diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index efd14e8cbe..c0e2c60046 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -5,14 +5,14 @@ jobs: runs-on: ubuntu-latest steps: - name: install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: - go-version: 1.19.x + go-version: 1.20.x - name: checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: fetch-depth: 0 - - uses: actions/cache@v2 + - uses: actions/cache@v3 with: path: | ~/go/pkg/mod @@ -22,10 +22,6 @@ jobs: key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go- - - name: golangci-lint - uses: golangci/golangci-lint-action@v3 - with: - args: --timeout=5m - name: install deps run: go install golang.org/x/tools/cmd/goimports@latest && go install github.com/klauspost/asmfmt/cmd/asmfmt@latest - name: gofmt @@ -35,24 +31,35 @@ jobs: go generate ./... git update-index --assume-unchanged go.mod git update-index --assume-unchanged go.sum - if [[ -n $(git status --porcelain) ]]; then echo "git repo is dirty after runing go generate -- please don't modify generated files"; echo $(git diff);echo $(git status --porcelain); exit 1; fi + if [[ -n $(git status --porcelain) ]]; then echo "git repo is dirty after running go generate -- please don't modify generated files"; echo $(git diff);echo $(git status --porcelain); exit 1; fi + + # hack to ensure golanglint process generated files + - name: remove "generated by" comments from generated files + run: | + find . -type f -name '*.go' -exec sed -i 's/Code generated by .* DO NOT EDIT/FOO/g' {} \; + # on macos: find . 
-type f -name '*.go' -exec sed -i '' -E 's/Code generated by .* DO NOT EDIT/FOO/g' {} \; + + - name: golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + args: --timeout=5m test: strategy: matrix: - go-version: [1.19.x] + go-version: [1.20.x] os: [ubuntu-latest] runs-on: ${{ matrix.os }} needs: - staticcheck steps: - name: install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go-version }} - name: checkout code - uses: actions/checkout@v2 - - uses: actions/cache@v2 + uses: actions/checkout@v3 + - uses: actions/cache@v3 with: path: | ~/go/pkg/mod @@ -63,10 +70,19 @@ jobs: restore-keys: | ${{ runner.os }}-go- - name: install deps - run: go install golang.org/x/tools/cmd/goimports@latest && go install github.com/klauspost/asmfmt/cmd/asmfmt@latest + run: | + go install golang.org/x/tools/cmd/goimports@latest && go install github.com/klauspost/asmfmt/cmd/asmfmt@latest + go install github.com/ethereum/go-ethereum/cmd/abigen@v1.12.0 + go install github.com/consensys/gnark-solidity-checker@latest + sudo add-apt-repository ppa:ethereum/ethereum + sudo apt-get update + sudo apt-get install solc - name: Test run: | - go test -v -short -timeout=30m ./... + go test -v -short -tags=solccheck -timeout=30m ./... 
+ - name: Test race + run: | + go test -v -short -race -timeout=30m slack-workflow-status-failed: if: failure() @@ -78,7 +94,7 @@ jobs: steps: - name: Notify slack -- workflow failed id: slack - uses: slackapi/slack-github-action@v1.19.0 + uses: slackapi/slack-github-action@v1.23.0 with: payload: | { @@ -86,7 +102,8 @@ jobs: "repo": "${{ github.repository }}", "status": "FAIL", "title": "${{ github.event.pull_request.title }}", - "pr": "${{ github.event.pull_request.head.ref }}" + "pr": "${{ github.event.pull_request.head.ref }}", + "failed_step_url": "${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}/" } env: SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} @@ -101,7 +118,7 @@ jobs: steps: - name: Notify slack -- workflow succeeded id: slack - uses: slackapi/slack-github-action@v1.19.0 + uses: slackapi/slack-github-action@v1.23.0 with: payload: | { diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 3ef5179a51..d894abb043 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -9,14 +9,14 @@ jobs: runs-on: ubuntu-latest steps: - name: install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: - go-version: 1.19.x + go-version: 1.20.x - name: checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: fetch-depth: 0 - - uses: actions/cache@v2 + - uses: actions/cache@v3 with: path: | ~/go/pkg/mod @@ -26,10 +26,8 @@ jobs: key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go- - - name: golangci-lint - uses: golangci/golangci-lint-action@v3 - with: - args: --timeout=5m + + - name: install deps run: go install golang.org/x/tools/cmd/goimports@latest && go install github.com/klauspost/asmfmt/cmd/asmfmt@latest - name: gofmt @@ -39,24 +37,36 @@ jobs: go generate ./... 
git update-index --assume-unchanged go.mod git update-index --assume-unchanged go.sum - if [[ -n $(git status --porcelain) ]]; then echo "git repo is dirty after runing go generate -- please don't modify generated files"; echo $(git diff);echo $(git status --porcelain); exit 1; fi + if [[ -n $(git status --porcelain) ]]; then echo "git repo is dirty after running go generate -- please don't modify generated files"; echo $(git diff);echo $(git status --porcelain); exit 1; fi + + # hack to ensure golanglint process generated files + - name: remove "generated by" comments from generated files + run: | + find . -type f -name '*.go' -exec sed -i 's/Code generated by .* DO NOT EDIT/FOO/g' {} \; + # on macos: find . -type f -name '*.go' -exec sed -i '' -E 's/Code generated by .* DO NOT EDIT/FOO/g' {} \; + + - name: golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + args: --timeout=5m + test: strategy: matrix: - go-version: [1.19.x] + go-version: [1.20.x] os: [ubuntu-latest, windows-latest, macos-latest] runs-on: ${{ matrix.os }} needs: - staticcheck steps: - name: install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go-version }} - name: checkout code - uses: actions/checkout@v2 - - uses: actions/cache@v2 + uses: actions/checkout@v3 + - uses: actions/cache@v3 with: path: | ~/go/pkg/mod @@ -86,7 +96,7 @@ jobs: steps: - name: Notify slack -- workflow failed id: slack - uses: slackapi/slack-github-action@v1.19.0 + uses: slackapi/slack-github-action@v1.23.0 with: payload: | { @@ -94,7 +104,8 @@ jobs: "repo": "${{ github.repository }}", "status": "FAIL", "title": "push to ${{ github.event.push.base_ref }}", - "pr": "" + "pr": "", + "failed_step_url": "${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}/" } env: SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} @@ -109,7 +120,7 @@ jobs: steps: - name: Notify slack -- workflow succeeded id: slack - uses: 
slackapi/slack-github-action@v1.19.0 + uses: slackapi/slack-github-action@v1.23.0 with: payload: | { diff --git a/.gitignore b/.gitignore index 5f2754da70..1be9af1ba2 100644 --- a/.gitignore +++ b/.gitignore @@ -50,3 +50,9 @@ gnarkd/circuits/** # Jetbrains stuff .idea/ + +# go workspace +go.work +go.work.sum + +examples/gbotrel/** \ No newline at end of file diff --git a/README.md b/README.md index 143ca2814a..91e014ab2c 100644 --- a/README.md +++ b/README.md @@ -99,14 +99,14 @@ func (circuit *CubicCircuit) Define(api frontend.API) error { func main() { // compiles our circuit into a R1CS var circuit CubicCircuit - ccs, _ := frontend.Compile(ecc.BN254, r1cs.NewBuilder, &circuit) + ccs, _ := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) // groth16 zkSNARK: Setup pk, vk, _ := groth16.Setup(ccs) // witness definition assignment := CubicCircuit{X: 3, Y: 35} - witness, _ := frontend.NewWitness(&assignment, ecc.BN254) + witness, _ := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) publicWitness, _ := witness.Public() // groth16: Prove & Verify diff --git a/backend/backend.go b/backend/backend.go index 73c3f05bae..84a05271ab 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -15,11 +15,7 @@ // Package backend implements Zero Knowledge Proof systems: it consumes circuit compiled with gnark/frontend. package backend -import ( - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/logger" - "github.com/rs/zerolog" -) +import "github.com/consensys/gnark/constraint/solver" // ID represent a unique ID for a proving scheme type ID uint16 @@ -50,26 +46,20 @@ func (id ID) String() string { } } -// ProverOption defines option for altering the behaviour of the prover in +// ProverOption defines option for altering the behavior of the prover in // Prove, ReadAndProve and IsSolved methods. See the descriptions of functions // returning instances of this type for implemented options. 
type ProverOption func(*ProverConfig) error // ProverConfig is the configuration for the prover with the options applied. type ProverConfig struct { - Force bool // defaults to false - HintFunctions map[hint.ID]hint.Function // defaults to all built-in hint functions - CircuitLogger zerolog.Logger // defaults to gnark.Logger + SolverOpts []solver.Option } // NewProverConfig returns a default ProverConfig with given prover options opts // applied. func NewProverConfig(opts ...ProverOption) (ProverConfig, error) { - log := logger.Logger() - opt := ProverConfig{CircuitLogger: log, HintFunctions: make(map[hint.ID]hint.Function)} - for _, v := range hint.GetRegistered() { - opt.HintFunctions[hint.UUID(v)] = v - } + opt := ProverConfig{} for _, option := range opts { if err := option(&opt); err != nil { return ProverConfig{}, err @@ -78,42 +68,10 @@ func NewProverConfig(opts ...ProverOption) (ProverConfig, error) { return opt, nil } -// IgnoreSolverError is a prover option that indicates that the Prove algorithm -// should complete even if constraint system is not solved. In that case, Prove -// will output an invalid Proof, but will execute all algorithms which is useful -// for test and benchmarking purposes. -func IgnoreSolverError() ProverOption { - return func(opt *ProverConfig) error { - opt.Force = true - return nil - } -} - -// WithHints is a prover option that specifies additional hint functions to be used -// by the constraint solver. -func WithHints(hintFunctions ...hint.Function) ProverOption { - log := logger.Logger() - return func(opt *ProverConfig) error { - // it is an error to register hint function several times, but as the - // prover already checks it then omit here. 
- for _, h := range hintFunctions { - uuid := hint.UUID(h) - if _, ok := opt.HintFunctions[uuid]; ok { - log.Warn().Int("hintID", int(uuid)).Str("name", hint.Name(h)).Msg("duplicate hint function") - } else { - opt.HintFunctions[uuid] = h - } - } - return nil - } -} - -// WithCircuitLogger is a prover option that specifies zerolog.Logger as a destination for the -// logs printed by api.Println(). By default, uses gnark/logger. -// zerolog.Nop() will disable logging -func WithCircuitLogger(l zerolog.Logger) ProverOption { +// WithSolverOptions specifies the constraint system solver options. +func WithSolverOptions(solverOpts ...solver.Option) ProverOption { return func(opt *ProverConfig) error { - opt.CircuitLogger = l + opt.SolverOpts = solverOpts return nil } } diff --git a/backend/groth16/bellman_test.go b/backend/groth16/bellman_test.go index af74b4e47d..eb4649d34b 100644 --- a/backend/groth16/bellman_test.go +++ b/backend/groth16/bellman_test.go @@ -23,62 +23,62 @@ func TestVerifyBellmanProof(t *testing.T) { ok bool }{ { - "hwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bhwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bo5ViaDBdO7ZBxAhLSe5k/5TFQyF5Lv7KN2tLKnwgoWMqB16OL8WdbePIwTCuPtJNAFKoTZylLDbSf02kckMcZQDPF9iGh+JC99Pio74vDpwTEjUx5tQ99gNQwxULtztsqDRsPnEvKvLmsxHt8LQVBkEBm2PBJFY+OXf1MNW021viDBpR10mX4WQ6zrsGL5L0GY4cwf4tlbh+Obit+LnN/SQTnREf8fPpdKZ1sa/ui3pGi8lMT6io4D7Ujlwx2RdChwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bkBF+isfMf77HCEGsZANw0hSrO2FGg14Sl26xLAIohdaW8O7gEaag8JdVAZ3OVLd5Df1NkZBEr753Xb8WwaXsJjE7qxwINL1KdqA4+EiYW4edb7+a9bbBeOPtb67ZxmFqAAAAAoMkzUv+KG8WoXszZI5NNMrbMLBDYP/xHunVgSWcix/kBrGlNozv1uFr0cmYZiij3YqToYs+EZa3dl2ILHx7H1n+b+Bjky/td2QduHVtf5t/Z9sKCfr+vOn12zVvOVz/6w==", + 
"hwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bhwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bo5ViaDBdO7ZBxAhLSe5k/5TFQyF5Lv7KN2tLKnwgoWMqB16OL8WdbePIwTCuPtJNAFKoTZylLDbSf02kckMcZQDPF9iGh+JC99Pio74vDpwTEjUx5tQ99gNQwxULtztsqDRsPnEvKvLmsxHt8LQVBkEBm2PBJFY+OXf1MNW021viDBpR10mX4WQ6zrsGL5L0GY4cwf4tlbh+Obit+LnN/SQTnREf8fPpdKZ1sa/ui3pGi8lMT6io4D7Ujlwx2RdChwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bkBF+isfMf77HCEGsZANw0hSrO2FGg14Sl26xLAIohdaW8O7gEaag8JdVAZ3OVLd5Df1NkZBEr753Xb8WwaXsJjE7qxwINL1KdqA4+EiYW4edb7+a9bbBeOPtb67ZxmFqAAAAAoMkzUv+KG8WoXszZI5NNMrbMLBDYP/xHunVgSWcix/kBrGlNozv1uFr0cmYZiij3YqToYs+EZa3dl2ILHx7H1n+b+Bjky/td2QduHVtf5t/Z9sKCfr+vOn12zVvOVz/6wAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", "lvQLU/KqgFhsLkt/5C/scqs7nWR+eYtyPdWiLVBux9GblT4AhHYMdCgwQfSJcudvsgV6fXoK+DUSRgJ++Nqt+Wvb7GlYlHpxCysQhz26TTu8Nyo7zpmVPH92+UYmbvbQCSvX2BhWtvkfHmqDVjmSIQ4RUMfeveA1KZbSf999NE4qKK8Do+8oXcmTM4LZVmh1rlyqznIdFXPN7x3pD4E0gb6/y69xtWMChv9654FMg05bAdueKt9uA4BEcAbpkdHF", "LcMT3OOlkHLzJBKCKjjzzVMg+r+FVgd52LlhZPB4RFg=", true, }, { - 
"hwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bhwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bo5ViaDBdO7ZBxAhLSe5k/5TFQyF5Lv7KN2tLKnwgoWMqB16OL8WdbePIwTCuPtJNAFKoTZylLDbSf02kckMcZQDPF9iGh+JC99Pio74vDpwTEjUx5tQ99gNQwxULtztsqDRsPnEvKvLmsxHt8LQVBkEBm2PBJFY+OXf1MNW021viDBpR10mX4WQ6zrsGL5L0GY4cwf4tlbh+Obit+LnN/SQTnREf8fPpdKZ1sa/ui3pGi8lMT6io4D7Ujlwx2RdChwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bkBF+isfMf77HCEGsZANw0hSrO2FGg14Sl26xLAIohdaW8O7gEaag8JdVAZ3OVLd5Df1NkZBEr753Xb8WwaXsJjE7qxwINL1KdqA4+EiYW4edb7+a9bbBeOPtb67ZxmFqAAAAAoMkzUv+KG8WoXszZI5NNMrbMLBDYP/xHunVgSWcix/kBrGlNozv1uFr0cmYZiij3YqToYs+EZa3dl2ILHx7H1n+b+Bjky/td2QduHVtf5t/Z9sKCfr+vOn12zVvOVz/6w==", + "hwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bhwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bo5ViaDBdO7ZBxAhLSe5k/5TFQyF5Lv7KN2tLKnwgoWMqB16OL8WdbePIwTCuPtJNAFKoTZylLDbSf02kckMcZQDPF9iGh+JC99Pio74vDpwTEjUx5tQ99gNQwxULtztsqDRsPnEvKvLmsxHt8LQVBkEBm2PBJFY+OXf1MNW021viDBpR10mX4WQ6zrsGL5L0GY4cwf4tlbh+Obit+LnN/SQTnREf8fPpdKZ1sa/ui3pGi8lMT6io4D7Ujlwx2RdChwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bkBF+isfMf77HCEGsZANw0hSrO2FGg14Sl26xLAIohdaW8O7gEaag8JdVAZ3OVLd5Df1NkZBEr753Xb8WwaXsJjE7qxwINL1KdqA4+EiYW4edb7+a9bbBeOPtb67ZxmFqAAAAAoMkzUv+KG8WoXszZI5NNMrbMLBDYP/xHunVgSWcix/kBrGlNozv1uFr0cmYZiij3YqToYs+EZa3dl2ILHx7H1n+b+Bjky/td2QduHVtf5t/Z9sKCfr+vOn12zVvOVz/6wAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", "lvQLU/KqgFhsLkt/5C/scqs7nWR+eYtyPdWiLVBux9GblT4AhHYMdCgwQfSJcudvsgV6fXoK+DUSRgJ++Nqt+Wvb7GlYlHpxCysQhz26TTu8Nyo7zpmVPH92+UYmbvbQCSvX2BhWtvkfHmqDVjmSIQ4RUMfeveA1KZbSf999NE4qKK8Do+8oXcmTM4LZVmh1rlyqznIdFXPN7x3pD4E0gb6/y69xtWMChv9654FMg05bAdueKt9uA4BEcAbpkdHF", 
"cmzVCcRVnckw3QUPhmG4Bkppeg4K50oDQwQ9EH+Fq1s=", false, }, { - "hwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bhwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bo5ViaDBdO7ZBxAhLSe5k/5TFQyF5Lv7KN2tLKnwgoWMqB16OL8WdbePIwTCuPtJNAFKoTZylLDbSf02kckMcZQDPF9iGh+JC99Pio74vDpwTEjUx5tQ99gNQwxULtztsqDRsPnEvKvLmsxHt8LQVBkEBm2PBJFY+OXf1MNW021viDBpR10mX4WQ6zrsGL5L0GY4cwf4tlbh+Obit+LnN/SQTnREf8fPpdKZ1sa/ui3pGi8lMT6io4D7Ujlwx2RdChwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bkBF+isfMf77HCEGsZANw0hSrO2FGg14Sl26xLAIohdaW8O7gEaag8JdVAZ3OVLd5Df1NkZBEr753Xb8WwaXsJjE7qxwINL1KdqA4+EiYW4edb7+a9bbBeOPtb67ZxmFqAAAAAoMkzUv+KG8WoXszZI5NNMrbMLBDYP/xHunVgSWcix/kBrGlNozv1uFr0cmYZiij3YqToYs+EZa3dl2ILHx7H1n+b+Bjky/td2QduHVtf5t/Z9sKCfr+vOn12zVvOVz/6w==", + "hwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bhwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bo5ViaDBdO7ZBxAhLSe5k/5TFQyF5Lv7KN2tLKnwgoWMqB16OL8WdbePIwTCuPtJNAFKoTZylLDbSf02kckMcZQDPF9iGh+JC99Pio74vDpwTEjUx5tQ99gNQwxULtztsqDRsPnEvKvLmsxHt8LQVBkEBm2PBJFY+OXf1MNW021viDBpR10mX4WQ6zrsGL5L0GY4cwf4tlbh+Obit+LnN/SQTnREf8fPpdKZ1sa/ui3pGi8lMT6io4D7Ujlwx2RdChwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bkBF+isfMf77HCEGsZANw0hSrO2FGg14Sl26xLAIohdaW8O7gEaag8JdVAZ3OVLd5Df1NkZBEr753Xb8WwaXsJjE7qxwINL1KdqA4+EiYW4edb7+a9bbBeOPtb67ZxmFqAAAAAoMkzUv+KG8WoXszZI5NNMrbMLBDYP/xHunVgSWcix/kBrGlNozv1uFr0cmYZiij3YqToYs+EZa3dl2ILHx7H1n+b+Bjky/td2QduHVtf5t/Z9sKCfr+vOn12zVvOVz/6wAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", 
"lvQLU/KqgFhsLkt/5C/scqs7nWR+eYtyPdWiLVBux9GblT4AhHYMdCgwQfSJcudvsgV6fXoK+DUSRgJ++Nqt+Wvb7GlYlHpxCysQhz26TTu8Nyo7zpmVPH92+UYmbvbQCSvX2BhWtvkfHmqDVjmSIQ4RUMfeveA1KZbSf999NE4qKK8Do+8oXcmTM4LZVmh1rlyqznIdFXPN7x3pD4E0gb6/y69xtWMChv9654FMg05bAdueKt9uA4BEcAbpkdHF", "cmzVCcRVnckw3QUPhmG4Bkppeg4K50oDQwQ9EH+Fq1s=", false, }, { - "kYYCAS8vM2T99GeCr4toQ+iQzvl5fI89mPrncYqx3C1d75BQbFk8LMtcnLWwntd6kYYCAS8vM2T99GeCr4toQ+iQzvl5fI89mPrncYqx3C1d75BQbFk8LMtcnLWwntd6knkzSwcsialcheg69eZYPK8EzKRVI5FrRHKi8rgB+R5jyPV70ejmYEx1neTmfYKODRmARr/ld6pZTzBWYDfrCkiS1QB+3q3M08OQgYcLzs/vjW4epetDCmk0K1CEGcWdh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0kYYCAS8vM2T99GeCr4toQ+iQzvl5fI89mPrncYqx3C1d75BQbFk8LMtcnLWwntd6jgld4oAppAOzvQ7eoIx2tbuuKVSdbJm65KDxl/T+boaYnjRm3omdETYnYRk3HAhrAeWpefX+dM/k7PrcheInnxHUyjzSzqlN03xYjg28kdda9FZJaVsQKqdEJ/St9ivXAAAAAZae/nTwyDn5u+4WkhZ76991cGB/ymyGpXziT0bwS86pRw/AcbpzXmzK+hq+kvrvpw==", + "kYYCAS8vM2T99GeCr4toQ+iQzvl5fI89mPrncYqx3C1d75BQbFk8LMtcnLWwntd6kYYCAS8vM2T99GeCr4toQ+iQzvl5fI89mPrncYqx3C1d75BQbFk8LMtcnLWwntd6knkzSwcsialcheg69eZYPK8EzKRVI5FrRHKi8rgB+R5jyPV70ejmYEx1neTmfYKODRmARr/ld6pZTzBWYDfrCkiS1QB+3q3M08OQgYcLzs/vjW4epetDCmk0K1CEGcWdh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0kYYCAS8vM2T99GeCr4toQ+iQzvl5fI89mPrncYqx3C1d75BQbFk8LMtcnLWwntd6jgld4oAppAOzvQ7eoIx2tbuuKVSdbJm65KDxl/T+boaYnjRm3omdETYnYRk3HAhrAeWpefX+dM/k7PrcheInnxHUyjzSzqlN03xYjg28kdda9FZJaVsQKqdEJ/St9ivXAAAAAZae/nTwyDn5u+4WkhZ76991cGB/ymyGpXziT0bwS86pRw/AcbpzXmzK+hq+kvrvpwAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", 
"sStVLdyxqInmv76iaNnRFB464lGq48iVeqYWSi2linE9DST0fTNhxSnvSXAoPpt8tFsanj5vPafC+ij/Fh98dOUlMbO42bf280pOZ4lm+zr63AWUpOOIugST+S6pq9zeB0OHp2NY8XFmriOEKhxeabhuV89ljqCDjlhXBeNZwM5zti4zg89Hd8TbKcw46jAsjIJe2Siw3Th7ELQQKR5ucX50f0GISmnOSceePPdvjbGJ8fSFOnSmSp8dK7uyehrU", "", true, }, { - "mY//hEITCBCZUJUN/wsOlw1iUSSOESL6PFSbN1abGK80t5jPNICNlPuSorio4mmWmY//hEITCBCZUJUN/wsOlw1iUSSOESL6PFSbN1abGK80t5jPNICNlPuSorio4mmWpf+4uOyv3gPZe54SYGM4pfhteqJpwFQxdlpwXWyYxMTNaSLDj8VtSn/EJaSu+P6nFmWsda3mTYUPYMZzWE4hMqpDgFPcJhw3prArMThDPbR3Hx7E6NRAAR0LqcrdtsbDqu2T0tto1rpnFILdvHL4PqEUfTmF2mkM+DKj7lKwvvZUbukqBwLrnnbdfyqZJryzGAMIa2JvMEMYszGsYyiPXZvYx6Luk54oWOlOrwEKrCY4NMPwch6DbFq6KpnNSQwOmY//hEITCBCZUJUN/wsOlw1iUSSOESL6PFSbN1abGK80t5jPNICNlPuSorio4mmWpgRYCz7wpjk57X+NGJmo85tYKc+TNa1rT4/DxG9v6SHkpXmmPeHhzIIW8MOdkFjxB5o6Qn8Fa0c6Tt6br2gzkrGr1eK5/+RiIgEzVhcRrqdY/p7PLmKXqawrEvIv9QZ3AAAAAoo8rTzcIp5QvF3USzv2Lz99z43CPVkjHB1ejzj/SjzKNa54GiDzHoCoAL0xKLjRSqeL98AF0V1+cRI8FwJjOcMgf0gDmjzwiv3ppbPZKqJR7Go+57k02670lfG6s1MM0A==", + "mY//hEITCBCZUJUN/wsOlw1iUSSOESL6PFSbN1abGK80t5jPNICNlPuSorio4mmWmY//hEITCBCZUJUN/wsOlw1iUSSOESL6PFSbN1abGK80t5jPNICNlPuSorio4mmWpf+4uOyv3gPZe54SYGM4pfhteqJpwFQxdlpwXWyYxMTNaSLDj8VtSn/EJaSu+P6nFmWsda3mTYUPYMZzWE4hMqpDgFPcJhw3prArMThDPbR3Hx7E6NRAAR0LqcrdtsbDqu2T0tto1rpnFILdvHL4PqEUfTmF2mkM+DKj7lKwvvZUbukqBwLrnnbdfyqZJryzGAMIa2JvMEMYszGsYyiPXZvYx6Luk54oWOlOrwEKrCY4NMPwch6DbFq6KpnNSQwOmY//hEITCBCZUJUN/wsOlw1iUSSOESL6PFSbN1abGK80t5jPNICNlPuSorio4mmWpgRYCz7wpjk57X+NGJmo85tYKc+TNa1rT4/DxG9v6SHkpXmmPeHhzIIW8MOdkFjxB5o6Qn8Fa0c6Tt6br2gzkrGr1eK5/+RiIgEzVhcRrqdY/p7PLmKXqawrEvIv9QZ3AAAAAoo8rTzcIp5QvF3USzv2Lz99z43CPVkjHB1ejzj/SjzKNa54GiDzHoCoAL0xKLjRSqeL98AF0V1+cRI8FwJjOcMgf0gDmjzwiv3ppbPZKqJR7Go+57k02670lfG6s1MM0AAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", 
"g53N8ecorvG2sDgNv8D7quVhKMIIpdP9Bqk/8gmV5cJ5Rhk9gKvb4F0ll8J/ZZJVqa27OyciJwx6lym6QpVK9q1ASrqio7rD5POMDGm64Iay/ixXXn+//F+uKgDXADj9AySri2J1j3qEkqqe3kxKthw94DzAfUBPncHfTPazVtE48AfzB1KWZA7Vf/x/3phYs4ckcP7ZrdVViJVLbUgFy543dpKfEH2MD30ZLLYRhw8SatRCyIJuTZcMlluEKG+d", "aZ8tqrOeEJKt4AMqiRF/WJhIKTDC0HeDTgiJVLZ8OEs=", true, }, { - "tRpqHB4HADuHAUvHTcrzxmq1awdwEBA0GOJfebYTODyUqXBQ7FkYrz1oDvPyx5Z3tRpqHB4HADuHAUvHTcrzxmq1awdwEBA0GOJfebYTODyUqXBQ7FkYrz1oDvPyx5Z3sUmODSJXAQmAFBVnS2t+Xzf5ZCr1gCtMiJVjQ48/nob/SkrS4cTHHjbKIVS9cdD/BG/VDrZvBt/dPqXmdUFyFuTTMrViagR57YRrDmm1qm5LQ/A8VwUBdiArwgRQXH9jsYhgVmfcRAjJytrbYeR6ck4ZfmGr6x6akKiBLY4B1l9LaHTyz/6KSM5t8atpuR3HBJZfbBm2/K8nnYTl+mAU/EnIN3YQdUd65Hsd4Gtf6VT2qfz6hcrSgHutxR1usIL2tRpqHB4HADuHAUvHTcrzxmq1awdwEBA0GOJfebYTODyUqXBQ7FkYrz1oDvPyx5Z3kyU9X4Kqjx6I6zYwVbn7PWbiy3OtY277z4ggIqW6AuDgzUeIyG9a4stMeQ07mOV/Ef4faj+eh4GJRKjJm7aUTYJCSAGY6klOXNoEzB54XF4EY5pkMPfW73SmxJi9B0aHAAAAEJGVg8trc1JcL8WfwX7A5FGZ7epiPqnQzrUxuiRSLUkGaLWBgwZusz3M8KN2QBqa/IIm0xOg40+xhjQxJduo4ACd2gHQa3+2G9L1hGIsziSuEjv1HfuP1sVw28u8W8JRWJIBLWGzDuj16M4Uag4qLSdAn3UhMTRwPQN+5kf26TTisoQK38r0gSCZ1EIDsOcDAavhjj+Z+/BPfWua2OBVxlJjNyxnafwr5BiE2H9OElh5GQBLnmB/emLOY6x5SGUANpPY9NiYvki/NgyRR/Cw4e+34Ifc4dMAIwgKmO/6+9uN+EQwPe23xGSWr0ZgBDbIH5bElW/Hfa0DAaVpd15G/JjZVDkn/iwF3l2EEeNmeMrlI8AFL5P//oprobFhfGQjJKW/cEP+nK1R+BORN3+iH/zLfw3Hp1pTzbb7tgiRWrXPKt9WknZ1oTDfFOuUl9wwaLg3PBFwxXebcMuFVjEuZYWOlW1P5UvE/KMoa/jSKbLbClJkodBDNaxslIdjzYCGM6Hgc5x1moKdljt5yGzWCHFxETgU/EKagOA6s8b+uuY8Goxl5gGsEb3Wasy6rwpHro3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywY4hgSvkxh3x/DmnXsYXXrexMciFom0pmplkL332ZMpd1pzKFW00N5TpLODHGt7FOYadYjHqXbtKyqAYdDV3MPRfYZIoIAyJETSDatq5b1MqT4kpqfFPLQ0jHhtGFUqLxZQOA7IwcJ7SR+OTYDW2P0W7v4X3u0LJE5AYk6NgPpJmEh++VL39lAF8AQE9T6BNLKrWJ3Rdim7YZehspVd4/TSCrSMx3fxsJXhMbjOypSNR9tj+/G6+c5ofA+1dWhXMT6X7UIY3IeGm17sRD+GoYHzZrYXYaEr0MpC4Y25hvQa78Jsmg0Xf7Bwk3kl5i6iMm63erQPGSyatgy8yiRDpZnzPWiIoS3vxS53X500pDStnreKPLzN2IwJoDp6wKOySKQ==", + 
"tRpqHB4HADuHAUvHTcrzxmq1awdwEBA0GOJfebYTODyUqXBQ7FkYrz1oDvPyx5Z3tRpqHB4HADuHAUvHTcrzxmq1awdwEBA0GOJfebYTODyUqXBQ7FkYrz1oDvPyx5Z3sUmODSJXAQmAFBVnS2t+Xzf5ZCr1gCtMiJVjQ48/nob/SkrS4cTHHjbKIVS9cdD/BG/VDrZvBt/dPqXmdUFyFuTTMrViagR57YRrDmm1qm5LQ/A8VwUBdiArwgRQXH9jsYhgVmfcRAjJytrbYeR6ck4ZfmGr6x6akKiBLY4B1l9LaHTyz/6KSM5t8atpuR3HBJZfbBm2/K8nnYTl+mAU/EnIN3YQdUd65Hsd4Gtf6VT2qfz6hcrSgHutxR1usIL2tRpqHB4HADuHAUvHTcrzxmq1awdwEBA0GOJfebYTODyUqXBQ7FkYrz1oDvPyx5Z3kyU9X4Kqjx6I6zYwVbn7PWbiy3OtY277z4ggIqW6AuDgzUeIyG9a4stMeQ07mOV/Ef4faj+eh4GJRKjJm7aUTYJCSAGY6klOXNoEzB54XF4EY5pkMPfW73SmxJi9B0aHAAAAEJGVg8trc1JcL8WfwX7A5FGZ7epiPqnQzrUxuiRSLUkGaLWBgwZusz3M8KN2QBqa/IIm0xOg40+xhjQxJduo4ACd2gHQa3+2G9L1hGIsziSuEjv1HfuP1sVw28u8W8JRWJIBLWGzDuj16M4Uag4qLSdAn3UhMTRwPQN+5kf26TTisoQK38r0gSCZ1EIDsOcDAavhjj+Z+/BPfWua2OBVxlJjNyxnafwr5BiE2H9OElh5GQBLnmB/emLOY6x5SGUANpPY9NiYvki/NgyRR/Cw4e+34Ifc4dMAIwgKmO/6+9uN+EQwPe23xGSWr0ZgBDbIH5bElW/Hfa0DAaVpd15G/JjZVDkn/iwF3l2EEeNmeMrlI8AFL5P//oprobFhfGQjJKW/cEP+nK1R+BORN3+iH/zLfw3Hp1pTzbb7tgiRWrXPKt9WknZ1oTDfFOuUl9wwaLg3PBFwxXebcMuFVjEuZYWOlW1P5UvE/KMoa/jSKbLbClJkodBDNaxslIdjzYCGM6Hgc5x1moKdljt5yGzWCHFxETgU/EKagOA6s8b+uuY8Goxl5gGsEb3Wasy6rwpHro3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywY4hgSvkxh3x/DmnXsYXXrexMciFom0pmplkL332ZMpd1pzKFW00N5TpLODHGt7FOYadYjHqXbtKyqAYdDV3MPRfYZIoIAyJETSDatq5b1MqT4kpqfFPLQ0jHhtGFUqLxZQOA7IwcJ7SR+OTYDW2P0W7v4X3u0LJE5AYk6NgPpJmEh++VL39lAF8AQE9T6BNLKrWJ3Rdim7YZehspVd4/TSCrSMx3fxsJXhMbjOypSNR9tj+/G6+c5ofA+1dWhXMT6X7UIY3IeGm17sRD+GoYHzZrYXYaEr0MpC4Y25hvQa78Jsmg0Xf7Bwk3kl5i6iMm63erQPGSyatgy8yiRDpZnzPWiIoS3vxS53X500pDStnreKPLzN2IwJoDp6wKOySKQAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", 
"lgFU4Jyo9GdHL7w31u3zXc8RQRnHVarZWNfd0lD45GvvQtwrZ1Y1OKB4T29a79UagPHOdk1S0k0hYAYQyyNAfRUzde1HP8R+2dms75gGZEnx2tXexEN+BVjRJfC8PR1lFJa6xvsEx5uSrOZzKmoMfCwcA55SMT5jFo4+KyWg2wP5OnFPx7XTdEKvf5YhpY0krQKiq3OUu79EwjNF1xV1+iLxx2KEIyK7RSYxO1BHrKOGOEzxSUK00MA+YVHe+DvW", "aZ8tqrOeEJKt4AMqiRF/WJhIKTDC0HeDTgiJVLZ8OEtiLNj7hflFeVnNXPguxyoqkI/V7pGJtXBpH5N+RswQNA0b23aM33aH0HKHOWoGY/T/L7TQzYFGJ3vTLiXDFZg1OVqkGOMvqAgonOrHGi6IgcALyUMyCKlL5BQY23SeILJpYKolybJNwJfbjxpg0Oz+D2fr7r9XL1GMvgblu52bVQT1fR8uCRJfSsgA2OGw6k/MpKDCfMcjbR8jnZa8ROEvF4cohm7iV1788Vp2/2bdcEZRQSoaGV8pOmA9EkqzJVRABjkDso40fnQcm2IzjBUOsX+uFExVan56/vl9VZVwB0wnee3Uxiredn0kOayiPB16yimxXCDet+M+0UKjmIlmXYpkrCDrH0dn53w+U3OHqMQxPDnUpYBxadM1eI8xWFFxzaLkvega0q0DmEquyY02yiTqo+7Q4qaJVTLgu6/8ekzPxGKRi845NL8gRgaTtM3kidDzIQpyODZD0yeEZDY1M+3sUKHcVkhoxTQBTMyKJPc+M5DeBL3uaWMrvxuL6q8+X0xeBt+9kguPUNtIYqUgPAaXvM2i041bWHTJ0dZLyDJVOyzGaXRaF4mNkAuh4Et6Zw5PuOpMM2mI1oFKEZj7", true, }, { - "kY4NWaOoYItWtLKVQnxDh+XTsa0Yev5Ae3Q9vlQSKp6+IUtwS7GH5ZrZefmBEwWEkY4NWaOoYItWtLKVQnxDh+XTsa0Yev5Ae3Q9vlQSKp6+IUtwS7GH5ZrZefmBEwWEqvAtYaSs5qW3riOiiRFoLp7MThW4vCEhK0j8BZY5ZM/tnjB7mrLB59kGvzpW8PM/AoQRIWzyvO3Dxxfyj/UQcQRw+KakVRvrFca3Vy2K5cFwxYHwl6PFDM+OmGrlgOCoqZtY1SLOd+ovmFOODKiHBZzDZhC/lRfjKVy4LzI7AXDuFn4tlWoT7IsJyy6lYNaWFfLjYZPAsrv1gXJ1NYat5B6E0Pnz5C67u2Uigmlol2D91re3oAqIo+r8kiyFKOSBkY4NWaOoYItWtLKVQnxDh+XTsa0Yev5Ae3Q9vlQSKp6+IUtwS7GH5ZrZefmBEwWEooG0cMN47zQor6qj0owuxJjn5Ymrcd/FCQ1ud4cKoUlNaGWIekSjxJEB87elMy5oEUlUzVI9ObMm+2SE3Udgws7pkMM8fgQUQUqUVyc7sNCE9m/hQzlwtbXrNSS5Pb+6AAAAEaMO2hzDmr41cml4ktH+m9acCaUtck/ivOVANQi6qsQmhMOfvFgIzMwTqypsQVKWAKSOBQQCGv0o3lP8GJ5Y1FDEzH5wXwkPDEtNYRUkGUqD8dXaPGcZ+WNzT4KWqJlw36clSvUNFNDZKkKj7JPk/gK6MUBsavX/xzl+SOWYmxdu3Wd9rQm0yqNthoLKarQL9or9Oj7m8kmOZUSBGJQo6+Y/AIZfYBbzfttnCIdYhorsAcoT4xg4D+Ye/MWVwgaCXpBGD3CNgtC7QXIeaWQvyaUtZBfZQS53aHYSJbrRK95pCqIUfg/3MzfxU3cVNm/NkKn2Th3Puq79m4hF8vAHaADMpI9XbCMO/3eTPF3ID4lMzZKOB+qdNkbdkdNTDmG6E6IGkB/JclOqPHPojkURhKGQ06uIbQvkGuwF06Hb0pht9yK8CVRjigzLb1iNVWHYVrN9kgdFtXfgaxW9DmwRrM/lJ+z/lfKnqwjKrvdOgZG43VprCmykAR
QvP2A03UovdqyKtSElEFp/PAIFv6vruij8cm1ORGYGwPhGcAgwejMgTYR3KwL1RXl/pI9UWNRsdZMwhN5XbE9+7Am2shbcjDGy+oA0AqE2nSV/44bPcIKdHWbo8DpNFnn4YMtQVB15f6vtp1wCj7yppYulqO/6WK/6tdxnLI+2e1kilZ+BZuF35CQ+tquqWgsTudQZSUBHJ6TTyku/s44ZkJU0YhK8g/L3uykM5NtHm+E4CDEdYSOaZ0Joxnk+esWckqdpw52A7KrJ1webkGPJcn+iGAvzx8xG960sfdZNGRLucOSDK1SvKLTc2R61LjNGj3SJqS0CeKhIL5nszkaXFAquEkafWxpd/8s1xObVmYJ90OpF8oxTIbvn6E2MtTVfhyWySNZ2DI3k693/kcUqYSGFsjGe7A90YA80ZOYkKg9SfvK3TiGZYjm365lmq6PwQcTb3dXzwJRRD4g3oAXA2lVh0tgNRTyAvXfg1NOb4s6wX5YurLvawr0gTVZ6A0gRds3lPtjY14+8nB2MQrmYJfHQbvBWY745Q1GQqn3atz7M0HqNl+ebawyRB3lVmkaCHIIhtoX0zQ==", + "kY4NWaOoYItWtLKVQnxDh+XTsa0Yev5Ae3Q9vlQSKp6+IUtwS7GH5ZrZefmBEwWEkY4NWaOoYItWtLKVQnxDh+XTsa0Yev5Ae3Q9vlQSKp6+IUtwS7GH5ZrZefmBEwWEqvAtYaSs5qW3riOiiRFoLp7MThW4vCEhK0j8BZY5ZM/tnjB7mrLB59kGvzpW8PM/AoQRIWzyvO3Dxxfyj/UQcQRw+KakVRvrFca3Vy2K5cFwxYHwl6PFDM+OmGrlgOCoqZtY1SLOd+ovmFOODKiHBZzDZhC/lRfjKVy4LzI7AXDuFn4tlWoT7IsJyy6lYNaWFfLjYZPAsrv1gXJ1NYat5B6E0Pnz5C67u2Uigmlol2D91re3oAqIo+r8kiyFKOSBkY4NWaOoYItWtLKVQnxDh+XTsa0Yev5Ae3Q9vlQSKp6+IUtwS7GH5ZrZefmBEwWEooG0cMN47zQor6qj0owuxJjn5Ymrcd/FCQ1ud4cKoUlNaGWIekSjxJEB87elMy5oEUlUzVI9ObMm+2SE3Udgws7pkMM8fgQUQUqUVyc7sNCE9m/hQzlwtbXrNSS5Pb+6AAAAEaMO2hzDmr41cml4ktH+m9acCaUtck/ivOVANQi6qsQmhMOfvFgIzMwTqypsQVKWAKSOBQQCGv0o3lP8GJ5Y1FDEzH5wXwkPDEtNYRUkGUqD8dXaPGcZ+WNzT4KWqJlw36clSvUNFNDZKkKj7JPk/gK6MUBsavX/xzl+SOWYmxdu3Wd9rQm0yqNthoLKarQL9or9Oj7m8kmOZUSBGJQo6+Y/AIZfYBbzfttnCIdYhorsAcoT4xg4D+Ye/MWVwgaCXpBGD3CNgtC7QXIeaWQvyaUtZBfZQS53aHYSJbrRK95pCqIUfg/3MzfxU3cVNm/NkKn2Th3Puq79m4hF8vAHaADMpI9XbCMO/3eTPF3ID4lMzZKOB+qdNkbdkdNTDmG6E6IGkB/JclOqPHPojkURhKGQ06uIbQvkGuwF06Hb0pht9yK8CVRjigzLb1iNVWHYVrN9kgdFtXfgaxW9DmwRrM/lJ+z/lfKnqwjKrvdOgZG43VprCmykARQvP2A03UovdqyKtSElEFp/PAIFv6vruij8cm1ORGYGwPhGcAgwejMgTYR3KwL1RXl/pI9UWNRsdZMwhN5XbE9+7Am2shbcjDGy+oA0AqE2nSV/44bPcIKdHWbo8DpNFnn4YMtQVB15f6vtp1wCj7yppYulqO/6WK/6tdxnLI+2e1kilZ+BZuF35CQ+tquqWgsTudQZSUBHJ6TTyku/s44ZkJU0YhK8g/L3uykM5NtHm+E4CDEdYSOaZ0Joxnk+esWckqdpw52A7KrJ1webkGPJcn+iGAvzx8xG960sfdZNGRLucOSDK1SvKLTc2R61LjNG
j3SJqS0CeKhIL5nszkaXFAquEkafWxpd/8s1xObVmYJ90OpF8oxTIbvn6E2MtTVfhyWySNZ2DI3k693/kcUqYSGFsjGe7A90YA80ZOYkKg9SfvK3TiGZYjm365lmq6PwQcTb3dXzwJRRD4g3oAXA2lVh0tgNRTyAvXfg1NOb4s6wX5YurLvawr0gTVZ6A0gRds3lPtjY14+8nB2MQrmYJfHQbvBWY745Q1GQqn3atz7M0HqNl+ebawyRB3lVmkaCHIIhtoX0zQAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", "jqPSA/XKqZDJnRSmM0sJxbrFv7GUcA45QMysIx1xTsI3+2iysF5Tr68565ZuO65qjo2lklZpQo+wtyKSA/56EaKOJZCZhSvDdBEdvVYJCjmWusuK5qav7xZO0w5W1qRiEgIdcGUz5V7JHqfRf4xI6/uUD846alyzzNjxQtKErqJbRw6yyBO6j6box363pinjiMTzU4w/qltzFuOEpKxy/H3vyH8RcsF24Ou/Rb6vfR7cSLtLwCsf/BMtPcsQfdRK", "aZ8tqrOeEJKt4AMqiRF/WJhIKTDC0HeDTgiJVLZ8OEtiLNj7hflFeVnNXPguxyoqkI/V7pGJtXBpH5N+RswQNA0b23aM33aH0HKHOWoGY/T/L7TQzYFGJ3vTLiXDFZg1OVqkGOMvqAgonOrHGi6IgcALyUMyCKlL5BQY23SeILJpYKolybJNwJfbjxpg0Oz+D2fr7r9XL1GMvgblu52bVQT1fR8uCRJfSsgA2OGw6k/MpKDCfMcjbR8jnZa8ROEvF4cohm7iV1788Vp2/2bdcEZRQSoaGV8pOmA9EkqzJVRABjkDso40fnQcm2IzjBUOsX+uFExVan56/vl9VZVwB0wnee3Uxiredn0kOayiPB16yimxXCDet+M+0UKjmIlmXYpkrCDrH0dn53w+U3OHqMQxPDnUpYBxadM1eI8xWFFxzaLkvega0q0DmEquyY02yiTqo+7Q4qaJVTLgu6/8ekzPxGKRi845NL8gRgaTtM3kidDzIQpyODZD0yeEZDY1M+3sUKHcVkhoxTQBTMyKJPc+M5DeBL3uaWMrvxuL6q8+X0xeBt+9kguPUNtIYqUgPAaXvM2i041bWHTJ0dZLyDJVOyzGaXRaF4mNkAuh4Et6Zw5PuOpMM2mI1oFKEZj7Xqf/yAmy/Le3GfJnMg5vNgE7QxmVsjuKUP28iN8rdi4=", true, }, { - 
"pQUlLSBu9HmVa9hB0rEu1weeBv2RKQQ8yCHpwXTHeSkcQqmSOuzednF8o0+MdyNupQUlLSBu9HmVa9hB0rEu1weeBv2RKQQ8yCHpwXTHeSkcQqmSOuzednF8o0+MdyNuhKgxmPN2c94UBtlYc0kZS6CwyMEEV/nVGSjajEZPdnpbK7fEcPd0hWNcOxKWq8qBBPfT69Ore74buf8C26ZTyKnjgMsGCvoDAMOsA07DjjQ1nIkkwIGFFUT3iMO83TdEpWgV/2z7WT9axNH/QFPOjXvwQJFnC7hLxHnX6pgKOdAaioKdi6FX3Y2SwWEO3UuxFd3KwsrZ2+mma/W3KP/cPpSzqyHa5VaJwOCw6vSM4wHSGKmDF4TSrrnMxzIYiTbTpQUlLSBu9HmVa9hB0rEu1weeBv2RKQQ8yCHpwXTHeSkcQqmSOuzednF8o0+MdyNulrwLi5GjMxD6BKzMMN9+7xFuO7txLCEIhGrIMFIvqTw1QFAO4rmAgyG+ljlYTfWHAkzqvImL1o8dMHhGOTsMLLMg39KsZVqalZwwL3ckpdAf81OJJeWCpCuaSgSXnWhJAAAAEph8ULgPc1Ia5pUdcBzvXnoB4f6dNaLD9MVNN62NaBqJzmdvGnGBujEjn2QZCk/jaKjnBFrS+EQj+rewVlx4CFJpYQhI/6cDVcfdlXN2cxPMzId1NfeiAh800mc9KzMCZJk9JdZu0HbwaalHqgscl4GPumn6rHQHRo2XrlDjwdkQ2ptwpto9meVcoL3SASNdqpSKBAYZ64QscekzfssIpXyNmgY807Z9KwnuyAPbGLXGJ910qKFO0wTQd/TvHGxoJ5hmEMoMQbEPxJo9igwqkOANTEZ0nt6urUIY06Kg4x0VxCs5VpGv+PoVjZyaYnKvy5k948Qh/f8q3vKhVF8vh6tsnIrY7966IMPocl5St6SKEJg7JCZ6gZN4cYrI90EK0Ir9Oj7m8kmOZUSBGJQo6+Y/AIZfYBbzfttnCIdYhorsAcoT4xg4D+Ye/MWVwgaCXpBGD3CNgtC7QXIeaWQvyaUtZBfZQS53aHYSJbrRK95pCqIUfg/3MzfxU3cVNm/NkKn2Th3Puq79m4hF8vAHaADMpI9XbCMO/3eTPF3ID4lMzZKOB+qdNkbdkdNTDmG6E6IGkB/JclOqPHPojkURhKGQ06uIbQvkGuwF06Hb0pht9yK8CVRjigzLb1iNVWHYVrN9kgdFtXfgaxW9DmwRrM/lJ+z/lfKnqwjKrvdOgZG43VprCmykARQvP2A03UovdqyKtSElEFp/PAIFv6vruij8cm1ORGYGwPhGcAgwejMgTYR3KwL1RXl/pI9UWNRsdZMwhN5XbE9+7Am2shbcjDGy+oA0AqE2nSV/44bPcIKdHWbo8DpNFnn4YMtQVB15f6vtp1wCj7yppYulqO/6WK/6tdxnLI+2e1kilZ+BZuF35CQ+tquqWgsTudQZSUBHJ6TTyku/s44ZkJU0YhK8g/L3uykM5NtHm+E4CDEdYSOaZ0Joxnk+esWckqdpw52A7KrJ1webkGPJcn+iGAvzx8xG960sfdZNGRLucOSDK1SvKLTc2R61LjNGj3SJqS0CeKhIL5nszkaXFAquEkafWxpd/8s1xObVmYJ90OpF8oxTIbvn6E2MtTVfhyWySNZ2DI3k693/kcUqYSGFsjGe7A90YA80ZOYkKg9SfvK3TiGZYjm365lmq6PwQcTb3dXzwA==", + 
"pQUlLSBu9HmVa9hB0rEu1weeBv2RKQQ8yCHpwXTHeSkcQqmSOuzednF8o0+MdyNupQUlLSBu9HmVa9hB0rEu1weeBv2RKQQ8yCHpwXTHeSkcQqmSOuzednF8o0+MdyNuhKgxmPN2c94UBtlYc0kZS6CwyMEEV/nVGSjajEZPdnpbK7fEcPd0hWNcOxKWq8qBBPfT69Ore74buf8C26ZTyKnjgMsGCvoDAMOsA07DjjQ1nIkkwIGFFUT3iMO83TdEpWgV/2z7WT9axNH/QFPOjXvwQJFnC7hLxHnX6pgKOdAaioKdi6FX3Y2SwWEO3UuxFd3KwsrZ2+mma/W3KP/cPpSzqyHa5VaJwOCw6vSM4wHSGKmDF4TSrrnMxzIYiTbTpQUlLSBu9HmVa9hB0rEu1weeBv2RKQQ8yCHpwXTHeSkcQqmSOuzednF8o0+MdyNulrwLi5GjMxD6BKzMMN9+7xFuO7txLCEIhGrIMFIvqTw1QFAO4rmAgyG+ljlYTfWHAkzqvImL1o8dMHhGOTsMLLMg39KsZVqalZwwL3ckpdAf81OJJeWCpCuaSgSXnWhJAAAAEph8ULgPc1Ia5pUdcBzvXnoB4f6dNaLD9MVNN62NaBqJzmdvGnGBujEjn2QZCk/jaKjnBFrS+EQj+rewVlx4CFJpYQhI/6cDVcfdlXN2cxPMzId1NfeiAh800mc9KzMCZJk9JdZu0HbwaalHqgscl4GPumn6rHQHRo2XrlDjwdkQ2ptwpto9meVcoL3SASNdqpSKBAYZ64QscekzfssIpXyNmgY807Z9KwnuyAPbGLXGJ910qKFO0wTQd/TvHGxoJ5hmEMoMQbEPxJo9igwqkOANTEZ0nt6urUIY06Kg4x0VxCs5VpGv+PoVjZyaYnKvy5k948Qh/f8q3vKhVF8vh6tsnIrY7966IMPocl5St6SKEJg7JCZ6gZN4cYrI90EK0Ir9Oj7m8kmOZUSBGJQo6+Y/AIZfYBbzfttnCIdYhorsAcoT4xg4D+Ye/MWVwgaCXpBGD3CNgtC7QXIeaWQvyaUtZBfZQS53aHYSJbrRK95pCqIUfg/3MzfxU3cVNm/NkKn2Th3Puq79m4hF8vAHaADMpI9XbCMO/3eTPF3ID4lMzZKOB+qdNkbdkdNTDmG6E6IGkB/JclOqPHPojkURhKGQ06uIbQvkGuwF06Hb0pht9yK8CVRjigzLb1iNVWHYVrN9kgdFtXfgaxW9DmwRrM/lJ+z/lfKnqwjKrvdOgZG43VprCmykARQvP2A03UovdqyKtSElEFp/PAIFv6vruij8cm1ORGYGwPhGcAgwejMgTYR3KwL1RXl/pI9UWNRsdZMwhN5XbE9+7Am2shbcjDGy+oA0AqE2nSV/44bPcIKdHWbo8DpNFnn4YMtQVB15f6vtp1wCj7yppYulqO/6WK/6tdxnLI+2e1kilZ+BZuF35CQ+tquqWgsTudQZSUBHJ6TTyku/s44ZkJU0YhK8g/L3uykM5NtHm+E4CDEdYSOaZ0Joxnk+esWckqdpw52A7KrJ1webkGPJcn+iGAvzx8xG960sfdZNGRLucOSDK1SvKLTc2R61LjNGj3SJqS0CeKhIL5nszkaXFAquEkafWxpd/8s1xObVmYJ90OpF8oxTIbvn6E2MtTVfhyWySNZ2DI3k693/kcUqYSGFsjGe7A90YA80ZOYkKg9SfvK3TiGZYjm365lmq6PwQcTb3dXzwAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", 
"qV2FNaBFqWeL6n9q9OUbCSTcIQvwO0vfaA/f/SxEtLSIaOGIOx8r+WVGFdxmC6i3oOaoEkJWvML7PpKBDtqiK7pKDIaMV5PkV/kQl6UgxZv9OInTwpVPtYcgeeTokG/eBi1qKzJwDoEHVqKeLqrLXJHXhBVQLdoIUOeKj8YMkagVniO9EtK0fW0/9QnRIxXoilxSj5HBEpYwFBitJXRk1ftFGWZFxJXU5PXdRmC+pomyo5Scx+UJQ2NLRWHjKlV0", "aZ8tqrOeEJKt4AMqiRF/WJhIKTDC0HeDTgiJVLZ8OEtiLNj7hflFeVnNXPguxyoqkI/V7pGJtXBpH5N+RswQNA0b23aM33aH0HKHOWoGY/T/L7TQzYFGJ3vTLiXDFZg1OVqkGOMvqAgonOrHGi6IgcALyUMyCKlL5BQY23SeILJpYKolybJNwJfbjxpg0Oz+D2fr7r9XL1GMvgblu52bVQT1fR8uCRJfSsgA2OGw6k/MpKDCfMcjbR8jnZa8ROEvF4cohm7iV1788Vp2/2bdcEZRQSoaGV8pOmA9EkqzJVRABjkDso40fnQcm2IzjBUOsX+uFExVan56/vl9VZVwB0wnee3Uxiredn0kOayiPB16yimxXCDet+M+0UKjmIlmXYpkrCDrH0dn53w+U3OHqMQxPDnUpYBxadM1eI8xWFFxzaLkvega0q0DmEquyY02yiTqo+7Q4qaJVTLgu6/8ekzPxGKRi845NL8gRgaTtM3kidDzIQpyODZD0yeEZDY1M+3sUKHcVkhoxTQBTMyKJPc+M5DeBL3uaWMrvxuL6q8+X0xeBt+9kguPUNtIYqUgPAaXvM2i041bWHTJ0dZLyDJVOyzGaXRaF4mNkAuh4Et6Zw5PuOpMM2mI1oFKEZj7Xqf/yAmy/Le3GfJnMg5vNgE7QxmVsjuKUP28iN8rdi4bUp7c0KJpqLXE6evfRrdZBDRYp+rmOLLDg55ggNuwog==", true, }, { - "lp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0iISo2JdNY1vPXlpwhlL2fVpW/WlREkF0bKlBadDIbNJBgM4niJGuEZDru3wqrGueETKHPv7hQ8em+p6vQolp7c0iknjXrGnvlpf4QtUtpg3z/D+snWjRPbVqRgKXWtihlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nuIvPFaM6dt7HZEbkeMnXWwSINeYC/j3lqYnce8Jq+XkuF42stVNiooI+TuXECnFdFi9Ib25b9wtyz3H/oKg48He1ftntj5uIRCOBvzkFHGUF6Ty214v3JYvXJjdS4uS2AAAAAY3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywQ==", + 
"lp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0iISo2JdNY1vPXlpwhlL2fVpW/WlREkF0bKlBadDIbNJBgM4niJGuEZDru3wqrGueETKHPv7hQ8em+p6vQolp7c0iknjXrGnvlpf4QtUtpg3z/D+snWjRPbVqRgKXWtihlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nuIvPFaM6dt7HZEbkeMnXWwSINeYC/j3lqYnce8Jq+XkuF42stVNiooI+TuXECnFdFi9Ib25b9wtyz3H/oKg48He1ftntj5uIRCOBvzkFHGUF6Ty214v3JYvXJjdS4uS2AAAAAY3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywQAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", "jiGBK+TGHfH8Oadexhdet7ExyIWibSmamWQvffZkyl3WnMoVbTQ3lOks4Mca3sU5qgcaLyQQ1FjFW4g6vtoMapZ43hTGKaWO7bQHsOCvdwHCdwJDulVH16cMTyS9F0BfBJxa88F+JKZc4qMTJjQhspmq755SrKhN9Jf+7uPUhgB4hJTSrmlOkTatgW+/HAf5kZKhv2oRK5p5kS4sU48oqlG1azhMtcHEXDQdcwf9ANel4Z9cb+MQyp2RzI/3hlIx", "", false, }, { - "lp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0iISo2JdNY1vPXlpwhlL2fVpW/WlREkF0bKlBadDIbNJBgM4niJGuEZDru3wqrGueETKHPv7hQ8em+p6vQolp7c0iknjXrGnvlpf4QtUtpg3z/D+snWjRPbVqRgKXWtihlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nuIvPFaM6dt7HZEbkeMnXWwSINeYC/j3lqYnce8Jq+XkuF42stVNiooI+TuXECnFdFi9Ib25b9wtyz3H/oKg48He1ftntj5uIRCOBvzkFHGUF6Ty214v3JYvXJjdS4uS2AAAAAo3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywY4hgSvkxh3x/DmnXsYXXrexMciFom0pmplkL332ZMpd1pzKFW00N5TpLODHGt7FOQ==", + 
"lp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0iISo2JdNY1vPXlpwhlL2fVpW/WlREkF0bKlBadDIbNJBgM4niJGuEZDru3wqrGueETKHPv7hQ8em+p6vQolp7c0iknjXrGnvlpf4QtUtpg3z/D+snWjRPbVqRgKXWtihlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nuIvPFaM6dt7HZEbkeMnXWwSINeYC/j3lqYnce8Jq+XkuF42stVNiooI+TuXECnFdFi9Ib25b9wtyz3H/oKg48He1ftntj5uIRCOBvzkFHGUF6Ty214v3JYvXJjdS4uS2AAAAAo3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywY4hgSvkxh3x/DmnXsYXXrexMciFom0pmplkL332ZMpd1pzKFW00N5TpLODHGt7FOQAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", "hp1iMepdu0rKoBh0NXcw9F9hkiggDIkRNINq2rlvUypPiSmp8U8tDSMeG0YVSovFteecr3THhBJj0qNeEe9jA2Ci64fKG9WT1heMYzEAQKebOErYXYCm9d72n97mYn1XBq+g1Y730XEDv4BIDI1hBDntJcgcj/cSvcILB1+60axJvtyMyuizxUr1JUBUq9njtmJ9m8zK6QZLNqMiKh0f2jokQb5mVhu6v5guW3KIjwQc/oFK/l5ehKAOPKUUggNh", "c9BSUPtO0xjPxWVNkEMfXe7O4UZKpaH/nLIyQJj7iA4=", false, }, { - 
"lp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0iISo2JdNY1vPXlpwhlL2fVpW/WlREkF0bKlBadDIbNJBgM4niJGuEZDru3wqrGueETKHPv7hQ8em+p6vQolp7c0iknjXrGnvlpf4QtUtpg3z/D+snWjRPbVqRgKXWtihlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nuIvPFaM6dt7HZEbkeMnXWwSINeYC/j3lqYnce8Jq+XkuF42stVNiooI+TuXECnFdFi9Ib25b9wtyz3H/oKg48He1ftntj5uIRCOBvzkFHGUF6Ty214v3JYvXJjdS4uS2AAAAEI3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywY4hgSvkxh3x/DmnXsYXXrexMciFom0pmplkL332ZMpd1pzKFW00N5TpLODHGt7FOYadYjHqXbtKyqAYdDV3MPRfYZIoIAyJETSDatq5b1MqT4kpqfFPLQ0jHhtGFUqLxZQOA7IwcJ7SR+OTYDW2P0W7v4X3u0LJE5AYk6NgPpJmEh++VL39lAF8AQE9T6BNLKrWJ3Rdim7YZehspVd4/TSCrSMx3fxsJXhMbjOypSNR9tj+/G6+c5ofA+1dWhXMT6X7UIY3IeGm17sRD+GoYHzZrYXYaEr0MpC4Y25hvQa78Jsmg0Xf7Bwk3kl5i6iMm63erQPGSyatgy8yiRDpZnzPWiIoS3vxS53X500pDStnreKPLzN2IwJoDp6wKOySKZdGaLefTVIk3ceY9uyFfvLIkVkUc/VN77b8+NtaoDxOZJ+mKyUlZ2CgqFngdxTX6YdVfjIfPeKN0i3Y/Z+6IH5R/7H14rGI+b5XkjZXSYv+u0yjOAYCNWmhnV7k7Xh6irYuuq10PvWvXfkpsCoJ0VKY1btbZK1mQkhW1broGWBGfQWY4VQkmOt+sAbhuihb+7AyoomdL10aqVI1AjhTH5ExvZyXaDrWrY5YgHn+/g0197VE5dZlXTXM5gJxIHomSat5jCsXGyonDl0LHKlPyYHdfmNm7MkLAyIMDf5Nt8u4wLmhISD5THi8y/OCZJeTfLGwCId+al2c+7XrMmHbfBbiV+hgruqlyjhbPGhZ/EVdsfQWvM+YhwQsEu0DgpmZ2pMsFPy29pBRGqrANivFv92Q8NrVuZjUKi5R/zEaBqeEjC7OmtAijtj4dOd9qHj6Q5YEKBdZF/acn/VAUGjSH65FwxkBkv69sui2U3T4r2LOpfa+gEVMYrEUc6m3vFr8VaD2ib6/F4P3akFs9pWILQnYhlm47zVIQ2KSnc0fvL/CEXq2JR+i/EaaQ0YYgs0E1A==", + 
"lp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0iISo2JdNY1vPXlpwhlL2fVpW/WlREkF0bKlBadDIbNJBgM4niJGuEZDru3wqrGueETKHPv7hQ8em+p6vQolp7c0iknjXrGnvlpf4QtUtpg3z/D+snWjRPbVqRgKXWtihlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nuIvPFaM6dt7HZEbkeMnXWwSINeYC/j3lqYnce8Jq+XkuF42stVNiooI+TuXECnFdFi9Ib25b9wtyz3H/oKg48He1ftntj5uIRCOBvzkFHGUF6Ty214v3JYvXJjdS4uS2AAAAEI3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywY4hgSvkxh3x/DmnXsYXXrexMciFom0pmplkL332ZMpd1pzKFW00N5TpLODHGt7FOYadYjHqXbtKyqAYdDV3MPRfYZIoIAyJETSDatq5b1MqT4kpqfFPLQ0jHhtGFUqLxZQOA7IwcJ7SR+OTYDW2P0W7v4X3u0LJE5AYk6NgPpJmEh++VL39lAF8AQE9T6BNLKrWJ3Rdim7YZehspVd4/TSCrSMx3fxsJXhMbjOypSNR9tj+/G6+c5ofA+1dWhXMT6X7UIY3IeGm17sRD+GoYHzZrYXYaEr0MpC4Y25hvQa78Jsmg0Xf7Bwk3kl5i6iMm63erQPGSyatgy8yiRDpZnzPWiIoS3vxS53X500pDStnreKPLzN2IwJoDp6wKOySKZdGaLefTVIk3ceY9uyFfvLIkVkUc/VN77b8+NtaoDxOZJ+mKyUlZ2CgqFngdxTX6YdVfjIfPeKN0i3Y/Z+6IH5R/7H14rGI+b5XkjZXSYv+u0yjOAYCNWmhnV7k7Xh6irYuuq10PvWvXfkpsCoJ0VKY1btbZK1mQkhW1broGWBGfQWY4VQkmOt+sAbhuihb+7AyoomdL10aqVI1AjhTH5ExvZyXaDrWrY5YgHn+/g0197VE5dZlXTXM5gJxIHomSat5jCsXGyonDl0LHKlPyYHdfmNm7MkLAyIMDf5Nt8u4wLmhISD5THi8y/OCZJeTfLGwCId+al2c+7XrMmHbfBbiV+hgruqlyjhbPGhZ/EVdsfQWvM+YhwQsEu0DgpmZ2pMsFPy29pBRGqrANivFv92Q8NrVuZjUKi5R/zEaBqeEjC7OmtAijtj4dOd9qHj6Q5YEKBdZF/acn/VAUGjSH65FwxkBkv69sui2U3T4r2LOpfa+gEVMYrEUc6m3vFr8VaD2ib6/F4P3akFs9pWILQnYhlm47zVIQ2KSnc0fvL/CEXq2JR+i/EaaQ0YYgs0E1AAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", 
"pNeWbxzzJPMsPpuXBXWZgtLic1s0KL8UeLDGBhEjygrv8m1eMM12pzd+r/scvBEHrnEoQHanlNTlWPywaXaFtB5Hd5RMrnbfLbpe16tvtlH2SRbJbGXSpib5uiuSa6z1ExLtXs9nNWiu10eupG6Pq4SNOacCEVvUgSzCzhyLIlz62gq4DlBBWKmEFI7KiFs7kr2EPBjj2m83dbA/GGVgoYYjgBmFX6/srvLADxerZTKG2moOQrmAx9GJ99nwhRbW", "I8C5RcBDPi2n4omt9oOV2rZk9T9xlSV8PQvLeVHjGb00fCVz7AHOIjLJ03ZCTLQwEKkAk9tQWJ6gFTBnG2+0DDHlXcVkwpMafcpS2diKFe0T4fRb0t9mxNzOFiRVcJoeMU1zb/rE4dIMm9rbEPSDnVSOd8tHNnJDkT+/NcNsQ2w0UEVJJRAEnC7G0Y3522RlDLxpTZ6w0U/9V0pLNkFgDCkFBKvpaEfPDJjoEVyCUWDC1ts9LIR43xh3ZZBdcO/HATHoLzxM3Ef11qF+riV7WDPEJfK11u8WGazzCAFhsx0aKkkbnKl7LnypBzwRvrG2JxdLI/oXL0eoIw9woVjqrg6elHudnHDXezDVXjRWMPaU+L3tOW9aqN+OdP4AhtpgT2CoRCjrOIU3MCFqsrCK9bh33PW1gtNeHC78mIetQM5LWZHtw4KNwafTrQ+GCKPelJhiC2x7ygBtat5rtBsJAVF5wjssLPZx/7fqNqifXB7WyMV7J1M8LBQVXj5kLoS9bpmNHlERRSadC0DEUbY9xhIG2xo7R88R0sq04a299MFv8XJNd+IdueYiMiGF5broHD4UUhPxRBlBO3lOfDTPnRSUGS3Sr6GxwCjKO3MObz/6RNxCk9SnQ4NccD17hS/m", false, }, { - "lp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0iISo2JdNY1vPXlpwhlL2fVpW/WlREkF0bKlBadDIbNJBgM4niJGuEZDru3wqrGueETKHPv7hQ8em+p6vQolp7c0iknjXrGnvlpf4QtUtpg3z/D+snWjRPbVqRgKXWtihlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nuIvPFaM6dt7HZEbkeMnXWwSINeYC/j3lqYnce8Jq+XkuF42stVNiooI+TuXECnFdFi9Ib25b9wtyz3H/oKg48He1ftntj5uIRCOBvzkFHGUF6Ty214v3JYvXJjdS4uS2AAAAEY3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywY4hgSvkxh3x/DmnXsYXXrexMciFom0pmplkL332ZMpd1pzKFW00N5TpLODHGt7FOYadYjHqXbtKyqAYdDV3MPRfYZIoIAyJETSDatq5b1MqT4kpqfFPLQ0jHhtGFUqLxZQOA7IwcJ7SR+OTYDW2P0W7v4X3u0LJE5AYk6NgPpJmEh++VL39lAF8AQE9T6BNLKrWJ3Rdim7YZehspVd4/TSCrSMx3fxsJXhMbjOypSNR9tj+/G6+c5ofA+1dWhXMT6X7UIY3IeGm17sRD+GoYHzZrYXYaEr0MpC4Y25hvQa78Jsmg0Xf7Bwk3kl5i6iMm63erQPGSyatgy8yiRDpZnzPWiIoS3vxS53X500pDStnreKPLzN2IwJoDp6wKOySKZdGaLefTVIk3ceY9uyFfvLIkVkUc/VN77b8+NtaoDxOZJ+mKyUlZ
2CgqFngdxTX6YdVfjIfPeKN0i3Y/Z+6IH5R/7H14rGI+b5XkjZXSYv+u0yjOAYCNWmhnV7k7Xh6irYuuq10PvWvXfkpsCoJ0VKY1btbZK1mQkhW1broGWBGfQWY4VQkmOt+sAbhuihb+7AyoomdL10aqVI1AjhTH5ExvZyXaDrWrY5YgHn+/g0197VE5dZlXTXM5gJxIHomSat5jCsXGyonDl0LHKlPyYHdfmNm7MkLAyIMDf5Nt8u4wLmhISD5THi8y/OCZJeTfLGwCId+al2c+7XrMmHbfBbiV+hgruqlyjhbPGhZ/EVdsfQWvM+YhwQsEu0DgpmZ2pMsFPy29pBRGqrANivFv92Q8NrVuZjUKi5R/zEaBqeEjC7OmtAijtj4dOd9qHj6Q5YEKBdZF/acn/VAUGjSH65FwxkBkv69sui2U3T4r2LOpfa+gEVMYrEUc6m3vFr8VaD2ib6/F4P3akFs9pWILQnYhlm47zVIQ2KSnc0fvL/CEXq2JR+i/EaaQ0YYgs0E1KTXlm8c8yTzLD6blwV1mYLS4nNbNCi/FHiwxgYRI8oK7/JtXjDNdqc3fq/7HLwRBw==", + "lp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0iISo2JdNY1vPXlpwhlL2fVpW/WlREkF0bKlBadDIbNJBgM4niJGuEZDru3wqrGueETKHPv7hQ8em+p6vQolp7c0iknjXrGnvlpf4QtUtpg3z/D+snWjRPbVqRgKXWtihlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nuIvPFaM6dt7HZEbkeMnXWwSINeYC/j3lqYnce8Jq+XkuF42stVNiooI+TuXECnFdFi9Ib25b9wtyz3H/oKg48He1ftntj5uIRCOBvzkFHGUF6Ty214v3JYvXJjdS4uS2AAAAEY3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywY4hgSvkxh3x/DmnXsYXXrexMciFom0pmplkL332ZMpd1pzKFW00N5TpLODHGt7FOYadYjHqXbtKyqAYdDV3MPRfYZIoIAyJETSDatq5b1MqT4kpqfFPLQ0jHhtGFUqLxZQOA7IwcJ7SR+OTYDW2P0W7v4X3u0LJE5AYk6NgPpJmEh++VL39lAF8AQE9T6BNLKrWJ3Rdim7YZehspVd4/TSCrSMx3fxsJXhMbjOypSNR9tj+/G6+c5ofA+1dWhXMT6X7UIY3IeGm17sRD+GoYHzZrYXYaEr0MpC4Y25hvQa78Jsmg0Xf7Bwk3kl5i6iMm63erQPGSyatgy8yiRDpZnzPWiIoS3vxS53X500pDStnreKPLzN2IwJoDp6wKOySKZdGaLefTVIk3ceY9uyFfvLIkVkUc/VN77b8+NtaoDxOZJ+mKyUlZ2CgqFngdxTX6YdVfjIfPeKN0i3Y/Z+6IH5R/7H14rGI+b5XkjZXSYv+u0yjOAYCNWmhnV7k7Xh6irYuuq10PvWvXfkpsCoJ0VKY1btbZK1mQkhW1broGWBGfQWY4VQkmOt+sAbhuihb+7AyoomdL10aqVI1AjhTH5ExvZyXaDrWrY5YgHn+/g0197VE5dZlXTXM5gJxIHomSat5jCsXGyonDl0LHKlPyYHdfmNm7MkLAyIMDf5Nt8u4wLmhISD5THi8y/OCZJeTfLGwCId+al2c+7XrMmHbfBbiV+hgruqlyjhbPGhZ/EVdsfQWvM+YhwQ
sEu0DgpmZ2pMsFPy29pBRGqrANivFv92Q8NrVuZjUKi5R/zEaBqeEjC7OmtAijtj4dOd9qHj6Q5YEKBdZF/acn/VAUGjSH65FwxkBkv69sui2U3T4r2LOpfa+gEVMYrEUc6m3vFr8VaD2ib6/F4P3akFs9pWILQnYhlm47zVIQ2KSnc0fvL/CEXq2JR+i/EaaQ0YYgs0E1KTXlm8c8yTzLD6blwV1mYLS4nNbNCi/FHiwxgYRI8oK7/JtXjDNdqc3fq/7HLwRBwAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", "iw5yhCCarVRq/h0Klq4tHNdF1j7PxaDn0AfHTxc2hb//Acav53QStwQShQ0BpQJ7sdchkTTJLkhM13+JpPY/I2WIc6DMZdRzw3pRjLSdMUmce7LYbBJOI+/IyuLZH5IXA7sX4r+xrPssIaMiKR3twmmReN9NrSoovLepDsNmzDVraO71B4rkx7uPXvkqvt3Zkr2EPBjj2m83dbA/GGVgoYYjgBmFX6/srvLADxerZTKG2moOQrmAx9GJ99nwhRbW", "I8C5RcBDPi2n4omt9oOV2rZk9T9xlSV8PQvLeVHjGb00fCVz7AHOIjLJ03ZCTLQwEKkAk9tQWJ6gFTBnG2+0DDHlXcVkwpMafcpS2diKFe0T4fRb0t9mxNzOFiRVcJoeMU1zb/rE4dIMm9rbEPSDnVSOd8tHNnJDkT+/NcNsQ2w0UEVJJRAEnC7G0Y3522RlDLxpTZ6w0U/9V0pLNkFgDCkFBKvpaEfPDJjoEVyCUWDC1ts9LIR43xh3ZZBdcO/HATHoLzxM3Ef11qF+riV7WDPEJfK11u8WGazzCAFhsx0aKkkbnKl7LnypBzwRvrG2JxdLI/oXL0eoIw9woVjqrg6elHudnHDXezDVXjRWMPaU+L3tOW9aqN+OdP4AhtpgT2CoRCjrOIU3MCFqsrCK9bh33PW1gtNeHC78mIetQM5LWZHtw4KNwafTrQ+GCKPelJhiC2x7ygBtat5rtBsJAVF5wjssLPZx/7fqNqifXB7WyMV7J1M8LBQVXj5kLoS9bpmNHlERRSadC0DEUbY9xhIG2xo7R88R0sq04a299MFv8XJNd+IdueYiMiGF5broHD4UUhPxRBlBO3lOfDTPnRSUGS3Sr6GxwCjKO3MObz/6RNxCk9SnQ4NccD17hS/mEFt8d4ERZOfmuvD3A0RCPCnx3Fr6rHdm6j+cfn/NM6o=", false, }, diff --git a/internal/backend/bls12-377/groth16/commitment.go b/backend/groth16/bls12-377/commitment.go similarity index 74% rename from internal/backend/bls12-377/groth16/commitment.go rename to backend/groth16/bls12-377/commitment.go index 72506760a4..c267e8a99b 100644 --- a/internal/backend/bls12-377/groth16/commitment.go +++ b/backend/groth16/bls12-377/commitment.go @@ -23,7 +23,7 @@ import ( "math/big" ) -func solveCommitmentWire(commitmentInfo *constraint.Commitment, commitment *curve.G1Affine, 
publicCommitted []*big.Int) (fr.Element, error) { - res, err := fr.Hash(commitmentInfo.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) +func solveCommitmentWire(commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { + res, err := fr.Hash(constraint.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) return res[0], err } diff --git a/internal/backend/bls12-377/groth16/commitment_test.go b/backend/groth16/bls12-377/commitment_test.go similarity index 91% rename from internal/backend/bls12-377/groth16/commitment_test.go rename to backend/groth16/bls12-377/commitment_test.go index 18f6f142d1..f18eefb837 100644 --- a/internal/backend/bls12-377/groth16/commitment_test.go +++ b/backend/groth16/bls12-377/commitment_test.go @@ -17,6 +17,9 @@ package groth16_test import ( + "fmt" + "testing" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/witness" @@ -24,7 +27,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/assert" - "testing" ) type singleSecretCommittedCircuit struct { @@ -33,7 +35,11 @@ type singleSecretCommittedCircuit struct { func (c *singleSecretCommittedCircuit) Define(api frontend.API) error { api.AssertIsEqual(c.One, 1) - commit, err := api.Compiler().Commit(c.One) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One) if err != nil { return err } @@ -119,8 +125,11 @@ type oneSecretOnePublicCommittedCircuit struct { } func (c *oneSecretOnePublicCommittedCircuit) Define(api frontend.API) error { - - commit, err := api.Compiler().Commit(c.One, c.Two) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does 
not commit") + } + commit, err := commitCompiler.Commit(c.One, c.Two) if err != nil { return err } diff --git a/internal/backend/bls12-377/groth16/marshal.go b/backend/groth16/bls12-377/marshal.go similarity index 81% rename from internal/backend/bls12-377/groth16/marshal.go rename to backend/groth16/bls12-377/marshal.go index 787d450165..bec919254c 100644 --- a/internal/backend/bls12-377/groth16/marshal.go +++ b/backend/groth16/bls12-377/marshal.go @@ -18,6 +18,9 @@ package groth16 import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/pedersen" + "github.com/consensys/gnark/internal/utils" "io" ) @@ -78,14 +81,24 @@ func (proof *Proof) ReadFrom(r io.Reader) (n int64, err error) { // points are compressed // use WriteRawTo(...) to encode the key without point compression func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, false) + if n, err = vk.writeTo(w, false); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteTo(w) + return m + n, err } // WriteRawTo writes binary encoding of the key elements to writer // points are not compressed // use WriteTo(...) 
to encode the key with point compression func (vk *VerifyingKey) WriteRawTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, true) + if n, err = vk.writeTo(w, true); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteRawTo(w) + return m + n, err } // writeTo serialization format: @@ -124,6 +137,14 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { if err := enc.Encode(vk.G1.K); err != nil { return enc.BytesWritten(), err } + + if vk.PublicAndCommitmentCommitted == nil { + vk.PublicAndCommitmentCommitted = [][]int{} // only matters in tests + } + if err := enc.Encode(utils.IntSliceSliceToUint64SliceSlice(vk.PublicAndCommitmentCommitted)); err != nil { + return enc.BytesWritten(), err + } + return enc.BytesWritten(), nil } @@ -133,13 +154,25 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { // https://github.com/zkcrypto/bellman/blob/fa9be45588227a8c6ec34957de3f68705f07bd92/src/groth16/mod.rs#L143 // [α]1,[β]1,[β]2,[γ]2,[δ]1,[δ]2,uint32(len(Kvk)),[Kvk]1 func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r) + n, err := vk.readFrom(r) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.ReadFrom(r) + return m + n, err } // UnsafeReadFrom has the same behavior as ReadFrom, except that it will not check that decode points // are on the curve and in the correct subgroup. 
func (vk *VerifyingKey) UnsafeReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r, curve.NoSubgroupChecks()) + n, err := vk.readFrom(r, curve.NoSubgroupChecks()) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.UnsafeReadFrom(r) + return m + n, err } func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { @@ -169,15 +202,16 @@ func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder) if err := dec.Decode(&vk.G1.K); err != nil { return dec.BytesRead(), err } + var publicCommitted [][]uint64 + if err := dec.Decode(&publicCommitted); err != nil { + return dec.BytesRead(), err + } + vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 - var err error - vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return dec.BytesRead(), err } - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) return dec.BytesRead(), nil } @@ -226,6 +260,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { pk.NbInfinityB, pk.InfinityA, pk.InfinityB, + uint32(len(pk.CommitmentKeys)), } for _, v := range toEncode { @@ -234,6 +269,23 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { } } + for i := range pk.CommitmentKeys { + var ( + n2 int64 + err error + ) + if raw { + n2, err = pk.CommitmentKeys[i].WriteRawTo(w) + } else { + n2, err = pk.CommitmentKeys[i].WriteTo(w) + } + + n += n2 + if err != nil { + return n, err + } + } + return n + enc.BytesWritten(), nil } @@ -260,6 +312,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) dec := curve.NewDecoder(r, decOptions...) 
var nbWires uint64 + var nbCommitments uint32 toDecode := []interface{}{ &pk.G1.Alpha, @@ -291,6 +344,18 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) if err := dec.Decode(&pk.InfinityB); err != nil { return n + dec.BytesRead(), err } + if err := dec.Decode(&nbCommitments); err != nil { + return n + dec.BytesRead(), err + } + + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) + for i := range pk.CommitmentKeys { + n2, err := pk.CommitmentKeys[i].ReadFrom(r) + n += n2 + if err != nil { + return n, err + } + } return n + dec.BytesRead(), nil } diff --git a/internal/backend/bls12-377/groth16/marshal_test.go b/backend/groth16/bls12-377/marshal_test.go similarity index 78% rename from internal/backend/bls12-377/groth16/marshal_test.go rename to backend/groth16/bls12-377/marshal_test.go index 901d6f885f..b5fe628704 100644 --- a/internal/backend/bls12-377/groth16/marshal_test.go +++ b/backend/groth16/bls12-377/marshal_test.go @@ -21,11 +21,17 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal/test_utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "bytes" "math/big" "reflect" "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" "github.com/leanovate/gopter/prop" "testing" @@ -87,13 +93,9 @@ func TestProofSerialization(t *testing.T) { } func TestVerifyingKeySerialization(t *testing.T) { - parameters := gopter.DefaultTestParameters() - parameters.MinSuccessfulTests = 10 - properties := gopter.NewProperties(parameters) - - properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + roundTrip := func(withCommitment bool) func(curve.G1Affine, curve.G2Affine) bool { + return func(p1 curve.G1Affine, p2 curve.G2Affine) bool { var vk, vkCompressed, 
vkRaw VerifyingKey // create a random vk @@ -121,6 +123,21 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } + if withCommitment { + vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization + bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) + elem := p1 + for i := 0; i < len(vk.PublicAndCommitmentCommitted); i++ { + bases[i] = make([]curve.G1Affine, len(vk.PublicAndCommitmentCommitted[i])) + for j := range bases[i] { + bases[i][j] = elem + elem.Add(&elem, &p1) + } + } + _, vk.CommitmentKey, err = pedersen.Setup(bases...) + assert.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := vk.WriteTo(&bufCompressed) if err != nil { @@ -158,7 +175,22 @@ func TestVerifyingKeySerialization(t *testing.T) { } return reflect.DeepEqual(&vk, &vkCompressed) && reflect.DeepEqual(&vk, &vkRaw) - }, + } + } + + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 10 + + properties := gopter.NewProperties(parameters) + + properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(false), + GenG1(), + GenG2(), + )) + + properties.Property("VerifyingKey (with commitments) -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(true), GenG1(), GenG2(), )) @@ -173,7 +205,7 @@ func TestProvingKeySerialization(t *testing.T) { properties := gopter.NewProperties(parameters) properties.Property("ProvingKey -> writer -> reader -> ProvingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + func(p1 curve.G1Affine, p2 curve.G2Affine, nbCommitment int) bool { var pk, pkCompressed, pkRaw ProvingKey // create a random pk @@ -202,6 +234,19 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, 
nbCommitment) + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases...) + require.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := pk.WriteTo(&bufCompressed) if err != nil { @@ -242,6 +287,7 @@ func TestProvingKeySerialization(t *testing.T) { }, GenG1(), GenG2(), + gen.IntRange(0, 2), )) properties.TestingRun(t, gopter.ConsoleReporter(false)) diff --git a/backend/groth16/bls12-377/mpcsetup/lagrange.go b/backend/groth16/bls12-377/mpcsetup/lagrange.go new file mode 100644 index 0000000000..92b6acb948 --- /dev/null +++ b/backend/groth16/bls12-377/mpcsetup/lagrange.go @@ -0,0 +1,216 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" + "github.com/consensys/gnark/internal/utils" +) + +func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { + coeffs := make([]curve.G1Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { + coeffs := make([]curve.G2Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func butterflyG1(a *curve.G1Affine, b *curve.G1Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +func butterflyG2(a *curve.G2Affine, b *curve.G2Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G1(a []curve.G1Affine, twiddles [][]fr.Element, stage int) { + butterflyG1(&a[0], &a[4]) + butterflyG1(&a[1], &a[5]) + 
butterflyG1(&a[2], &a[6]) + butterflyG1(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[2]) + butterflyG1(&a[1], &a[3]) + butterflyG1(&a[4], &a[6]) + butterflyG1(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[1]) + butterflyG1(&a[2], &a[3]) + butterflyG1(&a[4], &a[5]) + butterflyG1(&a[6], &a[7]) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G2(a []curve.G2Affine, twiddles [][]fr.Element, stage int) { + butterflyG2(&a[0], &a[4]) + butterflyG2(&a[1], &a[5]) + butterflyG2(&a[2], &a[6]) + butterflyG2(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[2]) + butterflyG2(&a[1], &a[3]) + butterflyG2(&a[4], &a[6]) + butterflyG2(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[1]) + butterflyG2(&a[2], &a[3]) + butterflyG2(&a[4], &a[5]) + butterflyG2(&a[6], &a[7]) +} + +func difFFTG1(a []curve.G1Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G1(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG1(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ 
{ + butterflyG1(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} +func difFFTG2(a []curve.G2Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G2(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG2(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + butterflyG2(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG2(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG2(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} diff --git a/backend/groth16/bls12-377/mpcsetup/marshal.go b/backend/groth16/bls12-377/mpcsetup/marshal.go new file mode 100644 index 0000000000..35ceece58f --- /dev/null +++ b/backend/groth16/bls12-377/mpcsetup/marshal.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + "io" +) + +// WriteTo implements io.WriterTo +func (phase1 *Phase1) WriteTo(writer io.Writer) (int64, error) { + n, err := phase1.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase1.Hash) + return int64(nBytes) + n, err +} + +func (phase1 *Phase1) writeTo(writer io.Writer) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + phase1.Parameters.G1.Tau, + phase1.Parameters.G1.AlphaTau, + phase1.Parameters.G1.BetaTau, + phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + enc := curve.NewEncoder(writer) + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (phase1 *Phase1) ReadFrom(reader io.Reader) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + &phase1.Parameters.G1.Tau, + 
&phase1.Parameters.G1.AlphaTau, + &phase1.Parameters.G1.BetaTau, + &phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + dec := curve.NewDecoder(reader) + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + phase1.Hash = make([]byte, 32) + nBytes, err := reader.Read(phase1.Hash) + return dec.BytesRead() + int64(nBytes), err +} + +// WriteTo implements io.WriterTo +func (phase2 *Phase2) WriteTo(writer io.Writer) (int64, error) { + n, err := phase2.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase2.Hash) + return int64(nBytes) + n, err +} + +func (c *Phase2) writeTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + c.Parameters.G1.L, + c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c *Phase2) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + &c.Parameters.G1.L, + &c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + c.Hash = make([]byte, 32) + n, err := reader.Read(c.Hash) + return int64(n) + dec.BytesRead(), err + +} + +// WriteTo implements io.WriterTo +func (c *Phase2Evaluations) WriteTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + c.G1.A, + c.G1.B, + c.G2.B, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom 
implements io.ReaderFrom +func (c *Phase2Evaluations) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.G1.A, + &c.G1.B, + &c.G2.B, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + return dec.BytesRead(), nil +} diff --git a/backend/groth16/bls12-377/mpcsetup/marshal_test.go b/backend/groth16/bls12-377/mpcsetup/marshal_test.go new file mode 100644 index 0000000000..fafeac27ca --- /dev/null +++ b/backend/groth16/bls12-377/mpcsetup/marshal_test.go @@ -0,0 +1,79 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark/constraint/bls12-377" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/stretchr/testify/require" + "io" + "reflect" + "testing" +) + +func TestContributionSerialization(t *testing.T) { + assert := require.New(t) + + // Phase 1 + srs1 := InitPhase1(9) + srs1.Contribute() + { + var reconstructed Phase1 + roundTripCheck(t, &srs1, &reconstructed) + } + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + r1cs := ccs.(*cs.R1CS) + + // Phase 2 + srs2, _ := InitPhase2(r1cs, &srs1) + srs2.Contribute() + + { + var reconstructed Phase2 + roundTripCheck(t, &srs2, &reconstructed) + } +} + +func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { + t.Helper() + + var buf bytes.Buffer + written, err := from.WriteTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} diff --git a/backend/groth16/bls12-377/mpcsetup/phase1.go b/backend/groth16/bls12-377/mpcsetup/phase1.go new file mode 100644 index 0000000000..573aaec996 --- /dev/null +++ b/backend/groth16/bls12-377/mpcsetup/phase1.go @@ -0,0 +1,203 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "math" + "math/big" +) + +// Phase1 represents the Phase1 of the MPC described in +// https://eprint.iacr.org/2017/1050.pdf +// +// Also known as "Powers of Tau" +type Phase1 struct { + Parameters struct { + G1 struct { + Tau []curve.G1Affine // {[τ⁰]₁, [τ¹]₁, [τ²]₁, …, [τ²ⁿ⁻²]₁} + AlphaTau []curve.G1Affine // {α[τ⁰]₁, α[τ¹]₁, α[τ²]₁, …, α[τⁿ⁻¹]₁} + BetaTau []curve.G1Affine // {β[τ⁰]₁, β[τ¹]₁, β[τ²]₁, …, β[τⁿ⁻¹]₁} + } + G2 struct { + Tau []curve.G2Affine // {[τ⁰]₂, [τ¹]₂, [τ²]₂, …, [τⁿ⁻¹]₂} + Beta curve.G2Affine // [β]₂ + } + } + PublicKeys struct { + Tau, Alpha, Beta PublicKey + } + Hash []byte // sha256 hash +} + +// InitPhase1 initialize phase 1 of the MPC. This is called once by the coordinator before +// any randomness contribution is made (see Contribute()). 
+func InitPhase1(power int) (phase1 Phase1) { + N := int(math.Pow(2, float64(power))) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetOne() + alpha.SetOne() + beta.SetOne() + phase1.PublicKeys.Tau = newPublicKey(tau, nil, 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, nil, 2) + phase1.PublicKeys.Beta = newPublicKey(beta, nil, 3) + + // First contribution use generators + _, _, g1, g2 := curve.Generators() + phase1.Parameters.G2.Beta.Set(&g2) + phase1.Parameters.G1.Tau = make([]curve.G1Affine, 2*N-1) + phase1.Parameters.G2.Tau = make([]curve.G2Affine, N) + phase1.Parameters.G1.AlphaTau = make([]curve.G1Affine, N) + phase1.Parameters.G1.BetaTau = make([]curve.G1Affine, N) + for i := 0; i < len(phase1.Parameters.G1.Tau); i++ { + phase1.Parameters.G1.Tau[i].Set(&g1) + } + for i := 0; i < len(phase1.Parameters.G2.Tau); i++ { + phase1.Parameters.G2.Tau[i].Set(&g2) + phase1.Parameters.G1.AlphaTau[i].Set(&g1) + phase1.Parameters.G1.BetaTau[i].Set(&g1) + } + + phase1.Parameters.G2.Beta.Set(&g2) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() + + return +} + +// Contribute contributes randomness to the phase1 object. This mutates phase1. +func (phase1 *Phase1) Contribute() { + N := len(phase1.Parameters.G2.Tau) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetRandom() + alpha.SetRandom() + beta.SetRandom() + phase1.PublicKeys.Tau = newPublicKey(tau, phase1.Hash[:], 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, phase1.Hash[:], 2) + phase1.PublicKeys.Beta = newPublicKey(beta, phase1.Hash[:], 3) + + // Compute powers of τ, ατ, and βτ + taus := powers(tau, 2*N-1) + alphaTau := make([]fr.Element, N) + betaTau := make([]fr.Element, N) + for i := 0; i < N; i++ { + alphaTau[i].Mul(&taus[i], &alpha) + betaTau[i].Mul(&taus[i], &beta) + } + + // Update using previous parameters + // TODO @gbotrel working with jacobian points here will help with perf. 
+ scaleG1InPlace(phase1.Parameters.G1.Tau, taus) + scaleG2InPlace(phase1.Parameters.G2.Tau, taus[0:N]) + scaleG1InPlace(phase1.Parameters.G1.AlphaTau, alphaTau) + scaleG1InPlace(phase1.Parameters.G1.BetaTau, betaTau) + var betaBI big.Int + beta.BigInt(&betaBI) + phase1.Parameters.G2.Beta.ScalarMultiplication(&phase1.Parameters.G2.Beta, &betaBI) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() +} + +func VerifyPhase1(c0, c1 *Phase1, c ...*Phase1) error { + contribs := append([]*Phase1{c0, c1}, c...) + for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase1(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +// verifyPhase1 checks that a contribution is based on a known previous Phase1 state. +func verifyPhase1(current, contribution *Phase1) error { + // Compute R for τ, α, β + tauR := genR(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, current.Hash[:], 1) + alphaR := genR(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, current.Hash[:], 2) + betaR := genR(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, current.Hash[:], 3) + + // Check for knowledge of toxic parameters + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.PublicKeys.Tau.XR, tauR) { + return errors.New("couldn't verify public key of τ") + } + if !sameRatio(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, contribution.PublicKeys.Alpha.XR, alphaR) { + return errors.New("couldn't verify public key of α") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.PublicKeys.Beta.XR, betaR) { + return errors.New("couldn't verify public key of β") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Tau[1], current.Parameters.G1.Tau[1], tauR, contribution.PublicKeys.Tau.XR) { + return errors.New("couldn't verify that [τ]₁ is based on previous 
contribution") + } + if !sameRatio(contribution.Parameters.G1.AlphaTau[0], current.Parameters.G1.AlphaTau[0], alphaR, contribution.PublicKeys.Alpha.XR) { + return errors.New("couldn't verify that [α]₁ is based on previous contribution") + } + if !sameRatio(contribution.Parameters.G1.BetaTau[0], current.Parameters.G1.BetaTau[0], betaR, contribution.PublicKeys.Beta.XR) { + return errors.New("couldn't verify that [β]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.Parameters.G2.Tau[1], current.Parameters.G2.Tau[1]) { + return errors.New("couldn't verify that [τ]₂ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.Parameters.G2.Beta, current.Parameters.G2.Beta) { + return errors.New("couldn't verify that [β]₂ is based on previous contribution") + } + + // Check for valid updates using powers of τ + _, _, g1, g2 := curve.Generators() + tauL1, tauL2 := linearCombinationG1(contribution.Parameters.G1.Tau) + if !sameRatio(tauL1, tauL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of τ in G₁") + } + alphaL1, alphaL2 := linearCombinationG1(contribution.Parameters.G1.AlphaTau) + if !sameRatio(alphaL1, alphaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of α(τ) in G₁") + } + betaL1, betaL2 := linearCombinationG1(contribution.Parameters.G1.BetaTau) + if !sameRatio(betaL1, betaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of β(τ) in G₁") + } + tau2L1, tau2L2 := linearCombinationG2(contribution.Parameters.G2.Tau) + if !sameRatio(contribution.Parameters.G1.Tau[1], g1, tau2L1, tau2L2) { + return errors.New("couldn't verify valid powers of τ in G₂") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] !=
contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (phase1 *Phase1) hash() []byte { + sha := sha256.New() + phase1.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bls12-377/mpcsetup/phase2.go b/backend/groth16/bls12-377/mpcsetup/phase2.go new file mode 100644 index 0000000000..3460138d12 --- /dev/null +++ b/backend/groth16/bls12-377/mpcsetup/phase2.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + "math/big" + + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/constraint/bls12-377" +) + +type Phase2Evaluations struct { + G1 struct { + A, B, VKK []curve.G1Affine + } + G2 struct { + B []curve.G2Affine + } +} + +type Phase2 struct { + Parameters struct { + G1 struct { + Delta curve.G1Affine + L, Z []curve.G1Affine + } + G2 struct { + Delta curve.G2Affine + } + } + PublicKey PublicKey + Hash []byte +} + +func InitPhase2(r1cs *cs.R1CS, srs1 *Phase1) (Phase2, Phase2Evaluations) { + srs := srs1.Parameters + size := len(srs.G1.AlphaTau) + if size < r1cs.GetNbConstraints() { + panic("Number of constraints is larger than expected") + } + + c2 := Phase2{} + + accumulateG1 := func(res *curve.G1Affine, t constraint.Term, value *curve.G1Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G1Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + accumulateG2 := func(res *curve.G2Affine, t constraint.Term, value *curve.G2Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G2Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + // Prepare Lagrange coefficients of [τ...]₁, [τ...]₂, [ατ...]₁, [βτ...]₁ + coeffTau1 := 
lagrangeCoeffsG1(srs.G1.Tau, size) + coeffTau2 := lagrangeCoeffsG2(srs.G2.Tau, size) + coeffAlphaTau1 := lagrangeCoeffsG1(srs.G1.AlphaTau, size) + coeffBetaTau1 := lagrangeCoeffsG1(srs.G1.BetaTau, size) + + internal, secret, public := r1cs.GetNbVariables() + nWires := internal + secret + public + var evals Phase2Evaluations + evals.G1.A = make([]curve.G1Affine, nWires) + evals.G1.B = make([]curve.G1Affine, nWires) + evals.G2.B = make([]curve.G2Affine, nWires) + bA := make([]curve.G1Affine, nWires) + aB := make([]curve.G1Affine, nWires) + C := make([]curve.G1Affine, nWires) + + // TODO @gbotrel use constraint iterator when available. + + i := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + // A + for _, t := range c.L { + accumulateG1(&evals.G1.A[t.WireID()], t, &coeffTau1[i]) + accumulateG1(&bA[t.WireID()], t, &coeffBetaTau1[i]) + } + // B + for _, t := range c.R { + accumulateG1(&evals.G1.B[t.WireID()], t, &coeffTau1[i]) + accumulateG2(&evals.G2.B[t.WireID()], t, &coeffTau2[i]) + accumulateG1(&aB[t.WireID()], t, &coeffAlphaTau1[i]) + } + // C + for _, t := range c.O { + accumulateG1(&C[t.WireID()], t, &coeffTau1[i]) + } + i++ + } + + // Prepare default contribution + _, _, g1, g2 := curve.Generators() + c2.Parameters.G1.Delta = g1 + c2.Parameters.G2.Delta = g2 + + // Build Z in PK as τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + // τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + n := len(srs.G1.AlphaTau) + c2.Parameters.G1.Z = make([]curve.G1Affine, n) + for i := 0; i < n-1; i++ { + c2.Parameters.G1.Z[i].Sub(&srs.G1.Tau[i+n], &srs.G1.Tau[i]) + } + bitReverse(c2.Parameters.G1.Z) + c2.Parameters.G1.Z = c2.Parameters.G1.Z[:n-1] + + // Evaluate L + nPrivate := internal + secret + c2.Parameters.G1.L = make([]curve.G1Affine, nPrivate) + evals.G1.VKK = make([]curve.G1Affine, public) + offset := public + for i := 0; i < nWires; i++ { + var tmp curve.G1Affine + tmp.Add(&bA[i], &aB[i]) + tmp.Add(&tmp, &C[i]) + if i < public { + 
evals.G1.VKK[i].Set(&tmp) + } else { + c2.Parameters.G1.L[i-offset].Set(&tmp) + } + } + // Set δ public key + var delta fr.Element + delta.SetOne() + c2.PublicKey = newPublicKey(delta, nil, 1) + + // Hash initial contribution + c2.Hash = c2.hash() + return c2, evals +} + +func (c *Phase2) Contribute() { + // Sample toxic δ + var delta, deltaInv fr.Element + var deltaBI, deltaInvBI big.Int + delta.SetRandom() + deltaInv.Inverse(&delta) + + delta.BigInt(&deltaBI) + deltaInv.BigInt(&deltaInvBI) + + // Set δ public key + c.PublicKey = newPublicKey(delta, c.Hash, 1) + + // Update δ + c.Parameters.G1.Delta.ScalarMultiplication(&c.Parameters.G1.Delta, &deltaBI) + c.Parameters.G2.Delta.ScalarMultiplication(&c.Parameters.G2.Delta, &deltaBI) + + // Update Z using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.Z); i++ { + c.Parameters.G1.Z[i].ScalarMultiplication(&c.Parameters.G1.Z[i], &deltaInvBI) + } + + // Update L using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.L); i++ { + c.Parameters.G1.L[i].ScalarMultiplication(&c.Parameters.G1.L[i], &deltaInvBI) + } + + // 4. Hash contribution + c.Hash = c.hash() +} + +func VerifyPhase2(c0, c1 *Phase2, c ...*Phase2) error { + contribs := append([]*Phase2{c0, c1}, c...) 
+ for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase2(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +func verifyPhase2(current, contribution *Phase2) error { + // Compute R for δ + deltaR := genR(contribution.PublicKey.SG, contribution.PublicKey.SXG, current.Hash[:], 1) + + // Check for knowledge of δ + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.PublicKey.XR, deltaR) { + return errors.New("couldn't verify knowledge of δ") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Delta, current.Parameters.G1.Delta, deltaR, contribution.PublicKey.XR) { + return errors.New("couldn't verify that [δ]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify that [δ]₂ is based on previous contribution") + } + + // Check for valid updates of L and Z using δ⁻¹ + L, prevL := merge(contribution.Parameters.G1.L, current.Parameters.G1.L) + if !sameRatio(L, prevL, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of L using δ⁻¹") + } + Z, prevZ := merge(contribution.Parameters.G1.Z, current.Parameters.G1.Z) + if !sameRatio(Z, prevZ, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of Z using δ⁻¹") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (c *Phase2) hash() []byte { + sha := sha256.New() + c.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bls12-377/mpcsetup/setup.go b/backend/groth16/bls12-377/mpcsetup/setup.go new file mode 100644 index 0000000000..683369c871 ---
/dev/null +++ b/backend/groth16/bls12-377/mpcsetup/setup.go @@ -0,0 +1,97 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" + groth16 "github.com/consensys/gnark/backend/groth16/bls12-377" +) + +func ExtractKeys(srs1 *Phase1, srs2 *Phase2, evals *Phase2Evaluations, nConstraints int) (pk groth16.ProvingKey, vk groth16.VerifyingKey) { + _, _, _, g2 := curve.Generators() + + // Initialize PK + pk.Domain = *fft.NewDomain(uint64(nConstraints)) + pk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + pk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + pk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + pk.G1.Z = srs2.Parameters.G1.Z + bitReverse(pk.G1.Z) + + pk.G1.K = srs2.Parameters.G1.L + pk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + pk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + + // Filter out infinity points + nWires := len(evals.G1.A) + pk.InfinityA = make([]bool, nWires) + A := make([]curve.G1Affine, nWires) + j := 0 + for i, e := range evals.G1.A { + if e.IsInfinity() { + pk.InfinityA[i] = true + continue + } + A[j] = evals.G1.A[i] + j++ + } + pk.G1.A = A[:j] + pk.NbInfinityA = uint64(nWires - j) + + pk.InfinityB = make([]bool, nWires) + B := make([]curve.G1Affine, nWires) + j = 0 + for i, e := range evals.G1.B { + if e.IsInfinity() { + 
pk.InfinityB[i] = true + continue + } + B[j] = evals.G1.B[i] + j++ + } + pk.G1.B = B[:j] + pk.NbInfinityB = uint64(nWires - j) + + B2 := make([]curve.G2Affine, nWires) + j = 0 + for i, e := range evals.G2.B { + if e.IsInfinity() { + // pk.InfinityB[i] = true should be the same as in B + continue + } + B2[j] = evals.G2.B[i] + j++ + } + pk.G2.B = B2[:j] + + // Initialize VK + vk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + vk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + vk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + vk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + vk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + vk.G2.Gamma.Set(&g2) + vk.G1.K = evals.G1.VKK + + // sets e, -[δ]2, -[γ]2 + if err := vk.Precompute(); err != nil { + panic(err) + } + + return pk, vk +} diff --git a/backend/groth16/bls12-377/mpcsetup/setup_test.go b/backend/groth16/bls12-377/mpcsetup/setup_test.go new file mode 100644 index 0000000000..6cfb710888 --- /dev/null +++ b/backend/groth16/bls12-377/mpcsetup/setup_test.go @@ -0,0 +1,199 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark/constraint/bls12-377" + "testing" + + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/stretchr/testify/require" + + native_mimc "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" +) + +func TestSetupCircuit(t *testing.T) { + if testing.Short() { + t.Skip() + } + const ( + nContributionsPhase1 = 3 + nContributionsPhase2 = 3 + power = 9 + ) + + assert := require.New(t) + + srs1 := InitPhase1(power) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase1; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. + prev := srs1.clone() + + srs1.Contribute() + assert.NoError(VerifyPhase1(&prev, &srs1)) + } + + // Compile the circuit + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + var evals Phase2Evaluations + r1cs := ccs.(*cs.R1CS) + + // Prepare for phase-2 + srs2, evals := InitPhase2(r1cs, &srs1) + + // Make and verify contributions for phase2 + for i := 1; i < nContributionsPhase2; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator.
+ prev := srs2.clone() + + srs2.Contribute() + assert.NoError(VerifyPhase2(&prev, &srs2)) + } + + // Extract the proving and verifying keys + pk, vk := ExtractKeys(&srs1, &srs2, &evals, ccs.GetNbConstraints()) + + // Build the witness + var preImage, hash fr.Element + { + m := native_mimc.NewMiMC() + m.Write(preImage.Marshal()) + hash.SetBytes(m.Sum(nil)) + } + + witness, err := frontend.NewWitness(&Circuit{PreImage: preImage, Hash: hash}, curve.ID.ScalarField()) + assert.NoError(err) + + pubWitness, err := witness.Public() + assert.NoError(err) + + // groth16: ensure proof is verified + proof, err := groth16.Prove(ccs, &pk, witness) + assert.NoError(err) + + err = groth16.Verify(proof, &vk, pubWitness) + assert.NoError(err) +} + +func BenchmarkPhase1(b *testing.B) { + const power = 14 + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = InitPhase1(power) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs1 := InitPhase1(power) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs1.Contribute() + } + }) + +} + +func BenchmarkPhase2(b *testing.B) { + const power = 14 + srs1 := InitPhase1(power) + srs1.Contribute() + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + if err != nil { + b.Fatal(err) + } + + r1cs := ccs.(*cs.R1CS) + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = InitPhase2(r1cs, &srs1) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs2, _ := InitPhase2(r1cs, &srs1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs2.Contribute() + } + }) + +} + +// Circuit defines a pre-image knowledge proof +// mimc(secret preImage) = public hash +type Circuit struct { + PreImage frontend.Variable + Hash frontend.Variable `gnark:",public"` +} + +// Define declares the circuit's constraints +// Hash = mimc(PreImage) +func (circuit *Circuit) Define(api frontend.API) error { + // hash function + mimc, _ := 
mimc.NewMiMC(api) + + // specify constraints + mimc.Write(circuit.PreImage) + api.AssertIsEqual(circuit.Hash, mimc.Sum()) + + return nil +} + +func (phase1 *Phase1) clone() Phase1 { + r := Phase1{} + r.Parameters.G1.Tau = append(r.Parameters.G1.Tau, phase1.Parameters.G1.Tau...) + r.Parameters.G1.AlphaTau = append(r.Parameters.G1.AlphaTau, phase1.Parameters.G1.AlphaTau...) + r.Parameters.G1.BetaTau = append(r.Parameters.G1.BetaTau, phase1.Parameters.G1.BetaTau...) + + r.Parameters.G2.Tau = append(r.Parameters.G2.Tau, phase1.Parameters.G2.Tau...) + r.Parameters.G2.Beta = phase1.Parameters.G2.Beta + + r.PublicKeys = phase1.PublicKeys + r.Hash = append(r.Hash, phase1.Hash...) + + return r +} + +func (phase2 *Phase2) clone() Phase2 { + r := Phase2{} + r.Parameters.G1.Delta = phase2.Parameters.G1.Delta + r.Parameters.G1.L = append(r.Parameters.G1.L, phase2.Parameters.G1.L...) + r.Parameters.G1.Z = append(r.Parameters.G1.Z, phase2.Parameters.G1.Z...) + r.Parameters.G2.Delta = phase2.Parameters.G2.Delta + r.PublicKey = phase2.PublicKey + r.Hash = append(r.Hash, phase2.Hash...) + + return r +} diff --git a/backend/groth16/bls12-377/mpcsetup/utils.go b/backend/groth16/bls12-377/mpcsetup/utils.go new file mode 100644 index 0000000000..978b2ecbde --- /dev/null +++ b/backend/groth16/bls12-377/mpcsetup/utils.go @@ -0,0 +1,170 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark/internal/utils" +) + +type PublicKey struct { + SG curve.G1Affine + SXG curve.G1Affine + XR curve.G2Affine +} + +func newPublicKey(x fr.Element, challenge []byte, dst byte) PublicKey { + var pk PublicKey + _, _, g1, _ := curve.Generators() + + var s fr.Element + var sBi big.Int + s.SetRandom() + s.BigInt(&sBi) + pk.SG.ScalarMultiplication(&g1, &sBi) + + // compute x*sG1 + var xBi big.Int + x.BigInt(&xBi) + pk.SXG.ScalarMultiplication(&pk.SG, &xBi) + + // generate R based on sG1, sxG1, challenge, and domain separation tag (tau, alpha or beta) + R := genR(pk.SG, pk.SXG, challenge, dst) + + // compute x*spG2 + pk.XR.ScalarMultiplication(&R, &xBi) + return pk +} + +func bitReverse[T any](a []T) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + irev := bits.Reverse64(i) >> nn + if irev > i { + a[i], a[irev] = a[irev], a[i] + } + } +} + +// Returns [1, a, a², ..., aⁿ⁻¹ ] in Montgomery form +func powers(a fr.Element, n int) []fr.Element { + result := make([]fr.Element, n) + result[0] = fr.NewElement(1) + for i := 1; i < n; i++ { + result[i].Mul(&result[i-1], &a) + } + return result +} + +// Returns [aᵢAᵢ, ...] in G1 +func scaleG1InPlace(A []curve.G1Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Returns [aᵢAᵢ, ...] 
in G2 +func scaleG2InPlace(A []curve.G2Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Check e(a₁, a₂) = e(b₁, b₂) +func sameRatio(a1, b1 curve.G1Affine, a2, b2 curve.G2Affine) bool { + if !a1.IsInSubGroup() || !b1.IsInSubGroup() || !a2.IsInSubGroup() || !b2.IsInSubGroup() { + panic("invalid point not in subgroup") + } + var na2 curve.G2Affine + na2.Neg(&a2) + res, err := curve.PairingCheck( + []curve.G1Affine{a1, b1}, + []curve.G2Affine{na2, b2}) + if err != nil { + panic(err) + } + return res +} + +// returns a = ∑ rᵢAᵢ, b = ∑ rᵢBᵢ +func merge(A, B []curve.G1Affine) (a, b curve.G1Affine) { + nc := runtime.NumCPU() + r := make([]fr.Element, len(A)) + for i := 0; i < len(A); i++ { + r[i].SetRandom() + } + a.MultiExp(A, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + b.MultiExp(B, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G1 +func linearCombinationG1(A []curve.G1Affine) (L1, L2 curve.G1Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G2 +func linearCombinationG2(A []curve.G2Affine) (L1, L2 curve.G2Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// Generate R in G₂ as Hash(gˢ, gˢˣ, challenge, dst) +func genR(sG1, sxG1 curve.G1Affine, challenge []byte, dst byte) curve.G2Affine { + var buf bytes.Buffer + buf.Grow(len(challenge) + curve.SizeOfG1AffineUncompressed*2) + buf.Write(sG1.Marshal()) + 
buf.Write(sxG1.Marshal()) + buf.Write(challenge) + spG2, err := curve.HashToG2(buf.Bytes(), []byte{dst}) + if err != nil { + panic(err) + } + return spG2 +} diff --git a/internal/backend/bls12-377/groth16/prove.go b/backend/groth16/bls12-377/prove.go similarity index 64% rename from internal/backend/bls12-377/groth16/prove.go rename to backend/groth16/bls12-377/prove.go index 4590530c00..7f52932504 100644 --- a/internal/backend/bls12-377/groth16/prove.go +++ b/backend/groth16/bls12-377/prove.go @@ -17,13 +17,17 @@ package groth16 import ( - "fmt" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/pedersen" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/constraint/bls12-377" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" "math/big" @@ -35,9 +39,10 @@ import ( // with a valid statement and a VerifyingKey // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type Proof struct { - Ar, Krs curve.G1Affine - Bs curve.G2Affine - Commitment, CommitmentPok curve.G1Affine + Ar, Krs curve.G1Affine + Bs curve.G2Affine + Commitments []curve.G1Affine // Pedersen commitments a la https://eprint.iacr.org/2022/1072 + CommitmentPok curve.G1Affine // Batched proof of knowledge of the above commitments } // isValid ensures proof elements are in the correct subgroup @@ -51,72 +56,71 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). 
-func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - // TODO @gbotrel witness size check is done by R1CS, doesn't mean we shouldn't sanitize here. - // if len(witness) != r1cs.NbPublicVariables-1+r1cs.NbSecretVariables { - // return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), r1cs.NbPublicVariables-1+r1cs.NbSecretVariables, r1cs.NbPublicVariables, r1cs.NbSecretVariables) - // } - - log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", len(r1cs.Constraints)).Str("backend", "groth16").Logger() - - // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - - proof := &Proof{} - if r1cs.CommitmentInfo.Is() { - opt.HintFunctions[r1cs.CommitmentInfo.HintID] = func(_ *big.Int, in []*big.Int, out []*big.Int) error { - // Perf-TODO: Converting these values to big.Int and back may be a performance bottleneck. - // If that is the case, figure out a way to feed the solution vector into this function - if len(in) != r1cs.CommitmentInfo.NbCommitted() { // TODO: Remove - return fmt.Errorf("unexpected number of committed variables") - } - values := make([]fr.Element, r1cs.CommitmentInfo.NbPrivateCommitted) - nbPublicCommitted := len(in) - len(values) - inPrivate := in[nbPublicCommitted:] - for i, inI := range inPrivate { - values[i].SetBigInt(inI) - } +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } + + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] - var err error - proof.Commitment, proof.CommitmentPok, err = pk.CommitmentKey.Commit(values) - if err != nil { + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + solverOpts = append(solverOpts, solver.OverrideHint(commitmentInfo[i].HintID, func(i int) solver.Hint { + return func(_ *big.Int, in []*big.Int, out []*big.Int) error { + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[len(hashed):] + for j, inJ := range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } + + var res fr.Element + res, err = solveCommitmentWire(&proof.Commitments[i], hashed) + res.BigInt(out[0]) return err } - - var res fr.Element - res, err = solveCommitmentWire(&r1cs.CommitmentInfo, &proof.Commitment, in[:r1cs.CommitmentInfo.NbPublicCommitted()]) - res.BigInt(out[0]) //Perf-TODO: Regular (non-mont) hashToField to obviate this conversion? 
- return err - } + }(i))) } - var wireValues []fr.Element - var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill wireValues with random values else multi exps don't do much - var r fr.Element - _, _ = r.SetRandom() - for i := r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables(); i < len(wireValues); i++ { - wireValues[i] = r - r.Double(&r) - } - } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + start := time.Now() + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + + if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil { + return nil, err + } + // H (witness reduction / FFT part) var h []fr.Element chHDone := make(chan struct{}, 1) go func() { - h = computeH(a, b, c, &pk.Domain) - a = nil - b = nil - c = nil + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + solution.A = nil + solution.B = nil + solution.C = nil chHDone <- struct{}{} }() @@ -203,15 +207,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC var krs, krs2, p1 curve.G1Jac chKrs2Done := make(chan error, 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 go func() { - _, err := krs2.MultiExp(pk.G1.Z, h, ecc.MultiExpConfig{NbTasks: n / 2}) + _, err := krs2.MultiExp(pk.G1.Z, h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) chKrs2Done <- err }() - // filter the wire values if needed; - _wireValues := filter(wireValues, r1cs.CommitmentInfo.PrivateToPublic()) + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + 
toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - if _, err := krs.MultiExp(pk.G1.K, _wireValues[r1cs.GetNbPublicVariables():], ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chKrsDone <- err return } @@ -292,26 +300,32 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC } // if len(toRemove) == 0, returns slice -// else, returns a new slice without the indexes in toRemove -// this assumes toRemove indexes are sorted and len(slice) > len(toRemove) -func filter(slice []fr.Element, toRemove []int) (r []fr.Element) { +// else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { if len(toRemove) == 0 { return slice } - r = make([]fr.Element, 0, len(slice)-len(toRemove)) - j := 0 + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) + // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) for i := 0; i < len(slice); i++ { - if j < len(toRemove) && i == toRemove[j] { - j++ + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } continue } r = append(r, slice[i]) } - return r + return } func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { @@ -334,9 +348,9 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { domain.FFTInverse(b, fft.DIF) domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, true) - domain.FFT(b, fft.DIT, true) - 
domain.FFT(c, fft.DIT, true) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) var den, one fr.Element one.SetOne() @@ -344,7 +358,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) - // reusing a to avoid unecessary memalloc + // reusing a to avoid unnecessary memory allocation utils.Parallelize(n, func(start, end int) { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). @@ -354,7 +368,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { }) // ifft_coset - domain.FFTInverse(a, fft.DIF, true) + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) return a } diff --git a/internal/backend/bls12-377/groth16/setup.go b/backend/groth16/bls12-377/setup.go similarity index 75% rename from internal/backend/bls12-377/groth16/setup.go rename to backend/groth16/bls12-377/setup.go index 8b3beb485e..1cd302a9a9 100644 --- a/internal/backend/bls12-377/groth16/setup.go +++ b/backend/groth16/bls12-377/setup.go @@ -17,11 +17,13 @@ package groth16 import ( + "errors" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/constraint/bls12-377" "math/big" @@ -34,15 +36,15 @@ type ProvingKey struct { // domain Domain fft.Domain - // [α]1, [β]1, [δ]1 - // [A(t)]1, [B(t)]1, [Kpk(t)]1, [Z(t)]1 + // [α]₁, [β]₁, [δ]₁ + // [A(t)]₁, [B(t)]₁, [Kpk(t)]₁, [Z(t)]₁ G1 struct { Alpha, Beta, Delta curve.G1Affine A, B, Z []curve.G1Affine K []curve.G1Affine // the indexes correspond to the private wires } - // [β]2, [δ]2, [B(t)]2 + // [β]₂, [δ]₂, [B(t)]₂ G2 struct { Beta, Delta curve.G2Affine B 
[]curve.G2Affine @@ -52,21 +54,21 @@ type ProvingKey struct { InfinityA, InfinityB []bool NbInfinityA, NbInfinityB uint64 - CommitmentKey pedersen.Key + CommitmentKeys []pedersen.ProvingKey } // VerifyingKey is used by a Groth16 verifier to verify the validity of a proof and a statement // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type VerifyingKey struct { - // [α]1, [Kvk]1 + // [α]₁, [Kvk]₁ G1 struct { Alpha curve.G1Affine Beta, Delta curve.G1Affine // unused, here for compatibility purposes K []curve.G1Affine // The indexes correspond to the public wires } - // [β]2, [δ]2, [γ]2, - // -[δ]2, -[γ]2: see proof.Verify() for more details + // [β]₂, [δ]₂, [γ]₂, + // -[δ]₂, -[γ]₂: see proof.Verify() for more details G2 struct { Beta, Delta, Gamma curve.G2Affine deltaNeg, gammaNeg curve.G2Affine // not serialized @@ -75,8 +77,8 @@ type VerifyingKey struct { // e(α, β) e curve.GT // not serialized - CommitmentKey pedersen.Key - CommitmentInfo constraint.Commitment // since the verifier doesn't input a constraint system, this needs to be provided here + CommitmentKey pedersen.VerifyingKey + PublicAndCommitmentCommitted [][]int // indexes of public/commitment committed variables } // Setup constructs the SRS @@ -93,17 +95,20 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbPrivateCommittedWires := r1cs.CommitmentInfo.NbPrivateCommitted - nbPublicWires := r1cs.GetNbPublicVariables() - nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - if r1cs.CommitmentInfo.Is() { // the commitment itself is defined by a hint so the prover considers it private - nbPublicWires++ // but the verifier will need to inject the value itself so on the groth16 - nbPrivateWires-- // level it must be considered public - } + commitmentInfo 
:= r1cs.CommitmentInfo.(constraint.Groth16Commitments) + commitmentWires := commitmentInfo.CommitmentIndexes() + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateCommittedWires := internal.NbElements(privateCommitted) + + // a commitment is itself defined by a hint so the prover considers it private + // but the verifier will need to inject the value itself so on the groth16 + // level it must be considered public + nbPublicWires := r1cs.GetNbPublicVariables() + len(commitmentInfo) + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - len(commitmentInfo) // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints))) + domain := fft.NewDomain(uint64(r1cs.GetNbConstraints())) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -137,7 +142,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // compute scalars for pkK, vkK and ckK pkK := make([]fr.Element, nbPrivateWires) vkK := make([]fr.Element, nbPublicWires) - ckK := make([]fr.Element, nbPrivateCommittedWires) + ckK := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + ckK[i] = make([]fr.Element, len(privateCommitted[i])) + } var t0, t1 fr.Element @@ -148,28 +156,42 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { Add(&t1, &C[i]). 
Mul(&t1, coeff) } - - vI, cI := 0, 0 - privateCommitted := r1cs.CommitmentInfo.PrivateCommitted() + vI := 0 // number of public wires seen so far + cI := make([]int, len(commitmentInfo)) // number of private committed wires seen so far for each commitment + nbPrivateCommittedSeen := 0 // = ∑ᵢ cI[i] + nbCommitmentsSeen := 0 for i := range A { - isCommittedPrivate := cI < len(privateCommitted) && i == privateCommitted[cI] - isCommitment := r1cs.CommitmentInfo.Is() && i == r1cs.CommitmentInfo.CommitmentIndex - isPublic := i < r1cs.GetNbPublicVariables() + commitment := -1 // index of the commitment that commits to this variable as a private or commitment value + var isCommitment, isPublic bool + if isPublic = i < r1cs.GetNbPublicVariables(); !isPublic { + if nbCommitmentsSeen < len(commitmentWires) && commitmentWires[nbCommitmentsSeen] == i { + isCommitment = true + nbCommitmentsSeen++ + } - if isPublic || isCommittedPrivate || isCommitment { + for j := range commitmentInfo { // does commitment j commit to i? + if cI[j] < len(privateCommitted[j]) && privateCommitted[j][cI[j]] == i { + commitment = j + break // frontend guarantees that no private variable is committed to more than once + } + } + } + + if isPublic || commitment != -1 || isCommitment { computeK(i, &toxicWaste.gammaInv) - if isCommittedPrivate { - ckK[cI] = t1 - cI++ - } else { + if isPublic || isCommitment { vkK[vI] = t1 vI++ + } else { // committed and private + ckK[commitment][cI[commitment]] = t1 + cI[commitment]++ + nbPrivateCommittedSeen++ } } else { computeK(i, &toxicWaste.deltaInv) - pkK[i-vI-cI] = t1 + pkK[i-vI-nbPrivateCommittedSeen] = t1 // vI = nbPublicSeen + nbCommitmentsSeen } } @@ -222,11 +244,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { g1Scalars = append(g1Scalars, Z...) g1Scalars = append(g1Scalars, vkK...) g1Scalars = append(g1Scalars, pkK...) - g1Scalars = append(g1Scalars, ckK...) + for i := range ckK { + g1Scalars = append(g1Scalars, ckK[i]...) 
+ } g1PointsAff := curve.BatchScalarMultiplicationG1(&g1, g1Scalars) - // sets pk: [α]1, [β]1, [δ]1 + // sets pk: [α]₁, [β]₁, [δ]₁ pk.G1.Alpha = g1PointsAff[0] pk.G1.Beta = g1PointsAff[1] pk.G1.Delta = g1PointsAff[2] @@ -238,8 +262,9 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G1.B = g1PointsAff[offset : offset+len(B)] offset += len(B) - pk.G1.Z = g1PointsAff[offset : offset+int(domain.Cardinality)] - bitReverse(pk.G1.Z) + bitReverse(g1PointsAff[offset : offset+int(domain.Cardinality)]) + sizeZ := int(domain.Cardinality) - 1 // deg(H)=deg(A*B-C/X^n-1)=(n-1)+(n-1)-n=n-2 + pk.G1.Z = g1PointsAff[offset : offset+sizeZ] offset += int(domain.Cardinality) @@ -252,17 +277,22 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // --------------------------------------------------------------------------------------------- // Commitment setup - if nbPrivateCommittedWires != 0 { - commitmentBasis := g1PointsAff[offset:] + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(ckK[i]) + commitmentBases[i] = g1PointsAff[offset : offset+size] + offset += size + } + if offset != len(g1PointsAff) { + return errors.New("didn't consume all G1 points") // TODO @Tabaie Remove this + } - vk.CommitmentKey, err = pedersen.Setup(commitmentBasis) - if err != nil { - return err - } - pk.CommitmentKey = vk.CommitmentKey + pk.CommitmentKeys, vk.CommitmentKey, err = pedersen.Setup(commitmentBases...) 
+ if err != nil { + return err } - vk.CommitmentInfo = r1cs.CommitmentInfo // unfortunate but necessary + vk.PublicAndCommitmentCommitted = commitmentInfo.GetPublicAndCommitmentCommitted(commitmentWires, r1cs.GetNbPublicVariables()) // --------------------------------------------------------------------------------------------- // G2 scalars @@ -279,15 +309,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G2.B = g2PointsAff[:len(B)] - // sets pk: [β]2, [δ]2 + // sets pk: [β]₂, [δ]₂ pk.G2.Beta = g2PointsAff[len(B)+0] pk.G2.Delta = g2PointsAff[len(B)+1] - // sets vk: [δ]2, [γ]2, -[δ]2, -[γ]2 + // sets vk: [δ]₂, [γ]₂ vk.G2.Delta = g2PointsAff[len(B)+1] vk.G2.Gamma = g2PointsAff[len(B)+2] - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) // --------------------------------------------------------------------------------------------- // Pairing: vk.e @@ -298,16 +326,29 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { vk.G1.Beta = pk.G1.Beta vk.G1.Delta = pk.G1.Delta - vk.e, err = curve.Pair([]curve.G1Affine{pk.G1.Alpha}, []curve.G2Affine{pk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return err } + // set domain pk.Domain = *domain return nil } +// Precompute sets e, -[δ]₂, -[γ]₂ +// This is meant to be called internally during setup or deserialization. +func (vk *VerifyingKey) Precompute() error { + var err error + vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) + if err != nil { + return err + } + vk.G2.deltaNeg.Neg(&vk.G2.Delta) + vk.G2.gammaNeg.Neg(&vk.G2.Gamma) + return nil +} + func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr.Element, B []fr.Element, C []fr.Element) { nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() @@ -322,7 +363,7 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. 
var w fr.Element w.Set(&domain.Generator) wi := fr.One() - t := make([]fr.Element, len(r1cs.Constraints)+1) + t := make([]fr.Element, r1cs.GetNbConstraints()+1) for i := 0; i < len(t); i++ { t[i].Sub(&toxicWaste.t, &wi) wi.Mul(&wi, &w) // TODO this is already pre computed in fft.Domain @@ -366,8 +407,10 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // for each term appearing in the linear expression, // we compute term.Coefficient * L, and cumulate it in // A, B or C at the index of the variable - for i, c := range r1cs.Constraints { + j := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } @@ -380,9 +423,12 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // Li+1 = w*Li*(t-w^i)/(t-w^(i+1)) L.Mul(&L, &w) - L.Mul(&L, &t[i]) - L.Mul(&L, &tInv[i+1]) + L.Mul(&L, &t[j]) + L.Mul(&L, &tInv[j+1]) + + j++ } + return } @@ -436,7 +482,10 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbConstraints := len(r1cs.Constraints) + nbConstraints := r1cs.GetNbConstraints() + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - internal.NbElements(privateCommitted) - len(commitmentInfo) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints)) @@ -448,8 +497,8 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // initialize proving key pk.G1.A = make([]curve.G1Affine, nbWires-nbZeroesA) pk.G1.B = make([]curve.G1Affine, nbWires-nbZeroesB) - pk.G1.K = make([]curve.G1Affine, nbWires-r1cs.GetNbPublicVariables()) - pk.G1.Z = make([]curve.G1Affine, 
domain.Cardinality) + pk.G1.K = make([]curve.G1Affine, nbPrivateWires) + pk.G1.Z = make([]curve.G1Affine, domain.Cardinality-1) pk.G2.B = make([]curve.G2Affine, nbWires-nbZeroesB) // set infinity markers @@ -503,6 +552,22 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { pk.Domain = *domain + // --------------------------------------------------------------------------------------------- + // Commitment setup + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(privateCommitted[i]) + commitmentBases[i] = make([]curve.G1Affine, size) + for j := range commitmentBases[i] { + commitmentBases[i][j] = r1Aff + } + } + + pk.CommitmentKeys, _, err = pedersen.Setup(commitmentBases...) + if err != nil { + return err + } + return nil } @@ -514,7 +579,9 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) - for _, c := range r1cs.Constraints { + + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { A[t.WireID()] = true } @@ -522,6 +589,7 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { B[t.WireID()] = true } } + for i := 0; i < nbWires; i++ { if !A[i] { nbZeroesA++ @@ -606,7 +674,7 @@ func (pk *ProvingKey) NbG2() int { return 2 + len(pk.G2.B) } -// bitRerverse permutation as in fft.BitReverse , but with []curve.G1Affine +// bitReverse permutation as in fft.BitReverse , but with []curve.G1Affine func bitReverse(a []curve.G1Affine) { n := uint(len(a)) nn := uint(bits.UintSize - bits.TrailingZeros(n)) diff --git a/internal/backend/bls12-377/groth16/verify.go b/backend/groth16/bls12-377/verify.go similarity index 64% rename from internal/backend/bls12-377/groth16/verify.go rename to backend/groth16/bls12-377/verify.go index 6a9819c093..3da51fcaee 100644 --- a/internal/backend/bls12-377/groth16/verify.go +++ b/backend/groth16/bls12-377/verify.go @@ -22,9 +22,11 @@ import ( 
"github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/pedersen" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/logger" "io" - "math/big" "time" ) @@ -36,10 +38,8 @@ var ( // Verify verifies a proof with given VerifyingKey and publicWitness func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - nbPublicVars := len(vk.G1.K) - if vk.CommitmentInfo.Is() { - nbPublicVars-- - } + nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted) + if len(publicWitness) != nbPublicVars-1 { return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(publicWitness), len(vk.G1.K)-1) } @@ -62,21 +62,32 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { close(chDone) }() - if vk.CommitmentInfo.Is() { - - if err := vk.CommitmentKey.VerifyKnowledgeProof(proof.Commitment, proof.CommitmentPok); err != nil { - return err + maxNbPublicCommitted := 0 + for _, s := range vk.PublicAndCommitmentCommitted { // iterate over commitments + maxNbPublicCommitted = utils.Max(maxNbPublicCommitted, len(s)) + } + commitmentsSerialized := make([]byte, len(vk.PublicAndCommitmentCommitted)*fr.Bytes) + commitmentPrehashSerialized := make([]byte, curve.SizeOfG1AffineUncompressed+maxNbPublicCommitted*fr.Bytes) + for i := range vk.PublicAndCommitmentCommitted { // solveCommitmentWire + copy(commitmentPrehashSerialized, proof.Commitments[i].Marshal()) + offset := curve.SizeOfG1AffineUncompressed + for j := range vk.PublicAndCommitmentCommitted[i] { + copy(commitmentPrehashSerialized[offset:], publicWitness[vk.PublicAndCommitmentCommitted[i][j]-1].Marshal()) + offset += fr.Bytes } - - publicCommitted := make([]*big.Int, vk.CommitmentInfo.NbPublicCommitted()) - for i := range publicCommitted { - var b big.Int - 
publicWitness[vk.CommitmentInfo.Committed[i]-1].BigInt(&b) - publicCommitted[i] = &b + if res, err := fr.Hash(commitmentPrehashSerialized[:offset], []byte(constraint.CommitmentDst), 1); err != nil { + return err + } else { + publicWitness = append(publicWitness, res[0]) + copy(commitmentsSerialized[i*fr.Bytes:], res[0].Marshal()) } + } - if res, err := solveCommitmentWire(&vk.CommitmentInfo, &proof.Commitment, publicCommitted); err == nil { - publicWitness = append(publicWitness, res) + if folded, err := pedersen.FoldCommitments(proof.Commitments, commitmentsSerialized); err != nil { + return err + } else { + if err = vk.CommitmentKey.Verify(folded, proof.CommitmentPok); err != nil { + return err } } @@ -87,8 +98,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { } kSum.AddMixed(&vk.G1.K[0]) - if vk.CommitmentInfo.Is() { - kSum.AddMixed(&proof.Commitment) + for i := range proof.Commitments { + kSum.AddMixed(&proof.Commitments[i]) } var kSumAff curve.G1Affine diff --git a/internal/backend/bls12-381/groth16/commitment.go b/backend/groth16/bls12-381/commitment.go similarity index 74% rename from internal/backend/bls12-381/groth16/commitment.go rename to backend/groth16/bls12-381/commitment.go index 680f862419..6fcc533b5b 100644 --- a/internal/backend/bls12-381/groth16/commitment.go +++ b/backend/groth16/bls12-381/commitment.go @@ -23,7 +23,7 @@ import ( "math/big" ) -func solveCommitmentWire(commitmentInfo *constraint.Commitment, commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { - res, err := fr.Hash(commitmentInfo.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) +func solveCommitmentWire(commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { + res, err := fr.Hash(constraint.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) return res[0], err } diff --git 
a/internal/backend/bls12-381/groth16/commitment_test.go b/backend/groth16/bls12-381/commitment_test.go similarity index 91% rename from internal/backend/bls12-381/groth16/commitment_test.go rename to backend/groth16/bls12-381/commitment_test.go index 9279f8449e..8f6882aeb1 100644 --- a/internal/backend/bls12-381/groth16/commitment_test.go +++ b/backend/groth16/bls12-381/commitment_test.go @@ -17,6 +17,9 @@ package groth16_test import ( + "fmt" + "testing" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/witness" @@ -24,7 +27,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/assert" - "testing" ) type singleSecretCommittedCircuit struct { @@ -33,7 +35,11 @@ type singleSecretCommittedCircuit struct { func (c *singleSecretCommittedCircuit) Define(api frontend.API) error { api.AssertIsEqual(c.One, 1) - commit, err := api.Compiler().Commit(c.One) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One) if err != nil { return err } @@ -119,8 +125,11 @@ type oneSecretOnePublicCommittedCircuit struct { } func (c *oneSecretOnePublicCommittedCircuit) Define(api frontend.API) error { - - commit, err := api.Compiler().Commit(c.One, c.Two) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One, c.Two) if err != nil { return err } diff --git a/internal/backend/bls12-381/groth16/marshal.go b/backend/groth16/bls12-381/marshal.go similarity index 81% rename from internal/backend/bls12-381/groth16/marshal.go rename to backend/groth16/bls12-381/marshal.go index 8f1cf81f4a..5a51575c78 100644 --- a/internal/backend/bls12-381/groth16/marshal.go +++ b/backend/groth16/bls12-381/marshal.go @@ -18,6 +18,9 @@ package 
groth16 import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-381" + + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/pedersen" + "github.com/consensys/gnark/internal/utils" "io" ) @@ -78,14 +81,24 @@ func (proof *Proof) ReadFrom(r io.Reader) (n int64, err error) { // points are compressed // use WriteRawTo(...) to encode the key without point compression func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, false) + if n, err = vk.writeTo(w, false); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteTo(w) + return m + n, err } // WriteRawTo writes binary encoding of the key elements to writer // points are not compressed // use WriteTo(...) to encode the key with point compression func (vk *VerifyingKey) WriteRawTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, true) + if n, err = vk.writeTo(w, true); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteRawTo(w) + return m + n, err } // writeTo serialization format: @@ -124,6 +137,14 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { if err := enc.Encode(vk.G1.K); err != nil { return enc.BytesWritten(), err } + + if vk.PublicAndCommitmentCommitted == nil { + vk.PublicAndCommitmentCommitted = [][]int{} // only matters in tests + } + if err := enc.Encode(utils.IntSliceSliceToUint64SliceSlice(vk.PublicAndCommitmentCommitted)); err != nil { + return enc.BytesWritten(), err + } + return enc.BytesWritten(), nil } @@ -133,13 +154,25 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { // https://github.com/zkcrypto/bellman/blob/fa9be45588227a8c6ec34957de3f68705f07bd92/src/groth16/mod.rs#L143 // [α]1,[β]1,[β]2,[γ]2,[δ]1,[δ]2,uint32(len(Kvk)),[Kvk]1 func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r) + n, err := vk.readFrom(r) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.ReadFrom(r) + return m + 
n, err } // UnsafeReadFrom has the same behavior as ReadFrom, except that it will not check that decode points // are on the curve and in the correct subgroup. func (vk *VerifyingKey) UnsafeReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r, curve.NoSubgroupChecks()) + n, err := vk.readFrom(r, curve.NoSubgroupChecks()) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.UnsafeReadFrom(r) + return m + n, err } func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { @@ -169,15 +202,16 @@ func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder) if err := dec.Decode(&vk.G1.K); err != nil { return dec.BytesRead(), err } + var publicCommitted [][]uint64 + if err := dec.Decode(&publicCommitted); err != nil { + return dec.BytesRead(), err + } + vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 - var err error - vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return dec.BytesRead(), err } - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) return dec.BytesRead(), nil } @@ -226,6 +260,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { pk.NbInfinityB, pk.InfinityA, pk.InfinityB, + uint32(len(pk.CommitmentKeys)), } for _, v := range toEncode { @@ -234,6 +269,23 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { } } + for i := range pk.CommitmentKeys { + var ( + n2 int64 + err error + ) + if raw { + n2, err = pk.CommitmentKeys[i].WriteRawTo(w) + } else { + n2, err = pk.CommitmentKeys[i].WriteTo(w) + } + + n += n2 + if err != nil { + return n, err + } + } + return n + enc.BytesWritten(), nil } @@ -260,6 +312,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) dec := curve.NewDecoder(r, decOptions...) 
var nbWires uint64 + var nbCommitments uint32 toDecode := []interface{}{ &pk.G1.Alpha, @@ -291,6 +344,18 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) if err := dec.Decode(&pk.InfinityB); err != nil { return n + dec.BytesRead(), err } + if err := dec.Decode(&nbCommitments); err != nil { + return n + dec.BytesRead(), err + } + + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) + for i := range pk.CommitmentKeys { + n2, err := pk.CommitmentKeys[i].ReadFrom(r) + n += n2 + if err != nil { + return n, err + } + } return n + dec.BytesRead(), nil } diff --git a/internal/backend/bls12-381/groth16/marshal_test.go b/backend/groth16/bls12-381/marshal_test.go similarity index 78% rename from internal/backend/bls12-381/groth16/marshal_test.go rename to backend/groth16/bls12-381/marshal_test.go index 990e3ee6b1..55d8c0856f 100644 --- a/internal/backend/bls12-381/groth16/marshal_test.go +++ b/backend/groth16/bls12-381/marshal_test.go @@ -21,11 +21,17 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal/test_utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "bytes" "math/big" "reflect" "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" "github.com/leanovate/gopter/prop" "testing" @@ -87,13 +93,9 @@ func TestProofSerialization(t *testing.T) { } func TestVerifyingKeySerialization(t *testing.T) { - parameters := gopter.DefaultTestParameters() - parameters.MinSuccessfulTests = 10 - properties := gopter.NewProperties(parameters) - - properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + roundTrip := func(withCommitment bool) func(curve.G1Affine, curve.G2Affine) bool { + return func(p1 curve.G1Affine, p2 curve.G2Affine) bool { var vk, vkCompressed, 
vkRaw VerifyingKey // create a random vk @@ -121,6 +123,21 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } + if withCommitment { + vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization + bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) + elem := p1 + for i := 0; i < len(vk.PublicAndCommitmentCommitted); i++ { + bases[i] = make([]curve.G1Affine, len(vk.PublicAndCommitmentCommitted[i])) + for j := range bases[i] { + bases[i][j] = elem + elem.Add(&elem, &p1) + } + } + _, vk.CommitmentKey, err = pedersen.Setup(bases...) + assert.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := vk.WriteTo(&bufCompressed) if err != nil { @@ -158,7 +175,22 @@ func TestVerifyingKeySerialization(t *testing.T) { } return reflect.DeepEqual(&vk, &vkCompressed) && reflect.DeepEqual(&vk, &vkRaw) - }, + } + } + + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 10 + + properties := gopter.NewProperties(parameters) + + properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(false), + GenG1(), + GenG2(), + )) + + properties.Property("VerifyingKey (with commitments) -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(true), GenG1(), GenG2(), )) @@ -173,7 +205,7 @@ func TestProvingKeySerialization(t *testing.T) { properties := gopter.NewProperties(parameters) properties.Property("ProvingKey -> writer -> reader -> ProvingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + func(p1 curve.G1Affine, p2 curve.G2Affine, nbCommitment int) bool { var pk, pkCompressed, pkRaw ProvingKey // create a random pk @@ -202,6 +234,19 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, 
nbCommitment) + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases...) + require.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := pk.WriteTo(&bufCompressed) if err != nil { @@ -242,6 +287,7 @@ func TestProvingKeySerialization(t *testing.T) { }, GenG1(), GenG2(), + gen.IntRange(0, 2), )) properties.TestingRun(t, gopter.ConsoleReporter(false)) diff --git a/backend/groth16/bls12-381/mpcsetup/lagrange.go b/backend/groth16/bls12-381/mpcsetup/lagrange.go new file mode 100644 index 0000000000..8da7a42f2b --- /dev/null +++ b/backend/groth16/bls12-381/mpcsetup/lagrange.go @@ -0,0 +1,216 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" + "github.com/consensys/gnark/internal/utils" +) + +func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { + coeffs := make([]curve.G1Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { + coeffs := make([]curve.G2Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func butterflyG1(a *curve.G1Affine, b *curve.G1Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +func butterflyG2(a *curve.G2Affine, b *curve.G2Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G1(a []curve.G1Affine, twiddles [][]fr.Element, stage int) { + butterflyG1(&a[0], &a[4]) + butterflyG1(&a[1], &a[5]) + 
butterflyG1(&a[2], &a[6]) + butterflyG1(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[2]) + butterflyG1(&a[1], &a[3]) + butterflyG1(&a[4], &a[6]) + butterflyG1(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[1]) + butterflyG1(&a[2], &a[3]) + butterflyG1(&a[4], &a[5]) + butterflyG1(&a[6], &a[7]) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G2(a []curve.G2Affine, twiddles [][]fr.Element, stage int) { + butterflyG2(&a[0], &a[4]) + butterflyG2(&a[1], &a[5]) + butterflyG2(&a[2], &a[6]) + butterflyG2(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[2]) + butterflyG2(&a[1], &a[3]) + butterflyG2(&a[4], &a[6]) + butterflyG2(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[1]) + butterflyG2(&a[2], &a[3]) + butterflyG2(&a[4], &a[5]) + butterflyG2(&a[6], &a[7]) +} + +func difFFTG1(a []curve.G1Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G1(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG1(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ 
{ + butterflyG1(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} +func difFFTG2(a []curve.G2Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G2(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG2(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + butterflyG2(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG2(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG2(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} diff --git a/backend/groth16/bls12-381/mpcsetup/marshal.go b/backend/groth16/bls12-381/mpcsetup/marshal.go new file mode 100644 index 0000000000..0aa64ea1f0 --- /dev/null +++ b/backend/groth16/bls12-381/mpcsetup/marshal.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls12-381" + "io" +) + +// WriteTo implements io.WriterTo +func (phase1 *Phase1) WriteTo(writer io.Writer) (int64, error) { + n, err := phase1.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase1.Hash) + return int64(nBytes) + n, err +} + +func (phase1 *Phase1) writeTo(writer io.Writer) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + phase1.Parameters.G1.Tau, + phase1.Parameters.G1.AlphaTau, + phase1.Parameters.G1.BetaTau, + phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + enc := curve.NewEncoder(writer) + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (phase1 *Phase1) ReadFrom(reader io.Reader) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + &phase1.Parameters.G1.Tau, + 
&phase1.Parameters.G1.AlphaTau, + &phase1.Parameters.G1.BetaTau, + &phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + dec := curve.NewDecoder(reader) + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + phase1.Hash = make([]byte, 32) + nBytes, err := reader.Read(phase1.Hash) + return dec.BytesRead() + int64(nBytes), err +} + +// WriteTo implements io.WriterTo +func (phase2 *Phase2) WriteTo(writer io.Writer) (int64, error) { + n, err := phase2.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase2.Hash) + return int64(nBytes) + n, err +} + +func (c *Phase2) writeTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + c.Parameters.G1.L, + c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c *Phase2) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + &c.Parameters.G1.L, + &c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + c.Hash = make([]byte, 32) + n, err := reader.Read(c.Hash) + return int64(n) + dec.BytesRead(), err + +} + +// WriteTo implements io.WriterTo +func (c *Phase2Evaluations) WriteTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + c.G1.A, + c.G1.B, + c.G2.B, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom 
implements io.ReaderFrom +func (c *Phase2Evaluations) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.G1.A, + &c.G1.B, + &c.G2.B, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + return dec.BytesRead(), nil +} diff --git a/backend/groth16/bls12-381/mpcsetup/marshal_test.go b/backend/groth16/bls12-381/mpcsetup/marshal_test.go new file mode 100644 index 0000000000..9432d50a6a --- /dev/null +++ b/backend/groth16/bls12-381/mpcsetup/marshal_test.go @@ -0,0 +1,79 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + curve "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/constraint/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/stretchr/testify/require" + "io" + "reflect" + "testing" +) + +func TestContributionSerialization(t *testing.T) { + assert := require.New(t) + + // Phase 1 + srs1 := InitPhase1(9) + srs1.Contribute() + { + var reconstructed Phase1 + roundTripCheck(t, &srs1, &reconstructed) + } + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + r1cs := ccs.(*cs.R1CS) + + // Phase 2 + srs2, _ := InitPhase2(r1cs, &srs1) + srs2.Contribute() + + { + var reconstructed Phase2 + roundTripCheck(t, &srs2, &reconstructed) + } +} + +func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { + t.Helper() + + var buf bytes.Buffer + written, err := from.WriteTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} diff --git a/backend/groth16/bls12-381/mpcsetup/phase1.go b/backend/groth16/bls12-381/mpcsetup/phase1.go new file mode 100644 index 0000000000..14cef5f605 --- /dev/null +++ b/backend/groth16/bls12-381/mpcsetup/phase1.go @@ -0,0 +1,203 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + curve "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "math" + "math/big" +) + +// Phase1 represents the Phase1 of the MPC described in +// https://eprint.iacr.org/2017/1050.pdf +// +// Also known as "Powers of Tau" +type Phase1 struct { + Parameters struct { + G1 struct { + Tau []curve.G1Affine // {[τ⁰]₁, [τ¹]₁, [τ²]₁, …, [τ²ⁿ⁻²]₁} + AlphaTau []curve.G1Affine // {α[τ⁰]₁, α[τ¹]₁, α[τ²]₁, …, α[τⁿ⁻¹]₁} + BetaTau []curve.G1Affine // {β[τ⁰]₁, β[τ¹]₁, β[τ²]₁, …, β[τⁿ⁻¹]₁} + } + G2 struct { + Tau []curve.G2Affine // {[τ⁰]₂, [τ¹]₂, [τ²]₂, …, [τⁿ⁻¹]₂} + Beta curve.G2Affine // [β]₂ + } + } + PublicKeys struct { + Tau, Alpha, Beta PublicKey + } + Hash []byte // sha256 hash +} + +// InitPhase1 initialize phase 1 of the MPC. This is called once by the coordinator before +// any randomness contribution is made (see Contribute()). 
+func InitPhase1(power int) (phase1 Phase1) { + N := int(math.Pow(2, float64(power))) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetOne() + alpha.SetOne() + beta.SetOne() + phase1.PublicKeys.Tau = newPublicKey(tau, nil, 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, nil, 2) + phase1.PublicKeys.Beta = newPublicKey(beta, nil, 3) + + // First contribution use generators + _, _, g1, g2 := curve.Generators() + phase1.Parameters.G2.Beta.Set(&g2) + phase1.Parameters.G1.Tau = make([]curve.G1Affine, 2*N-1) + phase1.Parameters.G2.Tau = make([]curve.G2Affine, N) + phase1.Parameters.G1.AlphaTau = make([]curve.G1Affine, N) + phase1.Parameters.G1.BetaTau = make([]curve.G1Affine, N) + for i := 0; i < len(phase1.Parameters.G1.Tau); i++ { + phase1.Parameters.G1.Tau[i].Set(&g1) + } + for i := 0; i < len(phase1.Parameters.G2.Tau); i++ { + phase1.Parameters.G2.Tau[i].Set(&g2) + phase1.Parameters.G1.AlphaTau[i].Set(&g1) + phase1.Parameters.G1.BetaTau[i].Set(&g1) + } + + phase1.Parameters.G2.Beta.Set(&g2) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() + + return +} + +// Contribute contributes randomness to the phase1 object. This mutates phase1. +func (phase1 *Phase1) Contribute() { + N := len(phase1.Parameters.G2.Tau) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetRandom() + alpha.SetRandom() + beta.SetRandom() + phase1.PublicKeys.Tau = newPublicKey(tau, phase1.Hash[:], 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, phase1.Hash[:], 2) + phase1.PublicKeys.Beta = newPublicKey(beta, phase1.Hash[:], 3) + + // Compute powers of τ, ατ, and βτ + taus := powers(tau, 2*N-1) + alphaTau := make([]fr.Element, N) + betaTau := make([]fr.Element, N) + for i := 0; i < N; i++ { + alphaTau[i].Mul(&taus[i], &alpha) + betaTau[i].Mul(&taus[i], &beta) + } + + // Update using previous parameters + // TODO @gbotrel working with jacobian points here will help with perf. 
+ scaleG1InPlace(phase1.Parameters.G1.Tau, taus) + scaleG2InPlace(phase1.Parameters.G2.Tau, taus[0:N]) + scaleG1InPlace(phase1.Parameters.G1.AlphaTau, alphaTau) + scaleG1InPlace(phase1.Parameters.G1.BetaTau, betaTau) + var betaBI big.Int + beta.BigInt(&betaBI) + phase1.Parameters.G2.Beta.ScalarMultiplication(&phase1.Parameters.G2.Beta, &betaBI) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() +} + +func VerifyPhase1(c0, c1 *Phase1, c ...*Phase1) error { + contribs := append([]*Phase1{c0, c1}, c...) + for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase1(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +// verifyPhase1 checks that a contribution is based on a known previous Phase1 state. +func verifyPhase1(current, contribution *Phase1) error { + // Compute R for τ, α, β + tauR := genR(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, current.Hash[:], 1) + alphaR := genR(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, current.Hash[:], 2) + betaR := genR(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, current.Hash[:], 3) + + // Check for knowledge of toxic parameters + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.PublicKeys.Tau.XR, tauR) { + return errors.New("couldn't verify public key of τ") + } + if !sameRatio(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, contribution.PublicKeys.Alpha.XR, alphaR) { + return errors.New("couldn't verify public key of α") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.PublicKeys.Beta.XR, betaR) { + return errors.New("couldn't verify public key of β") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Tau[1], current.Parameters.G1.Tau[1], tauR, contribution.PublicKeys.Tau.XR) { + return errors.New("couldn't verify that [τ]₁ is based on previous 
contribution") + } + if !sameRatio(contribution.Parameters.G1.AlphaTau[0], current.Parameters.G1.AlphaTau[0], alphaR, contribution.PublicKeys.Alpha.XR) { + return errors.New("couldn't verify that [α]₁ is based on previous contribution") + } + if !sameRatio(contribution.Parameters.G1.BetaTau[0], current.Parameters.G1.BetaTau[0], betaR, contribution.PublicKeys.Beta.XR) { + return errors.New("couldn't verify that [β]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.Parameters.G2.Tau[1], current.Parameters.G2.Tau[1]) { + return errors.New("couldn't verify that [τ]₂ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.Parameters.G2.Beta, current.Parameters.G2.Beta) { + return errors.New("couldn't verify that [β]₂ is based on previous contribution") + } + + // Check for valid updates using powers of τ + _, _, g1, g2 := curve.Generators() + tauL1, tauL2 := linearCombinationG1(contribution.Parameters.G1.Tau) + if !sameRatio(tauL1, tauL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of τ in G₁") + } + alphaL1, alphaL2 := linearCombinationG1(contribution.Parameters.G1.AlphaTau) + if !sameRatio(alphaL1, alphaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of α(τ) in G₁") + } + betaL1, betaL2 := linearCombinationG1(contribution.Parameters.G1.BetaTau) + if !sameRatio(betaL1, betaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of α(τ) in G₁") + } + tau2L1, tau2L2 := linearCombinationG2(contribution.Parameters.G2.Tau) + if !sameRatio(contribution.Parameters.G1.Tau[1], g1, tau2L1, tau2L2) { + return errors.New("couldn't verify valid powers of τ in G₂") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != 
contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (phase1 *Phase1) hash() []byte { + sha := sha256.New() + phase1.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bls12-381/mpcsetup/phase2.go b/backend/groth16/bls12-381/mpcsetup/phase2.go new file mode 100644 index 0000000000..f2cd98bd77 --- /dev/null +++ b/backend/groth16/bls12-381/mpcsetup/phase2.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + "math/big" + + curve "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/constraint/bls12-381" +) + +type Phase2Evaluations struct { + G1 struct { + A, B, VKK []curve.G1Affine + } + G2 struct { + B []curve.G2Affine + } +} + +type Phase2 struct { + Parameters struct { + G1 struct { + Delta curve.G1Affine + L, Z []curve.G1Affine + } + G2 struct { + Delta curve.G2Affine + } + } + PublicKey PublicKey + Hash []byte +} + +func InitPhase2(r1cs *cs.R1CS, srs1 *Phase1) (Phase2, Phase2Evaluations) { + srs := srs1.Parameters + size := len(srs.G1.AlphaTau) + if size < r1cs.GetNbConstraints() { + panic("Number of constraints is larger than expected") + } + + c2 := Phase2{} + + accumulateG1 := func(res *curve.G1Affine, t constraint.Term, value *curve.G1Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G1Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + accumulateG2 := func(res *curve.G2Affine, t constraint.Term, value *curve.G2Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G2Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + // Prepare Lagrange coefficients of [τ...]₁, [τ...]₂, [ατ...]₁, [βτ...]₁ + coeffTau1 := 
lagrangeCoeffsG1(srs.G1.Tau, size) + coeffTau2 := lagrangeCoeffsG2(srs.G2.Tau, size) + coeffAlphaTau1 := lagrangeCoeffsG1(srs.G1.AlphaTau, size) + coeffBetaTau1 := lagrangeCoeffsG1(srs.G1.BetaTau, size) + + internal, secret, public := r1cs.GetNbVariables() + nWires := internal + secret + public + var evals Phase2Evaluations + evals.G1.A = make([]curve.G1Affine, nWires) + evals.G1.B = make([]curve.G1Affine, nWires) + evals.G2.B = make([]curve.G2Affine, nWires) + bA := make([]curve.G1Affine, nWires) + aB := make([]curve.G1Affine, nWires) + C := make([]curve.G1Affine, nWires) + + // TODO @gbotrel use constraint iterator when available. + + i := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + // A + for _, t := range c.L { + accumulateG1(&evals.G1.A[t.WireID()], t, &coeffTau1[i]) + accumulateG1(&bA[t.WireID()], t, &coeffBetaTau1[i]) + } + // B + for _, t := range c.R { + accumulateG1(&evals.G1.B[t.WireID()], t, &coeffTau1[i]) + accumulateG2(&evals.G2.B[t.WireID()], t, &coeffTau2[i]) + accumulateG1(&aB[t.WireID()], t, &coeffAlphaTau1[i]) + } + // C + for _, t := range c.O { + accumulateG1(&C[t.WireID()], t, &coeffTau1[i]) + } + i++ + } + + // Prepare default contribution + _, _, g1, g2 := curve.Generators() + c2.Parameters.G1.Delta = g1 + c2.Parameters.G2.Delta = g2 + + // Build Z in PK as τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + // τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + n := len(srs.G1.AlphaTau) + c2.Parameters.G1.Z = make([]curve.G1Affine, n) + for i := 0; i < n-1; i++ { + c2.Parameters.G1.Z[i].Sub(&srs.G1.Tau[i+n], &srs.G1.Tau[i]) + } + bitReverse(c2.Parameters.G1.Z) + c2.Parameters.G1.Z = c2.Parameters.G1.Z[:n-1] + + // Evaluate L + nPrivate := internal + secret + c2.Parameters.G1.L = make([]curve.G1Affine, nPrivate) + evals.G1.VKK = make([]curve.G1Affine, public) + offset := public + for i := 0; i < nWires; i++ { + var tmp curve.G1Affine + tmp.Add(&bA[i], &aB[i]) + tmp.Add(&tmp, &C[i]) + if i < public { + 
evals.G1.VKK[i].Set(&tmp) + } else { + c2.Parameters.G1.L[i-offset].Set(&tmp) + } + } + // Set δ public key + var delta fr.Element + delta.SetOne() + c2.PublicKey = newPublicKey(delta, nil, 1) + + // Hash initial contribution + c2.Hash = c2.hash() + return c2, evals +} + +func (c *Phase2) Contribute() { + // Sample toxic δ + var delta, deltaInv fr.Element + var deltaBI, deltaInvBI big.Int + delta.SetRandom() + deltaInv.Inverse(&delta) + + delta.BigInt(&deltaBI) + deltaInv.BigInt(&deltaInvBI) + + // Set δ public key + c.PublicKey = newPublicKey(delta, c.Hash, 1) + + // Update δ + c.Parameters.G1.Delta.ScalarMultiplication(&c.Parameters.G1.Delta, &deltaBI) + c.Parameters.G2.Delta.ScalarMultiplication(&c.Parameters.G2.Delta, &deltaBI) + + // Update Z using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.Z); i++ { + c.Parameters.G1.Z[i].ScalarMultiplication(&c.Parameters.G1.Z[i], &deltaInvBI) + } + + // Update L using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.L); i++ { + c.Parameters.G1.L[i].ScalarMultiplication(&c.Parameters.G1.L[i], &deltaInvBI) + } + + // 4. Hash contribution + c.Hash = c.hash() +} + +func VerifyPhase2(c0, c1 *Phase2, c ...*Phase2) error { + contribs := append([]*Phase2{c0, c1}, c...) 
+ + for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase2(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +func verifyPhase2(current, contribution *Phase2) error { + // Compute R for δ + deltaR := genR(contribution.PublicKey.SG, contribution.PublicKey.SXG, current.Hash[:], 1) + + // Check for knowledge of δ + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.PublicKey.XR, deltaR) { + return errors.New("couldn't verify knowledge of δ") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Delta, current.Parameters.G1.Delta, deltaR, contribution.PublicKey.XR) { + return errors.New("couldn't verify that [δ]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify that [δ]₂ is based on previous contribution") + } + + // Check for valid updates of L and Z using δ⁻¹ + L, prevL := merge(contribution.Parameters.G1.L, current.Parameters.G1.L) + if !sameRatio(L, prevL, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of L using δ⁻¹") + } + Z, prevZ := merge(contribution.Parameters.G1.Z, current.Parameters.G1.Z) + if !sameRatio(Z, prevZ, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of Z using δ⁻¹") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (c *Phase2) hash() []byte { + sha := sha256.New() + c.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bls12-381/mpcsetup/setup.go b/backend/groth16/bls12-381/mpcsetup/setup.go new file mode 100644 index 0000000000..dd568aa21e --- 
/dev/null +++ b/backend/groth16/bls12-381/mpcsetup/setup.go @@ -0,0 +1,97 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" + groth16 "github.com/consensys/gnark/backend/groth16/bls12-381" +) + +func ExtractKeys(srs1 *Phase1, srs2 *Phase2, evals *Phase2Evaluations, nConstraints int) (pk groth16.ProvingKey, vk groth16.VerifyingKey) { + _, _, _, g2 := curve.Generators() + + // Initialize PK + pk.Domain = *fft.NewDomain(uint64(nConstraints)) + pk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + pk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + pk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + pk.G1.Z = srs2.Parameters.G1.Z + bitReverse(pk.G1.Z) + + pk.G1.K = srs2.Parameters.G1.L + pk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + pk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + + // Filter out infinity points + nWires := len(evals.G1.A) + pk.InfinityA = make([]bool, nWires) + A := make([]curve.G1Affine, nWires) + j := 0 + for i, e := range evals.G1.A { + if e.IsInfinity() { + pk.InfinityA[i] = true + continue + } + A[j] = evals.G1.A[i] + j++ + } + pk.G1.A = A[:j] + pk.NbInfinityA = uint64(nWires - j) + + pk.InfinityB = make([]bool, nWires) + B := make([]curve.G1Affine, nWires) + j = 0 + for i, e := range evals.G1.B { + if e.IsInfinity() { + 
pk.InfinityB[i] = true + continue + } + B[j] = evals.G1.B[i] + j++ + } + pk.G1.B = B[:j] + pk.NbInfinityB = uint64(nWires - j) + + B2 := make([]curve.G2Affine, nWires) + j = 0 + for i, e := range evals.G2.B { + if e.IsInfinity() { + // pk.InfinityB[i] = true should be the same as in B + continue + } + B2[j] = evals.G2.B[i] + j++ + } + pk.G2.B = B2[:j] + + // Initialize VK + vk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + vk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + vk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + vk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + vk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + vk.G2.Gamma.Set(&g2) + vk.G1.K = evals.G1.VKK + + // sets e, -[δ]2, -[γ]2 + if err := vk.Precompute(); err != nil { + panic(err) + } + + return pk, vk +} diff --git a/backend/groth16/bls12-381/mpcsetup/setup_test.go b/backend/groth16/bls12-381/mpcsetup/setup_test.go new file mode 100644 index 0000000000..e801821bea --- /dev/null +++ b/backend/groth16/bls12-381/mpcsetup/setup_test.go @@ -0,0 +1,199 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark/constraint/bls12-381" + "testing" + + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/stretchr/testify/require" + + native_mimc "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" +) + +func TestSetupCircuit(t *testing.T) { + if testing.Short() { + t.Skip() + } + const ( + nContributionsPhase1 = 3 + nContributionsPhase2 = 3 + power = 9 + ) + + assert := require.New(t) + + srs1 := InitPhase1(power) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase1; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. + prev := srs1.clone() + + srs1.Contribute() + assert.NoError(VerifyPhase1(&prev, &srs1)) + } + + // Compile the circuit + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + var evals Phase2Evaluations + r1cs := ccs.(*cs.R1CS) + + // Prepare for phase-2 + srs2, evals := InitPhase2(r1cs, &srs1) + + // Make and verify contributions for phase2 + for i := 1; i < nContributionsPhase2; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. 
+ prev := srs2.clone() + + srs2.Contribute() + assert.NoError(VerifyPhase2(&prev, &srs2)) + } + + // Extract the proving and verifying keys + pk, vk := ExtractKeys(&srs1, &srs2, &evals, ccs.GetNbConstraints()) + + // Build the witness + var preImage, hash fr.Element + { + m := native_mimc.NewMiMC() + m.Write(preImage.Marshal()) + hash.SetBytes(m.Sum(nil)) + } + + witness, err := frontend.NewWitness(&Circuit{PreImage: preImage, Hash: hash}, curve.ID.ScalarField()) + assert.NoError(err) + + pubWitness, err := witness.Public() + assert.NoError(err) + + // groth16: ensure proof is verified + proof, err := groth16.Prove(ccs, &pk, witness) + assert.NoError(err) + + err = groth16.Verify(proof, &vk, pubWitness) + assert.NoError(err) +} + +func BenchmarkPhase1(b *testing.B) { + const power = 14 + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = InitPhase1(power) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs1 := InitPhase1(power) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs1.Contribute() + } + }) + +} + +func BenchmarkPhase2(b *testing.B) { + const power = 14 + srs1 := InitPhase1(power) + srs1.Contribute() + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + if err != nil { + b.Fatal(err) + } + + r1cs := ccs.(*cs.R1CS) + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = InitPhase2(r1cs, &srs1) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs2, _ := InitPhase2(r1cs, &srs1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs2.Contribute() + } + }) + +} + +// Circuit defines a pre-image knowledge proof +// mimc(secret preImage) = public hash +type Circuit struct { + PreImage frontend.Variable + Hash frontend.Variable `gnark:",public"` +} + +// Define declares the circuit's constraints +// Hash = mimc(PreImage) +func (circuit *Circuit) Define(api frontend.API) error { + // hash function + mimc, _ := 
mimc.NewMiMC(api) + + // specify constraints + mimc.Write(circuit.PreImage) + api.AssertIsEqual(circuit.Hash, mimc.Sum()) + + return nil +} + +func (phase1 *Phase1) clone() Phase1 { + r := Phase1{} + r.Parameters.G1.Tau = append(r.Parameters.G1.Tau, phase1.Parameters.G1.Tau...) + r.Parameters.G1.AlphaTau = append(r.Parameters.G1.AlphaTau, phase1.Parameters.G1.AlphaTau...) + r.Parameters.G1.BetaTau = append(r.Parameters.G1.BetaTau, phase1.Parameters.G1.BetaTau...) + + r.Parameters.G2.Tau = append(r.Parameters.G2.Tau, phase1.Parameters.G2.Tau...) + r.Parameters.G2.Beta = phase1.Parameters.G2.Beta + + r.PublicKeys = phase1.PublicKeys + r.Hash = append(r.Hash, phase1.Hash...) + + return r +} + +func (phase2 *Phase2) clone() Phase2 { + r := Phase2{} + r.Parameters.G1.Delta = phase2.Parameters.G1.Delta + r.Parameters.G1.L = append(r.Parameters.G1.L, phase2.Parameters.G1.L...) + r.Parameters.G1.Z = append(r.Parameters.G1.Z, phase2.Parameters.G1.Z...) + r.Parameters.G2.Delta = phase2.Parameters.G2.Delta + r.PublicKey = phase2.PublicKey + r.Hash = append(r.Hash, phase2.Hash...) + + return r +} diff --git a/backend/groth16/bls12-381/mpcsetup/utils.go b/backend/groth16/bls12-381/mpcsetup/utils.go new file mode 100644 index 0000000000..e29ec7ae32 --- /dev/null +++ b/backend/groth16/bls12-381/mpcsetup/utils.go @@ -0,0 +1,170 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark/internal/utils" +) + +type PublicKey struct { + SG curve.G1Affine + SXG curve.G1Affine + XR curve.G2Affine +} + +func newPublicKey(x fr.Element, challenge []byte, dst byte) PublicKey { + var pk PublicKey + _, _, g1, _ := curve.Generators() + + var s fr.Element + var sBi big.Int + s.SetRandom() + s.BigInt(&sBi) + pk.SG.ScalarMultiplication(&g1, &sBi) + + // compute x*sG1 + var xBi big.Int + x.BigInt(&xBi) + pk.SXG.ScalarMultiplication(&pk.SG, &xBi) + + // generate R based on sG1, sxG1, challenge, and domain separation tag (tau, alpha or beta) + R := genR(pk.SG, pk.SXG, challenge, dst) + + // compute x*spG2 + pk.XR.ScalarMultiplication(&R, &xBi) + return pk +} + +func bitReverse[T any](a []T) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + irev := bits.Reverse64(i) >> nn + if irev > i { + a[i], a[irev] = a[irev], a[i] + } + } +} + +// Returns [1, a, a², ..., aⁿ⁻¹ ] in Montgomery form +func powers(a fr.Element, n int) []fr.Element { + result := make([]fr.Element, n) + result[0] = fr.NewElement(1) + for i := 1; i < n; i++ { + result[i].Mul(&result[i-1], &a) + } + return result +} + +// Returns [aᵢAᵢ, ...] in G1 +func scaleG1InPlace(A []curve.G1Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Returns [aᵢAᵢ, ...] 
in G2 +func scaleG2InPlace(A []curve.G2Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Check e(a₁, a₂) = e(b₁, b₂) +func sameRatio(a1, b1 curve.G1Affine, a2, b2 curve.G2Affine) bool { + if !a1.IsInSubGroup() || !b1.IsInSubGroup() || !a2.IsInSubGroup() || !b2.IsInSubGroup() { + panic("invalid point not in subgroup") + } + var na2 curve.G2Affine + na2.Neg(&a2) + res, err := curve.PairingCheck( + []curve.G1Affine{a1, b1}, + []curve.G2Affine{na2, b2}) + if err != nil { + panic(err) + } + return res +} + +// returns a = ∑ rᵢAᵢ, b = ∑ rᵢBᵢ +func merge(A, B []curve.G1Affine) (a, b curve.G1Affine) { + nc := runtime.NumCPU() + r := make([]fr.Element, len(A)) + for i := 0; i < len(A); i++ { + r[i].SetRandom() + } + a.MultiExp(A, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + b.MultiExp(B, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G1 +func linearCombinationG1(A []curve.G1Affine) (L1, L2 curve.G1Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G2 +func linearCombinationG2(A []curve.G2Affine) (L1, L2 curve.G2Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// Generate R in G₂ as Hash(gˢ, gˢˣ, challenge, dst) +func genR(sG1, sxG1 curve.G1Affine, challenge []byte, dst byte) curve.G2Affine { + var buf bytes.Buffer + buf.Grow(len(challenge) + curve.SizeOfG1AffineUncompressed*2) + buf.Write(sG1.Marshal()) + 
buf.Write(sxG1.Marshal()) + buf.Write(challenge) + spG2, err := curve.HashToG2(buf.Bytes(), []byte{dst}) + if err != nil { + panic(err) + } + return spG2 +} diff --git a/internal/backend/bls12-381/groth16/prove.go b/backend/groth16/bls12-381/prove.go similarity index 64% rename from internal/backend/bls12-381/groth16/prove.go rename to backend/groth16/bls12-381/prove.go index 9afd4240c1..23ab3529dc 100644 --- a/internal/backend/bls12-381/groth16/prove.go +++ b/backend/groth16/bls12-381/prove.go @@ -17,13 +17,17 @@ package groth16 import ( - "fmt" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls12-381" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/pedersen" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/constraint/bls12-381" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" "math/big" @@ -35,9 +39,10 @@ import ( // with a valid statement and a VerifyingKey // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type Proof struct { - Ar, Krs curve.G1Affine - Bs curve.G2Affine - Commitment, CommitmentPok curve.G1Affine + Ar, Krs curve.G1Affine + Bs curve.G2Affine + Commitments []curve.G1Affine // Pedersen commitments a la https://eprint.iacr.org/2022/1072 + CommitmentPok curve.G1Affine // Batched proof of knowledge of the above commitments } // isValid ensures proof elements are in the correct subgroup @@ -51,72 +56,71 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). 
-func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - // TODO @gbotrel witness size check is done by R1CS, doesn't mean we shouldn't sanitize here. - // if len(witness) != r1cs.NbPublicVariables-1+r1cs.NbSecretVariables { - // return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), r1cs.NbPublicVariables-1+r1cs.NbSecretVariables, r1cs.NbPublicVariables, r1cs.NbSecretVariables) - // } - - log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", len(r1cs.Constraints)).Str("backend", "groth16").Logger() - - // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - - proof := &Proof{} - if r1cs.CommitmentInfo.Is() { - opt.HintFunctions[r1cs.CommitmentInfo.HintID] = func(_ *big.Int, in []*big.Int, out []*big.Int) error { - // Perf-TODO: Converting these values to big.Int and back may be a performance bottleneck. - // If that is the case, figure out a way to feed the solution vector into this function - if len(in) != r1cs.CommitmentInfo.NbCommitted() { // TODO: Remove - return fmt.Errorf("unexpected number of committed variables") - } - values := make([]fr.Element, r1cs.CommitmentInfo.NbPrivateCommitted) - nbPublicCommitted := len(in) - len(values) - inPrivate := in[nbPublicCommitted:] - for i, inI := range inPrivate { - values[i].SetBigInt(inI) - } +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } + + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] - var err error - proof.Commitment, proof.CommitmentPok, err = pk.CommitmentKey.Commit(values) - if err != nil { + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + solverOpts = append(solverOpts, solver.OverrideHint(commitmentInfo[i].HintID, func(i int) solver.Hint { + return func(_ *big.Int, in []*big.Int, out []*big.Int) error { + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[len(hashed):] + for j, inJ := range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } + + var res fr.Element + res, err = solveCommitmentWire(&proof.Commitments[i], hashed) + res.BigInt(out[0]) return err } - - var res fr.Element - res, err = solveCommitmentWire(&r1cs.CommitmentInfo, &proof.Commitment, in[:r1cs.CommitmentInfo.NbPublicCommitted()]) - res.BigInt(out[0]) //Perf-TODO: Regular (non-mont) hashToField to obviate this conversion? 
- return err - } + }(i))) } - var wireValues []fr.Element - var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill wireValues with random values else multi exps don't do much - var r fr.Element - _, _ = r.SetRandom() - for i := r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables(); i < len(wireValues); i++ { - wireValues[i] = r - r.Double(&r) - } - } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + start := time.Now() + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + + if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil { + return nil, err + } + // H (witness reduction / FFT part) var h []fr.Element chHDone := make(chan struct{}, 1) go func() { - h = computeH(a, b, c, &pk.Domain) - a = nil - b = nil - c = nil + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + solution.A = nil + solution.B = nil + solution.C = nil chHDone <- struct{}{} }() @@ -203,15 +207,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC var krs, krs2, p1 curve.G1Jac chKrs2Done := make(chan error, 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 go func() { - _, err := krs2.MultiExp(pk.G1.Z, h, ecc.MultiExpConfig{NbTasks: n / 2}) + _, err := krs2.MultiExp(pk.G1.Z, h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) chKrs2Done <- err }() - // filter the wire values if needed; - _wireValues := filter(wireValues, r1cs.CommitmentInfo.PrivateToPublic()) + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + 
toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - if _, err := krs.MultiExp(pk.G1.K, _wireValues[r1cs.GetNbPublicVariables():], ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chKrsDone <- err return } @@ -292,26 +300,32 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC } // if len(toRemove) == 0, returns slice -// else, returns a new slice without the indexes in toRemove -// this assumes toRemove indexes are sorted and len(slice) > len(toRemove) -func filter(slice []fr.Element, toRemove []int) (r []fr.Element) { +// else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { if len(toRemove) == 0 { return slice } - r = make([]fr.Element, 0, len(slice)-len(toRemove)) - j := 0 + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) + // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) for i := 0; i < len(slice); i++ { - if j < len(toRemove) && i == toRemove[j] { - j++ + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } continue } r = append(r, slice[i]) } - return r + return } func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { @@ -334,9 +348,9 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { domain.FFTInverse(b, fft.DIF) domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, true) - domain.FFT(b, fft.DIT, true) - 
domain.FFT(c, fft.DIT, true) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) var den, one fr.Element one.SetOne() @@ -344,7 +358,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) - // reusing a to avoid unecessary memalloc + // reusing a to avoid unnecessary memory allocation utils.Parallelize(n, func(start, end int) { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). @@ -354,7 +368,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { }) // ifft_coset - domain.FFTInverse(a, fft.DIF, true) + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) return a } diff --git a/internal/backend/bls12-381/groth16/setup.go b/backend/groth16/bls12-381/setup.go similarity index 75% rename from internal/backend/bls12-381/groth16/setup.go rename to backend/groth16/bls12-381/setup.go index f1b05f3cc8..d1ca2e9b2d 100644 --- a/internal/backend/bls12-381/groth16/setup.go +++ b/backend/groth16/bls12-381/setup.go @@ -17,11 +17,13 @@ package groth16 import ( + "errors" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls12-381" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/constraint/bls12-381" "math/big" @@ -34,15 +36,15 @@ type ProvingKey struct { // domain Domain fft.Domain - // [α]1, [β]1, [δ]1 - // [A(t)]1, [B(t)]1, [Kpk(t)]1, [Z(t)]1 + // [α]₁, [β]₁, [δ]₁ + // [A(t)]₁, [B(t)]₁, [Kpk(t)]₁, [Z(t)]₁ G1 struct { Alpha, Beta, Delta curve.G1Affine A, B, Z []curve.G1Affine K []curve.G1Affine // the indexes correspond to the private wires } - // [β]2, [δ]2, [B(t)]2 + // [β]₂, [δ]₂, [B(t)]₂ G2 struct { Beta, Delta curve.G2Affine B 
[]curve.G2Affine @@ -52,21 +54,21 @@ type ProvingKey struct { InfinityA, InfinityB []bool NbInfinityA, NbInfinityB uint64 - CommitmentKey pedersen.Key + CommitmentKeys []pedersen.ProvingKey } // VerifyingKey is used by a Groth16 verifier to verify the validity of a proof and a statement // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type VerifyingKey struct { - // [α]1, [Kvk]1 + // [α]₁, [Kvk]₁ G1 struct { Alpha curve.G1Affine Beta, Delta curve.G1Affine // unused, here for compatibility purposes K []curve.G1Affine // The indexes correspond to the public wires } - // [β]2, [δ]2, [γ]2, - // -[δ]2, -[γ]2: see proof.Verify() for more details + // [β]₂, [δ]₂, [γ]₂, + // -[δ]₂, -[γ]₂: see proof.Verify() for more details G2 struct { Beta, Delta, Gamma curve.G2Affine deltaNeg, gammaNeg curve.G2Affine // not serialized @@ -75,8 +77,8 @@ type VerifyingKey struct { // e(α, β) e curve.GT // not serialized - CommitmentKey pedersen.Key - CommitmentInfo constraint.Commitment // since the verifier doesn't input a constraint system, this needs to be provided here + CommitmentKey pedersen.VerifyingKey + PublicAndCommitmentCommitted [][]int // indexes of public/commitment committed variables } // Setup constructs the SRS @@ -93,17 +95,20 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbPrivateCommittedWires := r1cs.CommitmentInfo.NbPrivateCommitted - nbPublicWires := r1cs.GetNbPublicVariables() - nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - if r1cs.CommitmentInfo.Is() { // the commitment itself is defined by a hint so the prover considers it private - nbPublicWires++ // but the verifier will need to inject the value itself so on the groth16 - nbPrivateWires-- // level it must be considered public - } + commitmentInfo 
:= r1cs.CommitmentInfo.(constraint.Groth16Commitments) + commitmentWires := commitmentInfo.CommitmentIndexes() + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateCommittedWires := internal.NbElements(privateCommitted) + + // a commitment is itself defined by a hint so the prover considers it private + // but the verifier will need to inject the value itself so on the groth16 + // level it must be considered public + nbPublicWires := r1cs.GetNbPublicVariables() + len(commitmentInfo) + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - len(commitmentInfo) // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints))) + domain := fft.NewDomain(uint64(r1cs.GetNbConstraints())) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -137,7 +142,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // compute scalars for pkK, vkK and ckK pkK := make([]fr.Element, nbPrivateWires) vkK := make([]fr.Element, nbPublicWires) - ckK := make([]fr.Element, nbPrivateCommittedWires) + ckK := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + ckK[i] = make([]fr.Element, len(privateCommitted[i])) + } var t0, t1 fr.Element @@ -148,28 +156,42 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { Add(&t1, &C[i]). 
Mul(&t1, coeff) } - - vI, cI := 0, 0 - privateCommitted := r1cs.CommitmentInfo.PrivateCommitted() + vI := 0 // number of public wires seen so far + cI := make([]int, len(commitmentInfo)) // number of private committed wires seen so far for each commitment + nbPrivateCommittedSeen := 0 // = ∑ᵢ cI[i] + nbCommitmentsSeen := 0 for i := range A { - isCommittedPrivate := cI < len(privateCommitted) && i == privateCommitted[cI] - isCommitment := r1cs.CommitmentInfo.Is() && i == r1cs.CommitmentInfo.CommitmentIndex - isPublic := i < r1cs.GetNbPublicVariables() + commitment := -1 // index of the commitment that commits to this variable as a private or commitment value + var isCommitment, isPublic bool + if isPublic = i < r1cs.GetNbPublicVariables(); !isPublic { + if nbCommitmentsSeen < len(commitmentWires) && commitmentWires[nbCommitmentsSeen] == i { + isCommitment = true + nbCommitmentsSeen++ + } - if isPublic || isCommittedPrivate || isCommitment { + for j := range commitmentInfo { // does commitment j commit to i? + if cI[j] < len(privateCommitted[j]) && privateCommitted[j][cI[j]] == i { + commitment = j + break // frontend guarantees that no private variable is committed to more than once + } + } + } + + if isPublic || commitment != -1 || isCommitment { computeK(i, &toxicWaste.gammaInv) - if isCommittedPrivate { - ckK[cI] = t1 - cI++ - } else { + if isPublic || isCommitment { vkK[vI] = t1 vI++ + } else { // committed and private + ckK[commitment][cI[commitment]] = t1 + cI[commitment]++ + nbPrivateCommittedSeen++ } } else { computeK(i, &toxicWaste.deltaInv) - pkK[i-vI-cI] = t1 + pkK[i-vI-nbPrivateCommittedSeen] = t1 // vI = nbPublicSeen + nbCommitmentsSeen } } @@ -222,11 +244,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { g1Scalars = append(g1Scalars, Z...) g1Scalars = append(g1Scalars, vkK...) g1Scalars = append(g1Scalars, pkK...) - g1Scalars = append(g1Scalars, ckK...) + for i := range ckK { + g1Scalars = append(g1Scalars, ckK[i]...) 
+ } g1PointsAff := curve.BatchScalarMultiplicationG1(&g1, g1Scalars) - // sets pk: [α]1, [β]1, [δ]1 + // sets pk: [α]₁, [β]₁, [δ]₁ pk.G1.Alpha = g1PointsAff[0] pk.G1.Beta = g1PointsAff[1] pk.G1.Delta = g1PointsAff[2] @@ -238,8 +262,9 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G1.B = g1PointsAff[offset : offset+len(B)] offset += len(B) - pk.G1.Z = g1PointsAff[offset : offset+int(domain.Cardinality)] - bitReverse(pk.G1.Z) + bitReverse(g1PointsAff[offset : offset+int(domain.Cardinality)]) + sizeZ := int(domain.Cardinality) - 1 // deg(H)=deg(A*B-C/X^n-1)=(n-1)+(n-1)-n=n-2 + pk.G1.Z = g1PointsAff[offset : offset+sizeZ] offset += int(domain.Cardinality) @@ -252,17 +277,22 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // --------------------------------------------------------------------------------------------- // Commitment setup - if nbPrivateCommittedWires != 0 { - commitmentBasis := g1PointsAff[offset:] + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(ckK[i]) + commitmentBases[i] = g1PointsAff[offset : offset+size] + offset += size + } + if offset != len(g1PointsAff) { + return errors.New("didn't consume all G1 points") // TODO @Tabaie Remove this + } - vk.CommitmentKey, err = pedersen.Setup(commitmentBasis) - if err != nil { - return err - } - pk.CommitmentKey = vk.CommitmentKey + pk.CommitmentKeys, vk.CommitmentKey, err = pedersen.Setup(commitmentBases...) 
+ if err != nil { + return err } - vk.CommitmentInfo = r1cs.CommitmentInfo // unfortunate but necessary + vk.PublicAndCommitmentCommitted = commitmentInfo.GetPublicAndCommitmentCommitted(commitmentWires, r1cs.GetNbPublicVariables()) // --------------------------------------------------------------------------------------------- // G2 scalars @@ -279,15 +309,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G2.B = g2PointsAff[:len(B)] - // sets pk: [β]2, [δ]2 + // sets pk: [β]₂, [δ]₂ pk.G2.Beta = g2PointsAff[len(B)+0] pk.G2.Delta = g2PointsAff[len(B)+1] - // sets vk: [δ]2, [γ]2, -[δ]2, -[γ]2 + // sets vk: [δ]₂, [γ]₂ vk.G2.Delta = g2PointsAff[len(B)+1] vk.G2.Gamma = g2PointsAff[len(B)+2] - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) // --------------------------------------------------------------------------------------------- // Pairing: vk.e @@ -298,16 +326,29 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { vk.G1.Beta = pk.G1.Beta vk.G1.Delta = pk.G1.Delta - vk.e, err = curve.Pair([]curve.G1Affine{pk.G1.Alpha}, []curve.G2Affine{pk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return err } + // set domain pk.Domain = *domain return nil } +// Precompute sets e, -[δ]₂, -[γ]₂ +// This is meant to be called internally during setup or deserialization. +func (vk *VerifyingKey) Precompute() error { + var err error + vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) + if err != nil { + return err + } + vk.G2.deltaNeg.Neg(&vk.G2.Delta) + vk.G2.gammaNeg.Neg(&vk.G2.Gamma) + return nil +} + func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr.Element, B []fr.Element, C []fr.Element) { nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() @@ -322,7 +363,7 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. 
var w fr.Element w.Set(&domain.Generator) wi := fr.One() - t := make([]fr.Element, len(r1cs.Constraints)+1) + t := make([]fr.Element, r1cs.GetNbConstraints()+1) for i := 0; i < len(t); i++ { t[i].Sub(&toxicWaste.t, &wi) wi.Mul(&wi, &w) // TODO this is already pre computed in fft.Domain @@ -366,8 +407,10 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // for each term appearing in the linear expression, // we compute term.Coefficient * L, and cumulate it in // A, B or C at the index of the variable - for i, c := range r1cs.Constraints { + j := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } @@ -380,9 +423,12 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // Li+1 = w*Li*(t-w^i)/(t-w^(i+1)) L.Mul(&L, &w) - L.Mul(&L, &t[i]) - L.Mul(&L, &tInv[i+1]) + L.Mul(&L, &t[j]) + L.Mul(&L, &tInv[j+1]) + + j++ } + return } @@ -436,7 +482,10 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbConstraints := len(r1cs.Constraints) + nbConstraints := r1cs.GetNbConstraints() + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - internal.NbElements(privateCommitted) - len(commitmentInfo) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints)) @@ -448,8 +497,8 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // initialize proving key pk.G1.A = make([]curve.G1Affine, nbWires-nbZeroesA) pk.G1.B = make([]curve.G1Affine, nbWires-nbZeroesB) - pk.G1.K = make([]curve.G1Affine, nbWires-r1cs.GetNbPublicVariables()) - pk.G1.Z = make([]curve.G1Affine, 
domain.Cardinality) + pk.G1.K = make([]curve.G1Affine, nbPrivateWires) + pk.G1.Z = make([]curve.G1Affine, domain.Cardinality-1) pk.G2.B = make([]curve.G2Affine, nbWires-nbZeroesB) // set infinity markers @@ -503,6 +552,22 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { pk.Domain = *domain + // --------------------------------------------------------------------------------------------- + // Commitment setup + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(privateCommitted[i]) + commitmentBases[i] = make([]curve.G1Affine, size) + for j := range commitmentBases[i] { + commitmentBases[i][j] = r1Aff + } + } + + pk.CommitmentKeys, _, err = pedersen.Setup(commitmentBases...) + if err != nil { + return err + } + return nil } @@ -514,7 +579,9 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) - for _, c := range r1cs.Constraints { + + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { A[t.WireID()] = true } @@ -522,6 +589,7 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { B[t.WireID()] = true } } + for i := 0; i < nbWires; i++ { if !A[i] { nbZeroesA++ @@ -606,7 +674,7 @@ func (pk *ProvingKey) NbG2() int { return 2 + len(pk.G2.B) } -// bitRerverse permutation as in fft.BitReverse , but with []curve.G1Affine +// bitReverse permutation as in fft.BitReverse , but with []curve.G1Affine func bitReverse(a []curve.G1Affine) { n := uint(len(a)) nn := uint(bits.UintSize - bits.TrailingZeros(n)) diff --git a/internal/backend/bls12-381/groth16/verify.go b/backend/groth16/bls12-381/verify.go similarity index 64% rename from internal/backend/bls12-381/groth16/verify.go rename to backend/groth16/bls12-381/verify.go index 0a9de8c797..646e052f42 100644 --- a/internal/backend/bls12-381/groth16/verify.go +++ b/backend/groth16/bls12-381/verify.go @@ -22,9 +22,11 @@ import ( 
"github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls12-381" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/pedersen" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/logger" "io" - "math/big" "time" ) @@ -36,10 +38,8 @@ var ( // Verify verifies a proof with given VerifyingKey and publicWitness func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - nbPublicVars := len(vk.G1.K) - if vk.CommitmentInfo.Is() { - nbPublicVars-- - } + nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted) + if len(publicWitness) != nbPublicVars-1 { return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(publicWitness), len(vk.G1.K)-1) } @@ -62,21 +62,32 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { close(chDone) }() - if vk.CommitmentInfo.Is() { - - if err := vk.CommitmentKey.VerifyKnowledgeProof(proof.Commitment, proof.CommitmentPok); err != nil { - return err + maxNbPublicCommitted := 0 + for _, s := range vk.PublicAndCommitmentCommitted { // iterate over commitments + maxNbPublicCommitted = utils.Max(maxNbPublicCommitted, len(s)) + } + commitmentsSerialized := make([]byte, len(vk.PublicAndCommitmentCommitted)*fr.Bytes) + commitmentPrehashSerialized := make([]byte, curve.SizeOfG1AffineUncompressed+maxNbPublicCommitted*fr.Bytes) + for i := range vk.PublicAndCommitmentCommitted { // solveCommitmentWire + copy(commitmentPrehashSerialized, proof.Commitments[i].Marshal()) + offset := curve.SizeOfG1AffineUncompressed + for j := range vk.PublicAndCommitmentCommitted[i] { + copy(commitmentPrehashSerialized[offset:], publicWitness[vk.PublicAndCommitmentCommitted[i][j]-1].Marshal()) + offset += fr.Bytes } - - publicCommitted := make([]*big.Int, vk.CommitmentInfo.NbPublicCommitted()) - for i := range publicCommitted { - var b big.Int - 
publicWitness[vk.CommitmentInfo.Committed[i]-1].BigInt(&b) - publicCommitted[i] = &b + if res, err := fr.Hash(commitmentPrehashSerialized[:offset], []byte(constraint.CommitmentDst), 1); err != nil { + return err + } else { + publicWitness = append(publicWitness, res[0]) + copy(commitmentsSerialized[i*fr.Bytes:], res[0].Marshal()) } + } - if res, err := solveCommitmentWire(&vk.CommitmentInfo, &proof.Commitment, publicCommitted); err == nil { - publicWitness = append(publicWitness, res) + if folded, err := pedersen.FoldCommitments(proof.Commitments, commitmentsSerialized); err != nil { + return err + } else { + if err = vk.CommitmentKey.Verify(folded, proof.CommitmentPok); err != nil { + return err } } @@ -87,8 +98,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { } kSum.AddMixed(&vk.G1.K[0]) - if vk.CommitmentInfo.Is() { - kSum.AddMixed(&proof.Commitment) + for i := range proof.Commitments { + kSum.AddMixed(&proof.Commitments[i]) } var kSumAff curve.G1Affine diff --git a/internal/backend/bls24-315/groth16/commitment.go b/backend/groth16/bls24-315/commitment.go similarity index 74% rename from internal/backend/bls24-315/groth16/commitment.go rename to backend/groth16/bls24-315/commitment.go index 3cb966f8f0..fc1a3def96 100644 --- a/internal/backend/bls24-315/groth16/commitment.go +++ b/backend/groth16/bls24-315/commitment.go @@ -23,7 +23,7 @@ import ( "math/big" ) -func solveCommitmentWire(commitmentInfo *constraint.Commitment, commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { - res, err := fr.Hash(commitmentInfo.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) +func solveCommitmentWire(commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { + res, err := fr.Hash(constraint.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) return res[0], err } diff --git 
a/internal/backend/bls24-315/groth16/commitment_test.go b/backend/groth16/bls24-315/commitment_test.go similarity index 91% rename from internal/backend/bls24-315/groth16/commitment_test.go rename to backend/groth16/bls24-315/commitment_test.go index 236bcda605..0f626448b3 100644 --- a/internal/backend/bls24-315/groth16/commitment_test.go +++ b/backend/groth16/bls24-315/commitment_test.go @@ -17,6 +17,9 @@ package groth16_test import ( + "fmt" + "testing" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/witness" @@ -24,7 +27,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/assert" - "testing" ) type singleSecretCommittedCircuit struct { @@ -33,7 +35,11 @@ type singleSecretCommittedCircuit struct { func (c *singleSecretCommittedCircuit) Define(api frontend.API) error { api.AssertIsEqual(c.One, 1) - commit, err := api.Compiler().Commit(c.One) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One) if err != nil { return err } @@ -119,8 +125,11 @@ type oneSecretOnePublicCommittedCircuit struct { } func (c *oneSecretOnePublicCommittedCircuit) Define(api frontend.API) error { - - commit, err := api.Compiler().Commit(c.One, c.Two) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One, c.Two) if err != nil { return err } diff --git a/internal/backend/bls24-315/groth16/marshal.go b/backend/groth16/bls24-315/marshal.go similarity index 81% rename from internal/backend/bls24-315/groth16/marshal.go rename to backend/groth16/bls24-315/marshal.go index ede9181c06..3da8e55772 100644 --- a/internal/backend/bls24-315/groth16/marshal.go +++ b/backend/groth16/bls24-315/marshal.go @@ -18,6 +18,9 @@ package 
groth16 import ( curve "github.com/consensys/gnark-crypto/ecc/bls24-315" + + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/pedersen" + "github.com/consensys/gnark/internal/utils" "io" ) @@ -78,14 +81,24 @@ func (proof *Proof) ReadFrom(r io.Reader) (n int64, err error) { // points are compressed // use WriteRawTo(...) to encode the key without point compression func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, false) + if n, err = vk.writeTo(w, false); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteTo(w) + return m + n, err } // WriteRawTo writes binary encoding of the key elements to writer // points are not compressed // use WriteTo(...) to encode the key with point compression func (vk *VerifyingKey) WriteRawTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, true) + if n, err = vk.writeTo(w, true); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteRawTo(w) + return m + n, err } // writeTo serialization format: @@ -124,6 +137,14 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { if err := enc.Encode(vk.G1.K); err != nil { return enc.BytesWritten(), err } + + if vk.PublicAndCommitmentCommitted == nil { + vk.PublicAndCommitmentCommitted = [][]int{} // only matters in tests + } + if err := enc.Encode(utils.IntSliceSliceToUint64SliceSlice(vk.PublicAndCommitmentCommitted)); err != nil { + return enc.BytesWritten(), err + } + return enc.BytesWritten(), nil } @@ -133,13 +154,25 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { // https://github.com/zkcrypto/bellman/blob/fa9be45588227a8c6ec34957de3f68705f07bd92/src/groth16/mod.rs#L143 // [α]1,[β]1,[β]2,[γ]2,[δ]1,[δ]2,uint32(len(Kvk)),[Kvk]1 func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r) + n, err := vk.readFrom(r) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.ReadFrom(r) + return m + 
n, err } // UnsafeReadFrom has the same behavior as ReadFrom, except that it will not check that decode points // are on the curve and in the correct subgroup. func (vk *VerifyingKey) UnsafeReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r, curve.NoSubgroupChecks()) + n, err := vk.readFrom(r, curve.NoSubgroupChecks()) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.UnsafeReadFrom(r) + return m + n, err } func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { @@ -169,15 +202,16 @@ func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder) if err := dec.Decode(&vk.G1.K); err != nil { return dec.BytesRead(), err } + var publicCommitted [][]uint64 + if err := dec.Decode(&publicCommitted); err != nil { + return dec.BytesRead(), err + } + vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 - var err error - vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return dec.BytesRead(), err } - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) return dec.BytesRead(), nil } @@ -226,6 +260,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { pk.NbInfinityB, pk.InfinityA, pk.InfinityB, + uint32(len(pk.CommitmentKeys)), } for _, v := range toEncode { @@ -234,6 +269,23 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { } } + for i := range pk.CommitmentKeys { + var ( + n2 int64 + err error + ) + if raw { + n2, err = pk.CommitmentKeys[i].WriteRawTo(w) + } else { + n2, err = pk.CommitmentKeys[i].WriteTo(w) + } + + n += n2 + if err != nil { + return n, err + } + } + return n + enc.BytesWritten(), nil } @@ -260,6 +312,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) dec := curve.NewDecoder(r, decOptions...) 
var nbWires uint64 + var nbCommitments uint32 toDecode := []interface{}{ &pk.G1.Alpha, @@ -291,6 +344,18 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) if err := dec.Decode(&pk.InfinityB); err != nil { return n + dec.BytesRead(), err } + if err := dec.Decode(&nbCommitments); err != nil { + return n + dec.BytesRead(), err + } + + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) + for i := range pk.CommitmentKeys { + n2, err := pk.CommitmentKeys[i].ReadFrom(r) + n += n2 + if err != nil { + return n, err + } + } return n + dec.BytesRead(), nil } diff --git a/internal/backend/bls24-315/groth16/marshal_test.go b/backend/groth16/bls24-315/marshal_test.go similarity index 78% rename from internal/backend/bls24-315/groth16/marshal_test.go rename to backend/groth16/bls24-315/marshal_test.go index 06bf1ec9de..ecdd8b4ea1 100644 --- a/internal/backend/bls24-315/groth16/marshal_test.go +++ b/backend/groth16/bls24-315/marshal_test.go @@ -21,11 +21,17 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal/test_utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "bytes" "math/big" "reflect" "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" "github.com/leanovate/gopter/prop" "testing" @@ -87,13 +93,9 @@ func TestProofSerialization(t *testing.T) { } func TestVerifyingKeySerialization(t *testing.T) { - parameters := gopter.DefaultTestParameters() - parameters.MinSuccessfulTests = 10 - properties := gopter.NewProperties(parameters) - - properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + roundTrip := func(withCommitment bool) func(curve.G1Affine, curve.G2Affine) bool { + return func(p1 curve.G1Affine, p2 curve.G2Affine) bool { var vk, vkCompressed, 
vkRaw VerifyingKey // create a random vk @@ -121,6 +123,21 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } + if withCommitment { + vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization + bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) + elem := p1 + for i := 0; i < len(vk.PublicAndCommitmentCommitted); i++ { + bases[i] = make([]curve.G1Affine, len(vk.PublicAndCommitmentCommitted[i])) + for j := range bases[i] { + bases[i][j] = elem + elem.Add(&elem, &p1) + } + } + _, vk.CommitmentKey, err = pedersen.Setup(bases...) + assert.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := vk.WriteTo(&bufCompressed) if err != nil { @@ -158,7 +175,22 @@ func TestVerifyingKeySerialization(t *testing.T) { } return reflect.DeepEqual(&vk, &vkCompressed) && reflect.DeepEqual(&vk, &vkRaw) - }, + } + } + + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 10 + + properties := gopter.NewProperties(parameters) + + properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(false), + GenG1(), + GenG2(), + )) + + properties.Property("VerifyingKey (with commitments) -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(true), GenG1(), GenG2(), )) @@ -173,7 +205,7 @@ func TestProvingKeySerialization(t *testing.T) { properties := gopter.NewProperties(parameters) properties.Property("ProvingKey -> writer -> reader -> ProvingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + func(p1 curve.G1Affine, p2 curve.G2Affine, nbCommitment int) bool { var pk, pkCompressed, pkRaw ProvingKey // create a random pk @@ -202,6 +234,19 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, 
nbCommitment) + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases...) + require.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := pk.WriteTo(&bufCompressed) if err != nil { @@ -242,6 +287,7 @@ func TestProvingKeySerialization(t *testing.T) { }, GenG1(), GenG2(), + gen.IntRange(0, 2), )) properties.TestingRun(t, gopter.ConsoleReporter(false)) diff --git a/backend/groth16/bls24-315/mpcsetup/lagrange.go b/backend/groth16/bls24-315/mpcsetup/lagrange.go new file mode 100644 index 0000000000..6fb61f2c68 --- /dev/null +++ b/backend/groth16/bls24-315/mpcsetup/lagrange.go @@ -0,0 +1,216 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bls24-315" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" + "github.com/consensys/gnark/internal/utils" +) + +func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { + coeffs := make([]curve.G1Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { + coeffs := make([]curve.G2Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func butterflyG1(a *curve.G1Affine, b *curve.G1Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +func butterflyG2(a *curve.G2Affine, b *curve.G2Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// kerDIF8G1 is a kernel that processes an FFT of size 8 +func kerDIF8G1(a []curve.G1Affine, twiddles [][]fr.Element, stage int) { + butterflyG1(&a[0], &a[4]) + butterflyG1(&a[1], &a[5]) + 
 butterflyG1(&a[2], &a[6]) + butterflyG1(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[2]) + butterflyG1(&a[1], &a[3]) + butterflyG1(&a[4], &a[6]) + butterflyG1(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[1]) + butterflyG1(&a[2], &a[3]) + butterflyG1(&a[4], &a[5]) + butterflyG1(&a[6], &a[7]) +} + +// kerDIF8G2 is a kernel that processes an FFT of size 8 +func kerDIF8G2(a []curve.G2Affine, twiddles [][]fr.Element, stage int) { + butterflyG2(&a[0], &a[4]) + butterflyG2(&a[1], &a[5]) + butterflyG2(&a[2], &a[6]) + butterflyG2(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[2]) + butterflyG2(&a[1], &a[3]) + butterflyG2(&a[4], &a[6]) + butterflyG2(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[1]) + butterflyG2(&a[2], &a[3]) + butterflyG2(&a[4], &a[5]) + butterflyG2(&a[6], &a[7]) +} + +func difFFTG1(a []curve.G1Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G1(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG1(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ 
{ + butterflyG1(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} +func difFFTG2(a []curve.G2Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G2(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG2(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + butterflyG2(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG2(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG2(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} diff --git a/backend/groth16/bls24-315/mpcsetup/marshal.go b/backend/groth16/bls24-315/mpcsetup/marshal.go new file mode 100644 index 0000000000..f670a7af74 --- /dev/null +++ b/backend/groth16/bls24-315/mpcsetup/marshal.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls24-315" + "io" +) + +// WriteTo implements io.WriterTo +func (phase1 *Phase1) WriteTo(writer io.Writer) (int64, error) { + n, err := phase1.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase1.Hash) + return int64(nBytes) + n, err +} + +func (phase1 *Phase1) writeTo(writer io.Writer) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + phase1.Parameters.G1.Tau, + phase1.Parameters.G1.AlphaTau, + phase1.Parameters.G1.BetaTau, + phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + enc := curve.NewEncoder(writer) + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (phase1 *Phase1) ReadFrom(reader io.Reader) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + &phase1.Parameters.G1.Tau, + 
&phase1.Parameters.G1.AlphaTau, + &phase1.Parameters.G1.BetaTau, + &phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + dec := curve.NewDecoder(reader) + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + phase1.Hash = make([]byte, 32) + nBytes, err := reader.Read(phase1.Hash) + return dec.BytesRead() + int64(nBytes), err +} + +// WriteTo implements io.WriterTo +func (phase2 *Phase2) WriteTo(writer io.Writer) (int64, error) { + n, err := phase2.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase2.Hash) + return int64(nBytes) + n, err +} + +func (c *Phase2) writeTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + c.Parameters.G1.L, + c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c *Phase2) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + &c.Parameters.G1.L, + &c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + c.Hash = make([]byte, 32) + n, err := reader.Read(c.Hash) + return int64(n) + dec.BytesRead(), err + +} + +// WriteTo implements io.WriterTo +func (c *Phase2Evaluations) WriteTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + c.G1.A, + c.G1.B, + c.G2.B, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom 
implements io.ReaderFrom +func (c *Phase2Evaluations) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.G1.A, + &c.G1.B, + &c.G2.B, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + return dec.BytesRead(), nil +} diff --git a/backend/groth16/bls24-315/mpcsetup/marshal_test.go b/backend/groth16/bls24-315/mpcsetup/marshal_test.go new file mode 100644 index 0000000000..9ca32105de --- /dev/null +++ b/backend/groth16/bls24-315/mpcsetup/marshal_test.go @@ -0,0 +1,79 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + curve "github.com/consensys/gnark-crypto/ecc/bls24-315" + "github.com/consensys/gnark/constraint/bls24-315" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/stretchr/testify/require" + "io" + "reflect" + "testing" +) + +func TestContributionSerialization(t *testing.T) { + assert := require.New(t) + + // Phase 1 + srs1 := InitPhase1(9) + srs1.Contribute() + { + var reconstructed Phase1 + roundTripCheck(t, &srs1, &reconstructed) + } + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + r1cs := ccs.(*cs.R1CS) + + // Phase 2 + srs2, _ := InitPhase2(r1cs, &srs1) + srs2.Contribute() + + { + var reconstructed Phase2 + roundTripCheck(t, &srs2, &reconstructed) + } +} + +func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { + t.Helper() + + var buf bytes.Buffer + written, err := from.WriteTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} diff --git a/backend/groth16/bls24-315/mpcsetup/phase1.go b/backend/groth16/bls24-315/mpcsetup/phase1.go new file mode 100644 index 0000000000..cefde6c90f --- /dev/null +++ b/backend/groth16/bls24-315/mpcsetup/phase1.go @@ -0,0 +1,203 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + curve "github.com/consensys/gnark-crypto/ecc/bls24-315" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "math" + "math/big" +) + +// Phase1 represents the Phase1 of the MPC described in +// https://eprint.iacr.org/2017/1050.pdf +// +// Also known as "Powers of Tau" +type Phase1 struct { + Parameters struct { + G1 struct { + Tau []curve.G1Affine // {[τ⁰]₁, [τ¹]₁, [τ²]₁, …, [τ²ⁿ⁻²]₁} + AlphaTau []curve.G1Affine // {α[τ⁰]₁, α[τ¹]₁, α[τ²]₁, …, α[τⁿ⁻¹]₁} + BetaTau []curve.G1Affine // {β[τ⁰]₁, β[τ¹]₁, β[τ²]₁, …, β[τⁿ⁻¹]₁} + } + G2 struct { + Tau []curve.G2Affine // {[τ⁰]₂, [τ¹]₂, [τ²]₂, …, [τⁿ⁻¹]₂} + Beta curve.G2Affine // [β]₂ + } + } + PublicKeys struct { + Tau, Alpha, Beta PublicKey + } + Hash []byte // sha256 hash +} + +// InitPhase1 initialize phase 1 of the MPC. This is called once by the coordinator before +// any randomness contribution is made (see Contribute()). 
+func InitPhase1(power int) (phase1 Phase1) { + N := int(math.Pow(2, float64(power))) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetOne() + alpha.SetOne() + beta.SetOne() + phase1.PublicKeys.Tau = newPublicKey(tau, nil, 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, nil, 2) + phase1.PublicKeys.Beta = newPublicKey(beta, nil, 3) + + // First contribution use generators + _, _, g1, g2 := curve.Generators() + phase1.Parameters.G2.Beta.Set(&g2) + phase1.Parameters.G1.Tau = make([]curve.G1Affine, 2*N-1) + phase1.Parameters.G2.Tau = make([]curve.G2Affine, N) + phase1.Parameters.G1.AlphaTau = make([]curve.G1Affine, N) + phase1.Parameters.G1.BetaTau = make([]curve.G1Affine, N) + for i := 0; i < len(phase1.Parameters.G1.Tau); i++ { + phase1.Parameters.G1.Tau[i].Set(&g1) + } + for i := 0; i < len(phase1.Parameters.G2.Tau); i++ { + phase1.Parameters.G2.Tau[i].Set(&g2) + phase1.Parameters.G1.AlphaTau[i].Set(&g1) + phase1.Parameters.G1.BetaTau[i].Set(&g1) + } + + phase1.Parameters.G2.Beta.Set(&g2) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() + + return +} + +// Contribute contributes randomness to the phase1 object. This mutates phase1. +func (phase1 *Phase1) Contribute() { + N := len(phase1.Parameters.G2.Tau) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetRandom() + alpha.SetRandom() + beta.SetRandom() + phase1.PublicKeys.Tau = newPublicKey(tau, phase1.Hash[:], 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, phase1.Hash[:], 2) + phase1.PublicKeys.Beta = newPublicKey(beta, phase1.Hash[:], 3) + + // Compute powers of τ, ατ, and βτ + taus := powers(tau, 2*N-1) + alphaTau := make([]fr.Element, N) + betaTau := make([]fr.Element, N) + for i := 0; i < N; i++ { + alphaTau[i].Mul(&taus[i], &alpha) + betaTau[i].Mul(&taus[i], &beta) + } + + // Update using previous parameters + // TODO @gbotrel working with jacobian points here will help with perf. 
+ scaleG1InPlace(phase1.Parameters.G1.Tau, taus) + scaleG2InPlace(phase1.Parameters.G2.Tau, taus[0:N]) + scaleG1InPlace(phase1.Parameters.G1.AlphaTau, alphaTau) + scaleG1InPlace(phase1.Parameters.G1.BetaTau, betaTau) + var betaBI big.Int + beta.BigInt(&betaBI) + phase1.Parameters.G2.Beta.ScalarMultiplication(&phase1.Parameters.G2.Beta, &betaBI) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() +} + +func VerifyPhase1(c0, c1 *Phase1, c ...*Phase1) error { + contribs := append([]*Phase1{c0, c1}, c...) + for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase1(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +// verifyPhase1 checks that a contribution is based on a known previous Phase1 state. +func verifyPhase1(current, contribution *Phase1) error { + // Compute R for τ, α, β + tauR := genR(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, current.Hash[:], 1) + alphaR := genR(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, current.Hash[:], 2) + betaR := genR(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, current.Hash[:], 3) + + // Check for knowledge of toxic parameters + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.PublicKeys.Tau.XR, tauR) { + return errors.New("couldn't verify public key of τ") + } + if !sameRatio(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, contribution.PublicKeys.Alpha.XR, alphaR) { + return errors.New("couldn't verify public key of α") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.PublicKeys.Beta.XR, betaR) { + return errors.New("couldn't verify public key of β") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Tau[1], current.Parameters.G1.Tau[1], tauR, contribution.PublicKeys.Tau.XR) { + return errors.New("couldn't verify that [τ]₁ is based on previous 
contribution") + } + if !sameRatio(contribution.Parameters.G1.AlphaTau[0], current.Parameters.G1.AlphaTau[0], alphaR, contribution.PublicKeys.Alpha.XR) { + return errors.New("couldn't verify that [α]₁ is based on previous contribution") + } + if !sameRatio(contribution.Parameters.G1.BetaTau[0], current.Parameters.G1.BetaTau[0], betaR, contribution.PublicKeys.Beta.XR) { + return errors.New("couldn't verify that [β]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.Parameters.G2.Tau[1], current.Parameters.G2.Tau[1]) { + return errors.New("couldn't verify that [τ]₂ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.Parameters.G2.Beta, current.Parameters.G2.Beta) { + return errors.New("couldn't verify that [β]₂ is based on previous contribution") + } + + // Check for valid updates using powers of τ + _, _, g1, g2 := curve.Generators() + tauL1, tauL2 := linearCombinationG1(contribution.Parameters.G1.Tau) + if !sameRatio(tauL1, tauL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of τ in G₁") + } + alphaL1, alphaL2 := linearCombinationG1(contribution.Parameters.G1.AlphaTau) + if !sameRatio(alphaL1, alphaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of α(τ) in G₁") + } + betaL1, betaL2 := linearCombinationG1(contribution.Parameters.G1.BetaTau) + if !sameRatio(betaL1, betaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of α(τ) in G₁") + } + tau2L1, tau2L2 := linearCombinationG2(contribution.Parameters.G2.Tau) + if !sameRatio(contribution.Parameters.G1.Tau[1], g1, tau2L1, tau2L2) { + return errors.New("couldn't verify valid powers of τ in G₂") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != 
contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (phase1 *Phase1) hash() []byte { + sha := sha256.New() + phase1.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bls24-315/mpcsetup/phase2.go b/backend/groth16/bls24-315/mpcsetup/phase2.go new file mode 100644 index 0000000000..b85efd3856 --- /dev/null +++ b/backend/groth16/bls24-315/mpcsetup/phase2.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + "math/big" + + curve "github.com/consensys/gnark-crypto/ecc/bls24-315" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/constraint/bls24-315" +) + +type Phase2Evaluations struct { + G1 struct { + A, B, VKK []curve.G1Affine + } + G2 struct { + B []curve.G2Affine + } +} + +type Phase2 struct { + Parameters struct { + G1 struct { + Delta curve.G1Affine + L, Z []curve.G1Affine + } + G2 struct { + Delta curve.G2Affine + } + } + PublicKey PublicKey + Hash []byte +} + +func InitPhase2(r1cs *cs.R1CS, srs1 *Phase1) (Phase2, Phase2Evaluations) { + srs := srs1.Parameters + size := len(srs.G1.AlphaTau) + if size < r1cs.GetNbConstraints() { + panic("Number of constraints is larger than expected") + } + + c2 := Phase2{} + + accumulateG1 := func(res *curve.G1Affine, t constraint.Term, value *curve.G1Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G1Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + accumulateG2 := func(res *curve.G2Affine, t constraint.Term, value *curve.G2Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G2Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + // Prepare Lagrange coefficients of [τ...]₁, [τ...]₂, [ατ...]₁, [βτ...]₁ + coeffTau1 := 
lagrangeCoeffsG1(srs.G1.Tau, size) + coeffTau2 := lagrangeCoeffsG2(srs.G2.Tau, size) + coeffAlphaTau1 := lagrangeCoeffsG1(srs.G1.AlphaTau, size) + coeffBetaTau1 := lagrangeCoeffsG1(srs.G1.BetaTau, size) + + internal, secret, public := r1cs.GetNbVariables() + nWires := internal + secret + public + var evals Phase2Evaluations + evals.G1.A = make([]curve.G1Affine, nWires) + evals.G1.B = make([]curve.G1Affine, nWires) + evals.G2.B = make([]curve.G2Affine, nWires) + bA := make([]curve.G1Affine, nWires) + aB := make([]curve.G1Affine, nWires) + C := make([]curve.G1Affine, nWires) + + // TODO @gbotrel use constraint iterator when available. + + i := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + // A + for _, t := range c.L { + accumulateG1(&evals.G1.A[t.WireID()], t, &coeffTau1[i]) + accumulateG1(&bA[t.WireID()], t, &coeffBetaTau1[i]) + } + // B + for _, t := range c.R { + accumulateG1(&evals.G1.B[t.WireID()], t, &coeffTau1[i]) + accumulateG2(&evals.G2.B[t.WireID()], t, &coeffTau2[i]) + accumulateG1(&aB[t.WireID()], t, &coeffAlphaTau1[i]) + } + // C + for _, t := range c.O { + accumulateG1(&C[t.WireID()], t, &coeffTau1[i]) + } + i++ + } + + // Prepare default contribution + _, _, g1, g2 := curve.Generators() + c2.Parameters.G1.Delta = g1 + c2.Parameters.G2.Delta = g2 + + // Build Z in PK as τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + // τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + n := len(srs.G1.AlphaTau) + c2.Parameters.G1.Z = make([]curve.G1Affine, n) + for i := 0; i < n-1; i++ { + c2.Parameters.G1.Z[i].Sub(&srs.G1.Tau[i+n], &srs.G1.Tau[i]) + } + bitReverse(c2.Parameters.G1.Z) + c2.Parameters.G1.Z = c2.Parameters.G1.Z[:n-1] + + // Evaluate L + nPrivate := internal + secret + c2.Parameters.G1.L = make([]curve.G1Affine, nPrivate) + evals.G1.VKK = make([]curve.G1Affine, public) + offset := public + for i := 0; i < nWires; i++ { + var tmp curve.G1Affine + tmp.Add(&bA[i], &aB[i]) + tmp.Add(&tmp, &C[i]) + if i < public { + 
evals.G1.VKK[i].Set(&tmp) + } else { + c2.Parameters.G1.L[i-offset].Set(&tmp) + } + } + // Set δ public key + var delta fr.Element + delta.SetOne() + c2.PublicKey = newPublicKey(delta, nil, 1) + + // Hash initial contribution + c2.Hash = c2.hash() + return c2, evals +} + +func (c *Phase2) Contribute() { + // Sample toxic δ + var delta, deltaInv fr.Element + var deltaBI, deltaInvBI big.Int + delta.SetRandom() + deltaInv.Inverse(&delta) + + delta.BigInt(&deltaBI) + deltaInv.BigInt(&deltaInvBI) + + // Set δ public key + c.PublicKey = newPublicKey(delta, c.Hash, 1) + + // Update δ + c.Parameters.G1.Delta.ScalarMultiplication(&c.Parameters.G1.Delta, &deltaBI) + c.Parameters.G2.Delta.ScalarMultiplication(&c.Parameters.G2.Delta, &deltaBI) + + // Update Z using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.Z); i++ { + c.Parameters.G1.Z[i].ScalarMultiplication(&c.Parameters.G1.Z[i], &deltaInvBI) + } + + // Update L using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.L); i++ { + c.Parameters.G1.L[i].ScalarMultiplication(&c.Parameters.G1.L[i], &deltaInvBI) + } + + // 4. Hash contribution + c.Hash = c.hash() +} + +func VerifyPhase2(c0, c1 *Phase2, c ...*Phase2) error { + contribs := append([]*Phase2{c0, c1}, c...) 
+ for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase2(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +func verifyPhase2(current, contribution *Phase2) error { + // Compute R for δ + deltaR := genR(contribution.PublicKey.SG, contribution.PublicKey.SXG, current.Hash[:], 1) + + // Check for knowledge of δ + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.PublicKey.XR, deltaR) { + return errors.New("couldn't verify knowledge of δ") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Delta, current.Parameters.G1.Delta, deltaR, contribution.PublicKey.XR) { + return errors.New("couldn't verify that [δ]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify that [δ]₂ is based on previous contribution") + } + + // Check for valid updates of L and Z using + L, prevL := merge(contribution.Parameters.G1.L, current.Parameters.G1.L) + if !sameRatio(L, prevL, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of L using δ⁻¹") + } + Z, prevZ := merge(contribution.Parameters.G1.Z, current.Parameters.G1.Z) + if !sameRatio(Z, prevZ, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of L using δ⁻¹") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (c *Phase2) hash() []byte { + sha := sha256.New() + c.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bls24-315/mpcsetup/setup.go b/backend/groth16/bls24-315/mpcsetup/setup.go new file mode 100644 index 0000000000..98fc63f1ad --- 
/dev/null +++ b/backend/groth16/bls24-315/mpcsetup/setup.go @@ -0,0 +1,97 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls24-315" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" + groth16 "github.com/consensys/gnark/backend/groth16/bls24-315" +) + +func ExtractKeys(srs1 *Phase1, srs2 *Phase2, evals *Phase2Evaluations, nConstraints int) (pk groth16.ProvingKey, vk groth16.VerifyingKey) { + _, _, _, g2 := curve.Generators() + + // Initialize PK + pk.Domain = *fft.NewDomain(uint64(nConstraints)) + pk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + pk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + pk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + pk.G1.Z = srs2.Parameters.G1.Z + bitReverse(pk.G1.Z) + + pk.G1.K = srs2.Parameters.G1.L + pk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + pk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + + // Filter out infinity points + nWires := len(evals.G1.A) + pk.InfinityA = make([]bool, nWires) + A := make([]curve.G1Affine, nWires) + j := 0 + for i, e := range evals.G1.A { + if e.IsInfinity() { + pk.InfinityA[i] = true + continue + } + A[j] = evals.G1.A[i] + j++ + } + pk.G1.A = A[:j] + pk.NbInfinityA = uint64(nWires - j) + + pk.InfinityB = make([]bool, nWires) + B := make([]curve.G1Affine, nWires) + j = 0 + for i, e := range evals.G1.B { + if e.IsInfinity() { + 
pk.InfinityB[i] = true + continue + } + B[j] = evals.G1.B[i] + j++ + } + pk.G1.B = B[:j] + pk.NbInfinityB = uint64(nWires - j) + + B2 := make([]curve.G2Affine, nWires) + j = 0 + for i, e := range evals.G2.B { + if e.IsInfinity() { + // pk.InfinityB[i] = true should be the same as in B + continue + } + B2[j] = evals.G2.B[i] + j++ + } + pk.G2.B = B2[:j] + + // Initialize VK + vk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + vk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + vk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + vk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + vk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + vk.G2.Gamma.Set(&g2) + vk.G1.K = evals.G1.VKK + + // sets e, -[δ]2, -[γ]2 + if err := vk.Precompute(); err != nil { + panic(err) + } + + return pk, vk +} diff --git a/backend/groth16/bls24-315/mpcsetup/setup_test.go b/backend/groth16/bls24-315/mpcsetup/setup_test.go new file mode 100644 index 0000000000..c640467623 --- /dev/null +++ b/backend/groth16/bls24-315/mpcsetup/setup_test.go @@ -0,0 +1,199 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls24-315" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark/constraint/bls24-315" + "testing" + + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/stretchr/testify/require" + + native_mimc "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/mimc" +) + +func TestSetupCircuit(t *testing.T) { + if testing.Short() { + t.Skip() + } + const ( + nContributionsPhase1 = 3 + nContributionsPhase2 = 3 + power = 9 + ) + + assert := require.New(t) + + srs1 := InitPhase1(power) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase1; i++ { + // we clone test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. + prev := srs1.clone() + + srs1.Contribute() + assert.NoError(VerifyPhase1(&prev, &srs1)) + } + + // Compile the circuit + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + var evals Phase2Evaluations + r1cs := ccs.(*cs.R1CS) + + // Prepare for phase-2 + srs2, evals := InitPhase2(r1cs, &srs1) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase2; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. 
+ prev := srs2.clone() + + srs2.Contribute() + assert.NoError(VerifyPhase2(&prev, &srs2)) + } + + // Extract the proving and verifying keys + pk, vk := ExtractKeys(&srs1, &srs2, &evals, ccs.GetNbConstraints()) + + // Build the witness + var preImage, hash fr.Element + { + m := native_mimc.NewMiMC() + m.Write(preImage.Marshal()) + hash.SetBytes(m.Sum(nil)) + } + + witness, err := frontend.NewWitness(&Circuit{PreImage: preImage, Hash: hash}, curve.ID.ScalarField()) + assert.NoError(err) + + pubWitness, err := witness.Public() + assert.NoError(err) + + // groth16: ensure proof is verified + proof, err := groth16.Prove(ccs, &pk, witness) + assert.NoError(err) + + err = groth16.Verify(proof, &vk, pubWitness) + assert.NoError(err) +} + +func BenchmarkPhase1(b *testing.B) { + const power = 14 + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = InitPhase1(power) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs1 := InitPhase1(power) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs1.Contribute() + } + }) + +} + +func BenchmarkPhase2(b *testing.B) { + const power = 14 + srs1 := InitPhase1(power) + srs1.Contribute() + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + if err != nil { + b.Fatal(err) + } + + r1cs := ccs.(*cs.R1CS) + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = InitPhase2(r1cs, &srs1) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs2, _ := InitPhase2(r1cs, &srs1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs2.Contribute() + } + }) + +} + +// Circuit defines a pre-image knowledge proof +// mimc(secret preImage) = public hash +type Circuit struct { + PreImage frontend.Variable + Hash frontend.Variable `gnark:",public"` +} + +// Define declares the circuit's constraints +// Hash = mimc(PreImage) +func (circuit *Circuit) Define(api frontend.API) error { + // hash function + mimc, _ := 
mimc.NewMiMC(api) + + // specify constraints + mimc.Write(circuit.PreImage) + api.AssertIsEqual(circuit.Hash, mimc.Sum()) + + return nil +} + +func (phase1 *Phase1) clone() Phase1 { + r := Phase1{} + r.Parameters.G1.Tau = append(r.Parameters.G1.Tau, phase1.Parameters.G1.Tau...) + r.Parameters.G1.AlphaTau = append(r.Parameters.G1.AlphaTau, phase1.Parameters.G1.AlphaTau...) + r.Parameters.G1.BetaTau = append(r.Parameters.G1.BetaTau, phase1.Parameters.G1.BetaTau...) + + r.Parameters.G2.Tau = append(r.Parameters.G2.Tau, phase1.Parameters.G2.Tau...) + r.Parameters.G2.Beta = phase1.Parameters.G2.Beta + + r.PublicKeys = phase1.PublicKeys + r.Hash = append(r.Hash, phase1.Hash...) + + return r +} + +func (phase2 *Phase2) clone() Phase2 { + r := Phase2{} + r.Parameters.G1.Delta = phase2.Parameters.G1.Delta + r.Parameters.G1.L = append(r.Parameters.G1.L, phase2.Parameters.G1.L...) + r.Parameters.G1.Z = append(r.Parameters.G1.Z, phase2.Parameters.G1.Z...) + r.Parameters.G2.Delta = phase2.Parameters.G2.Delta + r.PublicKey = phase2.PublicKey + r.Hash = append(r.Hash, phase2.Hash...) + + return r +} diff --git a/backend/groth16/bls24-315/mpcsetup/utils.go b/backend/groth16/bls24-315/mpcsetup/utils.go new file mode 100644 index 0000000000..c86248ac16 --- /dev/null +++ b/backend/groth16/bls24-315/mpcsetup/utils.go @@ -0,0 +1,170 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bls24-315" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark/internal/utils" +) + +type PublicKey struct { + SG curve.G1Affine + SXG curve.G1Affine + XR curve.G2Affine +} + +func newPublicKey(x fr.Element, challenge []byte, dst byte) PublicKey { + var pk PublicKey + _, _, g1, _ := curve.Generators() + + var s fr.Element + var sBi big.Int + s.SetRandom() + s.BigInt(&sBi) + pk.SG.ScalarMultiplication(&g1, &sBi) + + // compute x*sG1 + var xBi big.Int + x.BigInt(&xBi) + pk.SXG.ScalarMultiplication(&pk.SG, &xBi) + + // generate R based on sG1, sxG1, challenge, and domain separation tag (tau, alpha or beta) + R := genR(pk.SG, pk.SXG, challenge, dst) + + // compute x*spG2 + pk.XR.ScalarMultiplication(&R, &xBi) + return pk +} + +func bitReverse[T any](a []T) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + irev := bits.Reverse64(i) >> nn + if irev > i { + a[i], a[irev] = a[irev], a[i] + } + } +} + +// Returns [1, a, a², ..., aⁿ⁻¹ ] in Montgomery form +func powers(a fr.Element, n int) []fr.Element { + result := make([]fr.Element, n) + result[0] = fr.NewElement(1) + for i := 1; i < n; i++ { + result[i].Mul(&result[i-1], &a) + } + return result +} + +// Returns [aᵢAᵢ, ...] in G1 +func scaleG1InPlace(A []curve.G1Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Returns [aᵢAᵢ, ...] 
in G2 +func scaleG2InPlace(A []curve.G2Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Check e(a₁, a₂) = e(b₁, b₂) +func sameRatio(a1, b1 curve.G1Affine, a2, b2 curve.G2Affine) bool { + if !a1.IsInSubGroup() || !b1.IsInSubGroup() || !a2.IsInSubGroup() || !b2.IsInSubGroup() { + panic("invalid point not in subgroup") + } + var na2 curve.G2Affine + na2.Neg(&a2) + res, err := curve.PairingCheck( + []curve.G1Affine{a1, b1}, + []curve.G2Affine{na2, b2}) + if err != nil { + panic(err) + } + return res +} + +// returns a = ∑ rᵢAᵢ, b = ∑ rᵢBᵢ +func merge(A, B []curve.G1Affine) (a, b curve.G1Affine) { + nc := runtime.NumCPU() + r := make([]fr.Element, len(A)) + for i := 0; i < len(A); i++ { + r[i].SetRandom() + } + a.MultiExp(A, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + b.MultiExp(B, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G1 +func linearCombinationG1(A []curve.G1Affine) (L1, L2 curve.G1Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G2 +func linearCombinationG2(A []curve.G2Affine) (L1, L2 curve.G2Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// Generate R in G₂ as Hash(gˢ, gˢˣ, challenge, dst) +func genR(sG1, sxG1 curve.G1Affine, challenge []byte, dst byte) curve.G2Affine { + var buf bytes.Buffer + buf.Grow(len(challenge) + curve.SizeOfG1AffineUncompressed*2) + buf.Write(sG1.Marshal()) + 
buf.Write(sxG1.Marshal()) + buf.Write(challenge) + spG2, err := curve.HashToG2(buf.Bytes(), []byte{dst}) + if err != nil { + panic(err) + } + return spG2 +} diff --git a/internal/backend/bls24-315/groth16/prove.go b/backend/groth16/bls24-315/prove.go similarity index 64% rename from internal/backend/bls24-315/groth16/prove.go rename to backend/groth16/bls24-315/prove.go index cd2fa623e5..0b0774af00 100644 --- a/internal/backend/bls24-315/groth16/prove.go +++ b/backend/groth16/bls24-315/prove.go @@ -17,13 +17,17 @@ package groth16 import ( - "fmt" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/pedersen" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/constraint/bls24-315" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" "math/big" @@ -35,9 +39,10 @@ import ( // with a valid statement and a VerifyingKey // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type Proof struct { - Ar, Krs curve.G1Affine - Bs curve.G2Affine - Commitment, CommitmentPok curve.G1Affine + Ar, Krs curve.G1Affine + Bs curve.G2Affine + Commitments []curve.G1Affine // Pedersen commitments a la https://eprint.iacr.org/2022/1072 + CommitmentPok curve.G1Affine // Batched proof of knowledge of the above commitments } // isValid ensures proof elements are in the correct subgroup @@ -51,72 +56,71 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). 
-func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - // TODO @gbotrel witness size check is done by R1CS, doesn't mean we shouldn't sanitize here. - // if len(witness) != r1cs.NbPublicVariables-1+r1cs.NbSecretVariables { - // return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), r1cs.NbPublicVariables-1+r1cs.NbSecretVariables, r1cs.NbPublicVariables, r1cs.NbSecretVariables) - // } - - log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", len(r1cs.Constraints)).Str("backend", "groth16").Logger() - - // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - - proof := &Proof{} - if r1cs.CommitmentInfo.Is() { - opt.HintFunctions[r1cs.CommitmentInfo.HintID] = func(_ *big.Int, in []*big.Int, out []*big.Int) error { - // Perf-TODO: Converting these values to big.Int and back may be a performance bottleneck. - // If that is the case, figure out a way to feed the solution vector into this function - if len(in) != r1cs.CommitmentInfo.NbCommitted() { // TODO: Remove - return fmt.Errorf("unexpected number of committed variables") - } - values := make([]fr.Element, r1cs.CommitmentInfo.NbPrivateCommitted) - nbPublicCommitted := len(in) - len(values) - inPrivate := in[nbPublicCommitted:] - for i, inI := range inPrivate { - values[i].SetBigInt(inI) - } +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } + + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] - var err error - proof.Commitment, proof.CommitmentPok, err = pk.CommitmentKey.Commit(values) - if err != nil { + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + solverOpts = append(solverOpts, solver.OverrideHint(commitmentInfo[i].HintID, func(i int) solver.Hint { + return func(_ *big.Int, in []*big.Int, out []*big.Int) error { + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[len(hashed):] + for j, inJ := range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } + + var res fr.Element + res, err = solveCommitmentWire(&proof.Commitments[i], hashed) + res.BigInt(out[0]) return err } - - var res fr.Element - res, err = solveCommitmentWire(&r1cs.CommitmentInfo, &proof.Commitment, in[:r1cs.CommitmentInfo.NbPublicCommitted()]) - res.BigInt(out[0]) //Perf-TODO: Regular (non-mont) hashToField to obviate this conversion? 
- return err - } + }(i))) } - var wireValues []fr.Element - var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill wireValues with random values else multi exps don't do much - var r fr.Element - _, _ = r.SetRandom() - for i := r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables(); i < len(wireValues); i++ { - wireValues[i] = r - r.Double(&r) - } - } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + start := time.Now() + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + + if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil { + return nil, err + } + // H (witness reduction / FFT part) var h []fr.Element chHDone := make(chan struct{}, 1) go func() { - h = computeH(a, b, c, &pk.Domain) - a = nil - b = nil - c = nil + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + solution.A = nil + solution.B = nil + solution.C = nil chHDone <- struct{}{} }() @@ -203,15 +207,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC var krs, krs2, p1 curve.G1Jac chKrs2Done := make(chan error, 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 go func() { - _, err := krs2.MultiExp(pk.G1.Z, h, ecc.MultiExpConfig{NbTasks: n / 2}) + _, err := krs2.MultiExp(pk.G1.Z, h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) chKrs2Done <- err }() - // filter the wire values if needed; - _wireValues := filter(wireValues, r1cs.CommitmentInfo.PrivateToPublic()) + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + 
toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - if _, err := krs.MultiExp(pk.G1.K, _wireValues[r1cs.GetNbPublicVariables():], ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chKrsDone <- err return } @@ -292,26 +300,32 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC } // if len(toRemove) == 0, returns slice -// else, returns a new slice without the indexes in toRemove -// this assumes toRemove indexes are sorted and len(slice) > len(toRemove) -func filter(slice []fr.Element, toRemove []int) (r []fr.Element) { +// else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { if len(toRemove) == 0 { return slice } - r = make([]fr.Element, 0, len(slice)-len(toRemove)) - j := 0 + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) + // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) for i := 0; i < len(slice); i++ { - if j < len(toRemove) && i == toRemove[j] { - j++ + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } continue } r = append(r, slice[i]) } - return r + return } func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { @@ -334,9 +348,9 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { domain.FFTInverse(b, fft.DIF) domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, true) - domain.FFT(b, fft.DIT, true) - 
domain.FFT(c, fft.DIT, true) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) var den, one fr.Element one.SetOne() @@ -344,7 +358,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) - // reusing a to avoid unecessary memalloc + // reusing a to avoid unnecessary memory allocation utils.Parallelize(n, func(start, end int) { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). @@ -354,7 +368,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { }) // ifft_coset - domain.FFTInverse(a, fft.DIF, true) + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) return a } diff --git a/internal/backend/bls24-315/groth16/setup.go b/backend/groth16/bls24-315/setup.go similarity index 75% rename from internal/backend/bls24-315/groth16/setup.go rename to backend/groth16/bls24-315/setup.go index 8a80f7f023..a37ce828cd 100644 --- a/internal/backend/bls24-315/groth16/setup.go +++ b/backend/groth16/bls24-315/setup.go @@ -17,11 +17,13 @@ package groth16 import ( + "errors" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/constraint/bls24-315" "math/big" @@ -34,15 +36,15 @@ type ProvingKey struct { // domain Domain fft.Domain - // [α]1, [β]1, [δ]1 - // [A(t)]1, [B(t)]1, [Kpk(t)]1, [Z(t)]1 + // [α]₁, [β]₁, [δ]₁ + // [A(t)]₁, [B(t)]₁, [Kpk(t)]₁, [Z(t)]₁ G1 struct { Alpha, Beta, Delta curve.G1Affine A, B, Z []curve.G1Affine K []curve.G1Affine // the indexes correspond to the private wires } - // [β]2, [δ]2, [B(t)]2 + // [β]₂, [δ]₂, [B(t)]₂ G2 struct { Beta, Delta curve.G2Affine B 
[]curve.G2Affine @@ -52,21 +54,21 @@ type ProvingKey struct { InfinityA, InfinityB []bool NbInfinityA, NbInfinityB uint64 - CommitmentKey pedersen.Key + CommitmentKeys []pedersen.ProvingKey } // VerifyingKey is used by a Groth16 verifier to verify the validity of a proof and a statement // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type VerifyingKey struct { - // [α]1, [Kvk]1 + // [α]₁, [Kvk]₁ G1 struct { Alpha curve.G1Affine Beta, Delta curve.G1Affine // unused, here for compatibility purposes K []curve.G1Affine // The indexes correspond to the public wires } - // [β]2, [δ]2, [γ]2, - // -[δ]2, -[γ]2: see proof.Verify() for more details + // [β]₂, [δ]₂, [γ]₂, + // -[δ]₂, -[γ]₂: see proof.Verify() for more details G2 struct { Beta, Delta, Gamma curve.G2Affine deltaNeg, gammaNeg curve.G2Affine // not serialized @@ -75,8 +77,8 @@ type VerifyingKey struct { // e(α, β) e curve.GT // not serialized - CommitmentKey pedersen.Key - CommitmentInfo constraint.Commitment // since the verifier doesn't input a constraint system, this needs to be provided here + CommitmentKey pedersen.VerifyingKey + PublicAndCommitmentCommitted [][]int // indexes of public/commitment committed variables } // Setup constructs the SRS @@ -93,17 +95,20 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbPrivateCommittedWires := r1cs.CommitmentInfo.NbPrivateCommitted - nbPublicWires := r1cs.GetNbPublicVariables() - nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - if r1cs.CommitmentInfo.Is() { // the commitment itself is defined by a hint so the prover considers it private - nbPublicWires++ // but the verifier will need to inject the value itself so on the groth16 - nbPrivateWires-- // level it must be considered public - } + commitmentInfo 
:= r1cs.CommitmentInfo.(constraint.Groth16Commitments) + commitmentWires := commitmentInfo.CommitmentIndexes() + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateCommittedWires := internal.NbElements(privateCommitted) + + // a commitment is itself defined by a hint so the prover considers it private + // but the verifier will need to inject the value itself so on the groth16 + // level it must be considered public + nbPublicWires := r1cs.GetNbPublicVariables() + len(commitmentInfo) + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - len(commitmentInfo) // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints))) + domain := fft.NewDomain(uint64(r1cs.GetNbConstraints())) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -137,7 +142,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // compute scalars for pkK, vkK and ckK pkK := make([]fr.Element, nbPrivateWires) vkK := make([]fr.Element, nbPublicWires) - ckK := make([]fr.Element, nbPrivateCommittedWires) + ckK := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + ckK[i] = make([]fr.Element, len(privateCommitted[i])) + } var t0, t1 fr.Element @@ -148,28 +156,42 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { Add(&t1, &C[i]). 
Mul(&t1, coeff) } - - vI, cI := 0, 0 - privateCommitted := r1cs.CommitmentInfo.PrivateCommitted() + vI := 0 // number of public wires seen so far + cI := make([]int, len(commitmentInfo)) // number of private committed wires seen so far for each commitment + nbPrivateCommittedSeen := 0 // = ∑ᵢ cI[i] + nbCommitmentsSeen := 0 for i := range A { - isCommittedPrivate := cI < len(privateCommitted) && i == privateCommitted[cI] - isCommitment := r1cs.CommitmentInfo.Is() && i == r1cs.CommitmentInfo.CommitmentIndex - isPublic := i < r1cs.GetNbPublicVariables() + commitment := -1 // index of the commitment that commits to this variable as a private or commitment value + var isCommitment, isPublic bool + if isPublic = i < r1cs.GetNbPublicVariables(); !isPublic { + if nbCommitmentsSeen < len(commitmentWires) && commitmentWires[nbCommitmentsSeen] == i { + isCommitment = true + nbCommitmentsSeen++ + } - if isPublic || isCommittedPrivate || isCommitment { + for j := range commitmentInfo { // does commitment j commit to i? + if cI[j] < len(privateCommitted[j]) && privateCommitted[j][cI[j]] == i { + commitment = j + break // frontend guarantees that no private variable is committed to more than once + } + } + } + + if isPublic || commitment != -1 || isCommitment { computeK(i, &toxicWaste.gammaInv) - if isCommittedPrivate { - ckK[cI] = t1 - cI++ - } else { + if isPublic || isCommitment { vkK[vI] = t1 vI++ + } else { // committed and private + ckK[commitment][cI[commitment]] = t1 + cI[commitment]++ + nbPrivateCommittedSeen++ } } else { computeK(i, &toxicWaste.deltaInv) - pkK[i-vI-cI] = t1 + pkK[i-vI-nbPrivateCommittedSeen] = t1 // vI = nbPublicSeen + nbCommitmentsSeen } } @@ -222,11 +244,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { g1Scalars = append(g1Scalars, Z...) g1Scalars = append(g1Scalars, vkK...) g1Scalars = append(g1Scalars, pkK...) - g1Scalars = append(g1Scalars, ckK...) + for i := range ckK { + g1Scalars = append(g1Scalars, ckK[i]...) 
+ } g1PointsAff := curve.BatchScalarMultiplicationG1(&g1, g1Scalars) - // sets pk: [α]1, [β]1, [δ]1 + // sets pk: [α]₁, [β]₁, [δ]₁ pk.G1.Alpha = g1PointsAff[0] pk.G1.Beta = g1PointsAff[1] pk.G1.Delta = g1PointsAff[2] @@ -238,8 +262,9 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G1.B = g1PointsAff[offset : offset+len(B)] offset += len(B) - pk.G1.Z = g1PointsAff[offset : offset+int(domain.Cardinality)] - bitReverse(pk.G1.Z) + bitReverse(g1PointsAff[offset : offset+int(domain.Cardinality)]) + sizeZ := int(domain.Cardinality) - 1 // deg(H)=deg(A*B-C/X^n-1)=(n-1)+(n-1)-n=n-2 + pk.G1.Z = g1PointsAff[offset : offset+sizeZ] offset += int(domain.Cardinality) @@ -252,17 +277,22 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // --------------------------------------------------------------------------------------------- // Commitment setup - if nbPrivateCommittedWires != 0 { - commitmentBasis := g1PointsAff[offset:] + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(ckK[i]) + commitmentBases[i] = g1PointsAff[offset : offset+size] + offset += size + } + if offset != len(g1PointsAff) { + return errors.New("didn't consume all G1 points") // TODO @Tabaie Remove this + } - vk.CommitmentKey, err = pedersen.Setup(commitmentBasis) - if err != nil { - return err - } - pk.CommitmentKey = vk.CommitmentKey + pk.CommitmentKeys, vk.CommitmentKey, err = pedersen.Setup(commitmentBases...) 
+ if err != nil { + return err } - vk.CommitmentInfo = r1cs.CommitmentInfo // unfortunate but necessary + vk.PublicAndCommitmentCommitted = commitmentInfo.GetPublicAndCommitmentCommitted(commitmentWires, r1cs.GetNbPublicVariables()) // --------------------------------------------------------------------------------------------- // G2 scalars @@ -279,15 +309,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G2.B = g2PointsAff[:len(B)] - // sets pk: [β]2, [δ]2 + // sets pk: [β]₂, [δ]₂ pk.G2.Beta = g2PointsAff[len(B)+0] pk.G2.Delta = g2PointsAff[len(B)+1] - // sets vk: [δ]2, [γ]2, -[δ]2, -[γ]2 + // sets vk: [δ]₂, [γ]₂ vk.G2.Delta = g2PointsAff[len(B)+1] vk.G2.Gamma = g2PointsAff[len(B)+2] - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) // --------------------------------------------------------------------------------------------- // Pairing: vk.e @@ -298,16 +326,29 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { vk.G1.Beta = pk.G1.Beta vk.G1.Delta = pk.G1.Delta - vk.e, err = curve.Pair([]curve.G1Affine{pk.G1.Alpha}, []curve.G2Affine{pk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return err } + // set domain pk.Domain = *domain return nil } +// Precompute sets e, -[δ]₂, -[γ]₂ +// This is meant to be called internally during setup or deserialization. +func (vk *VerifyingKey) Precompute() error { + var err error + vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) + if err != nil { + return err + } + vk.G2.deltaNeg.Neg(&vk.G2.Delta) + vk.G2.gammaNeg.Neg(&vk.G2.Gamma) + return nil +} + func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr.Element, B []fr.Element, C []fr.Element) { nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() @@ -322,7 +363,7 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. 
var w fr.Element w.Set(&domain.Generator) wi := fr.One() - t := make([]fr.Element, len(r1cs.Constraints)+1) + t := make([]fr.Element, r1cs.GetNbConstraints()+1) for i := 0; i < len(t); i++ { t[i].Sub(&toxicWaste.t, &wi) wi.Mul(&wi, &w) // TODO this is already pre computed in fft.Domain @@ -366,8 +407,10 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // for each term appearing in the linear expression, // we compute term.Coefficient * L, and cumulate it in // A, B or C at the index of the variable - for i, c := range r1cs.Constraints { + j := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } @@ -380,9 +423,12 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // Li+1 = w*Li*(t-w^i)/(t-w^(i+1)) L.Mul(&L, &w) - L.Mul(&L, &t[i]) - L.Mul(&L, &tInv[i+1]) + L.Mul(&L, &t[j]) + L.Mul(&L, &tInv[j+1]) + + j++ } + return } @@ -436,7 +482,10 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbConstraints := len(r1cs.Constraints) + nbConstraints := r1cs.GetNbConstraints() + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - internal.NbElements(privateCommitted) - len(commitmentInfo) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints)) @@ -448,8 +497,8 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // initialize proving key pk.G1.A = make([]curve.G1Affine, nbWires-nbZeroesA) pk.G1.B = make([]curve.G1Affine, nbWires-nbZeroesB) - pk.G1.K = make([]curve.G1Affine, nbWires-r1cs.GetNbPublicVariables()) - pk.G1.Z = make([]curve.G1Affine, 
domain.Cardinality) + pk.G1.K = make([]curve.G1Affine, nbPrivateWires) + pk.G1.Z = make([]curve.G1Affine, domain.Cardinality-1) pk.G2.B = make([]curve.G2Affine, nbWires-nbZeroesB) // set infinity markers @@ -503,6 +552,22 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { pk.Domain = *domain + // --------------------------------------------------------------------------------------------- + // Commitment setup + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(privateCommitted[i]) + commitmentBases[i] = make([]curve.G1Affine, size) + for j := range commitmentBases[i] { + commitmentBases[i][j] = r1Aff + } + } + + pk.CommitmentKeys, _, err = pedersen.Setup(commitmentBases...) + if err != nil { + return err + } + return nil } @@ -514,7 +579,9 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) - for _, c := range r1cs.Constraints { + + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { A[t.WireID()] = true } @@ -522,6 +589,7 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { B[t.WireID()] = true } } + for i := 0; i < nbWires; i++ { if !A[i] { nbZeroesA++ @@ -606,7 +674,7 @@ func (pk *ProvingKey) NbG2() int { return 2 + len(pk.G2.B) } -// bitRerverse permutation as in fft.BitReverse , but with []curve.G1Affine +// bitReverse permutation as in fft.BitReverse , but with []curve.G1Affine func bitReverse(a []curve.G1Affine) { n := uint(len(a)) nn := uint(bits.UintSize - bits.TrailingZeros(n)) diff --git a/internal/backend/bls24-315/groth16/verify.go b/backend/groth16/bls24-315/verify.go similarity index 64% rename from internal/backend/bls24-315/groth16/verify.go rename to backend/groth16/bls24-315/verify.go index f0f170d9b3..6e85b70ecf 100644 --- a/internal/backend/bls24-315/groth16/verify.go +++ b/backend/groth16/bls24-315/verify.go @@ -22,9 +22,11 @@ import ( 
"github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/pedersen" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/logger" "io" - "math/big" "time" ) @@ -36,10 +38,8 @@ var ( // Verify verifies a proof with given VerifyingKey and publicWitness func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - nbPublicVars := len(vk.G1.K) - if vk.CommitmentInfo.Is() { - nbPublicVars-- - } + nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted) + if len(publicWitness) != nbPublicVars-1 { return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(publicWitness), len(vk.G1.K)-1) } @@ -62,21 +62,32 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { close(chDone) }() - if vk.CommitmentInfo.Is() { - - if err := vk.CommitmentKey.VerifyKnowledgeProof(proof.Commitment, proof.CommitmentPok); err != nil { - return err + maxNbPublicCommitted := 0 + for _, s := range vk.PublicAndCommitmentCommitted { // iterate over commitments + maxNbPublicCommitted = utils.Max(maxNbPublicCommitted, len(s)) + } + commitmentsSerialized := make([]byte, len(vk.PublicAndCommitmentCommitted)*fr.Bytes) + commitmentPrehashSerialized := make([]byte, curve.SizeOfG1AffineUncompressed+maxNbPublicCommitted*fr.Bytes) + for i := range vk.PublicAndCommitmentCommitted { // solveCommitmentWire + copy(commitmentPrehashSerialized, proof.Commitments[i].Marshal()) + offset := curve.SizeOfG1AffineUncompressed + for j := range vk.PublicAndCommitmentCommitted[i] { + copy(commitmentPrehashSerialized[offset:], publicWitness[vk.PublicAndCommitmentCommitted[i][j]-1].Marshal()) + offset += fr.Bytes } - - publicCommitted := make([]*big.Int, vk.CommitmentInfo.NbPublicCommitted()) - for i := range publicCommitted { - var b big.Int - 
publicWitness[vk.CommitmentInfo.Committed[i]-1].BigInt(&b) - publicCommitted[i] = &b + if res, err := fr.Hash(commitmentPrehashSerialized[:offset], []byte(constraint.CommitmentDst), 1); err != nil { + return err + } else { + publicWitness = append(publicWitness, res[0]) + copy(commitmentsSerialized[i*fr.Bytes:], res[0].Marshal()) } + } - if res, err := solveCommitmentWire(&vk.CommitmentInfo, &proof.Commitment, publicCommitted); err == nil { - publicWitness = append(publicWitness, res) + if folded, err := pedersen.FoldCommitments(proof.Commitments, commitmentsSerialized); err != nil { + return err + } else { + if err = vk.CommitmentKey.Verify(folded, proof.CommitmentPok); err != nil { + return err } } @@ -87,8 +98,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { } kSum.AddMixed(&vk.G1.K[0]) - if vk.CommitmentInfo.Is() { - kSum.AddMixed(&proof.Commitment) + for i := range proof.Commitments { + kSum.AddMixed(&proof.Commitments[i]) } var kSumAff curve.G1Affine diff --git a/internal/backend/bls24-317/groth16/commitment.go b/backend/groth16/bls24-317/commitment.go similarity index 74% rename from internal/backend/bls24-317/groth16/commitment.go rename to backend/groth16/bls24-317/commitment.go index 00b3713dd6..05d71ba172 100644 --- a/internal/backend/bls24-317/groth16/commitment.go +++ b/backend/groth16/bls24-317/commitment.go @@ -23,7 +23,7 @@ import ( "math/big" ) -func solveCommitmentWire(commitmentInfo *constraint.Commitment, commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { - res, err := fr.Hash(commitmentInfo.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) +func solveCommitmentWire(commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { + res, err := fr.Hash(constraint.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) return res[0], err } diff --git 
a/internal/backend/bls24-317/groth16/commitment_test.go b/backend/groth16/bls24-317/commitment_test.go similarity index 91% rename from internal/backend/bls24-317/groth16/commitment_test.go rename to backend/groth16/bls24-317/commitment_test.go index 3963ef349e..bfb2fc578e 100644 --- a/internal/backend/bls24-317/groth16/commitment_test.go +++ b/backend/groth16/bls24-317/commitment_test.go @@ -17,6 +17,9 @@ package groth16_test import ( + "fmt" + "testing" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/witness" @@ -24,7 +27,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/assert" - "testing" ) type singleSecretCommittedCircuit struct { @@ -33,7 +35,11 @@ type singleSecretCommittedCircuit struct { func (c *singleSecretCommittedCircuit) Define(api frontend.API) error { api.AssertIsEqual(c.One, 1) - commit, err := api.Compiler().Commit(c.One) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One) if err != nil { return err } @@ -119,8 +125,11 @@ type oneSecretOnePublicCommittedCircuit struct { } func (c *oneSecretOnePublicCommittedCircuit) Define(api frontend.API) error { - - commit, err := api.Compiler().Commit(c.One, c.Two) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One, c.Two) if err != nil { return err } diff --git a/internal/backend/bls24-317/groth16/marshal.go b/backend/groth16/bls24-317/marshal.go similarity index 81% rename from internal/backend/bls24-317/groth16/marshal.go rename to backend/groth16/bls24-317/marshal.go index 3b0ec45426..75deea5847 100644 --- a/internal/backend/bls24-317/groth16/marshal.go +++ b/backend/groth16/bls24-317/marshal.go @@ -18,6 +18,9 @@ package 
groth16 import ( curve "github.com/consensys/gnark-crypto/ecc/bls24-317" + + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/pedersen" + "github.com/consensys/gnark/internal/utils" "io" ) @@ -78,14 +81,24 @@ func (proof *Proof) ReadFrom(r io.Reader) (n int64, err error) { // points are compressed // use WriteRawTo(...) to encode the key without point compression func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, false) + if n, err = vk.writeTo(w, false); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteTo(w) + return m + n, err } // WriteRawTo writes binary encoding of the key elements to writer // points are not compressed // use WriteTo(...) to encode the key with point compression func (vk *VerifyingKey) WriteRawTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, true) + if n, err = vk.writeTo(w, true); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteRawTo(w) + return m + n, err } // writeTo serialization format: @@ -124,6 +137,14 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { if err := enc.Encode(vk.G1.K); err != nil { return enc.BytesWritten(), err } + + if vk.PublicAndCommitmentCommitted == nil { + vk.PublicAndCommitmentCommitted = [][]int{} // only matters in tests + } + if err := enc.Encode(utils.IntSliceSliceToUint64SliceSlice(vk.PublicAndCommitmentCommitted)); err != nil { + return enc.BytesWritten(), err + } + return enc.BytesWritten(), nil } @@ -133,13 +154,25 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { // https://github.com/zkcrypto/bellman/blob/fa9be45588227a8c6ec34957de3f68705f07bd92/src/groth16/mod.rs#L143 // [α]1,[β]1,[β]2,[γ]2,[δ]1,[δ]2,uint32(len(Kvk)),[Kvk]1 func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r) + n, err := vk.readFrom(r) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.ReadFrom(r) + return m + 
n, err } // UnsafeReadFrom has the same behavior as ReadFrom, except that it will not check that decode points // are on the curve and in the correct subgroup. func (vk *VerifyingKey) UnsafeReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r, curve.NoSubgroupChecks()) + n, err := vk.readFrom(r, curve.NoSubgroupChecks()) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.UnsafeReadFrom(r) + return m + n, err } func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { @@ -169,15 +202,16 @@ func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder) if err := dec.Decode(&vk.G1.K); err != nil { return dec.BytesRead(), err } + var publicCommitted [][]uint64 + if err := dec.Decode(&publicCommitted); err != nil { + return dec.BytesRead(), err + } + vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 - var err error - vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return dec.BytesRead(), err } - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) return dec.BytesRead(), nil } @@ -226,6 +260,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { pk.NbInfinityB, pk.InfinityA, pk.InfinityB, + uint32(len(pk.CommitmentKeys)), } for _, v := range toEncode { @@ -234,6 +269,23 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { } } + for i := range pk.CommitmentKeys { + var ( + n2 int64 + err error + ) + if raw { + n2, err = pk.CommitmentKeys[i].WriteRawTo(w) + } else { + n2, err = pk.CommitmentKeys[i].WriteTo(w) + } + + n += n2 + if err != nil { + return n, err + } + } + return n + enc.BytesWritten(), nil } @@ -260,6 +312,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) dec := curve.NewDecoder(r, decOptions...) 
var nbWires uint64 + var nbCommitments uint32 toDecode := []interface{}{ &pk.G1.Alpha, @@ -291,6 +344,18 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) if err := dec.Decode(&pk.InfinityB); err != nil { return n + dec.BytesRead(), err } + if err := dec.Decode(&nbCommitments); err != nil { + return n + dec.BytesRead(), err + } + + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) + for i := range pk.CommitmentKeys { + n2, err := pk.CommitmentKeys[i].ReadFrom(r) + n += n2 + if err != nil { + return n, err + } + } return n + dec.BytesRead(), nil } diff --git a/internal/backend/bls24-317/groth16/marshal_test.go b/backend/groth16/bls24-317/marshal_test.go similarity index 78% rename from internal/backend/bls24-317/groth16/marshal_test.go rename to backend/groth16/bls24-317/marshal_test.go index aa75c1d456..860e08725c 100644 --- a/internal/backend/bls24-317/groth16/marshal_test.go +++ b/backend/groth16/bls24-317/marshal_test.go @@ -21,11 +21,17 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal/test_utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "bytes" "math/big" "reflect" "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" "github.com/leanovate/gopter/prop" "testing" @@ -87,13 +93,9 @@ func TestProofSerialization(t *testing.T) { } func TestVerifyingKeySerialization(t *testing.T) { - parameters := gopter.DefaultTestParameters() - parameters.MinSuccessfulTests = 10 - properties := gopter.NewProperties(parameters) - - properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + roundTrip := func(withCommitment bool) func(curve.G1Affine, curve.G2Affine) bool { + return func(p1 curve.G1Affine, p2 curve.G2Affine) bool { var vk, vkCompressed, 
vkRaw VerifyingKey // create a random vk @@ -121,6 +123,21 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } + if withCommitment { + vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization + bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) + elem := p1 + for i := 0; i < len(vk.PublicAndCommitmentCommitted); i++ { + bases[i] = make([]curve.G1Affine, len(vk.PublicAndCommitmentCommitted[i])) + for j := range bases[i] { + bases[i][j] = elem + elem.Add(&elem, &p1) + } + } + _, vk.CommitmentKey, err = pedersen.Setup(bases...) + assert.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := vk.WriteTo(&bufCompressed) if err != nil { @@ -158,7 +175,22 @@ func TestVerifyingKeySerialization(t *testing.T) { } return reflect.DeepEqual(&vk, &vkCompressed) && reflect.DeepEqual(&vk, &vkRaw) - }, + } + } + + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 10 + + properties := gopter.NewProperties(parameters) + + properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(false), + GenG1(), + GenG2(), + )) + + properties.Property("VerifyingKey (with commitments) -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(true), GenG1(), GenG2(), )) @@ -173,7 +205,7 @@ func TestProvingKeySerialization(t *testing.T) { properties := gopter.NewProperties(parameters) properties.Property("ProvingKey -> writer -> reader -> ProvingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + func(p1 curve.G1Affine, p2 curve.G2Affine, nbCommitment int) bool { var pk, pkCompressed, pkRaw ProvingKey // create a random pk @@ -202,6 +234,19 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, 
nbCommitment) + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases...) + require.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := pk.WriteTo(&bufCompressed) if err != nil { @@ -242,6 +287,7 @@ func TestProvingKeySerialization(t *testing.T) { }, GenG1(), GenG2(), + gen.IntRange(0, 2), )) properties.TestingRun(t, gopter.ConsoleReporter(false)) diff --git a/backend/groth16/bls24-317/mpcsetup/lagrange.go b/backend/groth16/bls24-317/mpcsetup/lagrange.go new file mode 100644 index 0000000000..9d0ba07367 --- /dev/null +++ b/backend/groth16/bls24-317/mpcsetup/lagrange.go @@ -0,0 +1,216 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bls24-317" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" + "github.com/consensys/gnark/internal/utils" +) + +func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { + coeffs := make([]curve.G1Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { + coeffs := make([]curve.G2Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func butterflyG1(a *curve.G1Affine, b *curve.G1Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +func butterflyG2(a *curve.G2Affine, b *curve.G2Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G1(a []curve.G1Affine, twiddles [][]fr.Element, stage int) { + butterflyG1(&a[0], &a[4]) + butterflyG1(&a[1], &a[5]) + 
butterflyG1(&a[2], &a[6]) + butterflyG1(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[2]) + butterflyG1(&a[1], &a[3]) + butterflyG1(&a[4], &a[6]) + butterflyG1(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[1]) + butterflyG1(&a[2], &a[3]) + butterflyG1(&a[4], &a[5]) + butterflyG1(&a[6], &a[7]) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G2(a []curve.G2Affine, twiddles [][]fr.Element, stage int) { + butterflyG2(&a[0], &a[4]) + butterflyG2(&a[1], &a[5]) + butterflyG2(&a[2], &a[6]) + butterflyG2(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[2]) + butterflyG2(&a[1], &a[3]) + butterflyG2(&a[4], &a[6]) + butterflyG2(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[1]) + butterflyG2(&a[2], &a[3]) + butterflyG2(&a[4], &a[5]) + butterflyG2(&a[6], &a[7]) +} + +func difFFTG1(a []curve.G1Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G1(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG1(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ 
{ + butterflyG1(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} +func difFFTG2(a []curve.G2Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G2(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG2(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + butterflyG2(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG2(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG2(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} diff --git a/backend/groth16/bls24-317/mpcsetup/marshal.go b/backend/groth16/bls24-317/mpcsetup/marshal.go new file mode 100644 index 0000000000..6e1ebfc02c --- /dev/null +++ b/backend/groth16/bls24-317/mpcsetup/marshal.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls24-317" + "io" +) + +// WriteTo implements io.WriterTo +func (phase1 *Phase1) WriteTo(writer io.Writer) (int64, error) { + n, err := phase1.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase1.Hash) + return int64(nBytes) + n, err +} + +func (phase1 *Phase1) writeTo(writer io.Writer) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + phase1.Parameters.G1.Tau, + phase1.Parameters.G1.AlphaTau, + phase1.Parameters.G1.BetaTau, + phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + enc := curve.NewEncoder(writer) + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (phase1 *Phase1) ReadFrom(reader io.Reader) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + &phase1.Parameters.G1.Tau, + 
&phase1.Parameters.G1.AlphaTau, + &phase1.Parameters.G1.BetaTau, + &phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + dec := curve.NewDecoder(reader) + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + phase1.Hash = make([]byte, 32) + nBytes, err := reader.Read(phase1.Hash) + return dec.BytesRead() + int64(nBytes), err +} + +// WriteTo implements io.WriterTo +func (phase2 *Phase2) WriteTo(writer io.Writer) (int64, error) { + n, err := phase2.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase2.Hash) + return int64(nBytes) + n, err +} + +func (c *Phase2) writeTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + c.Parameters.G1.L, + c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c *Phase2) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + &c.Parameters.G1.L, + &c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + c.Hash = make([]byte, 32) + n, err := reader.Read(c.Hash) + return int64(n) + dec.BytesRead(), err + +} + +// WriteTo implements io.WriterTo +func (c *Phase2Evaluations) WriteTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + c.G1.A, + c.G1.B, + c.G2.B, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom 
implements io.ReaderFrom +func (c *Phase2Evaluations) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.G1.A, + &c.G1.B, + &c.G2.B, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + return dec.BytesRead(), nil +} diff --git a/backend/groth16/bls24-317/mpcsetup/marshal_test.go b/backend/groth16/bls24-317/mpcsetup/marshal_test.go new file mode 100644 index 0000000000..6fc34b923c --- /dev/null +++ b/backend/groth16/bls24-317/mpcsetup/marshal_test.go @@ -0,0 +1,79 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + curve "github.com/consensys/gnark-crypto/ecc/bls24-317" + "github.com/consensys/gnark/constraint/bls24-317" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/stretchr/testify/require" + "io" + "reflect" + "testing" +) + +func TestContributionSerialization(t *testing.T) { + assert := require.New(t) + + // Phase 1 + srs1 := InitPhase1(9) + srs1.Contribute() + { + var reconstructed Phase1 + roundTripCheck(t, &srs1, &reconstructed) + } + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + r1cs := ccs.(*cs.R1CS) + + // Phase 2 + srs2, _ := InitPhase2(r1cs, &srs1) + srs2.Contribute() + + { + var reconstructed Phase2 + roundTripCheck(t, &srs2, &reconstructed) + } +} + +func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { + t.Helper() + + var buf bytes.Buffer + written, err := from.WriteTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} diff --git a/backend/groth16/bls24-317/mpcsetup/phase1.go b/backend/groth16/bls24-317/mpcsetup/phase1.go new file mode 100644 index 0000000000..72e61e9977 --- /dev/null +++ b/backend/groth16/bls24-317/mpcsetup/phase1.go @@ -0,0 +1,203 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + curve "github.com/consensys/gnark-crypto/ecc/bls24-317" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "math" + "math/big" +) + +// Phase1 represents the Phase1 of the MPC described in +// https://eprint.iacr.org/2017/1050.pdf +// +// Also known as "Powers of Tau" +type Phase1 struct { + Parameters struct { + G1 struct { + Tau []curve.G1Affine // {[τ⁰]₁, [τ¹]₁, [τ²]₁, …, [τ²ⁿ⁻²]₁} + AlphaTau []curve.G1Affine // {α[τ⁰]₁, α[τ¹]₁, α[τ²]₁, …, α[τⁿ⁻¹]₁} + BetaTau []curve.G1Affine // {β[τ⁰]₁, β[τ¹]₁, β[τ²]₁, …, β[τⁿ⁻¹]₁} + } + G2 struct { + Tau []curve.G2Affine // {[τ⁰]₂, [τ¹]₂, [τ²]₂, …, [τⁿ⁻¹]₂} + Beta curve.G2Affine // [β]₂ + } + } + PublicKeys struct { + Tau, Alpha, Beta PublicKey + } + Hash []byte // sha256 hash +} + +// InitPhase1 initialize phase 1 of the MPC. This is called once by the coordinator before +// any randomness contribution is made (see Contribute()). 
+func InitPhase1(power int) (phase1 Phase1) { + N := int(math.Pow(2, float64(power))) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetOne() + alpha.SetOne() + beta.SetOne() + phase1.PublicKeys.Tau = newPublicKey(tau, nil, 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, nil, 2) + phase1.PublicKeys.Beta = newPublicKey(beta, nil, 3) + + // First contribution use generators + _, _, g1, g2 := curve.Generators() + phase1.Parameters.G2.Beta.Set(&g2) + phase1.Parameters.G1.Tau = make([]curve.G1Affine, 2*N-1) + phase1.Parameters.G2.Tau = make([]curve.G2Affine, N) + phase1.Parameters.G1.AlphaTau = make([]curve.G1Affine, N) + phase1.Parameters.G1.BetaTau = make([]curve.G1Affine, N) + for i := 0; i < len(phase1.Parameters.G1.Tau); i++ { + phase1.Parameters.G1.Tau[i].Set(&g1) + } + for i := 0; i < len(phase1.Parameters.G2.Tau); i++ { + phase1.Parameters.G2.Tau[i].Set(&g2) + phase1.Parameters.G1.AlphaTau[i].Set(&g1) + phase1.Parameters.G1.BetaTau[i].Set(&g1) + } + + phase1.Parameters.G2.Beta.Set(&g2) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() + + return +} + +// Contribute contributes randomness to the phase1 object. This mutates phase1. +func (phase1 *Phase1) Contribute() { + N := len(phase1.Parameters.G2.Tau) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetRandom() + alpha.SetRandom() + beta.SetRandom() + phase1.PublicKeys.Tau = newPublicKey(tau, phase1.Hash[:], 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, phase1.Hash[:], 2) + phase1.PublicKeys.Beta = newPublicKey(beta, phase1.Hash[:], 3) + + // Compute powers of τ, ατ, and βτ + taus := powers(tau, 2*N-1) + alphaTau := make([]fr.Element, N) + betaTau := make([]fr.Element, N) + for i := 0; i < N; i++ { + alphaTau[i].Mul(&taus[i], &alpha) + betaTau[i].Mul(&taus[i], &beta) + } + + // Update using previous parameters + // TODO @gbotrel working with jacobian points here will help with perf. 
+ scaleG1InPlace(phase1.Parameters.G1.Tau, taus) + scaleG2InPlace(phase1.Parameters.G2.Tau, taus[0:N]) + scaleG1InPlace(phase1.Parameters.G1.AlphaTau, alphaTau) + scaleG1InPlace(phase1.Parameters.G1.BetaTau, betaTau) + var betaBI big.Int + beta.BigInt(&betaBI) + phase1.Parameters.G2.Beta.ScalarMultiplication(&phase1.Parameters.G2.Beta, &betaBI) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() +} + +func VerifyPhase1(c0, c1 *Phase1, c ...*Phase1) error { + contribs := append([]*Phase1{c0, c1}, c...) + for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase1(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +// verifyPhase1 checks that a contribution is based on a known previous Phase1 state. +func verifyPhase1(current, contribution *Phase1) error { + // Compute R for τ, α, β + tauR := genR(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, current.Hash[:], 1) + alphaR := genR(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, current.Hash[:], 2) + betaR := genR(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, current.Hash[:], 3) + + // Check for knowledge of toxic parameters + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.PublicKeys.Tau.XR, tauR) { + return errors.New("couldn't verify public key of τ") + } + if !sameRatio(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, contribution.PublicKeys.Alpha.XR, alphaR) { + return errors.New("couldn't verify public key of α") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.PublicKeys.Beta.XR, betaR) { + return errors.New("couldn't verify public key of β") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Tau[1], current.Parameters.G1.Tau[1], tauR, contribution.PublicKeys.Tau.XR) { + return errors.New("couldn't verify that [τ]₁ is based on previous 
contribution")
+	}
+	if !sameRatio(contribution.Parameters.G1.AlphaTau[0], current.Parameters.G1.AlphaTau[0], alphaR, contribution.PublicKeys.Alpha.XR) {
+		return errors.New("couldn't verify that [α]₁ is based on previous contribution")
+	}
+	if !sameRatio(contribution.Parameters.G1.BetaTau[0], current.Parameters.G1.BetaTau[0], betaR, contribution.PublicKeys.Beta.XR) {
+		return errors.New("couldn't verify that [β]₁ is based on previous contribution")
+	}
+	if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.Parameters.G2.Tau[1], current.Parameters.G2.Tau[1]) {
+		return errors.New("couldn't verify that [τ]₂ is based on previous contribution")
+	}
+	if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.Parameters.G2.Beta, current.Parameters.G2.Beta) {
+		return errors.New("couldn't verify that [β]₂ is based on previous contribution")
+	}
+
+	// Check for valid updates using powers of τ
+	_, _, g1, g2 := curve.Generators()
+	tauL1, tauL2 := linearCombinationG1(contribution.Parameters.G1.Tau)
+	if !sameRatio(tauL1, tauL2, contribution.Parameters.G2.Tau[1], g2) {
+		return errors.New("couldn't verify valid powers of τ in G₁")
+	}
+	alphaL1, alphaL2 := linearCombinationG1(contribution.Parameters.G1.AlphaTau)
+	if !sameRatio(alphaL1, alphaL2, contribution.Parameters.G2.Tau[1], g2) {
+		return errors.New("couldn't verify valid powers of α(τ) in G₁")
+	}
+	betaL1, betaL2 := linearCombinationG1(contribution.Parameters.G1.BetaTau)
+	if !sameRatio(betaL1, betaL2, contribution.Parameters.G2.Tau[1], g2) {
+		return errors.New("couldn't verify valid powers of β(τ) in G₁")
+	}
+	tau2L1, tau2L2 := linearCombinationG2(contribution.Parameters.G2.Tau)
+	if !sameRatio(contribution.Parameters.G1.Tau[1], g1, tau2L1, tau2L2) {
+		return errors.New("couldn't verify valid powers of τ in G₂")
+	}
+
+	// Check hash of the contribution
+	h := contribution.hash()
+	for i := 0; i < len(h); i++ {
+		if h[i] != 
contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (phase1 *Phase1) hash() []byte { + sha := sha256.New() + phase1.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bls24-317/mpcsetup/phase2.go b/backend/groth16/bls24-317/mpcsetup/phase2.go new file mode 100644 index 0000000000..cac7bc08eb --- /dev/null +++ b/backend/groth16/bls24-317/mpcsetup/phase2.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + "math/big" + + curve "github.com/consensys/gnark-crypto/ecc/bls24-317" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/constraint/bls24-317" +) + +type Phase2Evaluations struct { + G1 struct { + A, B, VKK []curve.G1Affine + } + G2 struct { + B []curve.G2Affine + } +} + +type Phase2 struct { + Parameters struct { + G1 struct { + Delta curve.G1Affine + L, Z []curve.G1Affine + } + G2 struct { + Delta curve.G2Affine + } + } + PublicKey PublicKey + Hash []byte +} + +func InitPhase2(r1cs *cs.R1CS, srs1 *Phase1) (Phase2, Phase2Evaluations) { + srs := srs1.Parameters + size := len(srs.G1.AlphaTau) + if size < r1cs.GetNbConstraints() { + panic("Number of constraints is larger than expected") + } + + c2 := Phase2{} + + accumulateG1 := func(res *curve.G1Affine, t constraint.Term, value *curve.G1Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G1Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + accumulateG2 := func(res *curve.G2Affine, t constraint.Term, value *curve.G2Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G2Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + // Prepare Lagrange coefficients of [τ...]₁, [τ...]₂, [ατ...]₁, [βτ...]₁ + coeffTau1 := 
lagrangeCoeffsG1(srs.G1.Tau, size) + coeffTau2 := lagrangeCoeffsG2(srs.G2.Tau, size) + coeffAlphaTau1 := lagrangeCoeffsG1(srs.G1.AlphaTau, size) + coeffBetaTau1 := lagrangeCoeffsG1(srs.G1.BetaTau, size) + + internal, secret, public := r1cs.GetNbVariables() + nWires := internal + secret + public + var evals Phase2Evaluations + evals.G1.A = make([]curve.G1Affine, nWires) + evals.G1.B = make([]curve.G1Affine, nWires) + evals.G2.B = make([]curve.G2Affine, nWires) + bA := make([]curve.G1Affine, nWires) + aB := make([]curve.G1Affine, nWires) + C := make([]curve.G1Affine, nWires) + + // TODO @gbotrel use constraint iterator when available. + + i := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + // A + for _, t := range c.L { + accumulateG1(&evals.G1.A[t.WireID()], t, &coeffTau1[i]) + accumulateG1(&bA[t.WireID()], t, &coeffBetaTau1[i]) + } + // B + for _, t := range c.R { + accumulateG1(&evals.G1.B[t.WireID()], t, &coeffTau1[i]) + accumulateG2(&evals.G2.B[t.WireID()], t, &coeffTau2[i]) + accumulateG1(&aB[t.WireID()], t, &coeffAlphaTau1[i]) + } + // C + for _, t := range c.O { + accumulateG1(&C[t.WireID()], t, &coeffTau1[i]) + } + i++ + } + + // Prepare default contribution + _, _, g1, g2 := curve.Generators() + c2.Parameters.G1.Delta = g1 + c2.Parameters.G2.Delta = g2 + + // Build Z in PK as τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + // τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + n := len(srs.G1.AlphaTau) + c2.Parameters.G1.Z = make([]curve.G1Affine, n) + for i := 0; i < n-1; i++ { + c2.Parameters.G1.Z[i].Sub(&srs.G1.Tau[i+n], &srs.G1.Tau[i]) + } + bitReverse(c2.Parameters.G1.Z) + c2.Parameters.G1.Z = c2.Parameters.G1.Z[:n-1] + + // Evaluate L + nPrivate := internal + secret + c2.Parameters.G1.L = make([]curve.G1Affine, nPrivate) + evals.G1.VKK = make([]curve.G1Affine, public) + offset := public + for i := 0; i < nWires; i++ { + var tmp curve.G1Affine + tmp.Add(&bA[i], &aB[i]) + tmp.Add(&tmp, &C[i]) + if i < public { + 
evals.G1.VKK[i].Set(&tmp) + } else { + c2.Parameters.G1.L[i-offset].Set(&tmp) + } + } + // Set δ public key + var delta fr.Element + delta.SetOne() + c2.PublicKey = newPublicKey(delta, nil, 1) + + // Hash initial contribution + c2.Hash = c2.hash() + return c2, evals +} + +func (c *Phase2) Contribute() { + // Sample toxic δ + var delta, deltaInv fr.Element + var deltaBI, deltaInvBI big.Int + delta.SetRandom() + deltaInv.Inverse(&delta) + + delta.BigInt(&deltaBI) + deltaInv.BigInt(&deltaInvBI) + + // Set δ public key + c.PublicKey = newPublicKey(delta, c.Hash, 1) + + // Update δ + c.Parameters.G1.Delta.ScalarMultiplication(&c.Parameters.G1.Delta, &deltaBI) + c.Parameters.G2.Delta.ScalarMultiplication(&c.Parameters.G2.Delta, &deltaBI) + + // Update Z using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.Z); i++ { + c.Parameters.G1.Z[i].ScalarMultiplication(&c.Parameters.G1.Z[i], &deltaInvBI) + } + + // Update L using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.L); i++ { + c.Parameters.G1.L[i].ScalarMultiplication(&c.Parameters.G1.L[i], &deltaInvBI) + } + + // 4. Hash contribution + c.Hash = c.hash() +} + +func VerifyPhase2(c0, c1 *Phase2, c ...*Phase2) error { + contribs := append([]*Phase2{c0, c1}, c...) 
+	for i := 0; i < len(contribs)-1; i++ {
+		if err := verifyPhase2(contribs[i], contribs[i+1]); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func verifyPhase2(current, contribution *Phase2) error {
+	// Compute R for δ
+	deltaR := genR(contribution.PublicKey.SG, contribution.PublicKey.SXG, current.Hash[:], 1)
+
+	// Check for knowledge of δ
+	if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.PublicKey.XR, deltaR) {
+		return errors.New("couldn't verify knowledge of δ")
+	}
+
+	// Check for valid updates using previous parameters
+	if !sameRatio(contribution.Parameters.G1.Delta, current.Parameters.G1.Delta, deltaR, contribution.PublicKey.XR) {
+		return errors.New("couldn't verify that [δ]₁ is based on previous contribution")
+	}
+	if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) {
+		return errors.New("couldn't verify that [δ]₂ is based on previous contribution")
+	}
+
+	// Check for valid updates of L and Z using δ⁻¹
+	L, prevL := merge(contribution.Parameters.G1.L, current.Parameters.G1.L)
+	if !sameRatio(L, prevL, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) {
+		return errors.New("couldn't verify valid updates of L using δ⁻¹")
+	}
+	Z, prevZ := merge(contribution.Parameters.G1.Z, current.Parameters.G1.Z)
+	if !sameRatio(Z, prevZ, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) {
+		return errors.New("couldn't verify valid updates of Z using δ⁻¹")
+	}
+
+	// Check hash of the contribution
+	h := contribution.hash()
+	for i := 0; i < len(h); i++ {
+		if h[i] != contribution.Hash[i] {
+			return errors.New("couldn't verify hash of contribution")
+		}
+	}
+
+	return nil
+}
+
+func (c *Phase2) hash() []byte {
+	sha := sha256.New()
+	c.writeTo(sha)
+	return sha.Sum(nil)
+}
diff --git a/backend/groth16/bls24-317/mpcsetup/setup.go b/backend/groth16/bls24-317/mpcsetup/setup.go
new file mode 100644
index 0000000000..f90fea816c
--- 
/dev/null +++ b/backend/groth16/bls24-317/mpcsetup/setup.go @@ -0,0 +1,97 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls24-317" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" + groth16 "github.com/consensys/gnark/backend/groth16/bls24-317" +) + +func ExtractKeys(srs1 *Phase1, srs2 *Phase2, evals *Phase2Evaluations, nConstraints int) (pk groth16.ProvingKey, vk groth16.VerifyingKey) { + _, _, _, g2 := curve.Generators() + + // Initialize PK + pk.Domain = *fft.NewDomain(uint64(nConstraints)) + pk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + pk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + pk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + pk.G1.Z = srs2.Parameters.G1.Z + bitReverse(pk.G1.Z) + + pk.G1.K = srs2.Parameters.G1.L + pk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + pk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + + // Filter out infinity points + nWires := len(evals.G1.A) + pk.InfinityA = make([]bool, nWires) + A := make([]curve.G1Affine, nWires) + j := 0 + for i, e := range evals.G1.A { + if e.IsInfinity() { + pk.InfinityA[i] = true + continue + } + A[j] = evals.G1.A[i] + j++ + } + pk.G1.A = A[:j] + pk.NbInfinityA = uint64(nWires - j) + + pk.InfinityB = make([]bool, nWires) + B := make([]curve.G1Affine, nWires) + j = 0 + for i, e := range evals.G1.B { + if e.IsInfinity() { + 
pk.InfinityB[i] = true + continue + } + B[j] = evals.G1.B[i] + j++ + } + pk.G1.B = B[:j] + pk.NbInfinityB = uint64(nWires - j) + + B2 := make([]curve.G2Affine, nWires) + j = 0 + for i, e := range evals.G2.B { + if e.IsInfinity() { + // pk.InfinityB[i] = true should be the same as in B + continue + } + B2[j] = evals.G2.B[i] + j++ + } + pk.G2.B = B2[:j] + + // Initialize VK + vk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + vk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + vk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + vk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + vk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + vk.G2.Gamma.Set(&g2) + vk.G1.K = evals.G1.VKK + + // sets e, -[δ]2, -[γ]2 + if err := vk.Precompute(); err != nil { + panic(err) + } + + return pk, vk +} diff --git a/backend/groth16/bls24-317/mpcsetup/setup_test.go b/backend/groth16/bls24-317/mpcsetup/setup_test.go new file mode 100644 index 0000000000..8332441b27 --- /dev/null +++ b/backend/groth16/bls24-317/mpcsetup/setup_test.go @@ -0,0 +1,199 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls24-317" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark/constraint/bls24-317" + "testing" + + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/stretchr/testify/require" + + native_mimc "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/mimc" +) + +func TestSetupCircuit(t *testing.T) { + if testing.Short() { + t.Skip() + } + const ( + nContributionsPhase1 = 3 + nContributionsPhase2 = 3 + power = 9 + ) + + assert := require.New(t) + + srs1 := InitPhase1(power) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase1; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. + prev := srs1.clone() + + srs1.Contribute() + assert.NoError(VerifyPhase1(&prev, &srs1)) + } + + // Compile the circuit + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + var evals Phase2Evaluations + r1cs := ccs.(*cs.R1CS) + + // Prepare for phase-2 + srs2, evals := InitPhase2(r1cs, &srs1) + + // Make and verify contributions for phase2 + for i := 1; i < nContributionsPhase2; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. 
+ prev := srs2.clone() + + srs2.Contribute() + assert.NoError(VerifyPhase2(&prev, &srs2)) + } + + // Extract the proving and verifying keys + pk, vk := ExtractKeys(&srs1, &srs2, &evals, ccs.GetNbConstraints()) + + // Build the witness + var preImage, hash fr.Element + { + m := native_mimc.NewMiMC() + m.Write(preImage.Marshal()) + hash.SetBytes(m.Sum(nil)) + } + + witness, err := frontend.NewWitness(&Circuit{PreImage: preImage, Hash: hash}, curve.ID.ScalarField()) + assert.NoError(err) + + pubWitness, err := witness.Public() + assert.NoError(err) + + // groth16: ensure proof is verified + proof, err := groth16.Prove(ccs, &pk, witness) + assert.NoError(err) + + err = groth16.Verify(proof, &vk, pubWitness) + assert.NoError(err) +} + +func BenchmarkPhase1(b *testing.B) { + const power = 14 + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = InitPhase1(power) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs1 := InitPhase1(power) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs1.Contribute() + } + }) + +} + +func BenchmarkPhase2(b *testing.B) { + const power = 14 + srs1 := InitPhase1(power) + srs1.Contribute() + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + if err != nil { + b.Fatal(err) + } + + r1cs := ccs.(*cs.R1CS) + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = InitPhase2(r1cs, &srs1) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs2, _ := InitPhase2(r1cs, &srs1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs2.Contribute() + } + }) + +} + +// Circuit defines a pre-image knowledge proof +// mimc(secret preImage) = public hash +type Circuit struct { + PreImage frontend.Variable + Hash frontend.Variable `gnark:",public"` +} + +// Define declares the circuit's constraints +// Hash = mimc(PreImage) +func (circuit *Circuit) Define(api frontend.API) error { + // hash function + mimc, _ := 
mimc.NewMiMC(api) + + // specify constraints + mimc.Write(circuit.PreImage) + api.AssertIsEqual(circuit.Hash, mimc.Sum()) + + return nil +} + +func (phase1 *Phase1) clone() Phase1 { + r := Phase1{} + r.Parameters.G1.Tau = append(r.Parameters.G1.Tau, phase1.Parameters.G1.Tau...) + r.Parameters.G1.AlphaTau = append(r.Parameters.G1.AlphaTau, phase1.Parameters.G1.AlphaTau...) + r.Parameters.G1.BetaTau = append(r.Parameters.G1.BetaTau, phase1.Parameters.G1.BetaTau...) + + r.Parameters.G2.Tau = append(r.Parameters.G2.Tau, phase1.Parameters.G2.Tau...) + r.Parameters.G2.Beta = phase1.Parameters.G2.Beta + + r.PublicKeys = phase1.PublicKeys + r.Hash = append(r.Hash, phase1.Hash...) + + return r +} + +func (phase2 *Phase2) clone() Phase2 { + r := Phase2{} + r.Parameters.G1.Delta = phase2.Parameters.G1.Delta + r.Parameters.G1.L = append(r.Parameters.G1.L, phase2.Parameters.G1.L...) + r.Parameters.G1.Z = append(r.Parameters.G1.Z, phase2.Parameters.G1.Z...) + r.Parameters.G2.Delta = phase2.Parameters.G2.Delta + r.PublicKey = phase2.PublicKey + r.Hash = append(r.Hash, phase2.Hash...) + + return r +} diff --git a/backend/groth16/bls24-317/mpcsetup/utils.go b/backend/groth16/bls24-317/mpcsetup/utils.go new file mode 100644 index 0000000000..877fef7fad --- /dev/null +++ b/backend/groth16/bls24-317/mpcsetup/utils.go @@ -0,0 +1,170 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bls24-317" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark/internal/utils" +) + +type PublicKey struct { + SG curve.G1Affine + SXG curve.G1Affine + XR curve.G2Affine +} + +func newPublicKey(x fr.Element, challenge []byte, dst byte) PublicKey { + var pk PublicKey + _, _, g1, _ := curve.Generators() + + var s fr.Element + var sBi big.Int + s.SetRandom() + s.BigInt(&sBi) + pk.SG.ScalarMultiplication(&g1, &sBi) + + // compute x*sG1 + var xBi big.Int + x.BigInt(&xBi) + pk.SXG.ScalarMultiplication(&pk.SG, &xBi) + + // generate R based on sG1, sxG1, challenge, and domain separation tag (tau, alpha or beta) + R := genR(pk.SG, pk.SXG, challenge, dst) + + // compute x*spG2 + pk.XR.ScalarMultiplication(&R, &xBi) + return pk +} + +func bitReverse[T any](a []T) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + irev := bits.Reverse64(i) >> nn + if irev > i { + a[i], a[irev] = a[irev], a[i] + } + } +} + +// Returns [1, a, a², ..., aⁿ⁻¹ ] in Montgomery form +func powers(a fr.Element, n int) []fr.Element { + result := make([]fr.Element, n) + result[0] = fr.NewElement(1) + for i := 1; i < n; i++ { + result[i].Mul(&result[i-1], &a) + } + return result +} + +// Returns [aᵢAᵢ, ...] in G1 +func scaleG1InPlace(A []curve.G1Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Returns [aᵢAᵢ, ...] 
in G2 +func scaleG2InPlace(A []curve.G2Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Check e(a₁, a₂) = e(b₁, b₂) +func sameRatio(a1, b1 curve.G1Affine, a2, b2 curve.G2Affine) bool { + if !a1.IsInSubGroup() || !b1.IsInSubGroup() || !a2.IsInSubGroup() || !b2.IsInSubGroup() { + panic("invalid point not in subgroup") + } + var na2 curve.G2Affine + na2.Neg(&a2) + res, err := curve.PairingCheck( + []curve.G1Affine{a1, b1}, + []curve.G2Affine{na2, b2}) + if err != nil { + panic(err) + } + return res +} + +// returns a = ∑ rᵢAᵢ, b = ∑ rᵢBᵢ +func merge(A, B []curve.G1Affine) (a, b curve.G1Affine) { + nc := runtime.NumCPU() + r := make([]fr.Element, len(A)) + for i := 0; i < len(A); i++ { + r[i].SetRandom() + } + a.MultiExp(A, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + b.MultiExp(B, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G1 +func linearCombinationG1(A []curve.G1Affine) (L1, L2 curve.G1Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G2 +func linearCombinationG2(A []curve.G2Affine) (L1, L2 curve.G2Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// Generate R in G₂ as Hash(gˢ, gˢˣ, challenge, dst) +func genR(sG1, sxG1 curve.G1Affine, challenge []byte, dst byte) curve.G2Affine { + var buf bytes.Buffer + buf.Grow(len(challenge) + curve.SizeOfG1AffineUncompressed*2) + buf.Write(sG1.Marshal()) + 
buf.Write(sxG1.Marshal()) + buf.Write(challenge) + spG2, err := curve.HashToG2(buf.Bytes(), []byte{dst}) + if err != nil { + panic(err) + } + return spG2 +} diff --git a/internal/backend/bls24-317/groth16/prove.go b/backend/groth16/bls24-317/prove.go similarity index 64% rename from internal/backend/bls24-317/groth16/prove.go rename to backend/groth16/bls24-317/prove.go index 5dfeaf9228..72f08ac910 100644 --- a/internal/backend/bls24-317/groth16/prove.go +++ b/backend/groth16/bls24-317/prove.go @@ -17,13 +17,17 @@ package groth16 import ( - "fmt" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls24-317" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/pedersen" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/constraint/bls24-317" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" "math/big" @@ -35,9 +39,10 @@ import ( // with a valid statement and a VerifyingKey // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type Proof struct { - Ar, Krs curve.G1Affine - Bs curve.G2Affine - Commitment, CommitmentPok curve.G1Affine + Ar, Krs curve.G1Affine + Bs curve.G2Affine + Commitments []curve.G1Affine // Pedersen commitments a la https://eprint.iacr.org/2022/1072 + CommitmentPok curve.G1Affine // Batched proof of knowledge of the above commitments } // isValid ensures proof elements are in the correct subgroup @@ -51,72 +56,71 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). 
-func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - // TODO @gbotrel witness size check is done by R1CS, doesn't mean we shouldn't sanitize here. - // if len(witness) != r1cs.NbPublicVariables-1+r1cs.NbSecretVariables { - // return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), r1cs.NbPublicVariables-1+r1cs.NbSecretVariables, r1cs.NbPublicVariables, r1cs.NbSecretVariables) - // } - - log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", len(r1cs.Constraints)).Str("backend", "groth16").Logger() - - // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - - proof := &Proof{} - if r1cs.CommitmentInfo.Is() { - opt.HintFunctions[r1cs.CommitmentInfo.HintID] = func(_ *big.Int, in []*big.Int, out []*big.Int) error { - // Perf-TODO: Converting these values to big.Int and back may be a performance bottleneck. - // If that is the case, figure out a way to feed the solution vector into this function - if len(in) != r1cs.CommitmentInfo.NbCommitted() { // TODO: Remove - return fmt.Errorf("unexpected number of committed variables") - } - values := make([]fr.Element, r1cs.CommitmentInfo.NbPrivateCommitted) - nbPublicCommitted := len(in) - len(values) - inPrivate := in[nbPublicCommitted:] - for i, inI := range inPrivate { - values[i].SetBigInt(inI) - } +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } + + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] - var err error - proof.Commitment, proof.CommitmentPok, err = pk.CommitmentKey.Commit(values) - if err != nil { + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + solverOpts = append(solverOpts, solver.OverrideHint(commitmentInfo[i].HintID, func(i int) solver.Hint { + return func(_ *big.Int, in []*big.Int, out []*big.Int) error { + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[len(hashed):] + for j, inJ := range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } + + var res fr.Element + res, err = solveCommitmentWire(&proof.Commitments[i], hashed) + res.BigInt(out[0]) return err } - - var res fr.Element - res, err = solveCommitmentWire(&r1cs.CommitmentInfo, &proof.Commitment, in[:r1cs.CommitmentInfo.NbPublicCommitted()]) - res.BigInt(out[0]) //Perf-TODO: Regular (non-mont) hashToField to obviate this conversion? 
- return err - } + }(i))) } - var wireValues []fr.Element - var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill wireValues with random values else multi exps don't do much - var r fr.Element - _, _ = r.SetRandom() - for i := r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables(); i < len(wireValues); i++ { - wireValues[i] = r - r.Double(&r) - } - } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + start := time.Now() + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + + if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil { + return nil, err + } + // H (witness reduction / FFT part) var h []fr.Element chHDone := make(chan struct{}, 1) go func() { - h = computeH(a, b, c, &pk.Domain) - a = nil - b = nil - c = nil + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + solution.A = nil + solution.B = nil + solution.C = nil chHDone <- struct{}{} }() @@ -203,15 +207,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC var krs, krs2, p1 curve.G1Jac chKrs2Done := make(chan error, 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 go func() { - _, err := krs2.MultiExp(pk.G1.Z, h, ecc.MultiExpConfig{NbTasks: n / 2}) + _, err := krs2.MultiExp(pk.G1.Z, h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) chKrs2Done <- err }() - // filter the wire values if needed; - _wireValues := filter(wireValues, r1cs.CommitmentInfo.PrivateToPublic()) + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + 
toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - if _, err := krs.MultiExp(pk.G1.K, _wireValues[r1cs.GetNbPublicVariables():], ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chKrsDone <- err return } @@ -292,26 +300,32 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC } // if len(toRemove) == 0, returns slice -// else, returns a new slice without the indexes in toRemove -// this assumes toRemove indexes are sorted and len(slice) > len(toRemove) -func filter(slice []fr.Element, toRemove []int) (r []fr.Element) { +// else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { if len(toRemove) == 0 { return slice } - r = make([]fr.Element, 0, len(slice)-len(toRemove)) - j := 0 + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) + // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) for i := 0; i < len(slice); i++ { - if j < len(toRemove) && i == toRemove[j] { - j++ + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } continue } r = append(r, slice[i]) } - return r + return } func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { @@ -334,9 +348,9 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { domain.FFTInverse(b, fft.DIF) domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, true) - domain.FFT(b, fft.DIT, true) - 
domain.FFT(c, fft.DIT, true) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) var den, one fr.Element one.SetOne() @@ -344,7 +358,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) - // reusing a to avoid unecessary memalloc + // reusing a to avoid unnecessary memory allocation utils.Parallelize(n, func(start, end int) { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). @@ -354,7 +368,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { }) // ifft_coset - domain.FFTInverse(a, fft.DIF, true) + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) return a } diff --git a/internal/backend/bls24-317/groth16/setup.go b/backend/groth16/bls24-317/setup.go similarity index 75% rename from internal/backend/bls24-317/groth16/setup.go rename to backend/groth16/bls24-317/setup.go index c49e1718ea..04156ae8bd 100644 --- a/internal/backend/bls24-317/groth16/setup.go +++ b/backend/groth16/bls24-317/setup.go @@ -17,11 +17,13 @@ package groth16 import ( + "errors" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls24-317" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/constraint/bls24-317" "math/big" @@ -34,15 +36,15 @@ type ProvingKey struct { // domain Domain fft.Domain - // [α]1, [β]1, [δ]1 - // [A(t)]1, [B(t)]1, [Kpk(t)]1, [Z(t)]1 + // [α]₁, [β]₁, [δ]₁ + // [A(t)]₁, [B(t)]₁, [Kpk(t)]₁, [Z(t)]₁ G1 struct { Alpha, Beta, Delta curve.G1Affine A, B, Z []curve.G1Affine K []curve.G1Affine // the indexes correspond to the private wires } - // [β]2, [δ]2, [B(t)]2 + // [β]₂, [δ]₂, [B(t)]₂ G2 struct { Beta, Delta curve.G2Affine B 
[]curve.G2Affine @@ -52,21 +54,21 @@ type ProvingKey struct { InfinityA, InfinityB []bool NbInfinityA, NbInfinityB uint64 - CommitmentKey pedersen.Key + CommitmentKeys []pedersen.ProvingKey } // VerifyingKey is used by a Groth16 verifier to verify the validity of a proof and a statement // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type VerifyingKey struct { - // [α]1, [Kvk]1 + // [α]₁, [Kvk]₁ G1 struct { Alpha curve.G1Affine Beta, Delta curve.G1Affine // unused, here for compatibility purposes K []curve.G1Affine // The indexes correspond to the public wires } - // [β]2, [δ]2, [γ]2, - // -[δ]2, -[γ]2: see proof.Verify() for more details + // [β]₂, [δ]₂, [γ]₂, + // -[δ]₂, -[γ]₂: see proof.Verify() for more details G2 struct { Beta, Delta, Gamma curve.G2Affine deltaNeg, gammaNeg curve.G2Affine // not serialized @@ -75,8 +77,8 @@ type VerifyingKey struct { // e(α, β) e curve.GT // not serialized - CommitmentKey pedersen.Key - CommitmentInfo constraint.Commitment // since the verifier doesn't input a constraint system, this needs to be provided here + CommitmentKey pedersen.VerifyingKey + PublicAndCommitmentCommitted [][]int // indexes of public/commitment committed variables } // Setup constructs the SRS @@ -93,17 +95,20 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbPrivateCommittedWires := r1cs.CommitmentInfo.NbPrivateCommitted - nbPublicWires := r1cs.GetNbPublicVariables() - nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - if r1cs.CommitmentInfo.Is() { // the commitment itself is defined by a hint so the prover considers it private - nbPublicWires++ // but the verifier will need to inject the value itself so on the groth16 - nbPrivateWires-- // level it must be considered public - } + commitmentInfo 
:= r1cs.CommitmentInfo.(constraint.Groth16Commitments) + commitmentWires := commitmentInfo.CommitmentIndexes() + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateCommittedWires := internal.NbElements(privateCommitted) + + // a commitment is itself defined by a hint so the prover considers it private + // but the verifier will need to inject the value itself so on the groth16 + // level it must be considered public + nbPublicWires := r1cs.GetNbPublicVariables() + len(commitmentInfo) + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - len(commitmentInfo) // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints))) + domain := fft.NewDomain(uint64(r1cs.GetNbConstraints())) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -137,7 +142,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // compute scalars for pkK, vkK and ckK pkK := make([]fr.Element, nbPrivateWires) vkK := make([]fr.Element, nbPublicWires) - ckK := make([]fr.Element, nbPrivateCommittedWires) + ckK := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + ckK[i] = make([]fr.Element, len(privateCommitted[i])) + } var t0, t1 fr.Element @@ -148,28 +156,42 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { Add(&t1, &C[i]). 
Mul(&t1, coeff) } - - vI, cI := 0, 0 - privateCommitted := r1cs.CommitmentInfo.PrivateCommitted() + vI := 0 // number of public wires seen so far + cI := make([]int, len(commitmentInfo)) // number of private committed wires seen so far for each commitment + nbPrivateCommittedSeen := 0 // = ∑ᵢ cI[i] + nbCommitmentsSeen := 0 for i := range A { - isCommittedPrivate := cI < len(privateCommitted) && i == privateCommitted[cI] - isCommitment := r1cs.CommitmentInfo.Is() && i == r1cs.CommitmentInfo.CommitmentIndex - isPublic := i < r1cs.GetNbPublicVariables() + commitment := -1 // index of the commitment that commits to this variable as a private or commitment value + var isCommitment, isPublic bool + if isPublic = i < r1cs.GetNbPublicVariables(); !isPublic { + if nbCommitmentsSeen < len(commitmentWires) && commitmentWires[nbCommitmentsSeen] == i { + isCommitment = true + nbCommitmentsSeen++ + } - if isPublic || isCommittedPrivate || isCommitment { + for j := range commitmentInfo { // does commitment j commit to i? + if cI[j] < len(privateCommitted[j]) && privateCommitted[j][cI[j]] == i { + commitment = j + break // frontend guarantees that no private variable is committed to more than once + } + } + } + + if isPublic || commitment != -1 || isCommitment { computeK(i, &toxicWaste.gammaInv) - if isCommittedPrivate { - ckK[cI] = t1 - cI++ - } else { + if isPublic || isCommitment { vkK[vI] = t1 vI++ + } else { // committed and private + ckK[commitment][cI[commitment]] = t1 + cI[commitment]++ + nbPrivateCommittedSeen++ } } else { computeK(i, &toxicWaste.deltaInv) - pkK[i-vI-cI] = t1 + pkK[i-vI-nbPrivateCommittedSeen] = t1 // vI = nbPublicSeen + nbCommitmentsSeen } } @@ -222,11 +244,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { g1Scalars = append(g1Scalars, Z...) g1Scalars = append(g1Scalars, vkK...) g1Scalars = append(g1Scalars, pkK...) - g1Scalars = append(g1Scalars, ckK...) + for i := range ckK { + g1Scalars = append(g1Scalars, ckK[i]...) 
+ } g1PointsAff := curve.BatchScalarMultiplicationG1(&g1, g1Scalars) - // sets pk: [α]1, [β]1, [δ]1 + // sets pk: [α]₁, [β]₁, [δ]₁ pk.G1.Alpha = g1PointsAff[0] pk.G1.Beta = g1PointsAff[1] pk.G1.Delta = g1PointsAff[2] @@ -238,8 +262,9 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G1.B = g1PointsAff[offset : offset+len(B)] offset += len(B) - pk.G1.Z = g1PointsAff[offset : offset+int(domain.Cardinality)] - bitReverse(pk.G1.Z) + bitReverse(g1PointsAff[offset : offset+int(domain.Cardinality)]) + sizeZ := int(domain.Cardinality) - 1 // deg(H)=deg(A*B-C/X^n-1)=(n-1)+(n-1)-n=n-2 + pk.G1.Z = g1PointsAff[offset : offset+sizeZ] offset += int(domain.Cardinality) @@ -252,17 +277,22 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // --------------------------------------------------------------------------------------------- // Commitment setup - if nbPrivateCommittedWires != 0 { - commitmentBasis := g1PointsAff[offset:] + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(ckK[i]) + commitmentBases[i] = g1PointsAff[offset : offset+size] + offset += size + } + if offset != len(g1PointsAff) { + return errors.New("didn't consume all G1 points") // TODO @Tabaie Remove this + } - vk.CommitmentKey, err = pedersen.Setup(commitmentBasis) - if err != nil { - return err - } - pk.CommitmentKey = vk.CommitmentKey + pk.CommitmentKeys, vk.CommitmentKey, err = pedersen.Setup(commitmentBases...) 
+ if err != nil { + return err } - vk.CommitmentInfo = r1cs.CommitmentInfo // unfortunate but necessary + vk.PublicAndCommitmentCommitted = commitmentInfo.GetPublicAndCommitmentCommitted(commitmentWires, r1cs.GetNbPublicVariables()) // --------------------------------------------------------------------------------------------- // G2 scalars @@ -279,15 +309,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G2.B = g2PointsAff[:len(B)] - // sets pk: [β]2, [δ]2 + // sets pk: [β]₂, [δ]₂ pk.G2.Beta = g2PointsAff[len(B)+0] pk.G2.Delta = g2PointsAff[len(B)+1] - // sets vk: [δ]2, [γ]2, -[δ]2, -[γ]2 + // sets vk: [δ]₂, [γ]₂ vk.G2.Delta = g2PointsAff[len(B)+1] vk.G2.Gamma = g2PointsAff[len(B)+2] - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) // --------------------------------------------------------------------------------------------- // Pairing: vk.e @@ -298,16 +326,29 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { vk.G1.Beta = pk.G1.Beta vk.G1.Delta = pk.G1.Delta - vk.e, err = curve.Pair([]curve.G1Affine{pk.G1.Alpha}, []curve.G2Affine{pk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return err } + // set domain pk.Domain = *domain return nil } +// Precompute sets e, -[δ]₂, -[γ]₂ +// This is meant to be called internally during setup or deserialization. +func (vk *VerifyingKey) Precompute() error { + var err error + vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) + if err != nil { + return err + } + vk.G2.deltaNeg.Neg(&vk.G2.Delta) + vk.G2.gammaNeg.Neg(&vk.G2.Gamma) + return nil +} + func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr.Element, B []fr.Element, C []fr.Element) { nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() @@ -322,7 +363,7 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. 
var w fr.Element w.Set(&domain.Generator) wi := fr.One() - t := make([]fr.Element, len(r1cs.Constraints)+1) + t := make([]fr.Element, r1cs.GetNbConstraints()+1) for i := 0; i < len(t); i++ { t[i].Sub(&toxicWaste.t, &wi) wi.Mul(&wi, &w) // TODO this is already pre computed in fft.Domain @@ -366,8 +407,10 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // for each term appearing in the linear expression, // we compute term.Coefficient * L, and cumulate it in // A, B or C at the index of the variable - for i, c := range r1cs.Constraints { + j := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } @@ -380,9 +423,12 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // Li+1 = w*Li*(t-w^i)/(t-w^(i+1)) L.Mul(&L, &w) - L.Mul(&L, &t[i]) - L.Mul(&L, &tInv[i+1]) + L.Mul(&L, &t[j]) + L.Mul(&L, &tInv[j+1]) + + j++ } + return } @@ -436,7 +482,10 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbConstraints := len(r1cs.Constraints) + nbConstraints := r1cs.GetNbConstraints() + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - internal.NbElements(privateCommitted) - len(commitmentInfo) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints)) @@ -448,8 +497,8 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // initialize proving key pk.G1.A = make([]curve.G1Affine, nbWires-nbZeroesA) pk.G1.B = make([]curve.G1Affine, nbWires-nbZeroesB) - pk.G1.K = make([]curve.G1Affine, nbWires-r1cs.GetNbPublicVariables()) - pk.G1.Z = make([]curve.G1Affine, 
domain.Cardinality) + pk.G1.K = make([]curve.G1Affine, nbPrivateWires) + pk.G1.Z = make([]curve.G1Affine, domain.Cardinality-1) pk.G2.B = make([]curve.G2Affine, nbWires-nbZeroesB) // set infinity markers @@ -503,6 +552,22 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { pk.Domain = *domain + // --------------------------------------------------------------------------------------------- + // Commitment setup + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(privateCommitted[i]) + commitmentBases[i] = make([]curve.G1Affine, size) + for j := range commitmentBases[i] { + commitmentBases[i][j] = r1Aff + } + } + + pk.CommitmentKeys, _, err = pedersen.Setup(commitmentBases...) + if err != nil { + return err + } + return nil } @@ -514,7 +579,9 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) - for _, c := range r1cs.Constraints { + + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { A[t.WireID()] = true } @@ -522,6 +589,7 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { B[t.WireID()] = true } } + for i := 0; i < nbWires; i++ { if !A[i] { nbZeroesA++ @@ -606,7 +674,7 @@ func (pk *ProvingKey) NbG2() int { return 2 + len(pk.G2.B) } -// bitRerverse permutation as in fft.BitReverse , but with []curve.G1Affine +// bitReverse permutation as in fft.BitReverse , but with []curve.G1Affine func bitReverse(a []curve.G1Affine) { n := uint(len(a)) nn := uint(bits.UintSize - bits.TrailingZeros(n)) diff --git a/internal/backend/bls24-317/groth16/verify.go b/backend/groth16/bls24-317/verify.go similarity index 64% rename from internal/backend/bls24-317/groth16/verify.go rename to backend/groth16/bls24-317/verify.go index e6f5f33f46..3affc23113 100644 --- a/internal/backend/bls24-317/groth16/verify.go +++ b/backend/groth16/bls24-317/verify.go @@ -22,9 +22,11 @@ import ( 
"github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls24-317" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/pedersen" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/logger" "io" - "math/big" "time" ) @@ -36,10 +38,8 @@ var ( // Verify verifies a proof with given VerifyingKey and publicWitness func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - nbPublicVars := len(vk.G1.K) - if vk.CommitmentInfo.Is() { - nbPublicVars-- - } + nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted) + if len(publicWitness) != nbPublicVars-1 { return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(publicWitness), len(vk.G1.K)-1) } @@ -62,21 +62,32 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { close(chDone) }() - if vk.CommitmentInfo.Is() { - - if err := vk.CommitmentKey.VerifyKnowledgeProof(proof.Commitment, proof.CommitmentPok); err != nil { - return err + maxNbPublicCommitted := 0 + for _, s := range vk.PublicAndCommitmentCommitted { // iterate over commitments + maxNbPublicCommitted = utils.Max(maxNbPublicCommitted, len(s)) + } + commitmentsSerialized := make([]byte, len(vk.PublicAndCommitmentCommitted)*fr.Bytes) + commitmentPrehashSerialized := make([]byte, curve.SizeOfG1AffineUncompressed+maxNbPublicCommitted*fr.Bytes) + for i := range vk.PublicAndCommitmentCommitted { // solveCommitmentWire + copy(commitmentPrehashSerialized, proof.Commitments[i].Marshal()) + offset := curve.SizeOfG1AffineUncompressed + for j := range vk.PublicAndCommitmentCommitted[i] { + copy(commitmentPrehashSerialized[offset:], publicWitness[vk.PublicAndCommitmentCommitted[i][j]-1].Marshal()) + offset += fr.Bytes } - - publicCommitted := make([]*big.Int, vk.CommitmentInfo.NbPublicCommitted()) - for i := range publicCommitted { - var b big.Int - 
publicWitness[vk.CommitmentInfo.Committed[i]-1].BigInt(&b) - publicCommitted[i] = &b + if res, err := fr.Hash(commitmentPrehashSerialized[:offset], []byte(constraint.CommitmentDst), 1); err != nil { + return err + } else { + publicWitness = append(publicWitness, res[0]) + copy(commitmentsSerialized[i*fr.Bytes:], res[0].Marshal()) } + } - if res, err := solveCommitmentWire(&vk.CommitmentInfo, &proof.Commitment, publicCommitted); err == nil { - publicWitness = append(publicWitness, res) + if folded, err := pedersen.FoldCommitments(proof.Commitments, commitmentsSerialized); err != nil { + return err + } else { + if err = vk.CommitmentKey.Verify(folded, proof.CommitmentPok); err != nil { + return err } } @@ -87,8 +98,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { } kSum.AddMixed(&vk.G1.K[0]) - if vk.CommitmentInfo.Is() { - kSum.AddMixed(&proof.Commitment) + for i := range proof.Commitments { + kSum.AddMixed(&proof.Commitments[i]) } var kSumAff curve.G1Affine diff --git a/internal/backend/bn254/groth16/commitment.go b/backend/groth16/bn254/commitment.go similarity index 74% rename from internal/backend/bn254/groth16/commitment.go rename to backend/groth16/bn254/commitment.go index d3930c0f08..435a7c058c 100644 --- a/internal/backend/bn254/groth16/commitment.go +++ b/backend/groth16/bn254/commitment.go @@ -23,7 +23,7 @@ import ( "math/big" ) -func solveCommitmentWire(commitmentInfo *constraint.Commitment, commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { - res, err := fr.Hash(commitmentInfo.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) +func solveCommitmentWire(commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { + res, err := fr.Hash(constraint.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) return res[0], err } diff --git 
a/internal/backend/bn254/groth16/commitment_test.go b/backend/groth16/bn254/commitment_test.go similarity index 91% rename from internal/backend/bn254/groth16/commitment_test.go rename to backend/groth16/bn254/commitment_test.go index 4573f9749b..759501ebe2 100644 --- a/internal/backend/bn254/groth16/commitment_test.go +++ b/backend/groth16/bn254/commitment_test.go @@ -17,6 +17,9 @@ package groth16_test import ( + "fmt" + "testing" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/witness" @@ -24,7 +27,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/assert" - "testing" ) type singleSecretCommittedCircuit struct { @@ -33,7 +35,11 @@ type singleSecretCommittedCircuit struct { func (c *singleSecretCommittedCircuit) Define(api frontend.API) error { api.AssertIsEqual(c.One, 1) - commit, err := api.Compiler().Commit(c.One) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One) if err != nil { return err } @@ -119,8 +125,11 @@ type oneSecretOnePublicCommittedCircuit struct { } func (c *oneSecretOnePublicCommittedCircuit) Define(api frontend.API) error { - - commit, err := api.Compiler().Commit(c.One, c.Two) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One, c.Two) if err != nil { return err } diff --git a/internal/backend/bn254/groth16/marshal.go b/backend/groth16/bn254/marshal.go similarity index 82% rename from internal/backend/bn254/groth16/marshal.go rename to backend/groth16/bn254/marshal.go index 66e3c144a1..0a310d6d15 100644 --- a/internal/backend/bn254/groth16/marshal.go +++ b/backend/groth16/bn254/marshal.go @@ -18,6 +18,9 @@ package groth16 import ( curve 
"github.com/consensys/gnark-crypto/ecc/bn254" + + "github.com/consensys/gnark-crypto/ecc/bn254/fr/pedersen" + "github.com/consensys/gnark/internal/utils" "io" ) @@ -78,14 +81,24 @@ func (proof *Proof) ReadFrom(r io.Reader) (n int64, err error) { // points are compressed // use WriteRawTo(...) to encode the key without point compression func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, false) + if n, err = vk.writeTo(w, false); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteTo(w) + return m + n, err } // WriteRawTo writes binary encoding of the key elements to writer // points are not compressed // use WriteTo(...) to encode the key with point compression func (vk *VerifyingKey) WriteRawTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, true) + if n, err = vk.writeTo(w, true); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteRawTo(w) + return m + n, err } // writeTo serialization format: @@ -124,6 +137,14 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { if err := enc.Encode(vk.G1.K); err != nil { return enc.BytesWritten(), err } + + if vk.PublicAndCommitmentCommitted == nil { + vk.PublicAndCommitmentCommitted = [][]int{} // only matters in tests + } + if err := enc.Encode(utils.IntSliceSliceToUint64SliceSlice(vk.PublicAndCommitmentCommitted)); err != nil { + return enc.BytesWritten(), err + } + return enc.BytesWritten(), nil } @@ -133,13 +154,25 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { // https://github.com/zkcrypto/bellman/blob/fa9be45588227a8c6ec34957de3f68705f07bd92/src/groth16/mod.rs#L143 // [α]1,[β]1,[β]2,[γ]2,[δ]1,[δ]2,uint32(len(Kvk)),[Kvk]1 func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r) + n, err := vk.readFrom(r) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.ReadFrom(r) + return m + n, err } // UnsafeReadFrom has 
the same behavior as ReadFrom, except that it will not check that decode points // are on the curve and in the correct subgroup. func (vk *VerifyingKey) UnsafeReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r, curve.NoSubgroupChecks()) + n, err := vk.readFrom(r, curve.NoSubgroupChecks()) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.UnsafeReadFrom(r) + return m + n, err } func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { @@ -169,15 +202,16 @@ func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder) if err := dec.Decode(&vk.G1.K); err != nil { return dec.BytesRead(), err } + var publicCommitted [][]uint64 + if err := dec.Decode(&publicCommitted); err != nil { + return dec.BytesRead(), err + } + vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 - var err error - vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return dec.BytesRead(), err } - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) return dec.BytesRead(), nil } @@ -226,6 +260,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { pk.NbInfinityB, pk.InfinityA, pk.InfinityB, + uint32(len(pk.CommitmentKeys)), } for _, v := range toEncode { @@ -234,6 +269,23 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { } } + for i := range pk.CommitmentKeys { + var ( + n2 int64 + err error + ) + if raw { + n2, err = pk.CommitmentKeys[i].WriteRawTo(w) + } else { + n2, err = pk.CommitmentKeys[i].WriteTo(w) + } + + n += n2 + if err != nil { + return n, err + } + } + return n + enc.BytesWritten(), nil } @@ -260,6 +312,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) dec := curve.NewDecoder(r, decOptions...) 
var nbWires uint64 + var nbCommitments uint32 toDecode := []interface{}{ &pk.G1.Alpha, @@ -291,6 +344,18 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) if err := dec.Decode(&pk.InfinityB); err != nil { return n + dec.BytesRead(), err } + if err := dec.Decode(&nbCommitments); err != nil { + return n + dec.BytesRead(), err + } + + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) + for i := range pk.CommitmentKeys { + n2, err := pk.CommitmentKeys[i].ReadFrom(r) + n += n2 + if err != nil { + return n, err + } + } return n + dec.BytesRead(), nil } diff --git a/internal/backend/bn254/groth16/marshal_test.go b/backend/groth16/bn254/marshal_test.go similarity index 78% rename from internal/backend/bn254/groth16/marshal_test.go rename to backend/groth16/bn254/marshal_test.go index 2db9e45415..ff9d807805 100644 --- a/internal/backend/bn254/groth16/marshal_test.go +++ b/backend/groth16/bn254/marshal_test.go @@ -21,11 +21,17 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal/test_utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "bytes" "math/big" "reflect" "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" "github.com/leanovate/gopter/prop" "testing" @@ -87,13 +93,9 @@ func TestProofSerialization(t *testing.T) { } func TestVerifyingKeySerialization(t *testing.T) { - parameters := gopter.DefaultTestParameters() - parameters.MinSuccessfulTests = 10 - properties := gopter.NewProperties(parameters) - - properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + roundTrip := func(withCommitment bool) func(curve.G1Affine, curve.G2Affine) bool { + return func(p1 curve.G1Affine, p2 curve.G2Affine) bool { var vk, vkCompressed, vkRaw VerifyingKey // create a 
random vk @@ -121,6 +123,21 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } + if withCommitment { + vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization + bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) + elem := p1 + for i := 0; i < len(vk.PublicAndCommitmentCommitted); i++ { + bases[i] = make([]curve.G1Affine, len(vk.PublicAndCommitmentCommitted[i])) + for j := range bases[i] { + bases[i][j] = elem + elem.Add(&elem, &p1) + } + } + _, vk.CommitmentKey, err = pedersen.Setup(bases...) + assert.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := vk.WriteTo(&bufCompressed) if err != nil { @@ -158,7 +175,22 @@ func TestVerifyingKeySerialization(t *testing.T) { } return reflect.DeepEqual(&vk, &vkCompressed) && reflect.DeepEqual(&vk, &vkRaw) - }, + } + } + + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 10 + + properties := gopter.NewProperties(parameters) + + properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(false), + GenG1(), + GenG2(), + )) + + properties.Property("VerifyingKey (with commitments) -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(true), GenG1(), GenG2(), )) @@ -173,7 +205,7 @@ func TestProvingKeySerialization(t *testing.T) { properties := gopter.NewProperties(parameters) properties.Property("ProvingKey -> writer -> reader -> ProvingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + func(p1 curve.G1Affine, p2 curve.G2Affine, nbCommitment int) bool { var pk, pkCompressed, pkRaw ProvingKey // create a random pk @@ -202,6 +234,19 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, nbCommitment) + 
pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases...) + require.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := pk.WriteTo(&bufCompressed) if err != nil { @@ -242,6 +287,7 @@ func TestProvingKeySerialization(t *testing.T) { }, GenG1(), GenG2(), + gen.IntRange(0, 2), )) properties.TestingRun(t, gopter.ConsoleReporter(false)) diff --git a/backend/groth16/bn254/mpcsetup/lagrange.go b/backend/groth16/bn254/mpcsetup/lagrange.go new file mode 100644 index 0000000000..cbf377dc25 --- /dev/null +++ b/backend/groth16/bn254/mpcsetup/lagrange.go @@ -0,0 +1,216 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" + "github.com/consensys/gnark/internal/utils" +) + +func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { + coeffs := make([]curve.G1Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { + coeffs := make([]curve.G2Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func butterflyG1(a *curve.G1Affine, b *curve.G1Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +func butterflyG2(a *curve.G2Affine, b *curve.G2Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G1(a []curve.G1Affine, twiddles [][]fr.Element, stage int) { + butterflyG1(&a[0], &a[4]) + butterflyG1(&a[1], &a[5]) + butterflyG1(&a[2], 
&a[6]) + butterflyG1(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[2]) + butterflyG1(&a[1], &a[3]) + butterflyG1(&a[4], &a[6]) + butterflyG1(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[1]) + butterflyG1(&a[2], &a[3]) + butterflyG1(&a[4], &a[5]) + butterflyG1(&a[6], &a[7]) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G2(a []curve.G2Affine, twiddles [][]fr.Element, stage int) { + butterflyG2(&a[0], &a[4]) + butterflyG2(&a[1], &a[5]) + butterflyG2(&a[2], &a[6]) + butterflyG2(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[2]) + butterflyG2(&a[1], &a[3]) + butterflyG2(&a[4], &a[6]) + butterflyG2(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[1]) + butterflyG2(&a[2], &a[3]) + butterflyG2(&a[4], &a[5]) + butterflyG2(&a[6], &a[7]) +} + +func difFFTG1(a []curve.G1Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G1(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG1(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + 
butterflyG1(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} +func difFFTG2(a []curve.G2Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G2(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG2(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + butterflyG2(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG2(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG2(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} diff --git a/backend/groth16/bn254/mpcsetup/marshal.go b/backend/groth16/bn254/mpcsetup/marshal.go new file mode 100644 index 0000000000..08cb2ae3d1 --- /dev/null +++ b/backend/groth16/bn254/mpcsetup/marshal.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bn254" + "io" +) + +// WriteTo implements io.WriterTo +func (phase1 *Phase1) WriteTo(writer io.Writer) (int64, error) { + n, err := phase1.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase1.Hash) + return int64(nBytes) + n, err +} + +func (phase1 *Phase1) writeTo(writer io.Writer) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + phase1.Parameters.G1.Tau, + phase1.Parameters.G1.AlphaTau, + phase1.Parameters.G1.BetaTau, + phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + enc := curve.NewEncoder(writer) + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (phase1 *Phase1) ReadFrom(reader io.Reader) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + &phase1.Parameters.G1.Tau, + &phase1.Parameters.G1.AlphaTau, + 
&phase1.Parameters.G1.BetaTau, + &phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + dec := curve.NewDecoder(reader) + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + phase1.Hash = make([]byte, 32) + nBytes, err := reader.Read(phase1.Hash) + return dec.BytesRead() + int64(nBytes), err +} + +// WriteTo implements io.WriterTo +func (phase2 *Phase2) WriteTo(writer io.Writer) (int64, error) { + n, err := phase2.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase2.Hash) + return int64(nBytes) + n, err +} + +func (c *Phase2) writeTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + c.Parameters.G1.L, + c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c *Phase2) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + &c.Parameters.G1.L, + &c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + c.Hash = make([]byte, 32) + n, err := reader.Read(c.Hash) + return int64(n) + dec.BytesRead(), err + +} + +// WriteTo implements io.WriterTo +func (c *Phase2Evaluations) WriteTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + c.G1.A, + c.G1.B, + c.G2.B, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c 
*Phase2Evaluations) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.G1.A, + &c.G1.B, + &c.G2.B, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + return dec.BytesRead(), nil +} diff --git a/backend/groth16/bn254/mpcsetup/marshal_test.go b/backend/groth16/bn254/mpcsetup/marshal_test.go new file mode 100644 index 0000000000..b56e2bfdce --- /dev/null +++ b/backend/groth16/bn254/mpcsetup/marshal_test.go @@ -0,0 +1,79 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + curve "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/stretchr/testify/require" + "io" + "reflect" + "testing" +) + +func TestContributionSerialization(t *testing.T) { + assert := require.New(t) + + // Phase 1 + srs1 := InitPhase1(9) + srs1.Contribute() + { + var reconstructed Phase1 + roundTripCheck(t, &srs1, &reconstructed) + } + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + r1cs := ccs.(*cs.R1CS) + + // Phase 2 + srs2, _ := InitPhase2(r1cs, &srs1) + srs2.Contribute() + + { + var reconstructed Phase2 + roundTripCheck(t, &srs2, &reconstructed) + } +} + +func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { + t.Helper() + + var buf bytes.Buffer + written, err := from.WriteTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} diff --git a/backend/groth16/bn254/mpcsetup/phase1.go b/backend/groth16/bn254/mpcsetup/phase1.go new file mode 100644 index 0000000000..a912a473aa --- /dev/null +++ b/backend/groth16/bn254/mpcsetup/phase1.go @@ -0,0 +1,203 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + curve "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "math" + "math/big" +) + +// Phase1 represents the Phase1 of the MPC described in +// https://eprint.iacr.org/2017/1050.pdf +// +// Also known as "Powers of Tau" +type Phase1 struct { + Parameters struct { + G1 struct { + Tau []curve.G1Affine // {[τ⁰]₁, [τ¹]₁, [τ²]₁, …, [τ²ⁿ⁻²]₁} + AlphaTau []curve.G1Affine // {α[τ⁰]₁, α[τ¹]₁, α[τ²]₁, …, α[τⁿ⁻¹]₁} + BetaTau []curve.G1Affine // {β[τ⁰]₁, β[τ¹]₁, β[τ²]₁, …, β[τⁿ⁻¹]₁} + } + G2 struct { + Tau []curve.G2Affine // {[τ⁰]₂, [τ¹]₂, [τ²]₂, …, [τⁿ⁻¹]₂} + Beta curve.G2Affine // [β]₂ + } + } + PublicKeys struct { + Tau, Alpha, Beta PublicKey + } + Hash []byte // sha256 hash +} + +// InitPhase1 initialize phase 1 of the MPC. This is called once by the coordinator before +// any randomness contribution is made (see Contribute()). 
+func InitPhase1(power int) (phase1 Phase1) { + N := int(math.Pow(2, float64(power))) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetOne() + alpha.SetOne() + beta.SetOne() + phase1.PublicKeys.Tau = newPublicKey(tau, nil, 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, nil, 2) + phase1.PublicKeys.Beta = newPublicKey(beta, nil, 3) + + // First contribution use generators + _, _, g1, g2 := curve.Generators() + phase1.Parameters.G2.Beta.Set(&g2) + phase1.Parameters.G1.Tau = make([]curve.G1Affine, 2*N-1) + phase1.Parameters.G2.Tau = make([]curve.G2Affine, N) + phase1.Parameters.G1.AlphaTau = make([]curve.G1Affine, N) + phase1.Parameters.G1.BetaTau = make([]curve.G1Affine, N) + for i := 0; i < len(phase1.Parameters.G1.Tau); i++ { + phase1.Parameters.G1.Tau[i].Set(&g1) + } + for i := 0; i < len(phase1.Parameters.G2.Tau); i++ { + phase1.Parameters.G2.Tau[i].Set(&g2) + phase1.Parameters.G1.AlphaTau[i].Set(&g1) + phase1.Parameters.G1.BetaTau[i].Set(&g1) + } + + phase1.Parameters.G2.Beta.Set(&g2) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() + + return +} + +// Contribute contributes randomness to the phase1 object. This mutates phase1. +func (phase1 *Phase1) Contribute() { + N := len(phase1.Parameters.G2.Tau) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetRandom() + alpha.SetRandom() + beta.SetRandom() + phase1.PublicKeys.Tau = newPublicKey(tau, phase1.Hash[:], 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, phase1.Hash[:], 2) + phase1.PublicKeys.Beta = newPublicKey(beta, phase1.Hash[:], 3) + + // Compute powers of τ, ατ, and βτ + taus := powers(tau, 2*N-1) + alphaTau := make([]fr.Element, N) + betaTau := make([]fr.Element, N) + for i := 0; i < N; i++ { + alphaTau[i].Mul(&taus[i], &alpha) + betaTau[i].Mul(&taus[i], &beta) + } + + // Update using previous parameters + // TODO @gbotrel working with jacobian points here will help with perf. 
+ scaleG1InPlace(phase1.Parameters.G1.Tau, taus) + scaleG2InPlace(phase1.Parameters.G2.Tau, taus[0:N]) + scaleG1InPlace(phase1.Parameters.G1.AlphaTau, alphaTau) + scaleG1InPlace(phase1.Parameters.G1.BetaTau, betaTau) + var betaBI big.Int + beta.BigInt(&betaBI) + phase1.Parameters.G2.Beta.ScalarMultiplication(&phase1.Parameters.G2.Beta, &betaBI) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() +} + +func VerifyPhase1(c0, c1 *Phase1, c ...*Phase1) error { + contribs := append([]*Phase1{c0, c1}, c...) + for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase1(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +// verifyPhase1 checks that a contribution is based on a known previous Phase1 state. +func verifyPhase1(current, contribution *Phase1) error { + // Compute R for τ, α, β + tauR := genR(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, current.Hash[:], 1) + alphaR := genR(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, current.Hash[:], 2) + betaR := genR(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, current.Hash[:], 3) + + // Check for knowledge of toxic parameters + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.PublicKeys.Tau.XR, tauR) { + return errors.New("couldn't verify public key of τ") + } + if !sameRatio(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, contribution.PublicKeys.Alpha.XR, alphaR) { + return errors.New("couldn't verify public key of α") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.PublicKeys.Beta.XR, betaR) { + return errors.New("couldn't verify public key of β") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Tau[1], current.Parameters.G1.Tau[1], tauR, contribution.PublicKeys.Tau.XR) { + return errors.New("couldn't verify that [τ]₁ is based on previous 
contribution") + } + if !sameRatio(contribution.Parameters.G1.AlphaTau[0], current.Parameters.G1.AlphaTau[0], alphaR, contribution.PublicKeys.Alpha.XR) { + return errors.New("couldn't verify that [α]₁ is based on previous contribution") + } + if !sameRatio(contribution.Parameters.G1.BetaTau[0], current.Parameters.G1.BetaTau[0], betaR, contribution.PublicKeys.Beta.XR) { + return errors.New("couldn't verify that [β]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.Parameters.G2.Tau[1], current.Parameters.G2.Tau[1]) { + return errors.New("couldn't verify that [τ]₂ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.Parameters.G2.Beta, current.Parameters.G2.Beta) { + return errors.New("couldn't verify that [β]₂ is based on previous contribution") + } + + // Check for valid updates using powers of τ + _, _, g1, g2 := curve.Generators() + tauL1, tauL2 := linearCombinationG1(contribution.Parameters.G1.Tau) + if !sameRatio(tauL1, tauL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of τ in G₁") + } + alphaL1, alphaL2 := linearCombinationG1(contribution.Parameters.G1.AlphaTau) + if !sameRatio(alphaL1, alphaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of α(τ) in G₁") + } + betaL1, betaL2 := linearCombinationG1(contribution.Parameters.G1.BetaTau) + if !sameRatio(betaL1, betaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of α(τ) in G₁") + } + tau2L1, tau2L2 := linearCombinationG2(contribution.Parameters.G2.Tau) + if !sameRatio(contribution.Parameters.G1.Tau[1], g1, tau2L1, tau2L2) { + return errors.New("couldn't verify valid powers of τ in G₂") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != 
contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (phase1 *Phase1) hash() []byte { + sha := sha256.New() + phase1.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bn254/mpcsetup/phase2.go b/backend/groth16/bn254/mpcsetup/phase2.go new file mode 100644 index 0000000000..c14e14b509 --- /dev/null +++ b/backend/groth16/bn254/mpcsetup/phase2.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + "math/big" + + curve "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/constraint/bn254" +) + +type Phase2Evaluations struct { + G1 struct { + A, B, VKK []curve.G1Affine + } + G2 struct { + B []curve.G2Affine + } +} + +type Phase2 struct { + Parameters struct { + G1 struct { + Delta curve.G1Affine + L, Z []curve.G1Affine + } + G2 struct { + Delta curve.G2Affine + } + } + PublicKey PublicKey + Hash []byte +} + +func InitPhase2(r1cs *cs.R1CS, srs1 *Phase1) (Phase2, Phase2Evaluations) { + srs := srs1.Parameters + size := len(srs.G1.AlphaTau) + if size < r1cs.GetNbConstraints() { + panic("Number of constraints is larger than expected") + } + + c2 := Phase2{} + + accumulateG1 := func(res *curve.G1Affine, t constraint.Term, value *curve.G1Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G1Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + accumulateG2 := func(res *curve.G2Affine, t constraint.Term, value *curve.G2Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G2Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + // Prepare Lagrange coefficients of [τ...]₁, [τ...]₂, [ατ...]₁, [βτ...]₁ + coeffTau1 := 
lagrangeCoeffsG1(srs.G1.Tau, size) + coeffTau2 := lagrangeCoeffsG2(srs.G2.Tau, size) + coeffAlphaTau1 := lagrangeCoeffsG1(srs.G1.AlphaTau, size) + coeffBetaTau1 := lagrangeCoeffsG1(srs.G1.BetaTau, size) + + internal, secret, public := r1cs.GetNbVariables() + nWires := internal + secret + public + var evals Phase2Evaluations + evals.G1.A = make([]curve.G1Affine, nWires) + evals.G1.B = make([]curve.G1Affine, nWires) + evals.G2.B = make([]curve.G2Affine, nWires) + bA := make([]curve.G1Affine, nWires) + aB := make([]curve.G1Affine, nWires) + C := make([]curve.G1Affine, nWires) + + // TODO @gbotrel use constraint iterator when available. + + i := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + // A + for _, t := range c.L { + accumulateG1(&evals.G1.A[t.WireID()], t, &coeffTau1[i]) + accumulateG1(&bA[t.WireID()], t, &coeffBetaTau1[i]) + } + // B + for _, t := range c.R { + accumulateG1(&evals.G1.B[t.WireID()], t, &coeffTau1[i]) + accumulateG2(&evals.G2.B[t.WireID()], t, &coeffTau2[i]) + accumulateG1(&aB[t.WireID()], t, &coeffAlphaTau1[i]) + } + // C + for _, t := range c.O { + accumulateG1(&C[t.WireID()], t, &coeffTau1[i]) + } + i++ + } + + // Prepare default contribution + _, _, g1, g2 := curve.Generators() + c2.Parameters.G1.Delta = g1 + c2.Parameters.G2.Delta = g2 + + // Build Z in PK as τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + // τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + n := len(srs.G1.AlphaTau) + c2.Parameters.G1.Z = make([]curve.G1Affine, n) + for i := 0; i < n-1; i++ { + c2.Parameters.G1.Z[i].Sub(&srs.G1.Tau[i+n], &srs.G1.Tau[i]) + } + bitReverse(c2.Parameters.G1.Z) + c2.Parameters.G1.Z = c2.Parameters.G1.Z[:n-1] + + // Evaluate L + nPrivate := internal + secret + c2.Parameters.G1.L = make([]curve.G1Affine, nPrivate) + evals.G1.VKK = make([]curve.G1Affine, public) + offset := public + for i := 0; i < nWires; i++ { + var tmp curve.G1Affine + tmp.Add(&bA[i], &aB[i]) + tmp.Add(&tmp, &C[i]) + if i < public { + 
evals.G1.VKK[i].Set(&tmp) + } else { + c2.Parameters.G1.L[i-offset].Set(&tmp) + } + } + // Set δ public key + var delta fr.Element + delta.SetOne() + c2.PublicKey = newPublicKey(delta, nil, 1) + + // Hash initial contribution + c2.Hash = c2.hash() + return c2, evals +} + +func (c *Phase2) Contribute() { + // Sample toxic δ + var delta, deltaInv fr.Element + var deltaBI, deltaInvBI big.Int + delta.SetRandom() + deltaInv.Inverse(&delta) + + delta.BigInt(&deltaBI) + deltaInv.BigInt(&deltaInvBI) + + // Set δ public key + c.PublicKey = newPublicKey(delta, c.Hash, 1) + + // Update δ + c.Parameters.G1.Delta.ScalarMultiplication(&c.Parameters.G1.Delta, &deltaBI) + c.Parameters.G2.Delta.ScalarMultiplication(&c.Parameters.G2.Delta, &deltaBI) + + // Update Z using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.Z); i++ { + c.Parameters.G1.Z[i].ScalarMultiplication(&c.Parameters.G1.Z[i], &deltaInvBI) + } + + // Update L using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.L); i++ { + c.Parameters.G1.L[i].ScalarMultiplication(&c.Parameters.G1.L[i], &deltaInvBI) + } + + // 4. Hash contribution + c.Hash = c.hash() +} + +func VerifyPhase2(c0, c1 *Phase2, c ...*Phase2) error { + contribs := append([]*Phase2{c0, c1}, c...) 
+ for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase2(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +func verifyPhase2(current, contribution *Phase2) error { + // Compute R for δ + deltaR := genR(contribution.PublicKey.SG, contribution.PublicKey.SXG, current.Hash[:], 1) + + // Check for knowledge of δ + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.PublicKey.XR, deltaR) { + return errors.New("couldn't verify knowledge of δ") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Delta, current.Parameters.G1.Delta, deltaR, contribution.PublicKey.XR) { + return errors.New("couldn't verify that [δ]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify that [δ]₂ is based on previous contribution") + } + + // Check for valid updates of L and Z using + L, prevL := merge(contribution.Parameters.G1.L, current.Parameters.G1.L) + if !sameRatio(L, prevL, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of L using δ⁻¹") + } + Z, prevZ := merge(contribution.Parameters.G1.Z, current.Parameters.G1.Z) + if !sameRatio(Z, prevZ, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of L using δ⁻¹") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (c *Phase2) hash() []byte { + sha := sha256.New() + c.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bn254/mpcsetup/setup.go b/backend/groth16/bn254/mpcsetup/setup.go new file mode 100644 index 0000000000..4946e9f597 --- 
/dev/null +++ b/backend/groth16/bn254/mpcsetup/setup.go @@ -0,0 +1,97 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" + groth16 "github.com/consensys/gnark/backend/groth16/bn254" +) + +func ExtractKeys(srs1 *Phase1, srs2 *Phase2, evals *Phase2Evaluations, nConstraints int) (pk groth16.ProvingKey, vk groth16.VerifyingKey) { + _, _, _, g2 := curve.Generators() + + // Initialize PK + pk.Domain = *fft.NewDomain(uint64(nConstraints)) + pk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + pk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + pk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + pk.G1.Z = srs2.Parameters.G1.Z + bitReverse(pk.G1.Z) + + pk.G1.K = srs2.Parameters.G1.L + pk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + pk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + + // Filter out infinity points + nWires := len(evals.G1.A) + pk.InfinityA = make([]bool, nWires) + A := make([]curve.G1Affine, nWires) + j := 0 + for i, e := range evals.G1.A { + if e.IsInfinity() { + pk.InfinityA[i] = true + continue + } + A[j] = evals.G1.A[i] + j++ + } + pk.G1.A = A[:j] + pk.NbInfinityA = uint64(nWires - j) + + pk.InfinityB = make([]bool, nWires) + B := make([]curve.G1Affine, nWires) + j = 0 + for i, e := range evals.G1.B { + if e.IsInfinity() { + pk.InfinityB[i] = true 
+ continue + } + B[j] = evals.G1.B[i] + j++ + } + pk.G1.B = B[:j] + pk.NbInfinityB = uint64(nWires - j) + + B2 := make([]curve.G2Affine, nWires) + j = 0 + for i, e := range evals.G2.B { + if e.IsInfinity() { + // pk.InfinityB[i] = true should be the same as in B + continue + } + B2[j] = evals.G2.B[i] + j++ + } + pk.G2.B = B2[:j] + + // Initialize VK + vk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + vk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + vk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + vk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + vk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + vk.G2.Gamma.Set(&g2) + vk.G1.K = evals.G1.VKK + + // sets e, -[δ]2, -[γ]2 + if err := vk.Precompute(); err != nil { + panic(err) + } + + return pk, vk +} diff --git a/backend/groth16/bn254/mpcsetup/setup_test.go b/backend/groth16/bn254/mpcsetup/setup_test.go new file mode 100644 index 0000000000..e6644c664d --- /dev/null +++ b/backend/groth16/bn254/mpcsetup/setup_test.go @@ -0,0 +1,196 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark/constraint/bn254" + "testing" + + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/stretchr/testify/require" + + native_mimc "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" +) + +func TestSetupCircuit(t *testing.T) { + const ( + nContributionsPhase1 = 3 + nContributionsPhase2 = 3 + power = 9 + ) + + assert := require.New(t) + + srs1 := InitPhase1(power) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase1; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. + prev := srs1.clone() + + srs1.Contribute() + assert.NoError(VerifyPhase1(&prev, &srs1)) + } + + // Compile the circuit + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + var evals Phase2Evaluations + r1cs := ccs.(*cs.R1CS) + + // Prepare for phase-2 + srs2, evals := InitPhase2(r1cs, &srs1) + + // Make and verify contributions for phase2 + for i := 1; i < nContributionsPhase2; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator.
+ prev := srs2.clone() + + srs2.Contribute() + assert.NoError(VerifyPhase2(&prev, &srs2)) + } + + // Extract the proving and verifying keys + pk, vk := ExtractKeys(&srs1, &srs2, &evals, ccs.GetNbConstraints()) + + // Build the witness + var preImage, hash fr.Element + { + m := native_mimc.NewMiMC() + m.Write(preImage.Marshal()) + hash.SetBytes(m.Sum(nil)) + } + + witness, err := frontend.NewWitness(&Circuit{PreImage: preImage, Hash: hash}, curve.ID.ScalarField()) + assert.NoError(err) + + pubWitness, err := witness.Public() + assert.NoError(err) + + // groth16: ensure proof is verified + proof, err := groth16.Prove(ccs, &pk, witness) + assert.NoError(err) + + err = groth16.Verify(proof, &vk, pubWitness) + assert.NoError(err) +} + +func BenchmarkPhase1(b *testing.B) { + const power = 14 + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = InitPhase1(power) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs1 := InitPhase1(power) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs1.Contribute() + } + }) + +} + +func BenchmarkPhase2(b *testing.B) { + const power = 14 + srs1 := InitPhase1(power) + srs1.Contribute() + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + if err != nil { + b.Fatal(err) + } + + r1cs := ccs.(*cs.R1CS) + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = InitPhase2(r1cs, &srs1) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs2, _ := InitPhase2(r1cs, &srs1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs2.Contribute() + } + }) + +} + +// Circuit defines a pre-image knowledge proof +// mimc(secret preImage) = public hash +type Circuit struct { + PreImage frontend.Variable + Hash frontend.Variable `gnark:",public"` +} + +// Define declares the circuit's constraints +// Hash = mimc(PreImage) +func (circuit *Circuit) Define(api frontend.API) error { + // hash function + mimc, _ := 
mimc.NewMiMC(api) + + // specify constraints + mimc.Write(circuit.PreImage) + api.AssertIsEqual(circuit.Hash, mimc.Sum()) + + return nil +} + +func (phase1 *Phase1) clone() Phase1 { + r := Phase1{} + r.Parameters.G1.Tau = append(r.Parameters.G1.Tau, phase1.Parameters.G1.Tau...) + r.Parameters.G1.AlphaTau = append(r.Parameters.G1.AlphaTau, phase1.Parameters.G1.AlphaTau...) + r.Parameters.G1.BetaTau = append(r.Parameters.G1.BetaTau, phase1.Parameters.G1.BetaTau...) + + r.Parameters.G2.Tau = append(r.Parameters.G2.Tau, phase1.Parameters.G2.Tau...) + r.Parameters.G2.Beta = phase1.Parameters.G2.Beta + + r.PublicKeys = phase1.PublicKeys + r.Hash = append(r.Hash, phase1.Hash...) + + return r +} + +func (phase2 *Phase2) clone() Phase2 { + r := Phase2{} + r.Parameters.G1.Delta = phase2.Parameters.G1.Delta + r.Parameters.G1.L = append(r.Parameters.G1.L, phase2.Parameters.G1.L...) + r.Parameters.G1.Z = append(r.Parameters.G1.Z, phase2.Parameters.G1.Z...) + r.Parameters.G2.Delta = phase2.Parameters.G2.Delta + r.PublicKey = phase2.PublicKey + r.Hash = append(r.Hash, phase2.Hash...) + + return r +} diff --git a/backend/groth16/bn254/mpcsetup/utils.go b/backend/groth16/bn254/mpcsetup/utils.go new file mode 100644 index 0000000000..e3b47d1121 --- /dev/null +++ b/backend/groth16/bn254/mpcsetup/utils.go @@ -0,0 +1,170 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark/internal/utils" +) + +type PublicKey struct { + SG curve.G1Affine + SXG curve.G1Affine + XR curve.G2Affine +} + +func newPublicKey(x fr.Element, challenge []byte, dst byte) PublicKey { + var pk PublicKey + _, _, g1, _ := curve.Generators() + + var s fr.Element + var sBi big.Int + s.SetRandom() + s.BigInt(&sBi) + pk.SG.ScalarMultiplication(&g1, &sBi) + + // compute x*sG1 + var xBi big.Int + x.BigInt(&xBi) + pk.SXG.ScalarMultiplication(&pk.SG, &xBi) + + // generate R based on sG1, sxG1, challenge, and domain separation tag (tau, alpha or beta) + R := genR(pk.SG, pk.SXG, challenge, dst) + + // compute x*spG2 + pk.XR.ScalarMultiplication(&R, &xBi) + return pk +} + +func bitReverse[T any](a []T) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + irev := bits.Reverse64(i) >> nn + if irev > i { + a[i], a[irev] = a[irev], a[i] + } + } +} + +// Returns [1, a, a², ..., aⁿ⁻¹ ] in Montgomery form +func powers(a fr.Element, n int) []fr.Element { + result := make([]fr.Element, n) + result[0] = fr.NewElement(1) + for i := 1; i < n; i++ { + result[i].Mul(&result[i-1], &a) + } + return result +} + +// Returns [aᵢAᵢ, ...] in G1 +func scaleG1InPlace(A []curve.G1Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Returns [aᵢAᵢ, ...] 
in G2 +func scaleG2InPlace(A []curve.G2Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Check e(a₁, a₂) = e(b₁, b₂) +func sameRatio(a1, b1 curve.G1Affine, a2, b2 curve.G2Affine) bool { + if !a1.IsInSubGroup() || !b1.IsInSubGroup() || !a2.IsInSubGroup() || !b2.IsInSubGroup() { + panic("invalid point not in subgroup") + } + var na2 curve.G2Affine + na2.Neg(&a2) + res, err := curve.PairingCheck( + []curve.G1Affine{a1, b1}, + []curve.G2Affine{na2, b2}) + if err != nil { + panic(err) + } + return res +} + +// returns a = ∑ rᵢAᵢ, b = ∑ rᵢBᵢ +func merge(A, B []curve.G1Affine) (a, b curve.G1Affine) { + nc := runtime.NumCPU() + r := make([]fr.Element, len(A)) + for i := 0; i < len(A); i++ { + r[i].SetRandom() + } + a.MultiExp(A, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + b.MultiExp(B, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G1 +func linearCombinationG1(A []curve.G1Affine) (L1, L2 curve.G1Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G2 +func linearCombinationG2(A []curve.G2Affine) (L1, L2 curve.G2Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// Generate R in G₂ as Hash(gˢ, gˢˣ, challenge, dst) +func genR(sG1, sxG1 curve.G1Affine, challenge []byte, dst byte) curve.G2Affine { + var buf bytes.Buffer + buf.Grow(len(challenge) + curve.SizeOfG1AffineUncompressed*2) + buf.Write(sG1.Marshal()) + 
buf.Write(sxG1.Marshal()) + buf.Write(challenge) + spG2, err := curve.HashToG2(buf.Bytes(), []byte{dst}) + if err != nil { + panic(err) + } + return spG2 +} diff --git a/internal/backend/bn254/groth16/prove.go b/backend/groth16/bn254/prove.go similarity index 64% rename from internal/backend/bn254/groth16/prove.go rename to backend/groth16/bn254/prove.go index 8ca7d568d8..75eb17e895 100644 --- a/internal/backend/bn254/groth16/prove.go +++ b/backend/groth16/bn254/prove.go @@ -17,13 +17,17 @@ package groth16 import ( - "fmt" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bn254" "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/pedersen" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" "math/big" @@ -35,9 +39,10 @@ import ( // with a valid statement and a VerifyingKey // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type Proof struct { - Ar, Krs curve.G1Affine - Bs curve.G2Affine - Commitment, CommitmentPok curve.G1Affine + Ar, Krs curve.G1Affine + Bs curve.G2Affine + Commitments []curve.G1Affine // Pedersen commitments a la https://eprint.iacr.org/2022/1072 + CommitmentPok curve.G1Affine // Batched proof of knowledge of the above commitments } // isValid ensures proof elements are in the correct subgroup @@ -51,72 +56,71 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). 
-func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - // TODO @gbotrel witness size check is done by R1CS, doesn't mean we shouldn't sanitize here. - // if len(witness) != r1cs.NbPublicVariables-1+r1cs.NbSecretVariables { - // return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), r1cs.NbPublicVariables-1+r1cs.NbSecretVariables, r1cs.NbPublicVariables, r1cs.NbSecretVariables) - // } - - log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", len(r1cs.Constraints)).Str("backend", "groth16").Logger() - - // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - - proof := &Proof{} - if r1cs.CommitmentInfo.Is() { - opt.HintFunctions[r1cs.CommitmentInfo.HintID] = func(_ *big.Int, in []*big.Int, out []*big.Int) error { - // Perf-TODO: Converting these values to big.Int and back may be a performance bottleneck. - // If that is the case, figure out a way to feed the solution vector into this function - if len(in) != r1cs.CommitmentInfo.NbCommitted() { // TODO: Remove - return fmt.Errorf("unexpected number of committed variables") - } - values := make([]fr.Element, r1cs.CommitmentInfo.NbPrivateCommitted) - nbPublicCommitted := len(in) - len(values) - inPrivate := in[nbPublicCommitted:] - for i, inI := range inPrivate { - values[i].SetBigInt(inI) - } +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } + + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] - var err error - proof.Commitment, proof.CommitmentPok, err = pk.CommitmentKey.Commit(values) - if err != nil { + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + solverOpts = append(solverOpts, solver.OverrideHint(commitmentInfo[i].HintID, func(i int) solver.Hint { + return func(_ *big.Int, in []*big.Int, out []*big.Int) error { + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[len(hashed):] + for j, inJ := range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } + + var res fr.Element + res, err = solveCommitmentWire(&proof.Commitments[i], hashed) + res.BigInt(out[0]) return err } - - var res fr.Element - res, err = solveCommitmentWire(&r1cs.CommitmentInfo, &proof.Commitment, in[:r1cs.CommitmentInfo.NbPublicCommitted()]) - res.BigInt(out[0]) //Perf-TODO: Regular (non-mont) hashToField to obviate this conversion? 
- return err - } + }(i))) } - var wireValues []fr.Element - var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill wireValues with random values else multi exps don't do much - var r fr.Element - _, _ = r.SetRandom() - for i := r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables(); i < len(wireValues); i++ { - wireValues[i] = r - r.Double(&r) - } - } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + start := time.Now() + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + + if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil { + return nil, err + } + // H (witness reduction / FFT part) var h []fr.Element chHDone := make(chan struct{}, 1) go func() { - h = computeH(a, b, c, &pk.Domain) - a = nil - b = nil - c = nil + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + solution.A = nil + solution.B = nil + solution.C = nil chHDone <- struct{}{} }() @@ -203,15 +207,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC var krs, krs2, p1 curve.G1Jac chKrs2Done := make(chan error, 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 go func() { - _, err := krs2.MultiExp(pk.G1.Z, h, ecc.MultiExpConfig{NbTasks: n / 2}) + _, err := krs2.MultiExp(pk.G1.Z, h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) chKrs2Done <- err }() - // filter the wire values if needed; - _wireValues := filter(wireValues, r1cs.CommitmentInfo.PrivateToPublic()) + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + 
toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - if _, err := krs.MultiExp(pk.G1.K, _wireValues[r1cs.GetNbPublicVariables():], ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chKrsDone <- err return } @@ -292,26 +300,32 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC } // if len(toRemove) == 0, returns slice -// else, returns a new slice without the indexes in toRemove -// this assumes toRemove indexes are sorted and len(slice) > len(toRemove) -func filter(slice []fr.Element, toRemove []int) (r []fr.Element) { +// else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { if len(toRemove) == 0 { return slice } - r = make([]fr.Element, 0, len(slice)-len(toRemove)) - j := 0 + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) + // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) for i := 0; i < len(slice); i++ { - if j < len(toRemove) && i == toRemove[j] { - j++ + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } continue } r = append(r, slice[i]) } - return r + return } func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { @@ -334,9 +348,9 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { domain.FFTInverse(b, fft.DIF) domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, true) - domain.FFT(b, fft.DIT, true) - 
domain.FFT(c, fft.DIT, true) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) var den, one fr.Element one.SetOne() @@ -344,7 +358,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) - // reusing a to avoid unecessary memalloc + // reusing a to avoid unnecessary memory allocation utils.Parallelize(n, func(start, end int) { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). @@ -354,7 +368,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { }) // ifft_coset - domain.FFTInverse(a, fft.DIF, true) + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) return a } diff --git a/internal/backend/bn254/groth16/setup.go b/backend/groth16/bn254/setup.go similarity index 75% rename from internal/backend/bn254/groth16/setup.go rename to backend/groth16/bn254/setup.go index e13869bbcb..d6eaf37793 100644 --- a/internal/backend/bn254/groth16/setup.go +++ b/backend/groth16/bn254/setup.go @@ -17,11 +17,13 @@ package groth16 import ( + "errors" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bn254" "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" "github.com/consensys/gnark-crypto/ecc/bn254/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/constraint/bn254" "math/big" @@ -34,15 +36,15 @@ type ProvingKey struct { // domain Domain fft.Domain - // [α]1, [β]1, [δ]1 - // [A(t)]1, [B(t)]1, [Kpk(t)]1, [Z(t)]1 + // [α]₁, [β]₁, [δ]₁ + // [A(t)]₁, [B(t)]₁, [Kpk(t)]₁, [Z(t)]₁ G1 struct { Alpha, Beta, Delta curve.G1Affine A, B, Z []curve.G1Affine K []curve.G1Affine // the indexes correspond to the private wires } - // [β]2, [δ]2, [B(t)]2 + // [β]₂, [δ]₂, [B(t)]₂ G2 struct { Beta, Delta curve.G2Affine B []curve.G2Affine @@ -52,21 +54,21 @@ type 
ProvingKey struct { InfinityA, InfinityB []bool NbInfinityA, NbInfinityB uint64 - CommitmentKey pedersen.Key + CommitmentKeys []pedersen.ProvingKey } // VerifyingKey is used by a Groth16 verifier to verify the validity of a proof and a statement // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type VerifyingKey struct { - // [α]1, [Kvk]1 + // [α]₁, [Kvk]₁ G1 struct { Alpha curve.G1Affine Beta, Delta curve.G1Affine // unused, here for compatibility purposes K []curve.G1Affine // The indexes correspond to the public wires } - // [β]2, [δ]2, [γ]2, - // -[δ]2, -[γ]2: see proof.Verify() for more details + // [β]₂, [δ]₂, [γ]₂, + // -[δ]₂, -[γ]₂: see proof.Verify() for more details G2 struct { Beta, Delta, Gamma curve.G2Affine deltaNeg, gammaNeg curve.G2Affine // not serialized @@ -75,8 +77,8 @@ type VerifyingKey struct { // e(α, β) e curve.GT // not serialized - CommitmentKey pedersen.Key - CommitmentInfo constraint.Commitment // since the verifier doesn't input a constraint system, this needs to be provided here + CommitmentKey pedersen.VerifyingKey + PublicAndCommitmentCommitted [][]int // indexes of public/commitment committed variables } // Setup constructs the SRS @@ -93,17 +95,20 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbPrivateCommittedWires := r1cs.CommitmentInfo.NbPrivateCommitted - nbPublicWires := r1cs.GetNbPublicVariables() - nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - if r1cs.CommitmentInfo.Is() { // the commitment itself is defined by a hint so the prover considers it private - nbPublicWires++ // but the verifier will need to inject the value itself so on the groth16 - nbPrivateWires-- // level it must be considered public - } + commitmentInfo := 
r1cs.CommitmentInfo.(constraint.Groth16Commitments) + commitmentWires := commitmentInfo.CommitmentIndexes() + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateCommittedWires := internal.NbElements(privateCommitted) + + // a commitment is itself defined by a hint so the prover considers it private + // but the verifier will need to inject the value itself so on the groth16 + // level it must be considered public + nbPublicWires := r1cs.GetNbPublicVariables() + len(commitmentInfo) + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - len(commitmentInfo) // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints))) + domain := fft.NewDomain(uint64(r1cs.GetNbConstraints())) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -137,7 +142,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // compute scalars for pkK, vkK and ckK pkK := make([]fr.Element, nbPrivateWires) vkK := make([]fr.Element, nbPublicWires) - ckK := make([]fr.Element, nbPrivateCommittedWires) + ckK := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + ckK[i] = make([]fr.Element, len(privateCommitted[i])) + } var t0, t1 fr.Element @@ -148,28 +156,42 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { Add(&t1, &C[i]). 
Mul(&t1, coeff) } - - vI, cI := 0, 0 - privateCommitted := r1cs.CommitmentInfo.PrivateCommitted() + vI := 0 // number of public wires seen so far + cI := make([]int, len(commitmentInfo)) // number of private committed wires seen so far for each commitment + nbPrivateCommittedSeen := 0 // = ∑ᵢ cI[i] + nbCommitmentsSeen := 0 for i := range A { - isCommittedPrivate := cI < len(privateCommitted) && i == privateCommitted[cI] - isCommitment := r1cs.CommitmentInfo.Is() && i == r1cs.CommitmentInfo.CommitmentIndex - isPublic := i < r1cs.GetNbPublicVariables() + commitment := -1 // index of the commitment that commits to this variable as a private or commitment value + var isCommitment, isPublic bool + if isPublic = i < r1cs.GetNbPublicVariables(); !isPublic { + if nbCommitmentsSeen < len(commitmentWires) && commitmentWires[nbCommitmentsSeen] == i { + isCommitment = true + nbCommitmentsSeen++ + } - if isPublic || isCommittedPrivate || isCommitment { + for j := range commitmentInfo { // does commitment j commit to i? + if cI[j] < len(privateCommitted[j]) && privateCommitted[j][cI[j]] == i { + commitment = j + break // frontend guarantees that no private variable is committed to more than once + } + } + } + + if isPublic || commitment != -1 || isCommitment { computeK(i, &toxicWaste.gammaInv) - if isCommittedPrivate { - ckK[cI] = t1 - cI++ - } else { + if isPublic || isCommitment { vkK[vI] = t1 vI++ + } else { // committed and private + ckK[commitment][cI[commitment]] = t1 + cI[commitment]++ + nbPrivateCommittedSeen++ } } else { computeK(i, &toxicWaste.deltaInv) - pkK[i-vI-cI] = t1 + pkK[i-vI-nbPrivateCommittedSeen] = t1 // vI = nbPublicSeen + nbCommitmentsSeen } } @@ -222,11 +244,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { g1Scalars = append(g1Scalars, Z...) g1Scalars = append(g1Scalars, vkK...) g1Scalars = append(g1Scalars, pkK...) - g1Scalars = append(g1Scalars, ckK...) + for i := range ckK { + g1Scalars = append(g1Scalars, ckK[i]...) 
+ } g1PointsAff := curve.BatchScalarMultiplicationG1(&g1, g1Scalars) - // sets pk: [α]1, [β]1, [δ]1 + // sets pk: [α]₁, [β]₁, [δ]₁ pk.G1.Alpha = g1PointsAff[0] pk.G1.Beta = g1PointsAff[1] pk.G1.Delta = g1PointsAff[2] @@ -238,8 +262,9 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G1.B = g1PointsAff[offset : offset+len(B)] offset += len(B) - pk.G1.Z = g1PointsAff[offset : offset+int(domain.Cardinality)] - bitReverse(pk.G1.Z) + bitReverse(g1PointsAff[offset : offset+int(domain.Cardinality)]) + sizeZ := int(domain.Cardinality) - 1 // deg(H)=deg(A*B-C/X^n-1)=(n-1)+(n-1)-n=n-2 + pk.G1.Z = g1PointsAff[offset : offset+sizeZ] offset += int(domain.Cardinality) @@ -252,17 +277,22 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // --------------------------------------------------------------------------------------------- // Commitment setup - if nbPrivateCommittedWires != 0 { - commitmentBasis := g1PointsAff[offset:] + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(ckK[i]) + commitmentBases[i] = g1PointsAff[offset : offset+size] + offset += size + } + if offset != len(g1PointsAff) { + return errors.New("didn't consume all G1 points") // TODO @Tabaie Remove this + } - vk.CommitmentKey, err = pedersen.Setup(commitmentBasis) - if err != nil { - return err - } - pk.CommitmentKey = vk.CommitmentKey + pk.CommitmentKeys, vk.CommitmentKey, err = pedersen.Setup(commitmentBases...) 
+ if err != nil { + return err } - vk.CommitmentInfo = r1cs.CommitmentInfo // unfortunate but necessary + vk.PublicAndCommitmentCommitted = commitmentInfo.GetPublicAndCommitmentCommitted(commitmentWires, r1cs.GetNbPublicVariables()) // --------------------------------------------------------------------------------------------- // G2 scalars @@ -279,15 +309,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G2.B = g2PointsAff[:len(B)] - // sets pk: [β]2, [δ]2 + // sets pk: [β]₂, [δ]₂ pk.G2.Beta = g2PointsAff[len(B)+0] pk.G2.Delta = g2PointsAff[len(B)+1] - // sets vk: [δ]2, [γ]2, -[δ]2, -[γ]2 + // sets vk: [δ]₂, [γ]₂ vk.G2.Delta = g2PointsAff[len(B)+1] vk.G2.Gamma = g2PointsAff[len(B)+2] - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) // --------------------------------------------------------------------------------------------- // Pairing: vk.e @@ -298,16 +326,29 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { vk.G1.Beta = pk.G1.Beta vk.G1.Delta = pk.G1.Delta - vk.e, err = curve.Pair([]curve.G1Affine{pk.G1.Alpha}, []curve.G2Affine{pk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return err } + // set domain pk.Domain = *domain return nil } +// Precompute sets e, -[δ]₂, -[γ]₂ +// This is meant to be called internally during setup or deserialization. +func (vk *VerifyingKey) Precompute() error { + var err error + vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) + if err != nil { + return err + } + vk.G2.deltaNeg.Neg(&vk.G2.Delta) + vk.G2.gammaNeg.Neg(&vk.G2.Gamma) + return nil +} + func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr.Element, B []fr.Element, C []fr.Element) { nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() @@ -322,7 +363,7 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. 
var w fr.Element w.Set(&domain.Generator) wi := fr.One() - t := make([]fr.Element, len(r1cs.Constraints)+1) + t := make([]fr.Element, r1cs.GetNbConstraints()+1) for i := 0; i < len(t); i++ { t[i].Sub(&toxicWaste.t, &wi) wi.Mul(&wi, &w) // TODO this is already pre computed in fft.Domain @@ -366,8 +407,10 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // for each term appearing in the linear expression, // we compute term.Coefficient * L, and cumulate it in // A, B or C at the index of the variable - for i, c := range r1cs.Constraints { + j := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } @@ -380,9 +423,12 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // Li+1 = w*Li*(t-w^i)/(t-w^(i+1)) L.Mul(&L, &w) - L.Mul(&L, &t[i]) - L.Mul(&L, &tInv[i+1]) + L.Mul(&L, &t[j]) + L.Mul(&L, &tInv[j+1]) + + j++ } + return } @@ -436,7 +482,10 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbConstraints := len(r1cs.Constraints) + nbConstraints := r1cs.GetNbConstraints() + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - internal.NbElements(privateCommitted) - len(commitmentInfo) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints)) @@ -448,8 +497,8 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // initialize proving key pk.G1.A = make([]curve.G1Affine, nbWires-nbZeroesA) pk.G1.B = make([]curve.G1Affine, nbWires-nbZeroesB) - pk.G1.K = make([]curve.G1Affine, nbWires-r1cs.GetNbPublicVariables()) - pk.G1.Z = make([]curve.G1Affine, 
domain.Cardinality) + pk.G1.K = make([]curve.G1Affine, nbPrivateWires) + pk.G1.Z = make([]curve.G1Affine, domain.Cardinality-1) pk.G2.B = make([]curve.G2Affine, nbWires-nbZeroesB) // set infinity markers @@ -503,6 +552,22 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { pk.Domain = *domain + // --------------------------------------------------------------------------------------------- + // Commitment setup + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(privateCommitted[i]) + commitmentBases[i] = make([]curve.G1Affine, size) + for j := range commitmentBases[i] { + commitmentBases[i][j] = r1Aff + } + } + + pk.CommitmentKeys, _, err = pedersen.Setup(commitmentBases...) + if err != nil { + return err + } + return nil } @@ -514,7 +579,9 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) - for _, c := range r1cs.Constraints { + + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { A[t.WireID()] = true } @@ -522,6 +589,7 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { B[t.WireID()] = true } } + for i := 0; i < nbWires; i++ { if !A[i] { nbZeroesA++ @@ -606,7 +674,7 @@ func (pk *ProvingKey) NbG2() int { return 2 + len(pk.G2.B) } -// bitRerverse permutation as in fft.BitReverse , but with []curve.G1Affine +// bitReverse permutation as in fft.BitReverse , but with []curve.G1Affine func bitReverse(a []curve.G1Affine) { n := uint(len(a)) nn := uint(bits.UintSize - bits.TrailingZeros(n)) diff --git a/internal/backend/bn254/groth16/solidity.go b/backend/groth16/bn254/solidity.go similarity index 100% rename from internal/backend/bn254/groth16/solidity.go rename to backend/groth16/bn254/solidity.go diff --git a/backend/groth16/bn254/utils_test.go b/backend/groth16/bn254/utils_test.go new file mode 100644 index 0000000000..a5bc2a5770 --- /dev/null +++ 
b/backend/groth16/bn254/utils_test.go @@ -0,0 +1,38 @@ +package groth16 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/stretchr/testify/assert" +) + +func assertSliceEquals[T any](t *testing.T, expected []T, seen []T) { + assert.Equal(t, len(expected), len(seen)) + for i := range expected { + assert.Equal(t, expected[i], seen[i]) + } +} + +func TestFilterHeap(t *testing.T) { + elems := []fr.Element{{0}, {1}, {2}, {3}} + + r := filterHeap(elems, 0, []int{1, 2}) + expected := []fr.Element{{0}, {3}} + assertSliceEquals(t, expected, r) + + r = filterHeap(elems[1:], 1, []int{1, 2}) + expected = []fr.Element{{3}} + assertSliceEquals(t, expected, r) +} + +func TestFilterRepeated(t *testing.T) { + elems := []fr.Element{{0}, {1}, {2}, {3}} + r := filterHeap(elems, 0, []int{1, 1, 2}) + expected := []fr.Element{{0}, {3}} + assertSliceEquals(t, expected, r) + + r = filterHeap(elems[1:], 1, []int{1, 1, 2}) + expected = []fr.Element{{3}} + assertSliceEquals(t, expected, r) +} diff --git a/internal/backend/bn254/groth16/verify.go b/backend/groth16/bn254/verify.go similarity index 69% rename from internal/backend/bn254/groth16/verify.go rename to backend/groth16/bn254/verify.go index 917ebf9166..64c5073266 100644 --- a/internal/backend/bn254/groth16/verify.go +++ b/backend/groth16/bn254/verify.go @@ -22,9 +22,11 @@ import ( "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bn254" "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/pedersen" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/logger" "io" - "math/big" "text/template" "time" ) @@ -37,10 +39,8 @@ var ( // Verify verifies a proof with given VerifyingKey and publicWitness func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - nbPublicVars := len(vk.G1.K) - if vk.CommitmentInfo.Is() { - nbPublicVars-- - } + 
nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted) + if len(publicWitness) != nbPublicVars-1 { return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(publicWitness), len(vk.G1.K)-1) } @@ -63,21 +63,32 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { close(chDone) }() - if vk.CommitmentInfo.Is() { - - if err := vk.CommitmentKey.VerifyKnowledgeProof(proof.Commitment, proof.CommitmentPok); err != nil { - return err + maxNbPublicCommitted := 0 + for _, s := range vk.PublicAndCommitmentCommitted { // iterate over commitments + maxNbPublicCommitted = utils.Max(maxNbPublicCommitted, len(s)) + } + commitmentsSerialized := make([]byte, len(vk.PublicAndCommitmentCommitted)*fr.Bytes) + commitmentPrehashSerialized := make([]byte, curve.SizeOfG1AffineUncompressed+maxNbPublicCommitted*fr.Bytes) + for i := range vk.PublicAndCommitmentCommitted { // solveCommitmentWire + copy(commitmentPrehashSerialized, proof.Commitments[i].Marshal()) + offset := curve.SizeOfG1AffineUncompressed + for j := range vk.PublicAndCommitmentCommitted[i] { + copy(commitmentPrehashSerialized[offset:], publicWitness[vk.PublicAndCommitmentCommitted[i][j]-1].Marshal()) + offset += fr.Bytes } - - publicCommitted := make([]*big.Int, vk.CommitmentInfo.NbPublicCommitted()) - for i := range publicCommitted { - var b big.Int - publicWitness[vk.CommitmentInfo.Committed[i]-1].BigInt(&b) - publicCommitted[i] = &b + if res, err := fr.Hash(commitmentPrehashSerialized[:offset], []byte(constraint.CommitmentDst), 1); err != nil { + return err + } else { + publicWitness = append(publicWitness, res[0]) + copy(commitmentsSerialized[i*fr.Bytes:], res[0].Marshal()) } + } - if res, err := solveCommitmentWire(&vk.CommitmentInfo, &proof.Commitment, publicCommitted); err == nil { - publicWitness = append(publicWitness, res) + if folded, err := pedersen.FoldCommitments(proof.Commitments, commitmentsSerialized); err != nil { + return err + } else { + if 
err = vk.CommitmentKey.Verify(folded, proof.CommitmentPok); err != nil { + return err } } @@ -88,8 +99,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { } kSum.AddMixed(&vk.G1.K[0]) - if vk.CommitmentInfo.Is() { - kSum.AddMixed(&proof.Commitment) + for i := range proof.Commitments { + kSum.AddMixed(&proof.Commitments[i]) } var kSumAff curve.G1Affine diff --git a/internal/backend/bw6-633/groth16/commitment.go b/backend/groth16/bw6-633/commitment.go similarity index 74% rename from internal/backend/bw6-633/groth16/commitment.go rename to backend/groth16/bw6-633/commitment.go index d1243f342c..f8af92e4fc 100644 --- a/internal/backend/bw6-633/groth16/commitment.go +++ b/backend/groth16/bw6-633/commitment.go @@ -23,7 +23,7 @@ import ( "math/big" ) -func solveCommitmentWire(commitmentInfo *constraint.Commitment, commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { - res, err := fr.Hash(commitmentInfo.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) +func solveCommitmentWire(commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { + res, err := fr.Hash(constraint.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) return res[0], err } diff --git a/internal/backend/bw6-633/groth16/commitment_test.go b/backend/groth16/bw6-633/commitment_test.go similarity index 91% rename from internal/backend/bw6-633/groth16/commitment_test.go rename to backend/groth16/bw6-633/commitment_test.go index b7f700d2cb..a832752ffd 100644 --- a/internal/backend/bw6-633/groth16/commitment_test.go +++ b/backend/groth16/bw6-633/commitment_test.go @@ -17,6 +17,9 @@ package groth16_test import ( + "fmt" + "testing" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/witness" @@ -24,7 +27,6 @@ import ( 
"github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/assert" - "testing" ) type singleSecretCommittedCircuit struct { @@ -33,7 +35,11 @@ type singleSecretCommittedCircuit struct { func (c *singleSecretCommittedCircuit) Define(api frontend.API) error { api.AssertIsEqual(c.One, 1) - commit, err := api.Compiler().Commit(c.One) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One) if err != nil { return err } @@ -119,8 +125,11 @@ type oneSecretOnePublicCommittedCircuit struct { } func (c *oneSecretOnePublicCommittedCircuit) Define(api frontend.API) error { - - commit, err := api.Compiler().Commit(c.One, c.Two) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One, c.Two) if err != nil { return err } diff --git a/internal/backend/bw6-633/groth16/marshal.go b/backend/groth16/bw6-633/marshal.go similarity index 82% rename from internal/backend/bw6-633/groth16/marshal.go rename to backend/groth16/bw6-633/marshal.go index 8c09e3a403..edca7ba642 100644 --- a/internal/backend/bw6-633/groth16/marshal.go +++ b/backend/groth16/bw6-633/marshal.go @@ -18,6 +18,9 @@ package groth16 import ( curve "github.com/consensys/gnark-crypto/ecc/bw6-633" + + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/pedersen" + "github.com/consensys/gnark/internal/utils" "io" ) @@ -78,14 +81,24 @@ func (proof *Proof) ReadFrom(r io.Reader) (n int64, err error) { // points are compressed // use WriteRawTo(...) 
to encode the key without point compression func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, false) + if n, err = vk.writeTo(w, false); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteTo(w) + return m + n, err } // WriteRawTo writes binary encoding of the key elements to writer // points are not compressed // use WriteTo(...) to encode the key with point compression func (vk *VerifyingKey) WriteRawTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, true) + if n, err = vk.writeTo(w, true); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteRawTo(w) + return m + n, err } // writeTo serialization format: @@ -124,6 +137,14 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { if err := enc.Encode(vk.G1.K); err != nil { return enc.BytesWritten(), err } + + if vk.PublicAndCommitmentCommitted == nil { + vk.PublicAndCommitmentCommitted = [][]int{} // only matters in tests + } + if err := enc.Encode(utils.IntSliceSliceToUint64SliceSlice(vk.PublicAndCommitmentCommitted)); err != nil { + return enc.BytesWritten(), err + } + return enc.BytesWritten(), nil } @@ -133,13 +154,25 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { // https://github.com/zkcrypto/bellman/blob/fa9be45588227a8c6ec34957de3f68705f07bd92/src/groth16/mod.rs#L143 // [α]1,[β]1,[β]2,[γ]2,[δ]1,[δ]2,uint32(len(Kvk)),[Kvk]1 func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r) + n, err := vk.readFrom(r) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.ReadFrom(r) + return m + n, err } // UnsafeReadFrom has the same behavior as ReadFrom, except that it will not check that decode points // are on the curve and in the correct subgroup. 
func (vk *VerifyingKey) UnsafeReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r, curve.NoSubgroupChecks()) + n, err := vk.readFrom(r, curve.NoSubgroupChecks()) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.UnsafeReadFrom(r) + return m + n, err } func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { @@ -169,15 +202,16 @@ func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder) if err := dec.Decode(&vk.G1.K); err != nil { return dec.BytesRead(), err } + var publicCommitted [][]uint64 + if err := dec.Decode(&publicCommitted); err != nil { + return dec.BytesRead(), err + } + vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 - var err error - vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return dec.BytesRead(), err } - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) return dec.BytesRead(), nil } @@ -226,6 +260,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { pk.NbInfinityB, pk.InfinityA, pk.InfinityB, + uint32(len(pk.CommitmentKeys)), } for _, v := range toEncode { @@ -234,6 +269,23 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { } } + for i := range pk.CommitmentKeys { + var ( + n2 int64 + err error + ) + if raw { + n2, err = pk.CommitmentKeys[i].WriteRawTo(w) + } else { + n2, err = pk.CommitmentKeys[i].WriteTo(w) + } + + n += n2 + if err != nil { + return n, err + } + } + return n + enc.BytesWritten(), nil } @@ -260,6 +312,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) dec := curve.NewDecoder(r, decOptions...) 
var nbWires uint64 + var nbCommitments uint32 toDecode := []interface{}{ &pk.G1.Alpha, @@ -291,6 +344,18 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) if err := dec.Decode(&pk.InfinityB); err != nil { return n + dec.BytesRead(), err } + if err := dec.Decode(&nbCommitments); err != nil { + return n + dec.BytesRead(), err + } + + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) + for i := range pk.CommitmentKeys { + n2, err := pk.CommitmentKeys[i].ReadFrom(r) + n += n2 + if err != nil { + return n, err + } + } return n + dec.BytesRead(), nil } diff --git a/internal/backend/bw6-633/groth16/marshal_test.go b/backend/groth16/bw6-633/marshal_test.go similarity index 78% rename from internal/backend/bw6-633/groth16/marshal_test.go rename to backend/groth16/bw6-633/marshal_test.go index b2a3a63d4b..9e1c0f0728 100644 --- a/internal/backend/bw6-633/groth16/marshal_test.go +++ b/backend/groth16/bw6-633/marshal_test.go @@ -21,11 +21,17 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal/test_utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "bytes" "math/big" "reflect" "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" "github.com/leanovate/gopter/prop" "testing" @@ -87,13 +93,9 @@ func TestProofSerialization(t *testing.T) { } func TestVerifyingKeySerialization(t *testing.T) { - parameters := gopter.DefaultTestParameters() - parameters.MinSuccessfulTests = 10 - properties := gopter.NewProperties(parameters) - - properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + roundTrip := func(withCommitment bool) func(curve.G1Affine, curve.G2Affine) bool { + return func(p1 curve.G1Affine, p2 curve.G2Affine) bool { var vk, vkCompressed, vkRaw 
VerifyingKey // create a random vk @@ -121,6 +123,21 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } + if withCommitment { + vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization + bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) + elem := p1 + for i := 0; i < len(vk.PublicAndCommitmentCommitted); i++ { + bases[i] = make([]curve.G1Affine, len(vk.PublicAndCommitmentCommitted[i])) + for j := range bases[i] { + bases[i][j] = elem + elem.Add(&elem, &p1) + } + } + _, vk.CommitmentKey, err = pedersen.Setup(bases...) + assert.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := vk.WriteTo(&bufCompressed) if err != nil { @@ -158,7 +175,22 @@ func TestVerifyingKeySerialization(t *testing.T) { } return reflect.DeepEqual(&vk, &vkCompressed) && reflect.DeepEqual(&vk, &vkRaw) - }, + } + } + + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 10 + + properties := gopter.NewProperties(parameters) + + properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(false), + GenG1(), + GenG2(), + )) + + properties.Property("VerifyingKey (with commitments) -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(true), GenG1(), GenG2(), )) @@ -173,7 +205,7 @@ func TestProvingKeySerialization(t *testing.T) { properties := gopter.NewProperties(parameters) properties.Property("ProvingKey -> writer -> reader -> ProvingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + func(p1 curve.G1Affine, p2 curve.G2Affine, nbCommitment int) bool { var pk, pkCompressed, pkRaw ProvingKey // create a random pk @@ -202,6 +234,19 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, 
nbCommitment) + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases...) + require.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := pk.WriteTo(&bufCompressed) if err != nil { @@ -242,6 +287,7 @@ func TestProvingKeySerialization(t *testing.T) { }, GenG1(), GenG2(), + gen.IntRange(0, 2), )) properties.TestingRun(t, gopter.ConsoleReporter(false)) diff --git a/backend/groth16/bw6-633/mpcsetup/lagrange.go b/backend/groth16/bw6-633/mpcsetup/lagrange.go new file mode 100644 index 0000000000..563daf9891 --- /dev/null +++ b/backend/groth16/bw6-633/mpcsetup/lagrange.go @@ -0,0 +1,216 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bw6-633" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" + "github.com/consensys/gnark/internal/utils" +) + +func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { + coeffs := make([]curve.G1Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { + coeffs := make([]curve.G2Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func butterflyG1(a *curve.G1Affine, b *curve.G1Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +func butterflyG2(a *curve.G2Affine, b *curve.G2Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G1(a []curve.G1Affine, twiddles [][]fr.Element, stage int) { + butterflyG1(&a[0], &a[4]) + butterflyG1(&a[1], &a[5]) + 
butterflyG1(&a[2], &a[6]) + butterflyG1(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[2]) + butterflyG1(&a[1], &a[3]) + butterflyG1(&a[4], &a[6]) + butterflyG1(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[1]) + butterflyG1(&a[2], &a[3]) + butterflyG1(&a[4], &a[5]) + butterflyG1(&a[6], &a[7]) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G2(a []curve.G2Affine, twiddles [][]fr.Element, stage int) { + butterflyG2(&a[0], &a[4]) + butterflyG2(&a[1], &a[5]) + butterflyG2(&a[2], &a[6]) + butterflyG2(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[2]) + butterflyG2(&a[1], &a[3]) + butterflyG2(&a[4], &a[6]) + butterflyG2(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[1]) + butterflyG2(&a[2], &a[3]) + butterflyG2(&a[4], &a[5]) + butterflyG2(&a[6], &a[7]) +} + +func difFFTG1(a []curve.G1Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G1(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG1(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ 
{ + butterflyG1(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} +func difFFTG2(a []curve.G2Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G2(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG2(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + butterflyG2(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG2(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG2(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} diff --git a/backend/groth16/bw6-633/mpcsetup/marshal.go b/backend/groth16/bw6-633/mpcsetup/marshal.go new file mode 100644 index 0000000000..4955ec6d92 --- /dev/null +++ b/backend/groth16/bw6-633/mpcsetup/marshal.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bw6-633" + "io" +) + +// WriteTo implements io.WriterTo +func (phase1 *Phase1) WriteTo(writer io.Writer) (int64, error) { + n, err := phase1.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase1.Hash) + return int64(nBytes) + n, err +} + +func (phase1 *Phase1) writeTo(writer io.Writer) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + phase1.Parameters.G1.Tau, + phase1.Parameters.G1.AlphaTau, + phase1.Parameters.G1.BetaTau, + phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + enc := curve.NewEncoder(writer) + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (phase1 *Phase1) ReadFrom(reader io.Reader) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + &phase1.Parameters.G1.Tau, + &phase1.Parameters.G1.AlphaTau, 
+ &phase1.Parameters.G1.BetaTau, + &phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + dec := curve.NewDecoder(reader) + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + phase1.Hash = make([]byte, 32) + nBytes, err := reader.Read(phase1.Hash) + return dec.BytesRead() + int64(nBytes), err +} + +// WriteTo implements io.WriterTo +func (phase2 *Phase2) WriteTo(writer io.Writer) (int64, error) { + n, err := phase2.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase2.Hash) + return int64(nBytes) + n, err +} + +func (c *Phase2) writeTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + c.Parameters.G1.L, + c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c *Phase2) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + &c.Parameters.G1.L, + &c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + c.Hash = make([]byte, 32) + n, err := reader.Read(c.Hash) + return int64(n) + dec.BytesRead(), err + +} + +// WriteTo implements io.WriterTo +func (c *Phase2Evaluations) WriteTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + c.G1.A, + c.G1.B, + c.G2.B, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func 
(c *Phase2Evaluations) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.G1.A, + &c.G1.B, + &c.G2.B, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + return dec.BytesRead(), nil +} diff --git a/backend/groth16/bw6-633/mpcsetup/marshal_test.go b/backend/groth16/bw6-633/mpcsetup/marshal_test.go new file mode 100644 index 0000000000..0c04d15934 --- /dev/null +++ b/backend/groth16/bw6-633/mpcsetup/marshal_test.go @@ -0,0 +1,79 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + curve "github.com/consensys/gnark-crypto/ecc/bw6-633" + "github.com/consensys/gnark/constraint/bw6-633" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/stretchr/testify/require" + "io" + "reflect" + "testing" +) + +func TestContributionSerialization(t *testing.T) { + assert := require.New(t) + + // Phase 1 + srs1 := InitPhase1(9) + srs1.Contribute() + { + var reconstructed Phase1 + roundTripCheck(t, &srs1, &reconstructed) + } + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + r1cs := ccs.(*cs.R1CS) + + // Phase 2 + srs2, _ := InitPhase2(r1cs, &srs1) + srs2.Contribute() + + { + var reconstructed Phase2 + roundTripCheck(t, &srs2, &reconstructed) + } +} + +func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { + t.Helper() + + var buf bytes.Buffer + written, err := from.WriteTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} diff --git a/backend/groth16/bw6-633/mpcsetup/phase1.go b/backend/groth16/bw6-633/mpcsetup/phase1.go new file mode 100644 index 0000000000..fa91cc5025 --- /dev/null +++ b/backend/groth16/bw6-633/mpcsetup/phase1.go @@ -0,0 +1,203 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + curve "github.com/consensys/gnark-crypto/ecc/bw6-633" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "math" + "math/big" +) + +// Phase1 represents the Phase1 of the MPC described in +// https://eprint.iacr.org/2017/1050.pdf +// +// Also known as "Powers of Tau" +type Phase1 struct { + Parameters struct { + G1 struct { + Tau []curve.G1Affine // {[τ⁰]₁, [τ¹]₁, [τ²]₁, …, [τ²ⁿ⁻²]₁} + AlphaTau []curve.G1Affine // {α[τ⁰]₁, α[τ¹]₁, α[τ²]₁, …, α[τⁿ⁻¹]₁} + BetaTau []curve.G1Affine // {β[τ⁰]₁, β[τ¹]₁, β[τ²]₁, …, β[τⁿ⁻¹]₁} + } + G2 struct { + Tau []curve.G2Affine // {[τ⁰]₂, [τ¹]₂, [τ²]₂, …, [τⁿ⁻¹]₂} + Beta curve.G2Affine // [β]₂ + } + } + PublicKeys struct { + Tau, Alpha, Beta PublicKey + } + Hash []byte // sha256 hash +} + +// InitPhase1 initialize phase 1 of the MPC. This is called once by the coordinator before +// any randomness contribution is made (see Contribute()). 
+func InitPhase1(power int) (phase1 Phase1) { + N := int(math.Pow(2, float64(power))) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetOne() + alpha.SetOne() + beta.SetOne() + phase1.PublicKeys.Tau = newPublicKey(tau, nil, 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, nil, 2) + phase1.PublicKeys.Beta = newPublicKey(beta, nil, 3) + + // First contribution use generators + _, _, g1, g2 := curve.Generators() + phase1.Parameters.G2.Beta.Set(&g2) + phase1.Parameters.G1.Tau = make([]curve.G1Affine, 2*N-1) + phase1.Parameters.G2.Tau = make([]curve.G2Affine, N) + phase1.Parameters.G1.AlphaTau = make([]curve.G1Affine, N) + phase1.Parameters.G1.BetaTau = make([]curve.G1Affine, N) + for i := 0; i < len(phase1.Parameters.G1.Tau); i++ { + phase1.Parameters.G1.Tau[i].Set(&g1) + } + for i := 0; i < len(phase1.Parameters.G2.Tau); i++ { + phase1.Parameters.G2.Tau[i].Set(&g2) + phase1.Parameters.G1.AlphaTau[i].Set(&g1) + phase1.Parameters.G1.BetaTau[i].Set(&g1) + } + + phase1.Parameters.G2.Beta.Set(&g2) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() + + return +} + +// Contribute contributes randomness to the phase1 object. This mutates phase1. +func (phase1 *Phase1) Contribute() { + N := len(phase1.Parameters.G2.Tau) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetRandom() + alpha.SetRandom() + beta.SetRandom() + phase1.PublicKeys.Tau = newPublicKey(tau, phase1.Hash[:], 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, phase1.Hash[:], 2) + phase1.PublicKeys.Beta = newPublicKey(beta, phase1.Hash[:], 3) + + // Compute powers of τ, ατ, and βτ + taus := powers(tau, 2*N-1) + alphaTau := make([]fr.Element, N) + betaTau := make([]fr.Element, N) + for i := 0; i < N; i++ { + alphaTau[i].Mul(&taus[i], &alpha) + betaTau[i].Mul(&taus[i], &beta) + } + + // Update using previous parameters + // TODO @gbotrel working with jacobian points here will help with perf. 
+ scaleG1InPlace(phase1.Parameters.G1.Tau, taus) + scaleG2InPlace(phase1.Parameters.G2.Tau, taus[0:N]) + scaleG1InPlace(phase1.Parameters.G1.AlphaTau, alphaTau) + scaleG1InPlace(phase1.Parameters.G1.BetaTau, betaTau) + var betaBI big.Int + beta.BigInt(&betaBI) + phase1.Parameters.G2.Beta.ScalarMultiplication(&phase1.Parameters.G2.Beta, &betaBI) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() +} + +func VerifyPhase1(c0, c1 *Phase1, c ...*Phase1) error { + contribs := append([]*Phase1{c0, c1}, c...) + for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase1(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +// verifyPhase1 checks that a contribution is based on a known previous Phase1 state. +func verifyPhase1(current, contribution *Phase1) error { + // Compute R for τ, α, β + tauR := genR(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, current.Hash[:], 1) + alphaR := genR(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, current.Hash[:], 2) + betaR := genR(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, current.Hash[:], 3) + + // Check for knowledge of toxic parameters + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.PublicKeys.Tau.XR, tauR) { + return errors.New("couldn't verify public key of τ") + } + if !sameRatio(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, contribution.PublicKeys.Alpha.XR, alphaR) { + return errors.New("couldn't verify public key of α") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.PublicKeys.Beta.XR, betaR) { + return errors.New("couldn't verify public key of β") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Tau[1], current.Parameters.G1.Tau[1], tauR, contribution.PublicKeys.Tau.XR) { + return errors.New("couldn't verify that [τ]₁ is based on previous 
contribution") + } + if !sameRatio(contribution.Parameters.G1.AlphaTau[0], current.Parameters.G1.AlphaTau[0], alphaR, contribution.PublicKeys.Alpha.XR) { + return errors.New("couldn't verify that [α]₁ is based on previous contribution") + } + if !sameRatio(contribution.Parameters.G1.BetaTau[0], current.Parameters.G1.BetaTau[0], betaR, contribution.PublicKeys.Beta.XR) { + return errors.New("couldn't verify that [β]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.Parameters.G2.Tau[1], current.Parameters.G2.Tau[1]) { + return errors.New("couldn't verify that [τ]₂ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.Parameters.G2.Beta, current.Parameters.G2.Beta) { + return errors.New("couldn't verify that [β]₂ is based on previous contribution") + } + + // Check for valid updates using powers of τ + _, _, g1, g2 := curve.Generators() + tauL1, tauL2 := linearCombinationG1(contribution.Parameters.G1.Tau) + if !sameRatio(tauL1, tauL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of τ in G₁") + } + alphaL1, alphaL2 := linearCombinationG1(contribution.Parameters.G1.AlphaTau) + if !sameRatio(alphaL1, alphaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of α(τ) in G₁") + } + betaL1, betaL2 := linearCombinationG1(contribution.Parameters.G1.BetaTau) + if !sameRatio(betaL1, betaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of α(τ) in G₁") + } + tau2L1, tau2L2 := linearCombinationG2(contribution.Parameters.G2.Tau) + if !sameRatio(contribution.Parameters.G1.Tau[1], g1, tau2L1, tau2L2) { + return errors.New("couldn't verify valid powers of τ in G₂") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != 
contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (phase1 *Phase1) hash() []byte { + sha := sha256.New() + phase1.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bw6-633/mpcsetup/phase2.go b/backend/groth16/bw6-633/mpcsetup/phase2.go new file mode 100644 index 0000000000..e966fd021a --- /dev/null +++ b/backend/groth16/bw6-633/mpcsetup/phase2.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + "math/big" + + curve "github.com/consensys/gnark-crypto/ecc/bw6-633" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/constraint/bw6-633" +) + +type Phase2Evaluations struct { + G1 struct { + A, B, VKK []curve.G1Affine + } + G2 struct { + B []curve.G2Affine + } +} + +type Phase2 struct { + Parameters struct { + G1 struct { + Delta curve.G1Affine + L, Z []curve.G1Affine + } + G2 struct { + Delta curve.G2Affine + } + } + PublicKey PublicKey + Hash []byte +} + +func InitPhase2(r1cs *cs.R1CS, srs1 *Phase1) (Phase2, Phase2Evaluations) { + srs := srs1.Parameters + size := len(srs.G1.AlphaTau) + if size < r1cs.GetNbConstraints() { + panic("Number of constraints is larger than expected") + } + + c2 := Phase2{} + + accumulateG1 := func(res *curve.G1Affine, t constraint.Term, value *curve.G1Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G1Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + accumulateG2 := func(res *curve.G2Affine, t constraint.Term, value *curve.G2Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G2Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + // Prepare Lagrange coefficients of [τ...]₁, [τ...]₂, [ατ...]₁, [βτ...]₁ + coeffTau1 := 
lagrangeCoeffsG1(srs.G1.Tau, size) + coeffTau2 := lagrangeCoeffsG2(srs.G2.Tau, size) + coeffAlphaTau1 := lagrangeCoeffsG1(srs.G1.AlphaTau, size) + coeffBetaTau1 := lagrangeCoeffsG1(srs.G1.BetaTau, size) + + internal, secret, public := r1cs.GetNbVariables() + nWires := internal + secret + public + var evals Phase2Evaluations + evals.G1.A = make([]curve.G1Affine, nWires) + evals.G1.B = make([]curve.G1Affine, nWires) + evals.G2.B = make([]curve.G2Affine, nWires) + bA := make([]curve.G1Affine, nWires) + aB := make([]curve.G1Affine, nWires) + C := make([]curve.G1Affine, nWires) + + // TODO @gbotrel use constraint iterator when available. + + i := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + // A + for _, t := range c.L { + accumulateG1(&evals.G1.A[t.WireID()], t, &coeffTau1[i]) + accumulateG1(&bA[t.WireID()], t, &coeffBetaTau1[i]) + } + // B + for _, t := range c.R { + accumulateG1(&evals.G1.B[t.WireID()], t, &coeffTau1[i]) + accumulateG2(&evals.G2.B[t.WireID()], t, &coeffTau2[i]) + accumulateG1(&aB[t.WireID()], t, &coeffAlphaTau1[i]) + } + // C + for _, t := range c.O { + accumulateG1(&C[t.WireID()], t, &coeffTau1[i]) + } + i++ + } + + // Prepare default contribution + _, _, g1, g2 := curve.Generators() + c2.Parameters.G1.Delta = g1 + c2.Parameters.G2.Delta = g2 + + // Build Z in PK as τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + // τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + n := len(srs.G1.AlphaTau) + c2.Parameters.G1.Z = make([]curve.G1Affine, n) + for i := 0; i < n-1; i++ { + c2.Parameters.G1.Z[i].Sub(&srs.G1.Tau[i+n], &srs.G1.Tau[i]) + } + bitReverse(c2.Parameters.G1.Z) + c2.Parameters.G1.Z = c2.Parameters.G1.Z[:n-1] + + // Evaluate L + nPrivate := internal + secret + c2.Parameters.G1.L = make([]curve.G1Affine, nPrivate) + evals.G1.VKK = make([]curve.G1Affine, public) + offset := public + for i := 0; i < nWires; i++ { + var tmp curve.G1Affine + tmp.Add(&bA[i], &aB[i]) + tmp.Add(&tmp, &C[i]) + if i < public { + 
evals.G1.VKK[i].Set(&tmp) + } else { + c2.Parameters.G1.L[i-offset].Set(&tmp) + } + } + // Set δ public key + var delta fr.Element + delta.SetOne() + c2.PublicKey = newPublicKey(delta, nil, 1) + + // Hash initial contribution + c2.Hash = c2.hash() + return c2, evals +} + +func (c *Phase2) Contribute() { + // Sample toxic δ + var delta, deltaInv fr.Element + var deltaBI, deltaInvBI big.Int + delta.SetRandom() + deltaInv.Inverse(&delta) + + delta.BigInt(&deltaBI) + deltaInv.BigInt(&deltaInvBI) + + // Set δ public key + c.PublicKey = newPublicKey(delta, c.Hash, 1) + + // Update δ + c.Parameters.G1.Delta.ScalarMultiplication(&c.Parameters.G1.Delta, &deltaBI) + c.Parameters.G2.Delta.ScalarMultiplication(&c.Parameters.G2.Delta, &deltaBI) + + // Update Z using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.Z); i++ { + c.Parameters.G1.Z[i].ScalarMultiplication(&c.Parameters.G1.Z[i], &deltaInvBI) + } + + // Update L using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.L); i++ { + c.Parameters.G1.L[i].ScalarMultiplication(&c.Parameters.G1.L[i], &deltaInvBI) + } + + // 4. Hash contribution + c.Hash = c.hash() +} + +func VerifyPhase2(c0, c1 *Phase2, c ...*Phase2) error { + contribs := append([]*Phase2{c0, c1}, c...) 
+ for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase2(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +func verifyPhase2(current, contribution *Phase2) error { + // Compute R for δ + deltaR := genR(contribution.PublicKey.SG, contribution.PublicKey.SXG, current.Hash[:], 1) + + // Check for knowledge of δ + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.PublicKey.XR, deltaR) { + return errors.New("couldn't verify knowledge of δ") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Delta, current.Parameters.G1.Delta, deltaR, contribution.PublicKey.XR) { + return errors.New("couldn't verify that [δ]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify that [δ]₂ is based on previous contribution") + } + + // Check for valid updates of L and Z using + L, prevL := merge(contribution.Parameters.G1.L, current.Parameters.G1.L) + if !sameRatio(L, prevL, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of L using δ⁻¹") + } + Z, prevZ := merge(contribution.Parameters.G1.Z, current.Parameters.G1.Z) + if !sameRatio(Z, prevZ, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of L using δ⁻¹") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (c *Phase2) hash() []byte { + sha := sha256.New() + c.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bw6-633/mpcsetup/setup.go b/backend/groth16/bw6-633/mpcsetup/setup.go new file mode 100644 index 0000000000..35bfdca0e9 --- 
/dev/null +++ b/backend/groth16/bw6-633/mpcsetup/setup.go @@ -0,0 +1,97 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bw6-633" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" + groth16 "github.com/consensys/gnark/backend/groth16/bw6-633" +) + +func ExtractKeys(srs1 *Phase1, srs2 *Phase2, evals *Phase2Evaluations, nConstraints int) (pk groth16.ProvingKey, vk groth16.VerifyingKey) { + _, _, _, g2 := curve.Generators() + + // Initialize PK + pk.Domain = *fft.NewDomain(uint64(nConstraints)) + pk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + pk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + pk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + pk.G1.Z = srs2.Parameters.G1.Z + bitReverse(pk.G1.Z) + + pk.G1.K = srs2.Parameters.G1.L + pk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + pk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + + // Filter out infinity points + nWires := len(evals.G1.A) + pk.InfinityA = make([]bool, nWires) + A := make([]curve.G1Affine, nWires) + j := 0 + for i, e := range evals.G1.A { + if e.IsInfinity() { + pk.InfinityA[i] = true + continue + } + A[j] = evals.G1.A[i] + j++ + } + pk.G1.A = A[:j] + pk.NbInfinityA = uint64(nWires - j) + + pk.InfinityB = make([]bool, nWires) + B := make([]curve.G1Affine, nWires) + j = 0 + for i, e := range evals.G1.B { + if e.IsInfinity() { + 
pk.InfinityB[i] = true + continue + } + B[j] = evals.G1.B[i] + j++ + } + pk.G1.B = B[:j] + pk.NbInfinityB = uint64(nWires - j) + + B2 := make([]curve.G2Affine, nWires) + j = 0 + for i, e := range evals.G2.B { + if e.IsInfinity() { + // pk.InfinityB[i] = true should be the same as in B + continue + } + B2[j] = evals.G2.B[i] + j++ + } + pk.G2.B = B2[:j] + + // Initialize VK + vk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + vk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + vk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + vk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + vk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + vk.G2.Gamma.Set(&g2) + vk.G1.K = evals.G1.VKK + + // sets e, -[δ]2, -[γ]2 + if err := vk.Precompute(); err != nil { + panic(err) + } + + return pk, vk +} diff --git a/backend/groth16/bw6-633/mpcsetup/setup_test.go b/backend/groth16/bw6-633/mpcsetup/setup_test.go new file mode 100644 index 0000000000..6933f11050 --- /dev/null +++ b/backend/groth16/bw6-633/mpcsetup/setup_test.go @@ -0,0 +1,199 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bw6-633" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark/constraint/bw6-633" + "testing" + + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/stretchr/testify/require" + + native_mimc "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/mimc" +) + +func TestSetupCircuit(t *testing.T) { + if testing.Short() { + t.Skip() + } + const ( + nContributionsPhase1 = 3 + nContributionsPhase2 = 3 + power = 9 + ) + + assert := require.New(t) + + srs1 := InitPhase1(power) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase1; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. + prev := srs1.clone() + + srs1.Contribute() + assert.NoError(VerifyPhase1(&prev, &srs1)) + } + + // Compile the circuit + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + var evals Phase2Evaluations + r1cs := ccs.(*cs.R1CS) + + // Prepare for phase-2 + srs2, evals := InitPhase2(r1cs, &srs1) + + // Make and verify contributions for phase2 + for i := 1; i < nContributionsPhase2; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. 
+ prev := srs2.clone() + + srs2.Contribute() + assert.NoError(VerifyPhase2(&prev, &srs2)) + } + + // Extract the proving and verifying keys + pk, vk := ExtractKeys(&srs1, &srs2, &evals, ccs.GetNbConstraints()) + + // Build the witness + var preImage, hash fr.Element + { + m := native_mimc.NewMiMC() + m.Write(preImage.Marshal()) + hash.SetBytes(m.Sum(nil)) + } + + witness, err := frontend.NewWitness(&Circuit{PreImage: preImage, Hash: hash}, curve.ID.ScalarField()) + assert.NoError(err) + + pubWitness, err := witness.Public() + assert.NoError(err) + + // groth16: ensure proof is verified + proof, err := groth16.Prove(ccs, &pk, witness) + assert.NoError(err) + + err = groth16.Verify(proof, &vk, pubWitness) + assert.NoError(err) +} + +func BenchmarkPhase1(b *testing.B) { + const power = 14 + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = InitPhase1(power) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs1 := InitPhase1(power) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs1.Contribute() + } + }) + +} + +func BenchmarkPhase2(b *testing.B) { + const power = 14 + srs1 := InitPhase1(power) + srs1.Contribute() + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + if err != nil { + b.Fatal(err) + } + + r1cs := ccs.(*cs.R1CS) + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = InitPhase2(r1cs, &srs1) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs2, _ := InitPhase2(r1cs, &srs1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs2.Contribute() + } + }) + +} + +// Circuit defines a pre-image knowledge proof +// mimc(secret preImage) = public hash +type Circuit struct { + PreImage frontend.Variable + Hash frontend.Variable `gnark:",public"` +} + +// Define declares the circuit's constraints +// Hash = mimc(PreImage) +func (circuit *Circuit) Define(api frontend.API) error { + // hash function + mimc, _ := 
mimc.NewMiMC(api) + + // specify constraints + mimc.Write(circuit.PreImage) + api.AssertIsEqual(circuit.Hash, mimc.Sum()) + + return nil +} + +func (phase1 *Phase1) clone() Phase1 { + r := Phase1{} + r.Parameters.G1.Tau = append(r.Parameters.G1.Tau, phase1.Parameters.G1.Tau...) + r.Parameters.G1.AlphaTau = append(r.Parameters.G1.AlphaTau, phase1.Parameters.G1.AlphaTau...) + r.Parameters.G1.BetaTau = append(r.Parameters.G1.BetaTau, phase1.Parameters.G1.BetaTau...) + + r.Parameters.G2.Tau = append(r.Parameters.G2.Tau, phase1.Parameters.G2.Tau...) + r.Parameters.G2.Beta = phase1.Parameters.G2.Beta + + r.PublicKeys = phase1.PublicKeys + r.Hash = append(r.Hash, phase1.Hash...) + + return r +} + +func (phase2 *Phase2) clone() Phase2 { + r := Phase2{} + r.Parameters.G1.Delta = phase2.Parameters.G1.Delta + r.Parameters.G1.L = append(r.Parameters.G1.L, phase2.Parameters.G1.L...) + r.Parameters.G1.Z = append(r.Parameters.G1.Z, phase2.Parameters.G1.Z...) + r.Parameters.G2.Delta = phase2.Parameters.G2.Delta + r.PublicKey = phase2.PublicKey + r.Hash = append(r.Hash, phase2.Hash...) + + return r +} diff --git a/backend/groth16/bw6-633/mpcsetup/utils.go b/backend/groth16/bw6-633/mpcsetup/utils.go new file mode 100644 index 0000000000..7a2979c53c --- /dev/null +++ b/backend/groth16/bw6-633/mpcsetup/utils.go @@ -0,0 +1,170 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bw6-633" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark/internal/utils" +) + +type PublicKey struct { + SG curve.G1Affine + SXG curve.G1Affine + XR curve.G2Affine +} + +func newPublicKey(x fr.Element, challenge []byte, dst byte) PublicKey { + var pk PublicKey + _, _, g1, _ := curve.Generators() + + var s fr.Element + var sBi big.Int + s.SetRandom() + s.BigInt(&sBi) + pk.SG.ScalarMultiplication(&g1, &sBi) + + // compute x*sG1 + var xBi big.Int + x.BigInt(&xBi) + pk.SXG.ScalarMultiplication(&pk.SG, &xBi) + + // generate R based on sG1, sxG1, challenge, and domain separation tag (tau, alpha or beta) + R := genR(pk.SG, pk.SXG, challenge, dst) + + // compute x*spG2 + pk.XR.ScalarMultiplication(&R, &xBi) + return pk +} + +func bitReverse[T any](a []T) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + irev := bits.Reverse64(i) >> nn + if irev > i { + a[i], a[irev] = a[irev], a[i] + } + } +} + +// Returns [1, a, a², ..., aⁿ⁻¹ ] in Montgomery form +func powers(a fr.Element, n int) []fr.Element { + result := make([]fr.Element, n) + result[0] = fr.NewElement(1) + for i := 1; i < n; i++ { + result[i].Mul(&result[i-1], &a) + } + return result +} + +// Returns [aᵢAᵢ, ...] in G1 +func scaleG1InPlace(A []curve.G1Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Returns [aᵢAᵢ, ...] 
in G2 +func scaleG2InPlace(A []curve.G2Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Check e(a₁, a₂) = e(b₁, b₂) +func sameRatio(a1, b1 curve.G1Affine, a2, b2 curve.G2Affine) bool { + if !a1.IsInSubGroup() || !b1.IsInSubGroup() || !a2.IsInSubGroup() || !b2.IsInSubGroup() { + panic("invalid point not in subgroup") + } + var na2 curve.G2Affine + na2.Neg(&a2) + res, err := curve.PairingCheck( + []curve.G1Affine{a1, b1}, + []curve.G2Affine{na2, b2}) + if err != nil { + panic(err) + } + return res +} + +// returns a = ∑ rᵢAᵢ, b = ∑ rᵢBᵢ +func merge(A, B []curve.G1Affine) (a, b curve.G1Affine) { + nc := runtime.NumCPU() + r := make([]fr.Element, len(A)) + for i := 0; i < len(A); i++ { + r[i].SetRandom() + } + a.MultiExp(A, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + b.MultiExp(B, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G1 +func linearCombinationG1(A []curve.G1Affine) (L1, L2 curve.G1Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G2 +func linearCombinationG2(A []curve.G2Affine) (L1, L2 curve.G2Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// Generate R in G₂ as Hash(gˢ, gˢˣ, challenge, dst) +func genR(sG1, sxG1 curve.G1Affine, challenge []byte, dst byte) curve.G2Affine { + var buf bytes.Buffer + buf.Grow(len(challenge) + curve.SizeOfG1AffineUncompressed*2) + buf.Write(sG1.Marshal()) + 
buf.Write(sxG1.Marshal()) + buf.Write(challenge) + spG2, err := curve.HashToG2(buf.Bytes(), []byte{dst}) + if err != nil { + panic(err) + } + return spG2 +} diff --git a/internal/backend/bw6-633/groth16/prove.go b/backend/groth16/bw6-633/prove.go similarity index 64% rename from internal/backend/bw6-633/groth16/prove.go rename to backend/groth16/bw6-633/prove.go index 335a9f0ccb..41002c3f48 100644 --- a/internal/backend/bw6-633/groth16/prove.go +++ b/backend/groth16/bw6-633/prove.go @@ -17,13 +17,17 @@ package groth16 import ( - "fmt" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bw6-633" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/pedersen" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/constraint/bw6-633" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" "math/big" @@ -35,9 +39,10 @@ import ( // with a valid statement and a VerifyingKey // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type Proof struct { - Ar, Krs curve.G1Affine - Bs curve.G2Affine - Commitment, CommitmentPok curve.G1Affine + Ar, Krs curve.G1Affine + Bs curve.G2Affine + Commitments []curve.G1Affine // Pedersen commitments a la https://eprint.iacr.org/2022/1072 + CommitmentPok curve.G1Affine // Batched proof of knowledge of the above commitments } // isValid ensures proof elements are in the correct subgroup @@ -51,72 +56,71 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). 
-func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - // TODO @gbotrel witness size check is done by R1CS, doesn't mean we shouldn't sanitize here. - // if len(witness) != r1cs.NbPublicVariables-1+r1cs.NbSecretVariables { - // return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), r1cs.NbPublicVariables-1+r1cs.NbSecretVariables, r1cs.NbPublicVariables, r1cs.NbSecretVariables) - // } - - log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", len(r1cs.Constraints)).Str("backend", "groth16").Logger() - - // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - - proof := &Proof{} - if r1cs.CommitmentInfo.Is() { - opt.HintFunctions[r1cs.CommitmentInfo.HintID] = func(_ *big.Int, in []*big.Int, out []*big.Int) error { - // Perf-TODO: Converting these values to big.Int and back may be a performance bottleneck. - // If that is the case, figure out a way to feed the solution vector into this function - if len(in) != r1cs.CommitmentInfo.NbCommitted() { // TODO: Remove - return fmt.Errorf("unexpected number of committed variables") - } - values := make([]fr.Element, r1cs.CommitmentInfo.NbPrivateCommitted) - nbPublicCommitted := len(in) - len(values) - inPrivate := in[nbPublicCommitted:] - for i, inI := range inPrivate { - values[i].SetBigInt(inI) - } +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } + + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] - var err error - proof.Commitment, proof.CommitmentPok, err = pk.CommitmentKey.Commit(values) - if err != nil { + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + solverOpts = append(solverOpts, solver.OverrideHint(commitmentInfo[i].HintID, func(i int) solver.Hint { + return func(_ *big.Int, in []*big.Int, out []*big.Int) error { + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[len(hashed):] + for j, inJ := range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } + + var res fr.Element + res, err = solveCommitmentWire(&proof.Commitments[i], hashed) + res.BigInt(out[0]) return err } - - var res fr.Element - res, err = solveCommitmentWire(&r1cs.CommitmentInfo, &proof.Commitment, in[:r1cs.CommitmentInfo.NbPublicCommitted()]) - res.BigInt(out[0]) //Perf-TODO: Regular (non-mont) hashToField to obviate this conversion? 
- return err - } + }(i))) } - var wireValues []fr.Element - var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill wireValues with random values else multi exps don't do much - var r fr.Element - _, _ = r.SetRandom() - for i := r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables(); i < len(wireValues); i++ { - wireValues[i] = r - r.Double(&r) - } - } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + start := time.Now() + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + + if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil { + return nil, err + } + // H (witness reduction / FFT part) var h []fr.Element chHDone := make(chan struct{}, 1) go func() { - h = computeH(a, b, c, &pk.Domain) - a = nil - b = nil - c = nil + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + solution.A = nil + solution.B = nil + solution.C = nil chHDone <- struct{}{} }() @@ -203,15 +207,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC var krs, krs2, p1 curve.G1Jac chKrs2Done := make(chan error, 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 go func() { - _, err := krs2.MultiExp(pk.G1.Z, h, ecc.MultiExpConfig{NbTasks: n / 2}) + _, err := krs2.MultiExp(pk.G1.Z, h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) chKrs2Done <- err }() - // filter the wire values if needed; - _wireValues := filter(wireValues, r1cs.CommitmentInfo.PrivateToPublic()) + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + 
toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - if _, err := krs.MultiExp(pk.G1.K, _wireValues[r1cs.GetNbPublicVariables():], ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chKrsDone <- err return } @@ -292,26 +300,32 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC } // if len(toRemove) == 0, returns slice -// else, returns a new slice without the indexes in toRemove -// this assumes toRemove indexes are sorted and len(slice) > len(toRemove) -func filter(slice []fr.Element, toRemove []int) (r []fr.Element) { +// else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { if len(toRemove) == 0 { return slice } - r = make([]fr.Element, 0, len(slice)-len(toRemove)) - j := 0 + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) + // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) for i := 0; i < len(slice); i++ { - if j < len(toRemove) && i == toRemove[j] { - j++ + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } continue } r = append(r, slice[i]) } - return r + return } func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { @@ -334,9 +348,9 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { domain.FFTInverse(b, fft.DIF) domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, true) - domain.FFT(b, fft.DIT, true) - 
domain.FFT(c, fft.DIT, true) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) var den, one fr.Element one.SetOne() @@ -344,7 +358,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) - // reusing a to avoid unecessary memalloc + // reusing a to avoid unnecessary memory allocation utils.Parallelize(n, func(start, end int) { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). @@ -354,7 +368,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { }) // ifft_coset - domain.FFTInverse(a, fft.DIF, true) + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) return a } diff --git a/internal/backend/bw6-633/groth16/setup.go b/backend/groth16/bw6-633/setup.go similarity index 75% rename from internal/backend/bw6-633/groth16/setup.go rename to backend/groth16/bw6-633/setup.go index d6777090eb..92c73871b7 100644 --- a/internal/backend/bw6-633/groth16/setup.go +++ b/backend/groth16/bw6-633/setup.go @@ -17,11 +17,13 @@ package groth16 import ( + "errors" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bw6-633" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/constraint/bw6-633" "math/big" @@ -34,15 +36,15 @@ type ProvingKey struct { // domain Domain fft.Domain - // [α]1, [β]1, [δ]1 - // [A(t)]1, [B(t)]1, [Kpk(t)]1, [Z(t)]1 + // [α]₁, [β]₁, [δ]₁ + // [A(t)]₁, [B(t)]₁, [Kpk(t)]₁, [Z(t)]₁ G1 struct { Alpha, Beta, Delta curve.G1Affine A, B, Z []curve.G1Affine K []curve.G1Affine // the indexes correspond to the private wires } - // [β]2, [δ]2, [B(t)]2 + // [β]₂, [δ]₂, [B(t)]₂ G2 struct { Beta, Delta curve.G2Affine B []curve.G2Affine @@ 
-52,21 +54,21 @@ type ProvingKey struct { InfinityA, InfinityB []bool NbInfinityA, NbInfinityB uint64 - CommitmentKey pedersen.Key + CommitmentKeys []pedersen.ProvingKey } // VerifyingKey is used by a Groth16 verifier to verify the validity of a proof and a statement // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type VerifyingKey struct { - // [α]1, [Kvk]1 + // [α]₁, [Kvk]₁ G1 struct { Alpha curve.G1Affine Beta, Delta curve.G1Affine // unused, here for compatibility purposes K []curve.G1Affine // The indexes correspond to the public wires } - // [β]2, [δ]2, [γ]2, - // -[δ]2, -[γ]2: see proof.Verify() for more details + // [β]₂, [δ]₂, [γ]₂, + // -[δ]₂, -[γ]₂: see proof.Verify() for more details G2 struct { Beta, Delta, Gamma curve.G2Affine deltaNeg, gammaNeg curve.G2Affine // not serialized @@ -75,8 +77,8 @@ type VerifyingKey struct { // e(α, β) e curve.GT // not serialized - CommitmentKey pedersen.Key - CommitmentInfo constraint.Commitment // since the verifier doesn't input a constraint system, this needs to be provided here + CommitmentKey pedersen.VerifyingKey + PublicAndCommitmentCommitted [][]int // indexes of public/commitment committed variables } // Setup constructs the SRS @@ -93,17 +95,20 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbPrivateCommittedWires := r1cs.CommitmentInfo.NbPrivateCommitted - nbPublicWires := r1cs.GetNbPublicVariables() - nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - if r1cs.CommitmentInfo.Is() { // the commitment itself is defined by a hint so the prover considers it private - nbPublicWires++ // but the verifier will need to inject the value itself so on the groth16 - nbPrivateWires-- // level it must be considered public - } + commitmentInfo := 
r1cs.CommitmentInfo.(constraint.Groth16Commitments) + commitmentWires := commitmentInfo.CommitmentIndexes() + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateCommittedWires := internal.NbElements(privateCommitted) + + // a commitment is itself defined by a hint so the prover considers it private + // but the verifier will need to inject the value itself so on the groth16 + // level it must be considered public + nbPublicWires := r1cs.GetNbPublicVariables() + len(commitmentInfo) + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - len(commitmentInfo) // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints))) + domain := fft.NewDomain(uint64(r1cs.GetNbConstraints())) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -137,7 +142,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // compute scalars for pkK, vkK and ckK pkK := make([]fr.Element, nbPrivateWires) vkK := make([]fr.Element, nbPublicWires) - ckK := make([]fr.Element, nbPrivateCommittedWires) + ckK := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + ckK[i] = make([]fr.Element, len(privateCommitted[i])) + } var t0, t1 fr.Element @@ -148,28 +156,42 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { Add(&t1, &C[i]). 
Mul(&t1, coeff) } - - vI, cI := 0, 0 - privateCommitted := r1cs.CommitmentInfo.PrivateCommitted() + vI := 0 // number of public wires seen so far + cI := make([]int, len(commitmentInfo)) // number of private committed wires seen so far for each commitment + nbPrivateCommittedSeen := 0 // = ∑ᵢ cI[i] + nbCommitmentsSeen := 0 for i := range A { - isCommittedPrivate := cI < len(privateCommitted) && i == privateCommitted[cI] - isCommitment := r1cs.CommitmentInfo.Is() && i == r1cs.CommitmentInfo.CommitmentIndex - isPublic := i < r1cs.GetNbPublicVariables() + commitment := -1 // index of the commitment that commits to this variable as a private or commitment value + var isCommitment, isPublic bool + if isPublic = i < r1cs.GetNbPublicVariables(); !isPublic { + if nbCommitmentsSeen < len(commitmentWires) && commitmentWires[nbCommitmentsSeen] == i { + isCommitment = true + nbCommitmentsSeen++ + } - if isPublic || isCommittedPrivate || isCommitment { + for j := range commitmentInfo { // does commitment j commit to i? + if cI[j] < len(privateCommitted[j]) && privateCommitted[j][cI[j]] == i { + commitment = j + break // frontend guarantees that no private variable is committed to more than once + } + } + } + + if isPublic || commitment != -1 || isCommitment { computeK(i, &toxicWaste.gammaInv) - if isCommittedPrivate { - ckK[cI] = t1 - cI++ - } else { + if isPublic || isCommitment { vkK[vI] = t1 vI++ + } else { // committed and private + ckK[commitment][cI[commitment]] = t1 + cI[commitment]++ + nbPrivateCommittedSeen++ } } else { computeK(i, &toxicWaste.deltaInv) - pkK[i-vI-cI] = t1 + pkK[i-vI-nbPrivateCommittedSeen] = t1 // vI = nbPublicSeen + nbCommitmentsSeen } } @@ -222,11 +244,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { g1Scalars = append(g1Scalars, Z...) g1Scalars = append(g1Scalars, vkK...) g1Scalars = append(g1Scalars, pkK...) - g1Scalars = append(g1Scalars, ckK...) + for i := range ckK { + g1Scalars = append(g1Scalars, ckK[i]...) 
+ } g1PointsAff := curve.BatchScalarMultiplicationG1(&g1, g1Scalars) - // sets pk: [α]1, [β]1, [δ]1 + // sets pk: [α]₁, [β]₁, [δ]₁ pk.G1.Alpha = g1PointsAff[0] pk.G1.Beta = g1PointsAff[1] pk.G1.Delta = g1PointsAff[2] @@ -238,8 +262,9 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G1.B = g1PointsAff[offset : offset+len(B)] offset += len(B) - pk.G1.Z = g1PointsAff[offset : offset+int(domain.Cardinality)] - bitReverse(pk.G1.Z) + bitReverse(g1PointsAff[offset : offset+int(domain.Cardinality)]) + sizeZ := int(domain.Cardinality) - 1 // deg(H)=deg(A*B-C/X^n-1)=(n-1)+(n-1)-n=n-2 + pk.G1.Z = g1PointsAff[offset : offset+sizeZ] offset += int(domain.Cardinality) @@ -252,17 +277,22 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // --------------------------------------------------------------------------------------------- // Commitment setup - if nbPrivateCommittedWires != 0 { - commitmentBasis := g1PointsAff[offset:] + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(ckK[i]) + commitmentBases[i] = g1PointsAff[offset : offset+size] + offset += size + } + if offset != len(g1PointsAff) { + return errors.New("didn't consume all G1 points") // TODO @Tabaie Remove this + } - vk.CommitmentKey, err = pedersen.Setup(commitmentBasis) - if err != nil { - return err - } - pk.CommitmentKey = vk.CommitmentKey + pk.CommitmentKeys, vk.CommitmentKey, err = pedersen.Setup(commitmentBases...) 
+ if err != nil { + return err } - vk.CommitmentInfo = r1cs.CommitmentInfo // unfortunate but necessary + vk.PublicAndCommitmentCommitted = commitmentInfo.GetPublicAndCommitmentCommitted(commitmentWires, r1cs.GetNbPublicVariables()) // --------------------------------------------------------------------------------------------- // G2 scalars @@ -279,15 +309,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G2.B = g2PointsAff[:len(B)] - // sets pk: [β]2, [δ]2 + // sets pk: [β]₂, [δ]₂ pk.G2.Beta = g2PointsAff[len(B)+0] pk.G2.Delta = g2PointsAff[len(B)+1] - // sets vk: [δ]2, [γ]2, -[δ]2, -[γ]2 + // sets vk: [δ]₂, [γ]₂ vk.G2.Delta = g2PointsAff[len(B)+1] vk.G2.Gamma = g2PointsAff[len(B)+2] - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) // --------------------------------------------------------------------------------------------- // Pairing: vk.e @@ -298,16 +326,29 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { vk.G1.Beta = pk.G1.Beta vk.G1.Delta = pk.G1.Delta - vk.e, err = curve.Pair([]curve.G1Affine{pk.G1.Alpha}, []curve.G2Affine{pk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return err } + // set domain pk.Domain = *domain return nil } +// Precompute sets e, -[δ]₂, -[γ]₂ +// This is meant to be called internally during setup or deserialization. +func (vk *VerifyingKey) Precompute() error { + var err error + vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) + if err != nil { + return err + } + vk.G2.deltaNeg.Neg(&vk.G2.Delta) + vk.G2.gammaNeg.Neg(&vk.G2.Gamma) + return nil +} + func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr.Element, B []fr.Element, C []fr.Element) { nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() @@ -322,7 +363,7 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. 
var w fr.Element w.Set(&domain.Generator) wi := fr.One() - t := make([]fr.Element, len(r1cs.Constraints)+1) + t := make([]fr.Element, r1cs.GetNbConstraints()+1) for i := 0; i < len(t); i++ { t[i].Sub(&toxicWaste.t, &wi) wi.Mul(&wi, &w) // TODO this is already pre computed in fft.Domain @@ -366,8 +407,10 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // for each term appearing in the linear expression, // we compute term.Coefficient * L, and cumulate it in // A, B or C at the index of the variable - for i, c := range r1cs.Constraints { + j := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } @@ -380,9 +423,12 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // Li+1 = w*Li*(t-w^i)/(t-w^(i+1)) L.Mul(&L, &w) - L.Mul(&L, &t[i]) - L.Mul(&L, &tInv[i+1]) + L.Mul(&L, &t[j]) + L.Mul(&L, &tInv[j+1]) + + j++ } + return } @@ -436,7 +482,10 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbConstraints := len(r1cs.Constraints) + nbConstraints := r1cs.GetNbConstraints() + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - internal.NbElements(privateCommitted) - len(commitmentInfo) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints)) @@ -448,8 +497,8 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // initialize proving key pk.G1.A = make([]curve.G1Affine, nbWires-nbZeroesA) pk.G1.B = make([]curve.G1Affine, nbWires-nbZeroesB) - pk.G1.K = make([]curve.G1Affine, nbWires-r1cs.GetNbPublicVariables()) - pk.G1.Z = make([]curve.G1Affine, 
domain.Cardinality) + pk.G1.K = make([]curve.G1Affine, nbPrivateWires) + pk.G1.Z = make([]curve.G1Affine, domain.Cardinality-1) pk.G2.B = make([]curve.G2Affine, nbWires-nbZeroesB) // set infinity markers @@ -503,6 +552,22 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { pk.Domain = *domain + // --------------------------------------------------------------------------------------------- + // Commitment setup + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(privateCommitted[i]) + commitmentBases[i] = make([]curve.G1Affine, size) + for j := range commitmentBases[i] { + commitmentBases[i][j] = r1Aff + } + } + + pk.CommitmentKeys, _, err = pedersen.Setup(commitmentBases...) + if err != nil { + return err + } + return nil } @@ -514,7 +579,9 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) - for _, c := range r1cs.Constraints { + + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { A[t.WireID()] = true } @@ -522,6 +589,7 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { B[t.WireID()] = true } } + for i := 0; i < nbWires; i++ { if !A[i] { nbZeroesA++ @@ -606,7 +674,7 @@ func (pk *ProvingKey) NbG2() int { return 2 + len(pk.G2.B) } -// bitRerverse permutation as in fft.BitReverse , but with []curve.G1Affine +// bitReverse permutation as in fft.BitReverse , but with []curve.G1Affine func bitReverse(a []curve.G1Affine) { n := uint(len(a)) nn := uint(bits.UintSize - bits.TrailingZeros(n)) diff --git a/internal/backend/bw6-633/groth16/verify.go b/backend/groth16/bw6-633/verify.go similarity index 64% rename from internal/backend/bw6-633/groth16/verify.go rename to backend/groth16/bw6-633/verify.go index 642bedc2d4..c32a71e6a6 100644 --- a/internal/backend/bw6-633/groth16/verify.go +++ b/backend/groth16/bw6-633/verify.go @@ -22,9 +22,11 @@ import ( 
"github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bw6-633" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/pedersen" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/logger" "io" - "math/big" "time" ) @@ -36,10 +38,8 @@ var ( // Verify verifies a proof with given VerifyingKey and publicWitness func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - nbPublicVars := len(vk.G1.K) - if vk.CommitmentInfo.Is() { - nbPublicVars-- - } + nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted) + if len(publicWitness) != nbPublicVars-1 { return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(publicWitness), len(vk.G1.K)-1) } @@ -62,21 +62,32 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { close(chDone) }() - if vk.CommitmentInfo.Is() { - - if err := vk.CommitmentKey.VerifyKnowledgeProof(proof.Commitment, proof.CommitmentPok); err != nil { - return err + maxNbPublicCommitted := 0 + for _, s := range vk.PublicAndCommitmentCommitted { // iterate over commitments + maxNbPublicCommitted = utils.Max(maxNbPublicCommitted, len(s)) + } + commitmentsSerialized := make([]byte, len(vk.PublicAndCommitmentCommitted)*fr.Bytes) + commitmentPrehashSerialized := make([]byte, curve.SizeOfG1AffineUncompressed+maxNbPublicCommitted*fr.Bytes) + for i := range vk.PublicAndCommitmentCommitted { // solveCommitmentWire + copy(commitmentPrehashSerialized, proof.Commitments[i].Marshal()) + offset := curve.SizeOfG1AffineUncompressed + for j := range vk.PublicAndCommitmentCommitted[i] { + copy(commitmentPrehashSerialized[offset:], publicWitness[vk.PublicAndCommitmentCommitted[i][j]-1].Marshal()) + offset += fr.Bytes } - - publicCommitted := make([]*big.Int, vk.CommitmentInfo.NbPublicCommitted()) - for i := range publicCommitted { - var b big.Int - 
publicWitness[vk.CommitmentInfo.Committed[i]-1].BigInt(&b) - publicCommitted[i] = &b + if res, err := fr.Hash(commitmentPrehashSerialized[:offset], []byte(constraint.CommitmentDst), 1); err != nil { + return err + } else { + publicWitness = append(publicWitness, res[0]) + copy(commitmentsSerialized[i*fr.Bytes:], res[0].Marshal()) } + } - if res, err := solveCommitmentWire(&vk.CommitmentInfo, &proof.Commitment, publicCommitted); err == nil { - publicWitness = append(publicWitness, res) + if folded, err := pedersen.FoldCommitments(proof.Commitments, commitmentsSerialized); err != nil { + return err + } else { + if err = vk.CommitmentKey.Verify(folded, proof.CommitmentPok); err != nil { + return err } } @@ -87,8 +98,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { } kSum.AddMixed(&vk.G1.K[0]) - if vk.CommitmentInfo.Is() { - kSum.AddMixed(&proof.Commitment) + for i := range proof.Commitments { + kSum.AddMixed(&proof.Commitments[i]) } var kSumAff curve.G1Affine diff --git a/internal/backend/bw6-761/groth16/commitment.go b/backend/groth16/bw6-761/commitment.go similarity index 74% rename from internal/backend/bw6-761/groth16/commitment.go rename to backend/groth16/bw6-761/commitment.go index c332669a12..5c357c24ad 100644 --- a/internal/backend/bw6-761/groth16/commitment.go +++ b/backend/groth16/bw6-761/commitment.go @@ -23,7 +23,7 @@ import ( "math/big" ) -func solveCommitmentWire(commitmentInfo *constraint.Commitment, commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { - res, err := fr.Hash(commitmentInfo.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) +func solveCommitmentWire(commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { + res, err := fr.Hash(constraint.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) return res[0], err } diff --git 
a/internal/backend/bw6-761/groth16/commitment_test.go b/backend/groth16/bw6-761/commitment_test.go similarity index 91% rename from internal/backend/bw6-761/groth16/commitment_test.go rename to backend/groth16/bw6-761/commitment_test.go index ccddf6be25..f16cb25786 100644 --- a/internal/backend/bw6-761/groth16/commitment_test.go +++ b/backend/groth16/bw6-761/commitment_test.go @@ -17,6 +17,9 @@ package groth16_test import ( + "fmt" + "testing" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/witness" @@ -24,7 +27,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/assert" - "testing" ) type singleSecretCommittedCircuit struct { @@ -33,7 +35,11 @@ type singleSecretCommittedCircuit struct { func (c *singleSecretCommittedCircuit) Define(api frontend.API) error { api.AssertIsEqual(c.One, 1) - commit, err := api.Compiler().Commit(c.One) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One) if err != nil { return err } @@ -119,8 +125,11 @@ type oneSecretOnePublicCommittedCircuit struct { } func (c *oneSecretOnePublicCommittedCircuit) Define(api frontend.API) error { - - commit, err := api.Compiler().Commit(c.One, c.Two) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One, c.Two) if err != nil { return err } diff --git a/internal/backend/bw6-761/groth16/marshal.go b/backend/groth16/bw6-761/marshal.go similarity index 82% rename from internal/backend/bw6-761/groth16/marshal.go rename to backend/groth16/bw6-761/marshal.go index 0a18c29463..6f072bc90e 100644 --- a/internal/backend/bw6-761/groth16/marshal.go +++ b/backend/groth16/bw6-761/marshal.go @@ -18,6 +18,9 @@ package groth16 import ( curve 
"github.com/consensys/gnark-crypto/ecc/bw6-761" + + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/pedersen" + "github.com/consensys/gnark/internal/utils" "io" ) @@ -78,14 +81,24 @@ func (proof *Proof) ReadFrom(r io.Reader) (n int64, err error) { // points are compressed // use WriteRawTo(...) to encode the key without point compression func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, false) + if n, err = vk.writeTo(w, false); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteTo(w) + return m + n, err } // WriteRawTo writes binary encoding of the key elements to writer // points are not compressed // use WriteTo(...) to encode the key with point compression func (vk *VerifyingKey) WriteRawTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, true) + if n, err = vk.writeTo(w, true); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteRawTo(w) + return m + n, err } // writeTo serialization format: @@ -124,6 +137,14 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { if err := enc.Encode(vk.G1.K); err != nil { return enc.BytesWritten(), err } + + if vk.PublicAndCommitmentCommitted == nil { + vk.PublicAndCommitmentCommitted = [][]int{} // only matters in tests + } + if err := enc.Encode(utils.IntSliceSliceToUint64SliceSlice(vk.PublicAndCommitmentCommitted)); err != nil { + return enc.BytesWritten(), err + } + return enc.BytesWritten(), nil } @@ -133,13 +154,25 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { // https://github.com/zkcrypto/bellman/blob/fa9be45588227a8c6ec34957de3f68705f07bd92/src/groth16/mod.rs#L143 // [α]1,[β]1,[β]2,[γ]2,[δ]1,[δ]2,uint32(len(Kvk)),[Kvk]1 func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r) + n, err := vk.readFrom(r) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.ReadFrom(r) + return m + n, err } // UnsafeReadFrom 
has the same behavior as ReadFrom, except that it will not check that decode points // are on the curve and in the correct subgroup. func (vk *VerifyingKey) UnsafeReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r, curve.NoSubgroupChecks()) + n, err := vk.readFrom(r, curve.NoSubgroupChecks()) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.UnsafeReadFrom(r) + return m + n, err } func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { @@ -169,15 +202,16 @@ func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder) if err := dec.Decode(&vk.G1.K); err != nil { return dec.BytesRead(), err } + var publicCommitted [][]uint64 + if err := dec.Decode(&publicCommitted); err != nil { + return dec.BytesRead(), err + } + vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 - var err error - vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return dec.BytesRead(), err } - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) return dec.BytesRead(), nil } @@ -226,6 +260,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { pk.NbInfinityB, pk.InfinityA, pk.InfinityB, + uint32(len(pk.CommitmentKeys)), } for _, v := range toEncode { @@ -234,6 +269,23 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { } } + for i := range pk.CommitmentKeys { + var ( + n2 int64 + err error + ) + if raw { + n2, err = pk.CommitmentKeys[i].WriteRawTo(w) + } else { + n2, err = pk.CommitmentKeys[i].WriteTo(w) + } + + n += n2 + if err != nil { + return n, err + } + } + return n + enc.BytesWritten(), nil } @@ -260,6 +312,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) dec := curve.NewDecoder(r, decOptions...) 
var nbWires uint64 + var nbCommitments uint32 toDecode := []interface{}{ &pk.G1.Alpha, @@ -291,6 +344,18 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) if err := dec.Decode(&pk.InfinityB); err != nil { return n + dec.BytesRead(), err } + if err := dec.Decode(&nbCommitments); err != nil { + return n + dec.BytesRead(), err + } + + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) + for i := range pk.CommitmentKeys { + n2, err := pk.CommitmentKeys[i].ReadFrom(r) + n += n2 + if err != nil { + return n, err + } + } return n + dec.BytesRead(), nil } diff --git a/internal/backend/bw6-761/groth16/marshal_test.go b/backend/groth16/bw6-761/marshal_test.go similarity index 78% rename from internal/backend/bw6-761/groth16/marshal_test.go rename to backend/groth16/bw6-761/marshal_test.go index bccf077d51..36e5f692a4 100644 --- a/internal/backend/bw6-761/groth16/marshal_test.go +++ b/backend/groth16/bw6-761/marshal_test.go @@ -21,11 +21,17 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal/test_utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "bytes" "math/big" "reflect" "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" "github.com/leanovate/gopter/prop" "testing" @@ -87,13 +93,9 @@ func TestProofSerialization(t *testing.T) { } func TestVerifyingKeySerialization(t *testing.T) { - parameters := gopter.DefaultTestParameters() - parameters.MinSuccessfulTests = 10 - properties := gopter.NewProperties(parameters) - - properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + roundTrip := func(withCommitment bool) func(curve.G1Affine, curve.G2Affine) bool { + return func(p1 curve.G1Affine, p2 curve.G2Affine) bool { var vk, vkCompressed, vkRaw 
VerifyingKey // create a random vk @@ -121,6 +123,21 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } + if withCommitment { + vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization + bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) + elem := p1 + for i := 0; i < len(vk.PublicAndCommitmentCommitted); i++ { + bases[i] = make([]curve.G1Affine, len(vk.PublicAndCommitmentCommitted[i])) + for j := range bases[i] { + bases[i][j] = elem + elem.Add(&elem, &p1) + } + } + _, vk.CommitmentKey, err = pedersen.Setup(bases...) + assert.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := vk.WriteTo(&bufCompressed) if err != nil { @@ -158,7 +175,22 @@ func TestVerifyingKeySerialization(t *testing.T) { } return reflect.DeepEqual(&vk, &vkCompressed) && reflect.DeepEqual(&vk, &vkRaw) - }, + } + } + + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 10 + + properties := gopter.NewProperties(parameters) + + properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(false), + GenG1(), + GenG2(), + )) + + properties.Property("VerifyingKey (with commitments) -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(true), GenG1(), GenG2(), )) @@ -173,7 +205,7 @@ func TestProvingKeySerialization(t *testing.T) { properties := gopter.NewProperties(parameters) properties.Property("ProvingKey -> writer -> reader -> ProvingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + func(p1 curve.G1Affine, p2 curve.G2Affine, nbCommitment int) bool { var pk, pkCompressed, pkRaw ProvingKey // create a random pk @@ -202,6 +234,19 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, 
nbCommitment) + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases...) + require.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := pk.WriteTo(&bufCompressed) if err != nil { @@ -242,6 +287,7 @@ func TestProvingKeySerialization(t *testing.T) { }, GenG1(), GenG2(), + gen.IntRange(0, 2), )) properties.TestingRun(t, gopter.ConsoleReporter(false)) diff --git a/backend/groth16/bw6-761/mpcsetup/lagrange.go b/backend/groth16/bw6-761/mpcsetup/lagrange.go new file mode 100644 index 0000000000..3b8aaa3b7c --- /dev/null +++ b/backend/groth16/bw6-761/mpcsetup/lagrange.go @@ -0,0 +1,216 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bw6-761" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" + "github.com/consensys/gnark/internal/utils" +) + +func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { + coeffs := make([]curve.G1Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { + coeffs := make([]curve.G2Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func butterflyG1(a *curve.G1Affine, b *curve.G1Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +func butterflyG2(a *curve.G2Affine, b *curve.G2Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G1(a []curve.G1Affine, twiddles [][]fr.Element, stage int) { + butterflyG1(&a[0], &a[4]) + butterflyG1(&a[1], &a[5]) + 
butterflyG1(&a[2], &a[6]) + butterflyG1(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[2]) + butterflyG1(&a[1], &a[3]) + butterflyG1(&a[4], &a[6]) + butterflyG1(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[1]) + butterflyG1(&a[2], &a[3]) + butterflyG1(&a[4], &a[5]) + butterflyG1(&a[6], &a[7]) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G2(a []curve.G2Affine, twiddles [][]fr.Element, stage int) { + butterflyG2(&a[0], &a[4]) + butterflyG2(&a[1], &a[5]) + butterflyG2(&a[2], &a[6]) + butterflyG2(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[2]) + butterflyG2(&a[1], &a[3]) + butterflyG2(&a[4], &a[6]) + butterflyG2(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[1]) + butterflyG2(&a[2], &a[3]) + butterflyG2(&a[4], &a[5]) + butterflyG2(&a[6], &a[7]) +} + +func difFFTG1(a []curve.G1Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G1(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG1(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ 
{ + butterflyG1(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} +func difFFTG2(a []curve.G2Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G2(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG2(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + butterflyG2(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG2(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG2(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} diff --git a/backend/groth16/bw6-761/mpcsetup/marshal.go b/backend/groth16/bw6-761/mpcsetup/marshal.go new file mode 100644 index 0000000000..d1a071aa5d --- /dev/null +++ b/backend/groth16/bw6-761/mpcsetup/marshal.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bw6-761" + "io" +) + +// WriteTo implements io.WriterTo +func (phase1 *Phase1) WriteTo(writer io.Writer) (int64, error) { + n, err := phase1.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase1.Hash) + return int64(nBytes) + n, err +} + +func (phase1 *Phase1) writeTo(writer io.Writer) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + phase1.Parameters.G1.Tau, + phase1.Parameters.G1.AlphaTau, + phase1.Parameters.G1.BetaTau, + phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + enc := curve.NewEncoder(writer) + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (phase1 *Phase1) ReadFrom(reader io.Reader) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + &phase1.Parameters.G1.Tau, + &phase1.Parameters.G1.AlphaTau, 
+ &phase1.Parameters.G1.BetaTau, + &phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + dec := curve.NewDecoder(reader) + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + phase1.Hash = make([]byte, 32) + nBytes, err := reader.Read(phase1.Hash) + return dec.BytesRead() + int64(nBytes), err +} + +// WriteTo implements io.WriterTo +func (phase2 *Phase2) WriteTo(writer io.Writer) (int64, error) { + n, err := phase2.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase2.Hash) + return int64(nBytes) + n, err +} + +func (c *Phase2) writeTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + c.Parameters.G1.L, + c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c *Phase2) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + &c.Parameters.G1.L, + &c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + c.Hash = make([]byte, 32) + n, err := reader.Read(c.Hash) + return int64(n) + dec.BytesRead(), err + +} + +// WriteTo implements io.WriterTo +func (c *Phase2Evaluations) WriteTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + c.G1.A, + c.G1.B, + c.G2.B, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func 
(c *Phase2Evaluations) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.G1.A, + &c.G1.B, + &c.G2.B, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + return dec.BytesRead(), nil +} diff --git a/backend/groth16/bw6-761/mpcsetup/marshal_test.go b/backend/groth16/bw6-761/mpcsetup/marshal_test.go new file mode 100644 index 0000000000..46913c86b5 --- /dev/null +++ b/backend/groth16/bw6-761/mpcsetup/marshal_test.go @@ -0,0 +1,79 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + curve "github.com/consensys/gnark-crypto/ecc/bw6-761" + "github.com/consensys/gnark/constraint/bw6-761" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/stretchr/testify/require" + "io" + "reflect" + "testing" +) + +func TestContributionSerialization(t *testing.T) { + assert := require.New(t) + + // Phase 1 + srs1 := InitPhase1(9) + srs1.Contribute() + { + var reconstructed Phase1 + roundTripCheck(t, &srs1, &reconstructed) + } + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + r1cs := ccs.(*cs.R1CS) + + // Phase 2 + srs2, _ := InitPhase2(r1cs, &srs1) + srs2.Contribute() + + { + var reconstructed Phase2 + roundTripCheck(t, &srs2, &reconstructed) + } +} + +func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { + t.Helper() + + var buf bytes.Buffer + written, err := from.WriteTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} diff --git a/backend/groth16/bw6-761/mpcsetup/phase1.go b/backend/groth16/bw6-761/mpcsetup/phase1.go new file mode 100644 index 0000000000..07af196d3f --- /dev/null +++ b/backend/groth16/bw6-761/mpcsetup/phase1.go @@ -0,0 +1,203 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + curve "github.com/consensys/gnark-crypto/ecc/bw6-761" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "math" + "math/big" +) + +// Phase1 represents the Phase1 of the MPC described in +// https://eprint.iacr.org/2017/1050.pdf +// +// Also known as "Powers of Tau" +type Phase1 struct { + Parameters struct { + G1 struct { + Tau []curve.G1Affine // {[τ⁰]₁, [τ¹]₁, [τ²]₁, …, [τ²ⁿ⁻²]₁} + AlphaTau []curve.G1Affine // {α[τ⁰]₁, α[τ¹]₁, α[τ²]₁, …, α[τⁿ⁻¹]₁} + BetaTau []curve.G1Affine // {β[τ⁰]₁, β[τ¹]₁, β[τ²]₁, …, β[τⁿ⁻¹]₁} + } + G2 struct { + Tau []curve.G2Affine // {[τ⁰]₂, [τ¹]₂, [τ²]₂, …, [τⁿ⁻¹]₂} + Beta curve.G2Affine // [β]₂ + } + } + PublicKeys struct { + Tau, Alpha, Beta PublicKey + } + Hash []byte // sha256 hash +} + +// InitPhase1 initialize phase 1 of the MPC. This is called once by the coordinator before +// any randomness contribution is made (see Contribute()). 
+func InitPhase1(power int) (phase1 Phase1) { + N := int(math.Pow(2, float64(power))) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetOne() + alpha.SetOne() + beta.SetOne() + phase1.PublicKeys.Tau = newPublicKey(tau, nil, 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, nil, 2) + phase1.PublicKeys.Beta = newPublicKey(beta, nil, 3) + + // First contribution use generators + _, _, g1, g2 := curve.Generators() + phase1.Parameters.G2.Beta.Set(&g2) + phase1.Parameters.G1.Tau = make([]curve.G1Affine, 2*N-1) + phase1.Parameters.G2.Tau = make([]curve.G2Affine, N) + phase1.Parameters.G1.AlphaTau = make([]curve.G1Affine, N) + phase1.Parameters.G1.BetaTau = make([]curve.G1Affine, N) + for i := 0; i < len(phase1.Parameters.G1.Tau); i++ { + phase1.Parameters.G1.Tau[i].Set(&g1) + } + for i := 0; i < len(phase1.Parameters.G2.Tau); i++ { + phase1.Parameters.G2.Tau[i].Set(&g2) + phase1.Parameters.G1.AlphaTau[i].Set(&g1) + phase1.Parameters.G1.BetaTau[i].Set(&g1) + } + + phase1.Parameters.G2.Beta.Set(&g2) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() + + return +} + +// Contribute contributes randomness to the phase1 object. This mutates phase1. +func (phase1 *Phase1) Contribute() { + N := len(phase1.Parameters.G2.Tau) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetRandom() + alpha.SetRandom() + beta.SetRandom() + phase1.PublicKeys.Tau = newPublicKey(tau, phase1.Hash[:], 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, phase1.Hash[:], 2) + phase1.PublicKeys.Beta = newPublicKey(beta, phase1.Hash[:], 3) + + // Compute powers of τ, ατ, and βτ + taus := powers(tau, 2*N-1) + alphaTau := make([]fr.Element, N) + betaTau := make([]fr.Element, N) + for i := 0; i < N; i++ { + alphaTau[i].Mul(&taus[i], &alpha) + betaTau[i].Mul(&taus[i], &beta) + } + + // Update using previous parameters + // TODO @gbotrel working with jacobian points here will help with perf. 
+ scaleG1InPlace(phase1.Parameters.G1.Tau, taus) + scaleG2InPlace(phase1.Parameters.G2.Tau, taus[0:N]) + scaleG1InPlace(phase1.Parameters.G1.AlphaTau, alphaTau) + scaleG1InPlace(phase1.Parameters.G1.BetaTau, betaTau) + var betaBI big.Int + beta.BigInt(&betaBI) + phase1.Parameters.G2.Beta.ScalarMultiplication(&phase1.Parameters.G2.Beta, &betaBI) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() +} + +func VerifyPhase1(c0, c1 *Phase1, c ...*Phase1) error { + contribs := append([]*Phase1{c0, c1}, c...) + for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase1(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +// verifyPhase1 checks that a contribution is based on a known previous Phase1 state. +func verifyPhase1(current, contribution *Phase1) error { + // Compute R for τ, α, β + tauR := genR(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, current.Hash[:], 1) + alphaR := genR(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, current.Hash[:], 2) + betaR := genR(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, current.Hash[:], 3) + + // Check for knowledge of toxic parameters + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.PublicKeys.Tau.XR, tauR) { + return errors.New("couldn't verify public key of τ") + } + if !sameRatio(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, contribution.PublicKeys.Alpha.XR, alphaR) { + return errors.New("couldn't verify public key of α") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.PublicKeys.Beta.XR, betaR) { + return errors.New("couldn't verify public key of β") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Tau[1], current.Parameters.G1.Tau[1], tauR, contribution.PublicKeys.Tau.XR) { + return errors.New("couldn't verify that [τ]₁ is based on previous 
contribution")
+	}
+	if !sameRatio(contribution.Parameters.G1.AlphaTau[0], current.Parameters.G1.AlphaTau[0], alphaR, contribution.PublicKeys.Alpha.XR) {
+		return errors.New("couldn't verify that [α]₁ is based on previous contribution")
+	}
+	if !sameRatio(contribution.Parameters.G1.BetaTau[0], current.Parameters.G1.BetaTau[0], betaR, contribution.PublicKeys.Beta.XR) {
+		return errors.New("couldn't verify that [β]₁ is based on previous contribution")
+	}
+	if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.Parameters.G2.Tau[1], current.Parameters.G2.Tau[1]) {
+		return errors.New("couldn't verify that [τ]₂ is based on previous contribution")
+	}
+	if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.Parameters.G2.Beta, current.Parameters.G2.Beta) {
+		return errors.New("couldn't verify that [β]₂ is based on previous contribution")
+	}
+
+	// Check for valid updates using powers of τ
+	_, _, g1, g2 := curve.Generators()
+	tauL1, tauL2 := linearCombinationG1(contribution.Parameters.G1.Tau)
+	if !sameRatio(tauL1, tauL2, contribution.Parameters.G2.Tau[1], g2) {
+		return errors.New("couldn't verify valid powers of τ in G₁")
+	}
+	alphaL1, alphaL2 := linearCombinationG1(contribution.Parameters.G1.AlphaTau)
+	if !sameRatio(alphaL1, alphaL2, contribution.Parameters.G2.Tau[1], g2) {
+		return errors.New("couldn't verify valid powers of α(τ) in G₁")
+	}
+	betaL1, betaL2 := linearCombinationG1(contribution.Parameters.G1.BetaTau)
+	if !sameRatio(betaL1, betaL2, contribution.Parameters.G2.Tau[1], g2) {
+		return errors.New("couldn't verify valid powers of β(τ) in G₁")
+	}
+	tau2L1, tau2L2 := linearCombinationG2(contribution.Parameters.G2.Tau)
+	if !sameRatio(contribution.Parameters.G1.Tau[1], g1, tau2L1, tau2L2) {
+		return errors.New("couldn't verify valid powers of τ in G₂")
+	}
+
+	// Check hash of the contribution
+	h := contribution.hash()
+	for i := 0; i < len(h); i++ {
+		if h[i] != 
contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (phase1 *Phase1) hash() []byte { + sha := sha256.New() + phase1.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bw6-761/mpcsetup/phase2.go b/backend/groth16/bw6-761/mpcsetup/phase2.go new file mode 100644 index 0000000000..df2d2d5be9 --- /dev/null +++ b/backend/groth16/bw6-761/mpcsetup/phase2.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + "math/big" + + curve "github.com/consensys/gnark-crypto/ecc/bw6-761" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/constraint/bw6-761" +) + +type Phase2Evaluations struct { + G1 struct { + A, B, VKK []curve.G1Affine + } + G2 struct { + B []curve.G2Affine + } +} + +type Phase2 struct { + Parameters struct { + G1 struct { + Delta curve.G1Affine + L, Z []curve.G1Affine + } + G2 struct { + Delta curve.G2Affine + } + } + PublicKey PublicKey + Hash []byte +} + +func InitPhase2(r1cs *cs.R1CS, srs1 *Phase1) (Phase2, Phase2Evaluations) { + srs := srs1.Parameters + size := len(srs.G1.AlphaTau) + if size < r1cs.GetNbConstraints() { + panic("Number of constraints is larger than expected") + } + + c2 := Phase2{} + + accumulateG1 := func(res *curve.G1Affine, t constraint.Term, value *curve.G1Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G1Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + accumulateG2 := func(res *curve.G2Affine, t constraint.Term, value *curve.G2Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G2Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + // Prepare Lagrange coefficients of [τ...]₁, [τ...]₂, [ατ...]₁, [βτ...]₁ + coeffTau1 := 
lagrangeCoeffsG1(srs.G1.Tau, size) + coeffTau2 := lagrangeCoeffsG2(srs.G2.Tau, size) + coeffAlphaTau1 := lagrangeCoeffsG1(srs.G1.AlphaTau, size) + coeffBetaTau1 := lagrangeCoeffsG1(srs.G1.BetaTau, size) + + internal, secret, public := r1cs.GetNbVariables() + nWires := internal + secret + public + var evals Phase2Evaluations + evals.G1.A = make([]curve.G1Affine, nWires) + evals.G1.B = make([]curve.G1Affine, nWires) + evals.G2.B = make([]curve.G2Affine, nWires) + bA := make([]curve.G1Affine, nWires) + aB := make([]curve.G1Affine, nWires) + C := make([]curve.G1Affine, nWires) + + // TODO @gbotrel use constraint iterator when available. + + i := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + // A + for _, t := range c.L { + accumulateG1(&evals.G1.A[t.WireID()], t, &coeffTau1[i]) + accumulateG1(&bA[t.WireID()], t, &coeffBetaTau1[i]) + } + // B + for _, t := range c.R { + accumulateG1(&evals.G1.B[t.WireID()], t, &coeffTau1[i]) + accumulateG2(&evals.G2.B[t.WireID()], t, &coeffTau2[i]) + accumulateG1(&aB[t.WireID()], t, &coeffAlphaTau1[i]) + } + // C + for _, t := range c.O { + accumulateG1(&C[t.WireID()], t, &coeffTau1[i]) + } + i++ + } + + // Prepare default contribution + _, _, g1, g2 := curve.Generators() + c2.Parameters.G1.Delta = g1 + c2.Parameters.G2.Delta = g2 + + // Build Z in PK as τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + // τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + n := len(srs.G1.AlphaTau) + c2.Parameters.G1.Z = make([]curve.G1Affine, n) + for i := 0; i < n-1; i++ { + c2.Parameters.G1.Z[i].Sub(&srs.G1.Tau[i+n], &srs.G1.Tau[i]) + } + bitReverse(c2.Parameters.G1.Z) + c2.Parameters.G1.Z = c2.Parameters.G1.Z[:n-1] + + // Evaluate L + nPrivate := internal + secret + c2.Parameters.G1.L = make([]curve.G1Affine, nPrivate) + evals.G1.VKK = make([]curve.G1Affine, public) + offset := public + for i := 0; i < nWires; i++ { + var tmp curve.G1Affine + tmp.Add(&bA[i], &aB[i]) + tmp.Add(&tmp, &C[i]) + if i < public { + 
evals.G1.VKK[i].Set(&tmp) + } else { + c2.Parameters.G1.L[i-offset].Set(&tmp) + } + } + // Set δ public key + var delta fr.Element + delta.SetOne() + c2.PublicKey = newPublicKey(delta, nil, 1) + + // Hash initial contribution + c2.Hash = c2.hash() + return c2, evals +} + +func (c *Phase2) Contribute() { + // Sample toxic δ + var delta, deltaInv fr.Element + var deltaBI, deltaInvBI big.Int + delta.SetRandom() + deltaInv.Inverse(&delta) + + delta.BigInt(&deltaBI) + deltaInv.BigInt(&deltaInvBI) + + // Set δ public key + c.PublicKey = newPublicKey(delta, c.Hash, 1) + + // Update δ + c.Parameters.G1.Delta.ScalarMultiplication(&c.Parameters.G1.Delta, &deltaBI) + c.Parameters.G2.Delta.ScalarMultiplication(&c.Parameters.G2.Delta, &deltaBI) + + // Update Z using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.Z); i++ { + c.Parameters.G1.Z[i].ScalarMultiplication(&c.Parameters.G1.Z[i], &deltaInvBI) + } + + // Update L using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.L); i++ { + c.Parameters.G1.L[i].ScalarMultiplication(&c.Parameters.G1.L[i], &deltaInvBI) + } + + // 4. Hash contribution + c.Hash = c.hash() +} + +func VerifyPhase2(c0, c1 *Phase2, c ...*Phase2) error { + contribs := append([]*Phase2{c0, c1}, c...) 
+	for i := 0; i < len(contribs)-1; i++ {
+		if err := verifyPhase2(contribs[i], contribs[i+1]); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func verifyPhase2(current, contribution *Phase2) error {
+	// Compute R for δ
+	deltaR := genR(contribution.PublicKey.SG, contribution.PublicKey.SXG, current.Hash[:], 1)
+
+	// Check for knowledge of δ
+	if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.PublicKey.XR, deltaR) {
+		return errors.New("couldn't verify knowledge of δ")
+	}
+
+	// Check for valid updates using previous parameters
+	if !sameRatio(contribution.Parameters.G1.Delta, current.Parameters.G1.Delta, deltaR, contribution.PublicKey.XR) {
+		return errors.New("couldn't verify that [δ]₁ is based on previous contribution")
+	}
+	if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) {
+		return errors.New("couldn't verify that [δ]₂ is based on previous contribution")
+	}
+
+	// Check for valid updates of L and Z using δ⁻¹
+	L, prevL := merge(contribution.Parameters.G1.L, current.Parameters.G1.L)
+	if !sameRatio(L, prevL, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) {
+		return errors.New("couldn't verify valid updates of L using δ⁻¹")
+	}
+	Z, prevZ := merge(contribution.Parameters.G1.Z, current.Parameters.G1.Z)
+	if !sameRatio(Z, prevZ, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) {
+		return errors.New("couldn't verify valid updates of Z using δ⁻¹")
+	}
+
+	// Check hash of the contribution
+	h := contribution.hash()
+	for i := 0; i < len(h); i++ {
+		if h[i] != contribution.Hash[i] {
+			return errors.New("couldn't verify hash of contribution")
+		}
+	}
+
+	return nil
+}
+
+func (c *Phase2) hash() []byte {
+	sha := sha256.New()
+	c.writeTo(sha)
+	return sha.Sum(nil)
+}
diff --git a/backend/groth16/bw6-761/mpcsetup/setup.go b/backend/groth16/bw6-761/mpcsetup/setup.go
new file mode 100644
index 0000000000..9008fd26b0
--- 
/dev/null +++ b/backend/groth16/bw6-761/mpcsetup/setup.go @@ -0,0 +1,97 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bw6-761" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" + groth16 "github.com/consensys/gnark/backend/groth16/bw6-761" +) + +func ExtractKeys(srs1 *Phase1, srs2 *Phase2, evals *Phase2Evaluations, nConstraints int) (pk groth16.ProvingKey, vk groth16.VerifyingKey) { + _, _, _, g2 := curve.Generators() + + // Initialize PK + pk.Domain = *fft.NewDomain(uint64(nConstraints)) + pk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + pk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + pk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + pk.G1.Z = srs2.Parameters.G1.Z + bitReverse(pk.G1.Z) + + pk.G1.K = srs2.Parameters.G1.L + pk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + pk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + + // Filter out infinity points + nWires := len(evals.G1.A) + pk.InfinityA = make([]bool, nWires) + A := make([]curve.G1Affine, nWires) + j := 0 + for i, e := range evals.G1.A { + if e.IsInfinity() { + pk.InfinityA[i] = true + continue + } + A[j] = evals.G1.A[i] + j++ + } + pk.G1.A = A[:j] + pk.NbInfinityA = uint64(nWires - j) + + pk.InfinityB = make([]bool, nWires) + B := make([]curve.G1Affine, nWires) + j = 0 + for i, e := range evals.G1.B { + if e.IsInfinity() { + 
pk.InfinityB[i] = true + continue + } + B[j] = evals.G1.B[i] + j++ + } + pk.G1.B = B[:j] + pk.NbInfinityB = uint64(nWires - j) + + B2 := make([]curve.G2Affine, nWires) + j = 0 + for i, e := range evals.G2.B { + if e.IsInfinity() { + // pk.InfinityB[i] = true should be the same as in B + continue + } + B2[j] = evals.G2.B[i] + j++ + } + pk.G2.B = B2[:j] + + // Initialize VK + vk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + vk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + vk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + vk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + vk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + vk.G2.Gamma.Set(&g2) + vk.G1.K = evals.G1.VKK + + // sets e, -[δ]2, -[γ]2 + if err := vk.Precompute(); err != nil { + panic(err) + } + + return pk, vk +} diff --git a/backend/groth16/bw6-761/mpcsetup/setup_test.go b/backend/groth16/bw6-761/mpcsetup/setup_test.go new file mode 100644 index 0000000000..28e4c7a768 --- /dev/null +++ b/backend/groth16/bw6-761/mpcsetup/setup_test.go @@ -0,0 +1,199 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+// Code generated by gnark DO NOT EDIT
+
+package mpcsetup
+
+import (
+	curve "github.com/consensys/gnark-crypto/ecc/bw6-761"
+	"github.com/consensys/gnark-crypto/ecc/bw6-761/fr"
+	"github.com/consensys/gnark/constraint/bw6-761"
+	"testing"
+
+	"github.com/consensys/gnark/backend/groth16"
+	"github.com/consensys/gnark/frontend"
+	"github.com/consensys/gnark/frontend/cs/r1cs"
+	"github.com/consensys/gnark/std/hash/mimc"
+	"github.com/stretchr/testify/require"
+
+	native_mimc "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/mimc"
+)
+
+func TestSetupCircuit(t *testing.T) {
+	if testing.Short() {
+		t.Skip()
+	}
+	const (
+		nContributionsPhase1 = 3
+		nContributionsPhase2 = 3
+		power                = 9
+	)
+
+	assert := require.New(t)
+
+	srs1 := InitPhase1(power)
+
+	// Make and verify contributions for phase1
+	for i := 1; i < nContributionsPhase1; i++ {
+		// we clone for test purposes; but in practice, participant will receive a []byte, deserialize it,
+		// add his contribution and send back to coordinator.
+		prev := srs1.clone()
+
+		srs1.Contribute()
+		assert.NoError(VerifyPhase1(&prev, &srs1))
+	}
+
+	// Compile the circuit
+	var myCircuit Circuit
+	ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit)
+	assert.NoError(err)
+
+	var evals Phase2Evaluations
+	r1cs := ccs.(*cs.R1CS)
+
+	// Prepare for phase-2
+	srs2, evals := InitPhase2(r1cs, &srs1)
+
+	// Make and verify contributions for phase2
+	for i := 1; i < nContributionsPhase2; i++ {
+		// we clone for test purposes; but in practice, participant will receive a []byte, deserialize it,
+		// add his contribution and send back to coordinator.
+ prev := srs2.clone() + + srs2.Contribute() + assert.NoError(VerifyPhase2(&prev, &srs2)) + } + + // Extract the proving and verifying keys + pk, vk := ExtractKeys(&srs1, &srs2, &evals, ccs.GetNbConstraints()) + + // Build the witness + var preImage, hash fr.Element + { + m := native_mimc.NewMiMC() + m.Write(preImage.Marshal()) + hash.SetBytes(m.Sum(nil)) + } + + witness, err := frontend.NewWitness(&Circuit{PreImage: preImage, Hash: hash}, curve.ID.ScalarField()) + assert.NoError(err) + + pubWitness, err := witness.Public() + assert.NoError(err) + + // groth16: ensure proof is verified + proof, err := groth16.Prove(ccs, &pk, witness) + assert.NoError(err) + + err = groth16.Verify(proof, &vk, pubWitness) + assert.NoError(err) +} + +func BenchmarkPhase1(b *testing.B) { + const power = 14 + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = InitPhase1(power) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs1 := InitPhase1(power) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs1.Contribute() + } + }) + +} + +func BenchmarkPhase2(b *testing.B) { + const power = 14 + srs1 := InitPhase1(power) + srs1.Contribute() + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + if err != nil { + b.Fatal(err) + } + + r1cs := ccs.(*cs.R1CS) + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = InitPhase2(r1cs, &srs1) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs2, _ := InitPhase2(r1cs, &srs1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs2.Contribute() + } + }) + +} + +// Circuit defines a pre-image knowledge proof +// mimc(secret preImage) = public hash +type Circuit struct { + PreImage frontend.Variable + Hash frontend.Variable `gnark:",public"` +} + +// Define declares the circuit's constraints +// Hash = mimc(PreImage) +func (circuit *Circuit) Define(api frontend.API) error { + // hash function + mimc, _ := 
mimc.NewMiMC(api) + + // specify constraints + mimc.Write(circuit.PreImage) + api.AssertIsEqual(circuit.Hash, mimc.Sum()) + + return nil +} + +func (phase1 *Phase1) clone() Phase1 { + r := Phase1{} + r.Parameters.G1.Tau = append(r.Parameters.G1.Tau, phase1.Parameters.G1.Tau...) + r.Parameters.G1.AlphaTau = append(r.Parameters.G1.AlphaTau, phase1.Parameters.G1.AlphaTau...) + r.Parameters.G1.BetaTau = append(r.Parameters.G1.BetaTau, phase1.Parameters.G1.BetaTau...) + + r.Parameters.G2.Tau = append(r.Parameters.G2.Tau, phase1.Parameters.G2.Tau...) + r.Parameters.G2.Beta = phase1.Parameters.G2.Beta + + r.PublicKeys = phase1.PublicKeys + r.Hash = append(r.Hash, phase1.Hash...) + + return r +} + +func (phase2 *Phase2) clone() Phase2 { + r := Phase2{} + r.Parameters.G1.Delta = phase2.Parameters.G1.Delta + r.Parameters.G1.L = append(r.Parameters.G1.L, phase2.Parameters.G1.L...) + r.Parameters.G1.Z = append(r.Parameters.G1.Z, phase2.Parameters.G1.Z...) + r.Parameters.G2.Delta = phase2.Parameters.G2.Delta + r.PublicKey = phase2.PublicKey + r.Hash = append(r.Hash, phase2.Hash...) + + return r +} diff --git a/backend/groth16/bw6-761/mpcsetup/utils.go b/backend/groth16/bw6-761/mpcsetup/utils.go new file mode 100644 index 0000000000..dfdd1e8a97 --- /dev/null +++ b/backend/groth16/bw6-761/mpcsetup/utils.go @@ -0,0 +1,170 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bw6-761" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark/internal/utils" +) + +type PublicKey struct { + SG curve.G1Affine + SXG curve.G1Affine + XR curve.G2Affine +} + +func newPublicKey(x fr.Element, challenge []byte, dst byte) PublicKey { + var pk PublicKey + _, _, g1, _ := curve.Generators() + + var s fr.Element + var sBi big.Int + s.SetRandom() + s.BigInt(&sBi) + pk.SG.ScalarMultiplication(&g1, &sBi) + + // compute x*sG1 + var xBi big.Int + x.BigInt(&xBi) + pk.SXG.ScalarMultiplication(&pk.SG, &xBi) + + // generate R based on sG1, sxG1, challenge, and domain separation tag (tau, alpha or beta) + R := genR(pk.SG, pk.SXG, challenge, dst) + + // compute x*spG2 + pk.XR.ScalarMultiplication(&R, &xBi) + return pk +} + +func bitReverse[T any](a []T) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + irev := bits.Reverse64(i) >> nn + if irev > i { + a[i], a[irev] = a[irev], a[i] + } + } +} + +// Returns [1, a, a², ..., aⁿ⁻¹ ] in Montgomery form +func powers(a fr.Element, n int) []fr.Element { + result := make([]fr.Element, n) + result[0] = fr.NewElement(1) + for i := 1; i < n; i++ { + result[i].Mul(&result[i-1], &a) + } + return result +} + +// Returns [aᵢAᵢ, ...] in G1 +func scaleG1InPlace(A []curve.G1Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Returns [aᵢAᵢ, ...] 
in G2 +func scaleG2InPlace(A []curve.G2Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Check e(a₁, a₂) = e(b₁, b₂) +func sameRatio(a1, b1 curve.G1Affine, a2, b2 curve.G2Affine) bool { + if !a1.IsInSubGroup() || !b1.IsInSubGroup() || !a2.IsInSubGroup() || !b2.IsInSubGroup() { + panic("invalid point not in subgroup") + } + var na2 curve.G2Affine + na2.Neg(&a2) + res, err := curve.PairingCheck( + []curve.G1Affine{a1, b1}, + []curve.G2Affine{na2, b2}) + if err != nil { + panic(err) + } + return res +} + +// returns a = ∑ rᵢAᵢ, b = ∑ rᵢBᵢ +func merge(A, B []curve.G1Affine) (a, b curve.G1Affine) { + nc := runtime.NumCPU() + r := make([]fr.Element, len(A)) + for i := 0; i < len(A); i++ { + r[i].SetRandom() + } + a.MultiExp(A, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + b.MultiExp(B, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G1 +func linearCombinationG1(A []curve.G1Affine) (L1, L2 curve.G1Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G2 +func linearCombinationG2(A []curve.G2Affine) (L1, L2 curve.G2Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// Generate R in G₂ as Hash(gˢ, gˢˣ, challenge, dst) +func genR(sG1, sxG1 curve.G1Affine, challenge []byte, dst byte) curve.G2Affine { + var buf bytes.Buffer + buf.Grow(len(challenge) + curve.SizeOfG1AffineUncompressed*2) + buf.Write(sG1.Marshal()) + 
buf.Write(sxG1.Marshal()) + buf.Write(challenge) + spG2, err := curve.HashToG2(buf.Bytes(), []byte{dst}) + if err != nil { + panic(err) + } + return spG2 +} diff --git a/internal/backend/bw6-761/groth16/prove.go b/backend/groth16/bw6-761/prove.go similarity index 64% rename from internal/backend/bw6-761/groth16/prove.go rename to backend/groth16/bw6-761/prove.go index 950932327f..7234bf7d15 100644 --- a/internal/backend/bw6-761/groth16/prove.go +++ b/backend/groth16/bw6-761/prove.go @@ -17,13 +17,17 @@ package groth16 import ( - "fmt" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bw6-761" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/pedersen" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/constraint/bw6-761" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" "math/big" @@ -35,9 +39,10 @@ import ( // with a valid statement and a VerifyingKey // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type Proof struct { - Ar, Krs curve.G1Affine - Bs curve.G2Affine - Commitment, CommitmentPok curve.G1Affine + Ar, Krs curve.G1Affine + Bs curve.G2Affine + Commitments []curve.G1Affine // Pedersen commitments a la https://eprint.iacr.org/2022/1072 + CommitmentPok curve.G1Affine // Batched proof of knowledge of the above commitments } // isValid ensures proof elements are in the correct subgroup @@ -51,72 +56,71 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). 
-func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - // TODO @gbotrel witness size check is done by R1CS, doesn't mean we shouldn't sanitize here. - // if len(witness) != r1cs.NbPublicVariables-1+r1cs.NbSecretVariables { - // return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), r1cs.NbPublicVariables-1+r1cs.NbSecretVariables, r1cs.NbPublicVariables, r1cs.NbSecretVariables) - // } - - log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", len(r1cs.Constraints)).Str("backend", "groth16").Logger() - - // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - - proof := &Proof{} - if r1cs.CommitmentInfo.Is() { - opt.HintFunctions[r1cs.CommitmentInfo.HintID] = func(_ *big.Int, in []*big.Int, out []*big.Int) error { - // Perf-TODO: Converting these values to big.Int and back may be a performance bottleneck. - // If that is the case, figure out a way to feed the solution vector into this function - if len(in) != r1cs.CommitmentInfo.NbCommitted() { // TODO: Remove - return fmt.Errorf("unexpected number of committed variables") - } - values := make([]fr.Element, r1cs.CommitmentInfo.NbPrivateCommitted) - nbPublicCommitted := len(in) - len(values) - inPrivate := in[nbPublicCommitted:] - for i, inI := range inPrivate { - values[i].SetBigInt(inI) - } +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } + + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] - var err error - proof.Commitment, proof.CommitmentPok, err = pk.CommitmentKey.Commit(values) - if err != nil { + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + solverOpts = append(solverOpts, solver.OverrideHint(commitmentInfo[i].HintID, func(i int) solver.Hint { + return func(_ *big.Int, in []*big.Int, out []*big.Int) error { + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[len(hashed):] + for j, inJ := range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } + + var res fr.Element + res, err = solveCommitmentWire(&proof.Commitments[i], hashed) + res.BigInt(out[0]) return err } - - var res fr.Element - res, err = solveCommitmentWire(&r1cs.CommitmentInfo, &proof.Commitment, in[:r1cs.CommitmentInfo.NbPublicCommitted()]) - res.BigInt(out[0]) //Perf-TODO: Regular (non-mont) hashToField to obviate this conversion? 
- return err - } + }(i))) } - var wireValues []fr.Element - var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill wireValues with random values else multi exps don't do much - var r fr.Element - _, _ = r.SetRandom() - for i := r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables(); i < len(wireValues); i++ { - wireValues[i] = r - r.Double(&r) - } - } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + start := time.Now() + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + + if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil { + return nil, err + } + // H (witness reduction / FFT part) var h []fr.Element chHDone := make(chan struct{}, 1) go func() { - h = computeH(a, b, c, &pk.Domain) - a = nil - b = nil - c = nil + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + solution.A = nil + solution.B = nil + solution.C = nil chHDone <- struct{}{} }() @@ -203,15 +207,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC var krs, krs2, p1 curve.G1Jac chKrs2Done := make(chan error, 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 go func() { - _, err := krs2.MultiExp(pk.G1.Z, h, ecc.MultiExpConfig{NbTasks: n / 2}) + _, err := krs2.MultiExp(pk.G1.Z, h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) chKrs2Done <- err }() - // filter the wire values if needed; - _wireValues := filter(wireValues, r1cs.CommitmentInfo.PrivateToPublic()) + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + 
toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - if _, err := krs.MultiExp(pk.G1.K, _wireValues[r1cs.GetNbPublicVariables():], ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chKrsDone <- err return } @@ -292,26 +300,32 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC } // if len(toRemove) == 0, returns slice -// else, returns a new slice without the indexes in toRemove -// this assumes toRemove indexes are sorted and len(slice) > len(toRemove) -func filter(slice []fr.Element, toRemove []int) (r []fr.Element) { +// else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { if len(toRemove) == 0 { return slice } - r = make([]fr.Element, 0, len(slice)-len(toRemove)) - j := 0 + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) + // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) for i := 0; i < len(slice); i++ { - if j < len(toRemove) && i == toRemove[j] { - j++ + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } continue } r = append(r, slice[i]) } - return r + return } func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { @@ -334,9 +348,9 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { domain.FFTInverse(b, fft.DIF) domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, true) - domain.FFT(b, fft.DIT, true) - 
domain.FFT(c, fft.DIT, true) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) var den, one fr.Element one.SetOne() @@ -344,7 +358,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) - // reusing a to avoid unecessary memalloc + // reusing a to avoid unnecessary memory allocation utils.Parallelize(n, func(start, end int) { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). @@ -354,7 +368,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { }) // ifft_coset - domain.FFTInverse(a, fft.DIF, true) + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) return a } diff --git a/internal/backend/bw6-761/groth16/setup.go b/backend/groth16/bw6-761/setup.go similarity index 75% rename from internal/backend/bw6-761/groth16/setup.go rename to backend/groth16/bw6-761/setup.go index c0616ff53a..5a69e9799a 100644 --- a/internal/backend/bw6-761/groth16/setup.go +++ b/backend/groth16/bw6-761/setup.go @@ -17,11 +17,13 @@ package groth16 import ( + "errors" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bw6-761" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/constraint/bw6-761" "math/big" @@ -34,15 +36,15 @@ type ProvingKey struct { // domain Domain fft.Domain - // [α]1, [β]1, [δ]1 - // [A(t)]1, [B(t)]1, [Kpk(t)]1, [Z(t)]1 + // [α]₁, [β]₁, [δ]₁ + // [A(t)]₁, [B(t)]₁, [Kpk(t)]₁, [Z(t)]₁ G1 struct { Alpha, Beta, Delta curve.G1Affine A, B, Z []curve.G1Affine K []curve.G1Affine // the indexes correspond to the private wires } - // [β]2, [δ]2, [B(t)]2 + // [β]₂, [δ]₂, [B(t)]₂ G2 struct { Beta, Delta curve.G2Affine B []curve.G2Affine @@ 
-52,21 +54,21 @@ type ProvingKey struct { InfinityA, InfinityB []bool NbInfinityA, NbInfinityB uint64 - CommitmentKey pedersen.Key + CommitmentKeys []pedersen.ProvingKey } // VerifyingKey is used by a Groth16 verifier to verify the validity of a proof and a statement // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type VerifyingKey struct { - // [α]1, [Kvk]1 + // [α]₁, [Kvk]₁ G1 struct { Alpha curve.G1Affine Beta, Delta curve.G1Affine // unused, here for compatibility purposes K []curve.G1Affine // The indexes correspond to the public wires } - // [β]2, [δ]2, [γ]2, - // -[δ]2, -[γ]2: see proof.Verify() for more details + // [β]₂, [δ]₂, [γ]₂, + // -[δ]₂, -[γ]₂: see proof.Verify() for more details G2 struct { Beta, Delta, Gamma curve.G2Affine deltaNeg, gammaNeg curve.G2Affine // not serialized @@ -75,8 +77,8 @@ type VerifyingKey struct { // e(α, β) e curve.GT // not serialized - CommitmentKey pedersen.Key - CommitmentInfo constraint.Commitment // since the verifier doesn't input a constraint system, this needs to be provided here + CommitmentKey pedersen.VerifyingKey + PublicAndCommitmentCommitted [][]int // indexes of public/commitment committed variables } // Setup constructs the SRS @@ -93,17 +95,20 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbPrivateCommittedWires := r1cs.CommitmentInfo.NbPrivateCommitted - nbPublicWires := r1cs.GetNbPublicVariables() - nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - if r1cs.CommitmentInfo.Is() { // the commitment itself is defined by a hint so the prover considers it private - nbPublicWires++ // but the verifier will need to inject the value itself so on the groth16 - nbPrivateWires-- // level it must be considered public - } + commitmentInfo := 
r1cs.CommitmentInfo.(constraint.Groth16Commitments) + commitmentWires := commitmentInfo.CommitmentIndexes() + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateCommittedWires := internal.NbElements(privateCommitted) + + // a commitment is itself defined by a hint so the prover considers it private + // but the verifier will need to inject the value itself so on the groth16 + // level it must be considered public + nbPublicWires := r1cs.GetNbPublicVariables() + len(commitmentInfo) + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - len(commitmentInfo) // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints))) + domain := fft.NewDomain(uint64(r1cs.GetNbConstraints())) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -137,7 +142,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // compute scalars for pkK, vkK and ckK pkK := make([]fr.Element, nbPrivateWires) vkK := make([]fr.Element, nbPublicWires) - ckK := make([]fr.Element, nbPrivateCommittedWires) + ckK := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + ckK[i] = make([]fr.Element, len(privateCommitted[i])) + } var t0, t1 fr.Element @@ -148,28 +156,42 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { Add(&t1, &C[i]). 
Mul(&t1, coeff) } - - vI, cI := 0, 0 - privateCommitted := r1cs.CommitmentInfo.PrivateCommitted() + vI := 0 // number of public wires seen so far + cI := make([]int, len(commitmentInfo)) // number of private committed wires seen so far for each commitment + nbPrivateCommittedSeen := 0 // = ∑ᵢ cI[i] + nbCommitmentsSeen := 0 for i := range A { - isCommittedPrivate := cI < len(privateCommitted) && i == privateCommitted[cI] - isCommitment := r1cs.CommitmentInfo.Is() && i == r1cs.CommitmentInfo.CommitmentIndex - isPublic := i < r1cs.GetNbPublicVariables() + commitment := -1 // index of the commitment that commits to this variable as a private or commitment value + var isCommitment, isPublic bool + if isPublic = i < r1cs.GetNbPublicVariables(); !isPublic { + if nbCommitmentsSeen < len(commitmentWires) && commitmentWires[nbCommitmentsSeen] == i { + isCommitment = true + nbCommitmentsSeen++ + } - if isPublic || isCommittedPrivate || isCommitment { + for j := range commitmentInfo { // does commitment j commit to i? + if cI[j] < len(privateCommitted[j]) && privateCommitted[j][cI[j]] == i { + commitment = j + break // frontend guarantees that no private variable is committed to more than once + } + } + } + + if isPublic || commitment != -1 || isCommitment { computeK(i, &toxicWaste.gammaInv) - if isCommittedPrivate { - ckK[cI] = t1 - cI++ - } else { + if isPublic || isCommitment { vkK[vI] = t1 vI++ + } else { // committed and private + ckK[commitment][cI[commitment]] = t1 + cI[commitment]++ + nbPrivateCommittedSeen++ } } else { computeK(i, &toxicWaste.deltaInv) - pkK[i-vI-cI] = t1 + pkK[i-vI-nbPrivateCommittedSeen] = t1 // vI = nbPublicSeen + nbCommitmentsSeen } } @@ -222,11 +244,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { g1Scalars = append(g1Scalars, Z...) g1Scalars = append(g1Scalars, vkK...) g1Scalars = append(g1Scalars, pkK...) - g1Scalars = append(g1Scalars, ckK...) + for i := range ckK { + g1Scalars = append(g1Scalars, ckK[i]...) 
+ } g1PointsAff := curve.BatchScalarMultiplicationG1(&g1, g1Scalars) - // sets pk: [α]1, [β]1, [δ]1 + // sets pk: [α]₁, [β]₁, [δ]₁ pk.G1.Alpha = g1PointsAff[0] pk.G1.Beta = g1PointsAff[1] pk.G1.Delta = g1PointsAff[2] @@ -238,8 +262,9 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G1.B = g1PointsAff[offset : offset+len(B)] offset += len(B) - pk.G1.Z = g1PointsAff[offset : offset+int(domain.Cardinality)] - bitReverse(pk.G1.Z) + bitReverse(g1PointsAff[offset : offset+int(domain.Cardinality)]) + sizeZ := int(domain.Cardinality) - 1 // deg(H)=deg(A*B-C/X^n-1)=(n-1)+(n-1)-n=n-2 + pk.G1.Z = g1PointsAff[offset : offset+sizeZ] offset += int(domain.Cardinality) @@ -252,17 +277,22 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // --------------------------------------------------------------------------------------------- // Commitment setup - if nbPrivateCommittedWires != 0 { - commitmentBasis := g1PointsAff[offset:] + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(ckK[i]) + commitmentBases[i] = g1PointsAff[offset : offset+size] + offset += size + } + if offset != len(g1PointsAff) { + return errors.New("didn't consume all G1 points") // TODO @Tabaie Remove this + } - vk.CommitmentKey, err = pedersen.Setup(commitmentBasis) - if err != nil { - return err - } - pk.CommitmentKey = vk.CommitmentKey + pk.CommitmentKeys, vk.CommitmentKey, err = pedersen.Setup(commitmentBases...) 
+ if err != nil { + return err } - vk.CommitmentInfo = r1cs.CommitmentInfo // unfortunate but necessary + vk.PublicAndCommitmentCommitted = commitmentInfo.GetPublicAndCommitmentCommitted(commitmentWires, r1cs.GetNbPublicVariables()) // --------------------------------------------------------------------------------------------- // G2 scalars @@ -279,15 +309,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G2.B = g2PointsAff[:len(B)] - // sets pk: [β]2, [δ]2 + // sets pk: [β]₂, [δ]₂ pk.G2.Beta = g2PointsAff[len(B)+0] pk.G2.Delta = g2PointsAff[len(B)+1] - // sets vk: [δ]2, [γ]2, -[δ]2, -[γ]2 + // sets vk: [δ]₂, [γ]₂ vk.G2.Delta = g2PointsAff[len(B)+1] vk.G2.Gamma = g2PointsAff[len(B)+2] - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) // --------------------------------------------------------------------------------------------- // Pairing: vk.e @@ -298,16 +326,29 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { vk.G1.Beta = pk.G1.Beta vk.G1.Delta = pk.G1.Delta - vk.e, err = curve.Pair([]curve.G1Affine{pk.G1.Alpha}, []curve.G2Affine{pk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return err } + // set domain pk.Domain = *domain return nil } +// Precompute sets e, -[δ]₂, -[γ]₂ +// This is meant to be called internally during setup or deserialization. +func (vk *VerifyingKey) Precompute() error { + var err error + vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) + if err != nil { + return err + } + vk.G2.deltaNeg.Neg(&vk.G2.Delta) + vk.G2.gammaNeg.Neg(&vk.G2.Gamma) + return nil +} + func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr.Element, B []fr.Element, C []fr.Element) { nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() @@ -322,7 +363,7 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. 
var w fr.Element w.Set(&domain.Generator) wi := fr.One() - t := make([]fr.Element, len(r1cs.Constraints)+1) + t := make([]fr.Element, r1cs.GetNbConstraints()+1) for i := 0; i < len(t); i++ { t[i].Sub(&toxicWaste.t, &wi) wi.Mul(&wi, &w) // TODO this is already pre computed in fft.Domain @@ -366,8 +407,10 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // for each term appearing in the linear expression, // we compute term.Coefficient * L, and cumulate it in // A, B or C at the index of the variable - for i, c := range r1cs.Constraints { + j := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } @@ -380,9 +423,12 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // Li+1 = w*Li*(t-w^i)/(t-w^(i+1)) L.Mul(&L, &w) - L.Mul(&L, &t[i]) - L.Mul(&L, &tInv[i+1]) + L.Mul(&L, &t[j]) + L.Mul(&L, &tInv[j+1]) + + j++ } + return } @@ -436,7 +482,10 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbConstraints := len(r1cs.Constraints) + nbConstraints := r1cs.GetNbConstraints() + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - internal.NbElements(privateCommitted) - len(commitmentInfo) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints)) @@ -448,8 +497,8 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // initialize proving key pk.G1.A = make([]curve.G1Affine, nbWires-nbZeroesA) pk.G1.B = make([]curve.G1Affine, nbWires-nbZeroesB) - pk.G1.K = make([]curve.G1Affine, nbWires-r1cs.GetNbPublicVariables()) - pk.G1.Z = make([]curve.G1Affine, 
domain.Cardinality) + pk.G1.K = make([]curve.G1Affine, nbPrivateWires) + pk.G1.Z = make([]curve.G1Affine, domain.Cardinality-1) pk.G2.B = make([]curve.G2Affine, nbWires-nbZeroesB) // set infinity markers @@ -503,6 +552,22 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { pk.Domain = *domain + // --------------------------------------------------------------------------------------------- + // Commitment setup + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(privateCommitted[i]) + commitmentBases[i] = make([]curve.G1Affine, size) + for j := range commitmentBases[i] { + commitmentBases[i][j] = r1Aff + } + } + + pk.CommitmentKeys, _, err = pedersen.Setup(commitmentBases...) + if err != nil { + return err + } + return nil } @@ -514,7 +579,9 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) - for _, c := range r1cs.Constraints { + + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { A[t.WireID()] = true } @@ -522,6 +589,7 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { B[t.WireID()] = true } } + for i := 0; i < nbWires; i++ { if !A[i] { nbZeroesA++ @@ -606,7 +674,7 @@ func (pk *ProvingKey) NbG2() int { return 2 + len(pk.G2.B) } -// bitRerverse permutation as in fft.BitReverse , but with []curve.G1Affine +// bitReverse permutation as in fft.BitReverse , but with []curve.G1Affine func bitReverse(a []curve.G1Affine) { n := uint(len(a)) nn := uint(bits.UintSize - bits.TrailingZeros(n)) diff --git a/internal/backend/bw6-761/groth16/verify.go b/backend/groth16/bw6-761/verify.go similarity index 64% rename from internal/backend/bw6-761/groth16/verify.go rename to backend/groth16/bw6-761/verify.go index d1644cdb3b..ca49685a5e 100644 --- a/internal/backend/bw6-761/groth16/verify.go +++ b/backend/groth16/bw6-761/verify.go @@ -22,9 +22,11 @@ import ( 
"github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bw6-761" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/pedersen" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/logger" "io" - "math/big" "time" ) @@ -36,10 +38,8 @@ var ( // Verify verifies a proof with given VerifyingKey and publicWitness func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - nbPublicVars := len(vk.G1.K) - if vk.CommitmentInfo.Is() { - nbPublicVars-- - } + nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted) + if len(publicWitness) != nbPublicVars-1 { return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(publicWitness), len(vk.G1.K)-1) } @@ -62,21 +62,32 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { close(chDone) }() - if vk.CommitmentInfo.Is() { - - if err := vk.CommitmentKey.VerifyKnowledgeProof(proof.Commitment, proof.CommitmentPok); err != nil { - return err + maxNbPublicCommitted := 0 + for _, s := range vk.PublicAndCommitmentCommitted { // iterate over commitments + maxNbPublicCommitted = utils.Max(maxNbPublicCommitted, len(s)) + } + commitmentsSerialized := make([]byte, len(vk.PublicAndCommitmentCommitted)*fr.Bytes) + commitmentPrehashSerialized := make([]byte, curve.SizeOfG1AffineUncompressed+maxNbPublicCommitted*fr.Bytes) + for i := range vk.PublicAndCommitmentCommitted { // solveCommitmentWire + copy(commitmentPrehashSerialized, proof.Commitments[i].Marshal()) + offset := curve.SizeOfG1AffineUncompressed + for j := range vk.PublicAndCommitmentCommitted[i] { + copy(commitmentPrehashSerialized[offset:], publicWitness[vk.PublicAndCommitmentCommitted[i][j]-1].Marshal()) + offset += fr.Bytes } - - publicCommitted := make([]*big.Int, vk.CommitmentInfo.NbPublicCommitted()) - for i := range publicCommitted { - var b big.Int - 
publicWitness[vk.CommitmentInfo.Committed[i]-1].BigInt(&b) - publicCommitted[i] = &b + if res, err := fr.Hash(commitmentPrehashSerialized[:offset], []byte(constraint.CommitmentDst), 1); err != nil { + return err + } else { + publicWitness = append(publicWitness, res[0]) + copy(commitmentsSerialized[i*fr.Bytes:], res[0].Marshal()) } + } - if res, err := solveCommitmentWire(&vk.CommitmentInfo, &proof.Commitment, publicCommitted); err == nil { - publicWitness = append(publicWitness, res) + if folded, err := pedersen.FoldCommitments(proof.Commitments, commitmentsSerialized); err != nil { + return err + } else { + if err = vk.CommitmentKey.Verify(folded, proof.CommitmentPok); err != nil { + return err } } @@ -87,8 +98,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { } kSum.AddMixed(&vk.G1.K[0]) - if vk.CommitmentInfo.Is() { - kSum.AddMixed(&proof.Commitment) + for i := range proof.Commitments { + kSum.AddMixed(&proof.Commitments[i]) } var kSumAff curve.G1Affine diff --git a/backend/groth16/groth16.go b/backend/groth16/groth16.go index 1aefdfa072..41e0f63c7e 100644 --- a/backend/groth16/groth16.go +++ b/backend/groth16/groth16.go @@ -44,13 +44,13 @@ import ( gnarkio "github.com/consensys/gnark/io" - groth16_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/groth16" - groth16_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/groth16" - groth16_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/groth16" - groth16_bls24317 "github.com/consensys/gnark/internal/backend/bls24-317/groth16" - groth16_bn254 "github.com/consensys/gnark/internal/backend/bn254/groth16" - groth16_bw6633 "github.com/consensys/gnark/internal/backend/bw6-633/groth16" - groth16_bw6761 "github.com/consensys/gnark/internal/backend/bw6-761/groth16" + groth16_bls12377 "github.com/consensys/gnark/backend/groth16/bls12-377" + groth16_bls12381 "github.com/consensys/gnark/backend/groth16/bls12-381" + groth16_bls24315 
"github.com/consensys/gnark/backend/groth16/bls24-315" + groth16_bls24317 "github.com/consensys/gnark/backend/groth16/bls24-317" + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + groth16_bw6633 "github.com/consensys/gnark/backend/groth16/bw6-633" + groth16_bw6761 "github.com/consensys/gnark/backend/groth16/bw6-761" ) type groth16Object interface { @@ -168,55 +168,28 @@ func Verify(proof Proof, vk VerifyingKey, publicWitness witness.Witness) error { // internally, the solution vector to the R1CS will be filled with random values which may impact benchmarking func Prove(r1cs constraint.ConstraintSystem, pk ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (Proof, error) { - // apply options - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return nil, err - } - switch _r1cs := r1cs.(type) { case *cs_bls12377.R1CS: - w, ok := fullWitness.Vector().(fr_bls12377.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return groth16_bls12377.Prove(_r1cs, pk.(*groth16_bls12377.ProvingKey), w, opt) + return groth16_bls12377.Prove(_r1cs, pk.(*groth16_bls12377.ProvingKey), fullWitness, opts...) + case *cs_bls12381.R1CS: - w, ok := fullWitness.Vector().(fr_bls12381.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), w, opt) + return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), fullWitness, opts...) + case *cs_bn254.R1CS: - w, ok := fullWitness.Vector().(fr_bn254.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return groth16_bn254.Prove(_r1cs, pk.(*groth16_bn254.ProvingKey), w, opt) + return groth16_bn254.Prove(_r1cs, pk.(*groth16_bn254.ProvingKey), fullWitness, opts...) 
+ case *cs_bw6761.R1CS: - w, ok := fullWitness.Vector().(fr_bw6761.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return groth16_bw6761.Prove(_r1cs, pk.(*groth16_bw6761.ProvingKey), w, opt) + return groth16_bw6761.Prove(_r1cs, pk.(*groth16_bw6761.ProvingKey), fullWitness, opts...) + case *cs_bls24317.R1CS: - w, ok := fullWitness.Vector().(fr_bls24317.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return groth16_bls24317.Prove(_r1cs, pk.(*groth16_bls24317.ProvingKey), w, opt) + return groth16_bls24317.Prove(_r1cs, pk.(*groth16_bls24317.ProvingKey), fullWitness, opts...) + case *cs_bls24315.R1CS: - w, ok := fullWitness.Vector().(fr_bls24315.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return groth16_bls24315.Prove(_r1cs, pk.(*groth16_bls24315.ProvingKey), w, opt) + return groth16_bls24315.Prove(_r1cs, pk.(*groth16_bls24315.ProvingKey), fullWitness, opts...) + case *cs_bw6633.R1CS: - w, ok := fullWitness.Vector().(fr_bw6633.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return groth16_bw6633.Prove(_r1cs, pk.(*groth16_bw6633.ProvingKey), w, opt) + return groth16_bw6633.Prove(_r1cs, pk.(*groth16_bw6633.ProvingKey), fullWitness, opts...) 
+ default: panic("unrecognized R1CS curve type") } diff --git a/backend/groth16/internal/test_utils/test_utils.go b/backend/groth16/internal/test_utils/test_utils.go new file mode 100644 index 0000000000..4a76414298 --- /dev/null +++ b/backend/groth16/internal/test_utils/test_utils.go @@ -0,0 +1,14 @@ +package test_utils + +import "math/rand" + +func Random2DIntSlice(maxN, maxM int) [][]int { + res := make([][]int, rand.Intn(maxN)) //#nosec G404 weak rng OK for test + for i := range res { + res[i] = make([]int, rand.Intn(maxM)) //#nosec G404 weak rng OK for test + for j := range res[i] { + res[i][j] = rand.Int() //#nosec G404 weak rng OK for test + } + } + return res +} diff --git a/backend/groth16/internal/utils.go b/backend/groth16/internal/utils.go new file mode 100644 index 0000000000..6062ef57ce --- /dev/null +++ b/backend/groth16/internal/utils.go @@ -0,0 +1,22 @@ +package internal + +func ConcatAll(slices ...[]int) []int { // copyright note: written by GitHub Copilot + totalLen := 0 + for _, s := range slices { + totalLen += len(s) + } + res := make([]int, totalLen) + i := 0 + for _, s := range slices { + i += copy(res[i:], s) + } + return res +} + +func NbElements(slices [][]int) int { // copyright note: written by GitHub Copilot + totalLen := 0 + for _, s := range slices { + totalLen += len(s) + } + return totalLen +} diff --git a/backend/hint/builtin.go b/backend/hint/builtin.go deleted file mode 100644 index 4d0ee814d3..0000000000 --- a/backend/hint/builtin.go +++ /dev/null @@ -1,25 +0,0 @@ -package hint - -import ( - "math/big" -) - -func init() { - Register(InvZero) -} - -// InvZero computes the value 1/a for the single input a. If a == 0, returns 0. 
-func InvZero(q *big.Int, inputs []*big.Int, results []*big.Int) error { - result := results[0] - - // save input - result.Set(inputs[0]) - - // a == 0, return - if result.IsUint64() && result.Uint64() == 0 { - return nil - } - - result.ModInverse(result, q) - return nil -} diff --git a/backend/hint/hint.go b/backend/hint/hint.go deleted file mode 100644 index 60976a3dc8..0000000000 --- a/backend/hint/hint.go +++ /dev/null @@ -1,102 +0,0 @@ -/* -Package hint allows to define computations outside of a circuit. - -Usually, it is expected that computations in circuits are performed on -variables. However, in some cases defining the computations in circuits may be -complicated or computationally expensive. By using hints, the computations are -performed outside of the circuit on integers (compared to the frontend.Variable -values inside the circuits) and the result of a hint function is assigned to a -newly created variable in a circuit. - -As the computations are perfomed outside of the circuit, then the correctness of -the result is not guaranteed. This also means that the result of a hint function -is unconstrained by default, leading to failure while composing circuit proof. -Thus, it is the circuit developer responsibility to verify the correctness hint -result by adding necessary constraints in the circuit. - -As an example, lets say the hint function computes a factorization of a -semiprime n: - - p, q <- hint(n) st. p * q = n - -into primes p and q. Then, the circuit developer needs to assert in the circuit -that p*q indeed equals to n: - - n == p * q. - -However, if the hint function is incorrectly defined (e.g. in the previous -example, it returns 1 and n instead of p and q), then the assertion may still -hold, but the constructed proof is semantically invalid. Thus, the user -constructing the proof must be extremely cautious when using hints. 
- -# Using hint functions in circuits - -To use a hint function in a circuit, the developer first needs to define a hint -function hintFn according to the Function interface. Then, in a circuit, the -developer applies the hint function with frontend.API.NewHint(hintFn, vars...), -where vars are the variables the hint function will be applied to (and -correspond to the argument inputs in the Function type) which returns a new -unconstrained variable. The returned variables must be constrained using -frontend.API.Assert[.*] methods. - -As explained, the hints are essentially black boxes from the circuit point of -view and thus the defined hints in circuits are not used when constructing a -proof. To allow the particular hint functions to be used during proof -construction, the user needs to supply a backend.ProverOption indicating the -enabled hints. Such options can be optained by a call to -backend.WithHints(hintFns...), where hintFns are the corresponding hint -functions. - -# Using hint functions in gadgets - -Similar considerations apply for hint functions used in gadgets as in -user-defined circuits. However, listing all hint functions used in a particular -gadget for constructing backend.ProverOption puts high overhead for the user to -enable all necessary hints. - -For that, this package also provides a registry of trusted hint functions. When -a gadget registers a hint function, then it is automatically enabled during -proof computation and the prover does not need to provide a corresponding -proving option. - -In the init() method of the gadget, call the method Register(hintFn) method on -the hint function hintFn to register a hint function in the package registry. -*/ -package hint - -import ( - "hash/fnv" - "math/big" - "reflect" - "runtime" -) - -// ID is a unique identifier for a hint function used for lookup. 
-type ID uint32 - -// Function defines an annotated hint function; the number of inputs and outputs injected at solving -// time is defined in the circuit (compile time). -// -// For example: -// -// b := api.NewHint(hint, 2, a) -// --> at solving time, hint is going to be invoked with 1 input (a) and is expected to return 2 outputs -// b[0] and b[1]. -type Function func(field *big.Int, inputs []*big.Int, outputs []*big.Int) error - -// UUID is a reference function for computing the hint ID based on a function name -func UUID(fn Function) ID { - hf := fnv.New32a() - name := Name(fn) - - // TODO relying on name to derive UUID is risky; if fn is an anonymous func, wil be package.glob..funcN - // and if new anonymous functions are added in the package, N may change, so will UUID. - hf.Write([]byte(name)) // #nosec G104 -- does not err - - return ID(hf.Sum32()) -} - -func Name(fn Function) string { - fnptr := reflect.ValueOf(fn).Pointer() - return runtime.FuncForPC(fnptr).Name() -} diff --git a/backend/hint/registry.go b/backend/hint/registry.go deleted file mode 100644 index dfd5a66b8d..0000000000 --- a/backend/hint/registry.go +++ /dev/null @@ -1,37 +0,0 @@ -package hint - -import ( - "sync" - - "github.com/consensys/gnark/logger" -) - -var registry = make(map[ID]Function) -var registryM sync.RWMutex - -// Register registers a hint function in the global registry. -func Register(hintFns ...Function) { - registryM.Lock() - defer registryM.Unlock() - for _, hintFn := range hintFns { - key := UUID(hintFn) - name := Name(hintFn) - if _, ok := registry[key]; ok { - log := logger.Logger() - log.Warn().Str("name", name).Msg("function registered multiple times") - return - } - registry[key] = hintFn - } -} - -// GetRegistered returns all registered hint functions. 
-func GetRegistered() []Function { - registryM.RLock() - defer registryM.RUnlock() - ret := make([]Function, 0, len(registry)) - for _, v := range registry { - ret = append(ret, v) - } - return ret -} diff --git a/internal/backend/bls12-377/plonk/marshal.go b/backend/plonk/bls12-377/marshal.go similarity index 67% rename from internal/backend/bls12-377/plonk/marshal.go rename to backend/plonk/bls12-377/marshal.go index 64daa4c899..5066e4edd2 100644 --- a/internal/backend/bls12-377/plonk/marshal.go +++ b/backend/plonk/bls12-377/marshal.go @@ -19,12 +19,15 @@ package plonk import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-377" - "errors" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + + "errors" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/iop" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/kzg" "io" ) -// WriteTo writes binary encoding of Proof to w without point compression +// WriteRawTo writes binary encoding of Proof to w without point compression func (proof *Proof) WriteRawTo(w io.Writer) (int64, error) { return proof.writeTo(w, curve.RawEncoding()) } @@ -49,6 +52,7 @@ func (proof *Proof) writeTo(w io.Writer, options ...func(*curve.Encoder)) (int64 proof.BatchedProof.ClaimedValues, &proof.ZShiftedOpening.H, &proof.ZShiftedOpening.ClaimedValue, + proof.Bsb22Commitments, } for _, v := range toEncode { @@ -75,6 +79,7 @@ func (proof *Proof) ReadFrom(r io.Reader) (int64, error) { &proof.BatchedProof.ClaimedValues, &proof.ZShiftedOpening.H, &proof.ZShiftedOpening.ClaimedValue, + &proof.Bsb22Commitments, } for _, v := range toDecode { @@ -83,6 +88,10 @@ func (proof *Proof) ReadFrom(r io.Reader) (int64, error) { } } + if proof.Bsb22Commitments == nil { + proof.Bsb22Commitments = []kzg.Digest{} + } + return dec.BytesRead(), nil } @@ -107,8 +116,15 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } n += n2 + // KZG key + n2, err = pk.Kzg.WriteTo(w) + if err != nil { + return + } + n += n2 + // sanity check 
len(Permutation) == 3*int(pk.Domain[0].Cardinality) - if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { + if len(pk.trace.S) != (3 * int(pk.Domain[0].Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } @@ -117,16 +133,17 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { // encode the size (nor does it convert from Montgomery to Regular form) // so we explicitly transmit []fr.Element toEncode := []interface{}{ - ([]fr.Element)(pk.Ql), - ([]fr.Element)(pk.Qr), - ([]fr.Element)(pk.Qm), - ([]fr.Element)(pk.Qo), - ([]fr.Element)(pk.CQk), - ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.S1Canonical), - ([]fr.Element)(pk.S2Canonical), - ([]fr.Element)(pk.S3Canonical), - pk.Permutation, + pk.trace.Ql.Coefficients(), + pk.trace.Qr.Coefficients(), + pk.trace.Qm.Coefficients(), + pk.trace.Qo.Coefficients(), + pk.trace.Qk.Coefficients(), + coefficients(pk.trace.Qcp), + pk.lQk.Coefficients(), + pk.trace.S1.Coefficients(), + pk.trace.S2.Coefficients(), + pk.trace.S3.Coefficients(), + pk.trace.S, } for _, v := range toEncode { @@ -158,20 +175,30 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { return n, err } - pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) + n2, err = pk.Kzg.ReadFrom(r) + n += n2 + if err != nil { + return n, err + } + + pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) dec := curve.NewDecoder(r) + + var ql, qr, qm, qo, qk, lqk, s1, s2, s3 []fr.Element + var qcp [][]fr.Element toDecode := []interface{}{ - (*[]fr.Element)(&pk.Ql), - (*[]fr.Element)(&pk.Qr), - (*[]fr.Element)(&pk.Qm), - (*[]fr.Element)(&pk.Qo), - (*[]fr.Element)(&pk.CQk), - (*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.S1Canonical), - (*[]fr.Element)(&pk.S2Canonical), - (*[]fr.Element)(&pk.S3Canonical), - &pk.Permutation, + &ql, + &qr, + &qm, + &qo, + &qk, + &qcp, + &lqk, + &s1, + &s2, + &s3, + &pk.trace.S, } for _, v := range toDecode { @@ -180,6 +207,23 @@ func (pk *ProvingKey) 
ReadFrom(r io.Reader) (int64, error) { } } + canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} + pk.trace.Ql = iop.NewPolynomial(&ql, canReg) + pk.trace.Qr = iop.NewPolynomial(&qr, canReg) + pk.trace.Qm = iop.NewPolynomial(&qm, canReg) + pk.trace.Qo = iop.NewPolynomial(&qo, canReg) + pk.trace.Qk = iop.NewPolynomial(&qk, canReg) + pk.trace.S1 = iop.NewPolynomial(&s1, canReg) + pk.trace.S2 = iop.NewPolynomial(&s2, canReg) + pk.trace.S3 = iop.NewPolynomial(&s3, canReg) + + pk.trace.Qcp = make([]*iop.Polynomial, len(qcp)) + for i := range qcp { + pk.trace.Qcp[i] = iop.NewPolynomial(&qcp[i], canReg) + } + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + pk.lQk = iop.NewPolynomial(&lqk, lagReg) + pk.computeLagrangeCosetPolys() return n + dec.BytesRead(), nil @@ -188,6 +232,15 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { // WriteTo writes binary encoding of VerifyingKey to w func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { + return vk.writeTo(w) +} + +// WriteRawTo writes binary encoding of VerifyingKey to w without point compression +func (vk *VerifyingKey) WriteRawTo(w io.Writer) (int64, error) { + return vk.writeTo(w, curve.RawEncoding()) +} + +func (vk *VerifyingKey) writeTo(w io.Writer, options ...func(*curve.Encoder)) (n int64, err error) { enc := curve.NewEncoder(w) toEncode := []interface{}{ @@ -204,6 +257,11 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { &vk.Qm, &vk.Qo, &vk.Qk, + vk.Qcp, + &vk.Kzg.G1, + &vk.Kzg.G2[0], + &vk.Kzg.G2[1], + vk.CommitmentConstraintIndexes, } for _, v := range toEncode { @@ -232,6 +290,11 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { &vk.Qm, &vk.Qo, &vk.Qk, + &vk.Qcp, + &vk.Kzg.G1, + &vk.Kzg.G2[0], + &vk.Kzg.G2[1], + &vk.CommitmentConstraintIndexes, } for _, v := range toDecode { @@ -240,5 +303,9 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { } } + if vk.Qcp == nil { + vk.Qcp = []kzg.Digest{} + } + return 
dec.BytesRead(), nil } diff --git a/internal/backend/bls12-377/plonk/marshal_test.go b/backend/plonk/bls12-377/marshal_test.go similarity index 54% rename from internal/backend/bls12-377/plonk/marshal_test.go rename to backend/plonk/bls12-377/marshal_test.go index 94a69f6295..5f5b257510 100644 --- a/internal/backend/bls12-377/plonk/marshal_test.go +++ b/backend/plonk/bls12-377/marshal_test.go @@ -23,6 +23,7 @@ import ( "bytes" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/iop" gnarkio "github.com/consensys/gnark/io" "io" "math/big" @@ -106,67 +107,114 @@ func roundTripCheckRaw(t *testing.T, from gnarkio.WriterRawTo, reconstructed io. } func (pk *ProvingKey) randomize() { + var vk VerifyingKey vk.randomize() pk.Vk = &vk pk.Domain[0] = *fft.NewDomain(42) pk.Domain[1] = *fft.NewDomain(4 * 42) + pk.Kzg.G1 = make([]curve.G1Affine, 7) + for i := range pk.Kzg.G1 { + pk.Kzg.G1[i] = randomG1Point() + } + n := int(pk.Domain[0].Cardinality) - pk.Ql = randomScalars(n) - pk.Qr = randomScalars(n) - pk.Qm = randomScalars(n) - pk.Qo = randomScalars(n) - pk.CQk = randomScalars(n) - pk.LQk = randomScalars(n) - pk.S1Canonical = randomScalars(n) - pk.S2Canonical = randomScalars(n) - pk.S3Canonical = randomScalars(n) - - pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) - pk.Permutation[0] = -12 - pk.Permutation[len(pk.Permutation)-1] = 8888 + ql := randomScalars(n) + qr := randomScalars(n) + qm := randomScalars(n) + qo := randomScalars(n) + qk := randomScalars(n) + lqk := randomScalars(n) + s1 := randomScalars(n) + s2 := randomScalars(n) + s3 := randomScalars(n) + + canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} + pk.trace.Ql = iop.NewPolynomial(&ql, canReg) + pk.trace.Qr = iop.NewPolynomial(&qr, canReg) + pk.trace.Qm = iop.NewPolynomial(&qm, canReg) + pk.trace.Qo = iop.NewPolynomial(&qo, canReg) + pk.trace.Qk = iop.NewPolynomial(&qk, canReg) + pk.trace.S1 = iop.NewPolynomial(&s1, canReg) + 
pk.trace.S2 = iop.NewPolynomial(&s2, canReg) + pk.trace.S3 = iop.NewPolynomial(&s3, canReg) + + pk.trace.Qcp = make([]*iop.Polynomial, rand.Intn(4)) //#nosec G404 weak rng is fine here + for i := range pk.trace.Qcp { + qcp := randomScalars(rand.Intn(n / 4)) //#nosec G404 weak rng is fine here + pk.trace.Qcp[i] = iop.NewPolynomial(&qcp, canReg) + } + + pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) + pk.trace.S[0] = -12 + pk.trace.S[len(pk.trace.S)-1] = 8888 + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + pk.lQk = iop.NewPolynomial(&lqk, lagReg) pk.computeLagrangeCosetPolys() } func (vk *VerifyingKey) randomize() { - vk.Size = rand.Uint64() + vk.Size = rand.Uint64() //#nosec G404 weak rng is fine here vk.SizeInv.SetRandom() vk.Generator.SetRandom() - vk.NbPublicVariables = rand.Uint64() + vk.NbPublicVariables = rand.Uint64() //#nosec G404 weak rng is fine here + vk.CommitmentConstraintIndexes = []uint64{rand.Uint64()} //#nosec G404 weak rng is fine here vk.CosetShift.SetRandom() - vk.S[0] = randomPoint() - vk.S[1] = randomPoint() - vk.S[2] = randomPoint() - vk.Ql = randomPoint() - vk.Qr = randomPoint() - vk.Qm = randomPoint() - vk.Qo = randomPoint() - vk.Qk = randomPoint() + vk.S[0] = randomG1Point() + vk.S[1] = randomG1Point() + vk.S[2] = randomG1Point() + + vk.Kzg.G1 = randomG1Point() + vk.Kzg.G2[0] = randomG2Point() + vk.Kzg.G2[1] = randomG2Point() + + vk.Ql = randomG1Point() + vk.Qr = randomG1Point() + vk.Qm = randomG1Point() + vk.Qo = randomG1Point() + vk.Qk = randomG1Point() + vk.Qcp = randomG1Points(rand.Intn(4)) //#nosec G404 weak rng is fine here } func (proof *Proof) randomize() { - proof.LRO[0] = randomPoint() - proof.LRO[1] = randomPoint() - proof.LRO[2] = randomPoint() - proof.Z = randomPoint() - proof.H[0] = randomPoint() - proof.H[1] = randomPoint() - proof.H[2] = randomPoint() - proof.BatchedProof.H = randomPoint() + proof.LRO[0] = randomG1Point() + proof.LRO[1] = randomG1Point() + proof.LRO[2] = randomG1Point() + proof.Z = 
randomG1Point() + proof.H[0] = randomG1Point() + proof.H[1] = randomG1Point() + proof.H[2] = randomG1Point() + proof.BatchedProof.H = randomG1Point() proof.BatchedProof.ClaimedValues = randomScalars(2) - proof.ZShiftedOpening.H = randomPoint() + proof.ZShiftedOpening.H = randomG1Point() proof.ZShiftedOpening.ClaimedValue.SetRandom() + proof.Bsb22Commitments = randomG1Points(rand.Intn(4)) //#nosec G404 weak rng is fine here +} + +func randomG2Point() curve.G2Affine { + _, _, _, r := curve.Generators() + r.ScalarMultiplication(&r, big.NewInt(int64(rand.Uint64()))) //#nosec G404 weak rng is fine here + return r } -func randomPoint() curve.G1Affine { +func randomG1Point() curve.G1Affine { _, _, r, _ := curve.Generators() - r.ScalarMultiplication(&r, big.NewInt(int64(rand.Uint64()))) + r.ScalarMultiplication(&r, big.NewInt(int64(rand.Uint64()))) //#nosec G404 weak rng is fine here return r } +func randomG1Points(n int) []curve.G1Affine { + res := make([]curve.G1Affine, n) + for i := range res { + res[i] = randomG1Point() + } + return res +} + func randomScalars(n int) []fr.Element { v := make([]fr.Element, n) one := fr.One() diff --git a/backend/plonk/bls12-377/prove.go b/backend/plonk/bls12-377/prove.go new file mode 100644 index 0000000000..43e999fd87 --- /dev/null +++ b/backend/plonk/bls12-377/prove.go @@ -0,0 +1,721 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package plonk + +import ( + "crypto/sha256" + "math/big" + "runtime" + "sync" + "time" + + "github.com/consensys/gnark/backend/witness" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/kzg" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/iop" + "github.com/consensys/gnark/constraint/bls12-377" + + "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/internal/utils" + "github.com/consensys/gnark/logger" +) + +type Proof struct { + + // Commitments to the solution vectors + LRO [3]kzg.Digest + + // Commitment to Z, the permutation polynomial + Z kzg.Digest + + // Commitments to h1, h2, h3 such that h = h1 + Xh2 + X**2h3 is the quotient polynomial + H [3]kzg.Digest + + Bsb22Commitments []kzg.Digest + + // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2, qCPrime + BatchedProof kzg.BatchOpeningProof + + // Opening proof of Z at zeta*mu + ZShiftedOpening kzg.OpeningProof +} + +// Computing and verifying Bsb22 multi-commits explained in https://hackmd.io/x8KsadW3RRyX7YTCFJIkHg +func bsb22ComputeCommitmentHint(spr *cs.SparseR1CS, pk *ProvingKey, proof *Proof, cCommitments []*iop.Polynomial, res *fr.Element, commDepth int) solver.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments)[commDepth] + committedValues := make([]fr.Element, pk.Domain[0].Cardinality) + offset := spr.GetNbPublicVariables() + for i := range ins { + committedValues[offset+commitmentInfo.Committed[i]].SetBigInt(ins[i]) + } + var ( + err error + hashRes []fr.Element + ) + if _, err = 
committedValues[offset+commitmentInfo.CommitmentIndex].SetRandom(); err != nil { // Commitment injection constraint has qcp = 0. Safe to use for blinding. + return err + } + if _, err = committedValues[offset+spr.GetNbConstraints()-1].SetRandom(); err != nil { // Last constraint has qcp = 0. Safe to use for blinding + return err + } + pi2iop := iop.NewPolynomial(&committedValues, iop.Form{Basis: iop.Lagrange, Layout: iop.Regular}) + cCommitments[commDepth] = pi2iop.ShallowClone() + cCommitments[commDepth].ToCanonical(&pk.Domain[0]).ToRegular() + if proof.Bsb22Commitments[commDepth], err = kzg.Commit(cCommitments[commDepth].Coefficients(), pk.Kzg); err != nil { + return err + } + if hashRes, err = fr.Hash(proof.Bsb22Commitments[commDepth].Marshal(), []byte("BSB22-Plonk"), 1); err != nil { + return err + } + res.Set(&hashRes[0]) // TODO @Tabaie use CommitmentIndex for this; create a new variable CommitmentConstraintIndex for other uses + res.BigInt(outs[0]) + + return nil + } +} + +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + + log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", spr.GetNbConstraints()).Str("backend", "plonk").Logger() + + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } + + start := time.Now() + // pick a hash function that will be used to derive the challenges + hFunc := sha256.New() + + // create a transcript manager to apply Fiat Shamir + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") + + // result + proof := &Proof{} + + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments) + commitmentVal := make([]fr.Element, len(commitmentInfo)) // TODO @Tabaie get rid of this + cCommitments := make([]*iop.Polynomial, len(commitmentInfo)) + proof.Bsb22Commitments = make([]kzg.Digest, len(commitmentInfo)) + for i := range commitmentInfo { + opt.SolverOpts = append(opt.SolverOpts, solver.OverrideHint(commitmentInfo[i].HintID, + bsb22ComputeCommitmentHint(spr, pk, proof, cCommitments, &commitmentVal[i], i))) + } + + // query l, r, o in Lagrange basis, not blinded + _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) + if err != nil { + return nil, err + } + // TODO @gbotrel deal with that conversion lazily + lcCommitments := make([]*iop.Polynomial, len(cCommitments)) + for i := range cCommitments { + lcCommitments[i] = cCommitments[i].Clone(int(pk.Domain[1].Cardinality)).ToLagrangeCoset(&pk.Domain[1]) // lagrange coset form + } + solution := _solution.(*cs.SparseR1CSSolution) + evaluationLDomainSmall := []fr.Element(solution.L) + evaluationRDomainSmall := []fr.Element(solution.R) + evaluationODomainSmall := []fr.Element(solution.O) + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + // l, r, o and blinded versions + var ( + wliop, + wriop, + woiop, + bwliop, + bwriop, + bwoiop *iop.Polynomial + ) + var wgLRO sync.WaitGroup + wgLRO.Add(3) + go func() { + // we keep in lagrange regular form since iop.BuildRatioCopyConstraint prefers it in this form. + wliop = iop.NewPolynomial(&evaluationLDomainSmall, lagReg) + // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. 
+ bwliop = wliop.Clone(int(pk.Domain[1].Cardinality)).ToCanonical(&pk.Domain[0]).ToRegular().Blind(1) + wgLRO.Done() + }() + go func() { + wriop = iop.NewPolynomial(&evaluationRDomainSmall, lagReg) + bwriop = wriop.Clone(int(pk.Domain[1].Cardinality)).ToCanonical(&pk.Domain[0]).ToRegular().Blind(1) + wgLRO.Done() + }() + go func() { + woiop = iop.NewPolynomial(&evaluationODomainSmall, lagReg) + bwoiop = woiop.Clone(int(pk.Domain[1].Cardinality)).ToCanonical(&pk.Domain[0]).ToRegular().Blind(1) + wgLRO.Done() + }() + + fw, ok := fullWitness.Vector().(fr.Vector) + if !ok { + return nil, witness.ErrInvalidWitness + } + + // start computing lcqk + var lcqk *iop.Polynomial + chLcqk := make(chan struct{}, 1) + go func() { + // compute qk in canonical basis, completed with the public inputs + // We copy the coeffs of qk to pk is not mutated + lqkcoef := pk.lQk.Coefficients() + qkCompletedCanonical := make([]fr.Element, len(lqkcoef)) + copy(qkCompletedCanonical, fw[:len(spr.Public)]) + copy(qkCompletedCanonical[len(spr.Public):], lqkcoef[len(spr.Public):]) + for i := range commitmentInfo { + qkCompletedCanonical[spr.GetNbPublicVariables()+commitmentInfo[i].CommitmentIndex] = commitmentVal[i] + } + pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) + fft.BitReverse(qkCompletedCanonical) + + canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} + lcqk = iop.NewPolynomial(&qkCompletedCanonical, canReg) + lcqk.ToLagrangeCoset(&pk.Domain[1]) + close(chLcqk) + }() + + // The first challenge is derived using the public data: the commitments to the permutation, + // the coefficients of the circuit, and the public inputs. 
+ // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) + if err := bindPublicData(&fs, "gamma", *pk.Vk, fw[:len(spr.Public)], proof.Bsb22Commitments); err != nil { + return nil, err + } + + // wait for polys to be blinded + wgLRO.Wait() + if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Kzg); err != nil { + return nil, err + } + + gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) // TODO @Tabaie @ThomasPiellard add BSB commitment here? + if err != nil { + return nil, err + } + + // Fiat Shamir this + bbeta, err := fs.ComputeChallenge("beta") + if err != nil { + return nil, err + } + var beta fr.Element + beta.SetBytes(bbeta) + + // l, r, o are already blinded + wgLRO.Add(3) + go func() { + bwliop.ToLagrangeCoset(&pk.Domain[1]) + wgLRO.Done() + }() + go func() { + bwriop.ToLagrangeCoset(&pk.Domain[1]) + wgLRO.Done() + }() + go func() { + bwoiop.ToLagrangeCoset(&pk.Domain[1]) + wgLRO.Done() + }() + + // compute the copy constraint's ratio + // note that wliop, wriop and woiop are fft'ed (mutated) in the process. 
+ ziop, err := iop.BuildRatioCopyConstraint( + []*iop.Polynomial{ + wliop, + wriop, + woiop, + }, + pk.trace.S, + beta, + gamma, + iop.Form{Basis: iop.Canonical, Layout: iop.Regular}, + &pk.Domain[0], + ) + if err != nil { + return proof, err + } + + // commit to the blinded version of z + chZ := make(chan error, 1) + var bwziop, bwsziop *iop.Polynomial + var alpha fr.Element + go func() { + bwziop = ziop // iop.NewWrappedPolynomial(&ziop) + bwziop.Blind(2) + proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Kzg, runtime.NumCPU()*2) + if err != nil { + chZ <- err + } + + // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) + alpha, err = deriveRandomness(&fs, "alpha", &proof.Z) + if err != nil { + chZ <- err + } + + // Store z(g*x), without reallocating a slice + bwsziop = bwziop.ShallowClone().Shift(1) + bwsziop.ToLagrangeCoset(&pk.Domain[1]) + chZ <- nil + close(chZ) + }() + + // Full capture using latest gnark crypto... + fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element, pi2QcPrime []fr.Element) fr.Element { // TODO @Tabaie make use of the fact that qCPrime is a selector: sparse and binary + + var ic, tmp fr.Element + + ic.Mul(&fql, &l) + tmp.Mul(&fqr, &r) + ic.Add(&ic, &tmp) + tmp.Mul(&fqm, &l).Mul(&tmp, &r) + ic.Add(&ic, &tmp) + tmp.Mul(&fqo, &o) + ic.Add(&ic, &tmp).Add(&ic, &fqk) + nbComms := len(commitmentInfo) + for i := range commitmentInfo { + tmp.Mul(&pi2QcPrime[i], &pi2QcPrime[i+nbComms]) + ic.Add(&ic, &tmp) + } + + return ic + } + + fo := func(l, r, o, fid, fs1, fs2, fs3, fz, fzs fr.Element) fr.Element { + u := &pk.Domain[0].FrMultiplicativeGen + var a, b, tmp fr.Element + b.Mul(&beta, &fid) + a.Add(&b, &l).Add(&a, &gamma) + b.Mul(&b, u) + tmp.Add(&b, &r).Add(&tmp, &gamma) + a.Mul(&a, &tmp) + tmp.Mul(&b, u).Add(&tmp, &o).Add(&tmp, &gamma) + a.Mul(&a, &tmp).Mul(&a, &fz) + + b.Mul(&beta, &fs1).Add(&b, &l).Add(&b, &gamma) + tmp.Mul(&beta, &fs2).Add(&tmp, &r).Add(&tmp, &gamma) + b.Mul(&b, &tmp) + tmp.Mul(&beta, &fs3).Add(&tmp, &o).Add(&tmp, 
&gamma) + b.Mul(&b, &tmp).Mul(&b, &fzs) + + b.Sub(&b, &a) + + return b + } + + fone := func(fz, flone fr.Element) fr.Element { + one := fr.One() + one.Sub(&fz, &one).Mul(&one, &flone) + return one + } + + // 0 , 1 , 2, 3 , 4 , 5 , 6 , 7, 8 , 9 , 10, 11, 12, 13, 14, 15:15+nbComm , 15+nbComm:15+2×nbComm + // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk ,lone, Bsb22Commitments, qCPrime + fm := func(x ...fr.Element) fr.Element { + + a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2], x[15:]) + b := fo(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8]) + c := fone(x[7], x[14]) + + c.Mul(&c, &alpha).Add(&c, &b).Mul(&c, &alpha).Add(&c, &a) + + return c + } + + // wait for lcqk + <-chLcqk + + // wait for Z part + if err := <-chZ; err != nil { + return proof, err + } + + // wait for l, r o lagrange coset conversion + wgLRO.Wait() + + toEval := []*iop.Polynomial{ + bwliop, + bwriop, + bwoiop, + pk.lcIdIOP, + pk.lcS1, + pk.lcS2, + pk.lcS3, + bwziop, + bwsziop, + pk.lcQl, + pk.lcQr, + pk.lcQm, + pk.lcQo, + lcqk, + pk.lLoneIOP, + } + toEval = append(toEval, lcCommitments...) // TODO: Add this at beginning + toEval = append(toEval, pk.lcQcp...) + systemEvaluation, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, toEval...) 
+ if err != nil { + return nil, err + } + // open blinded Z at zeta*z + chbwzIOP := make(chan struct{}, 1) + go func() { + bwziop.ToCanonical(&pk.Domain[1]).ToRegular() + close(chbwzIOP) + }() + + h, err := iop.DivideByXMinusOne(systemEvaluation, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) // TODO Rename to DivideByXNMinusOne or DivideByVanishingPoly etc + if err != nil { + return nil, err + } + + // compute kzg commitments of h1, h2 and h3 + if err := commitToQuotient( + h.Coefficients()[:pk.Domain[0].Cardinality+2], + h.Coefficients()[pk.Domain[0].Cardinality+2:2*(pk.Domain[0].Cardinality+2)], + h.Coefficients()[2*(pk.Domain[0].Cardinality+2):3*(pk.Domain[0].Cardinality+2)], + proof, pk.Kzg); err != nil { + return nil, err + } + + // derive zeta + zeta, err := deriveRandomness(&fs, "zeta", &proof.H[0], &proof.H[1], &proof.H[2]) + if err != nil { + return nil, err + } + + // compute evaluations of (blinded version of) l, r, o, z, qCPrime at zeta + var blzeta, brzeta, bozeta fr.Element + qcpzeta := make([]fr.Element, len(commitmentInfo)) + + var wgEvals sync.WaitGroup + wgEvals.Add(3) + evalAtZeta := func(poly *iop.Polynomial, res *fr.Element) { + poly.ToCanonical(&pk.Domain[1]).ToRegular() + *res = poly.Evaluate(zeta) + wgEvals.Done() + } + go evalAtZeta(bwliop, &blzeta) + go evalAtZeta(bwriop, &brzeta) + go evalAtZeta(bwoiop, &bozeta) + evalQcpAtZeta := func(begin, end int) { + for i := begin; i < end; i++ { + qcpzeta[i] = pk.trace.Qcp[i].Evaluate(zeta) + } + } + utils.Parallelize(len(commitmentInfo), evalQcpAtZeta) + + var zetaShifted fr.Element + zetaShifted.Mul(&zeta, &pk.Vk.Generator) + <-chbwzIOP + proof.ZShiftedOpening, err = kzg.Open( + bwziop.Coefficients()[:bwziop.BlindedSize()], + zetaShifted, + pk.Kzg, + ) + if err != nil { + return nil, err + } + + // start to compute foldedH and foldedHDigest while computeLinearizedPolynomial runs. 
+ computeFoldedH := make(chan struct{}, 1) + var foldedH []fr.Element + var foldedHDigest kzg.Digest + go func() { + // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) + var bZetaPowerm, bSize big.Int + bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + var zetaPowerm fr.Element + zetaPowerm.Exp(zeta, &bSize) + zetaPowerm.BigInt(&bZetaPowerm) + foldedHDigest = proof.H[2] + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) + + // foldedH = h1 + ζ*h2 + ζ²*h3 + foldedH = h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] + h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] + h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] + utils.Parallelize(len(foldedH), func(start, end int) { + for i := start; i < end; i++ { + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² + foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 + } + }) + close(computeFoldedH) + }() + + wgEvals.Wait() // wait for the evaluations + + var ( + linearizedPolynomialCanonical []fr.Element + linearizedPolynomialDigest curve.G1Affine + errLPoly error + ) + + // blinded z evaluated at u*zeta + bzuzeta := proof.ZShiftedOpening.ClaimedValue + + // compute the linearization polynomial r at zeta + // (goal: save committing separately to z, ql, qr, qm, qo, k + // note: we linearizedPolynomialCanonical reuses bwziop memory + linearizedPolynomialCanonical = computeLinearizedPolynomial( + blzeta, + brzeta, + bozeta, + alpha, + beta, + gamma, + zeta, + bzuzeta, + qcpzeta, + 
bwziop.Coefficients()[:bwziop.BlindedSize()], + coefficients(cCommitments), + pk, + ) + + // TODO this commitment is only necessary to derive the challenge, we should + // be able to avoid doing it and get the challenge in another way + linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Kzg, runtime.NumCPU()*2) + if errLPoly != nil { + return nil, errLPoly + } + + // wait for foldedH and foldedHDigest + <-computeFoldedH + + // Batch open the first list of polynomials + polysQcp := coefficients(pk.trace.Qcp) + polysToOpen := make([][]fr.Element, 7+len(polysQcp)) + copy(polysToOpen[7:], polysQcp) + // offset := len(polysQcp) + polysToOpen[0] = foldedH + polysToOpen[1] = linearizedPolynomialCanonical + polysToOpen[2] = bwliop.Coefficients()[:bwliop.BlindedSize()] + polysToOpen[3] = bwriop.Coefficients()[:bwriop.BlindedSize()] + polysToOpen[4] = bwoiop.Coefficients()[:bwoiop.BlindedSize()] + polysToOpen[5] = pk.trace.S1.Coefficients() + polysToOpen[6] = pk.trace.S2.Coefficients() + + digestsToOpen := make([]curve.G1Affine, len(pk.Vk.Qcp)+7) + copy(digestsToOpen[7:], pk.Vk.Qcp) + // offset = len(pk.Vk.Qcp) + digestsToOpen[0] = foldedHDigest + digestsToOpen[1] = linearizedPolynomialDigest + digestsToOpen[2] = proof.LRO[0] + digestsToOpen[3] = proof.LRO[1] + digestsToOpen[4] = proof.LRO[2] + digestsToOpen[5] = pk.Vk.S[0] + digestsToOpen[6] = pk.Vk.S[1] + + proof.BatchedProof, err = kzg.BatchOpenSinglePoint( + polysToOpen, + digestsToOpen, + zeta, + hFunc, + pk.Kzg, + ) + + log.Debug().Dur("took", time.Since(start)).Msg("prover done") + + if err != nil { + return nil, err + } + + return proof, nil + +} + +func coefficients(p []*iop.Polynomial) [][]fr.Element { + res := make([][]fr.Element, len(p)) + for i, pI := range p { + res[i] = pI.Coefficients() + } + return res +} + +// fills proof.LRO with kzg commits of bcl, bcr and bco +func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, kzgPk kzg.ProvingKey) error { + n := runtime.NumCPU() 
+ var err0, err1, err2 error + chCommit0 := make(chan struct{}, 1) + chCommit1 := make(chan struct{}, 1) + go func() { + proof.LRO[0], err0 = kzg.Commit(bcl, kzgPk, n) + close(chCommit0) + }() + go func() { + proof.LRO[1], err1 = kzg.Commit(bcr, kzgPk, n) + close(chCommit1) + }() + if proof.LRO[2], err2 = kzg.Commit(bco, kzgPk, n); err2 != nil { + return err2 + } + <-chCommit0 + <-chCommit1 + + if err0 != nil { + return err0 + } + + return err1 +} + +func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, kzgPk kzg.ProvingKey) error { + n := runtime.NumCPU() + var err0, err1, err2 error + chCommit0 := make(chan struct{}, 1) + chCommit1 := make(chan struct{}, 1) + go func() { + proof.H[0], err0 = kzg.Commit(h1, kzgPk, n) + close(chCommit0) + }() + go func() { + proof.H[1], err1 = kzg.Commit(h2, kzgPk, n) + close(chCommit1) + }() + if proof.H[2], err2 = kzg.Commit(h3, kzgPk, n); err2 != nil { + return err2 + } + <-chCommit0 + <-chCommit1 + + if err0 != nil { + return err0 + } + + return err1 +} + +// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. +// The purpose is to commit and open all in one ql, qr, qm, qo, qk. +// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta +// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z +// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. 
+// +// The Linearized polynomial is: +// +// α²*L₁(ζ)*Z(X) +// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) +// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) +func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { + + // first part: individual constraints + var rl fr.Element + rl.Mul(&rZeta, &lZeta) + + // second part: + // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) + var s1, s2 fr.Element + chS1 := make(chan struct{}, 1) + go func() { + s1 = pk.trace.S1.Evaluate(zeta) // s1(ζ) + s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) + close(chS1) + }() + // ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) + tmp := pk.trace.S2.Evaluate(zeta) // s2(ζ) + tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) + <-chS1 + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) + + var uzeta, uuzeta fr.Element + uzeta.Mul(&zeta, &pk.Vk.CosetShift) + uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) + + s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) + tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) + tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + // third part L₁(ζ)*α²*Z + var lagrangeZeta, one, den, frNbElmt fr.Element + one.SetOne() + nbElmt := int64(pk.Domain[0].Cardinality) + lagrangeZeta.Set(&zeta). + Exp(lagrangeZeta, big.NewInt(nbElmt)). 
+ Sub(&lagrangeZeta, &one) + frNbElmt.SetUint64(uint64(nbElmt)) + den.Sub(&zeta, &one). + Inverse(&den) + lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) + + s3canonical := pk.trace.S3.Coefficients() + utils.Parallelize(len(blindedZCanonical), func(start, end int) { + + var t, t0, t1 fr.Element + + for i := start; i < end; i++ { + + t.Mul(&blindedZCanonical[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + if i < len(s3canonical) { + + t0.Mul(&s3canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) + + t.Add(&t, &t0) + } + + t.Mul(&t, &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) + + cql := pk.trace.Ql.Coefficients() + cqr := pk.trace.Qr.Coefficients() + cqm := pk.trace.Qm.Coefficients() + cqo := pk.trace.Qo.Coefficients() + cqk := pk.trace.Qk.Coefficients() + if i < len(cqm) { + + t1.Mul(&cqm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) + t0.Mul(&cql[i], &lZeta) + t0.Add(&t0, &t1) + t.Add(&t, &t0) // linPol = linPol + l(ζ)*Ql(X) + + t0.Mul(&cqr[i], &rZeta) + t.Add(&t, &t0) // linPol = linPol + r(ζ)*Qr(X) + + t0.Mul(&cqo[i], &oZeta).Add(&t0, &cqk[i]) + t.Add(&t, &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) + + for j := range qcpZeta { + t0.Mul(&pi2Canonical[j][i], &qcpZeta[j]) + t.Add(&t, &t0) + } + } + + t0.Mul(&blindedZCanonical[i], &lagrangeZeta) + blindedZCanonical[i].Add(&t, &t0) // finish the computation + } + }) + return blindedZCanonical +} diff --git a/backend/plonk/bls12-377/setup.go b/backend/plonk/bls12-377/setup.go new file mode 100644 index 0000000000..20d21a0d29 --- /dev/null +++ b/backend/plonk/bls12-377/setup.go @@ -0,0 +1,456 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package plonk + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/iop" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/kzg" + "github.com/consensys/gnark/backend/plonk/internal" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/constraint/bls12-377" +) + +// Trace stores a plonk trace as columns +type Trace struct { + + // Constants describing a plonk circuit. The first entries + // of LQk (whose index correspond to the public inputs) are set to 0, and are to be + // completed by the prover. At those indices i (so from 0 to nb_public_variables), LQl[i]=-1 + // so the first nb_public_variables constraints look like this: + // -1*Wire[i] + 0* + 0 . It is zero when the constant coefficient is replaced by Wire[i]. + Ql, Qr, Qm, Qo, Qk *iop.Polynomial + Qcp []*iop.Polynomial + + // Polynomials representing the splitted permutation. The full permutation's support is 3*N where N=nb wires. + // The set of interpolation is of size N, so to represent the permutation S we let S acts on the + // set A=(, u*, u^{2}*) of size 3*N, where u is outside (its use is to shift the set ). + // We obtain a permutation of A, A'. We split A' in 3 (A'_{1}, A'_{2}, A'_{3}), and S1, S2, S3 are + // respectively the interpolation of A'_{1}, A'_{2}, A'_{3} on . 
+ S1, S2, S3 *iop.Polynomial + + // S full permutation, i -> S[i] + S []int64 +} + +// VerifyingKey stores the data needed to verify a proof: +// * The commitment scheme +// * Commitments of ql prepended with as many ones as there are public inputs +// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs +// * Commitments to S1, S2, S3 +type VerifyingKey struct { + + // Size circuit + Size uint64 + SizeInv fr.Element + Generator fr.Element + NbPublicVariables uint64 + + // Commitment scheme that is used for an instantiation of PLONK + Kzg kzg.VerifyingKey + + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + + // S commitments to S1, S2, S3 + S [3]kzg.Digest + + // Commitments to ql, qr, qm, qo, qcp prepended with as many zeroes (ones for l) as there are public inputs. + // In particular Qk is not complete. + Ql, Qr, Qm, Qo, Qk kzg.Digest + Qcp []kzg.Digest + + CommitmentConstraintIndexes []uint64 +} + +// ProvingKey stores the data needed to generate a proof: +// * the commitment scheme +// * ql, prepended with as many ones as they are public inputs +// * qr, qm, qo prepended with as many zeroes as there are public inputs. +// * qk, prepended with as many zeroes as public inputs, to be completed by the prover +// with the list of public inputs. +// * sigma_1, sigma_2, sigma_3 in both basis +// * the copy constraint permutation +type ProvingKey struct { + + // stores ql, qr, qm, qo, qk (-> to be completed by the prover) + // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used + // for computing the opening proofs (hence the canonical form). The canonical version + // of qk incomplete is used in the linearisation polynomial. + // The polynomials in trace are in canonical basis. 
+ trace Trace + + Kzg kzg.ProvingKey + + // Verifying Key is embedded into the proving key (needed by Prove) + Vk *VerifyingKey + + // qr,ql,qm,qo,qcp in LagrangeCoset --> these are not serialized, but computed from Ql, Qr, Qm, Qo, Qcp once. + lcQl, lcQr, lcQm, lcQo *iop.Polynomial + lcQcp []*iop.Polynomial + + // LQk qk in Lagrange form -> to be completed by the prover. After being completed, + lQk *iop.Polynomial + + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + + // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. + lcS1, lcS2, lcS3 *iop.Polynomial + + // in lagrange coset basis --> not serialized id and L_{g^{0}} + lcIdIOP, lLoneIOP *iop.Polynomial +} + +func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { + + var pk ProvingKey + var vk VerifyingKey + pk.Vk = &vk + vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) + + // step 0: set the fft domains + pk.initDomains(spr) + + // step 1: set the verifying key + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) + vk.Size = pk.Domain[0].Cardinality + vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) + vk.Generator.Set(&pk.Domain[0].Generator) + vk.NbPublicVariables = uint64(len(spr.Public)) + if len(kzgSrs.Pk.G1) < int(vk.Size) { + return nil, nil, errors.New("kzg srs is too small") + } + pk.Kzg = kzgSrs.Pk + vk.Kzg = kzgSrs.Vk + + // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis + BuildTrace(spr, &pk.trace) + + // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. 
+ // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &pk.trace, nbVariables) + s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) + pk.trace.S1 = s[0] + pk.trace.S2 = s[1] + pk.trace.S3 = s[2] + + // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. + // All the above polynomials are expressed in canonical basis afterwards. This is why + // we save lqk before, because the prover needs to complete it in Lagrange form, and + // then express it on the Lagrange coset basis. + pk.lQk = pk.trace.Qk.Clone() // it will be completed by the prover, and the evaluated on the coset + err := commitTrace(&pk.trace, &pk) + if err != nil { + return nil, nil, err + } + + // step 5: evaluate ql, qr, qm, qo, s1, s2, s3 on LagrangeCoset (NOT qk) + // we clone them, because the canonical versions are going to be used in + // the opening proof + pk.computeLagrangeCosetPolys() + + return &pk, &vk, nil +} + +// computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset +// basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
+func (pk *ProvingKey) computeLagrangeCosetPolys() { + pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) + for i, qcpI := range pk.trace.Qcp { + pk.lcQcp[i] = qcpI.Clone().ToLagrangeCoset(&pk.Domain[1]) + } + pk.lcQl = pk.trace.Ql.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQr = pk.trace.Qr.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQm = pk.trace.Qm.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQo = pk.trace.Qo.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS1 = pk.trace.S1.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS2 = pk.trace.S2.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS3 = pk.trace.S3.Clone().ToLagrangeCoset(&pk.Domain[1]) + + // storing Id + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + id := make([]fr.Element, pk.Domain[1].Cardinality) + id[0].Set(&pk.Domain[1].FrMultiplicativeGen) + for i := 1; i < int(pk.Domain[1].Cardinality); i++ { + id[i].Mul(&id[i-1], &pk.Domain[1].Generator) + } + pk.lcIdIOP = iop.NewPolynomial(&id, lagReg) + + // L_{g^{0}} + cap := pk.Domain[1].Cardinality + if cap < pk.Domain[0].Cardinality { + cap = pk.Domain[0].Cardinality // sanity check + } + lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) + lone[0].SetOne() + pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). + ToRegular(). + ToLagrangeCoset(&pk.Domain[1]) +} + +// NbPublicWitness returns the expected public witness size (number of field elements) +func (vk *VerifyingKey) NbPublicWitness() int { + return int(vk.NbPublicVariables) +} + +// VerifyingKey returns pk.Vk +func (pk *ProvingKey) VerifyingKey() interface{} { + return pk.Vk +} + +// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. 
+// Size is the size of the system that is nb_constraints+nb_public_variables +func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) + size := ecc.NextPowerOfTwo(sizeSystem) + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments) + + ql := make([]fr.Element, size) + qr := make([]fr.Element, size) + qm := make([]fr.Element, size) + qo := make([]fr.Element, size) + qk := make([]fr.Element, size) + qcp := make([][]fr.Element, len(commitmentInfo)) + + for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistent + ql[i].SetOne().Neg(&ql[i]) + qr[i].SetZero() + qm[i].SetZero() + qo[i].SetZero() + qk[i].SetZero() // → to be completed by the prover + } + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + ql[offset+j].Set(&spr.Coefficients[c.QL]) + qr[offset+j].Set(&spr.Coefficients[c.QR]) + qm[offset+j].Set(&spr.Coefficients[c.QM]) + qo[offset+j].Set(&spr.Coefficients[c.QO]) + qk[offset+j].Set(&spr.Coefficients[c.QC]) + j++ + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + + pt.Ql = iop.NewPolynomial(&ql, lagReg) + pt.Qr = iop.NewPolynomial(&qr, lagReg) + pt.Qm = iop.NewPolynomial(&qm, lagReg) + pt.Qo = iop.NewPolynomial(&qo, lagReg) + pt.Qk = iop.NewPolynomial(&qk, lagReg) + pt.Qcp = make([]*iop.Polynomial, len(qcp)) + + for i := range commitmentInfo { + qcp[i] = make([]fr.Element, size) + for _, committed := range commitmentInfo[i].Committed { + qcp[i][offset+committed].SetOne() + } + pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + } +} + +// commitTrace commits to every polynomial in the trace, and put +// the commitments int the verifying key. 
+func commitTrace(trace *Trace, pk *ProvingKey) error { + + trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete + trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() + + var err error + pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + for i := range trace.Qcp { + trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() + if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + return err + } + } + if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + return err + } + return nil +} + +func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints + pk.Domain[0] = *fft.NewDomain(sizeSystem) + + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) + } else { + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) + } + +} + +// buildPermutation builds the Permutation associated with a circuit. +// +// The permutation s is composed of cycles of maximum length such that +// +// s. (l∥r∥o) = (l∥r∥o) +// +// , where l∥r∥o is the concatenation of the indices of l, r, o in +// ql.l+qr.r+qm.l.r+qo.O+k = 0. +// +// The permutation is encoded as a slice s of size 3*size(l), where the +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab +// like this: for i in tab: tab[i] = tab[permutation[i]] +func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { + + // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + sizeSolution := len(pt.Ql.Coefficients()) + sizePermutation := 3 * sizeSolution + + // init permutation + permutation := make([]int64, sizePermutation) + for i := 0; i < len(permutation); i++ { + permutation[i] = -1 + } + + // init LRO position -> variable_ID + lro := make([]int, sizePermutation) // position -> variable_ID + for i := 0; i < len(spr.Public); i++ { + lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) + } + + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + + j++ + } + + // init cycle: + // map ID -> last position the ID was seen + cycle := make([]int64, nbVariables) + for i := 0; i < len(cycle); i++ { + cycle[i] = -1 + } + + for i := 0; i < len(lro); i++ { + if cycle[lro[i]] != -1 { + // if != -1, it means we already encountered this value + // so we need to set the corresponding permutation index. 
+ permutation[i] = cycle[lro[i]] + } + cycle[lro[i]] = int64(i) + } + + // complete the Permutation by filling the first IDs encountered + for i := 0; i < sizePermutation; i++ { + if permutation[i] == -1 { + permutation[i] = cycle[lro[i]] + } + } + + pt.S = permutation +} + +// computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. +// We let the permutation act on || u || u^{2}, split the result in 3 parts, +// and interpolate each of the 3 parts on . +func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { + + nbElmts := int(domain.Cardinality) + + var res [3]*iop.Polynomial + + // Lagrange form of ID + evaluationIDSmallDomain := getSupportPermutation(domain) + + // Lagrange form of S1, S2, S3 + s1Canonical := make([]fr.Element, nbElmts) + s2Canonical := make([]fr.Element, nbElmts) + s3Canonical := make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + res[0] = iop.NewPolynomial(&s1Canonical, lagReg) + res[1] = iop.NewPolynomial(&s2Canonical, lagReg) + res[2] = iop.NewPolynomial(&s3Canonical, lagReg) + + return res +} + +// getSupportPermutation returns the support on which the permutation acts, it is +// || u || u^{2} +func getSupportPermutation(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + + return 
res +} diff --git a/internal/backend/bls12-377/plonk/verify.go b/backend/plonk/bls12-377/verify.go similarity index 77% rename from internal/backend/bls12-377/plonk/verify.go rename to backend/plonk/bls12-377/verify.go index ca61aa05c4..c28a894d02 100644 --- a/internal/backend/bls12-377/plonk/verify.go +++ b/backend/plonk/bls12-377/verify.go @@ -39,7 +39,7 @@ var ( ) func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - log := logger.Logger().With().Str("curve", "bls12_377").Str("backend", "plonk").Logger() + log := logger.Logger().With().Str("curve", "bls12-377").Str("backend", "plonk").Logger() start := time.Now() // pick a hash function to derive the challenge (the same as in the prover) @@ -51,7 +51,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { // The first challenge is derived using the public data: the commitments to the permutation, // the coefficients of the circuit, and the public inputs. // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *vk, publicWitness); err != nil { + if err := bindPublicData(&fs, "gamma", *vk, publicWitness, proof.Bsb22Commitments); err != nil { return err } gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -85,25 +85,51 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = ∑_{i of size N, so to represent the permutation S we let S acts on the + // set A=(, u*, u^{2}*) of size 3*N, where u is outside (its use is to shift the set ). + // We obtain a permutation of A, A'. We split A' in 3 (A'_{1}, A'_{2}, A'_{3}), and S1, S2, S3 are + // respectively the interpolation of A'_{1}, A'_{2}, A'_{3} on . 
+ S1, S2, S3 *iop.Polynomial + + // S full permutation, i -> S[i] + S []int64 +} + +// VerifyingKey stores the data needed to verify a proof: +// * The commitment scheme +// * Commitments of ql prepended with as many ones as there are public inputs +// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs +// * Commitments to S1, S2, S3 +type VerifyingKey struct { + + // Size circuit + Size uint64 + SizeInv fr.Element + Generator fr.Element + NbPublicVariables uint64 + + // Commitment scheme that is used for an instantiation of PLONK + Kzg kzg.VerifyingKey + + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + + // S commitments to S1, S2, S3 + S [3]kzg.Digest + + // Commitments to ql, qr, qm, qo, qcp prepended with as many zeroes (ones for l) as there are public inputs. + // In particular Qk is not complete. + Ql, Qr, Qm, Qo, Qk kzg.Digest + Qcp []kzg.Digest + + CommitmentConstraintIndexes []uint64 +} + +// ProvingKey stores the data needed to generate a proof: +// * the commitment scheme +// * ql, prepended with as many ones as they are public inputs +// * qr, qm, qo prepended with as many zeroes as there are public inputs. +// * qk, prepended with as many zeroes as public inputs, to be completed by the prover +// with the list of public inputs. +// * sigma_1, sigma_2, sigma_3 in both basis +// * the copy constraint permutation +type ProvingKey struct { + + // stores ql, qr, qm, qo, qk (-> to be completed by the prover) + // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used + // for computing the opening proofs (hence the canonical form). The canonical version + // of qk incomplete is used in the linearisation polynomial. + // The polynomials in trace are in canonical basis. 
+ trace Trace + + Kzg kzg.ProvingKey + + // Verifying Key is embedded into the proving key (needed by Prove) + Vk *VerifyingKey + + // qr,ql,qm,qo,qcp in LagrangeCoset --> these are not serialized, but computed from Ql, Qr, Qm, Qo, Qcp once. + lcQl, lcQr, lcQm, lcQo *iop.Polynomial + lcQcp []*iop.Polynomial + + // LQk qk in Lagrange form -> to be completed by the prover. After being completed, + lQk *iop.Polynomial + + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + + // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. + lcS1, lcS2, lcS3 *iop.Polynomial + + // in lagrange coset basis --> not serialized id and L_{g^{0}} + lcIdIOP, lLoneIOP *iop.Polynomial +} + +func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { + + var pk ProvingKey + var vk VerifyingKey + pk.Vk = &vk + vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) + + // step 0: set the fft domains + pk.initDomains(spr) + + // step 1: set the verifying key + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) + vk.Size = pk.Domain[0].Cardinality + vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) + vk.Generator.Set(&pk.Domain[0].Generator) + vk.NbPublicVariables = uint64(len(spr.Public)) + if len(kzgSrs.Pk.G1) < int(vk.Size) { + return nil, nil, errors.New("kzg srs is too small") + } + pk.Kzg = kzgSrs.Pk + vk.Kzg = kzgSrs.Vk + + // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis + BuildTrace(spr, &pk.trace) + + // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. 
+ // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &pk.trace, nbVariables) + s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) + pk.trace.S1 = s[0] + pk.trace.S2 = s[1] + pk.trace.S3 = s[2] + + // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. + // All the above polynomials are expressed in canonical basis afterwards. This is why + // we save lqk before, because the prover needs to complete it in Lagrange form, and + // then express it on the Lagrange coset basis. + pk.lQk = pk.trace.Qk.Clone() // it will be completed by the prover, and the evaluated on the coset + err := commitTrace(&pk.trace, &pk) + if err != nil { + return nil, nil, err + } + + // step 5: evaluate ql, qr, qm, qo, s1, s2, s3 on LagrangeCoset (NOT qk) + // we clone them, because the canonical versions are going to be used in + // the opening proof + pk.computeLagrangeCosetPolys() + + return &pk, &vk, nil +} + +// computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset +// basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
+func (pk *ProvingKey) computeLagrangeCosetPolys() { + pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) + for i, qcpI := range pk.trace.Qcp { + pk.lcQcp[i] = qcpI.Clone().ToLagrangeCoset(&pk.Domain[1]) + } + pk.lcQl = pk.trace.Ql.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQr = pk.trace.Qr.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQm = pk.trace.Qm.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQo = pk.trace.Qo.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS1 = pk.trace.S1.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS2 = pk.trace.S2.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS3 = pk.trace.S3.Clone().ToLagrangeCoset(&pk.Domain[1]) + + // storing Id + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + id := make([]fr.Element, pk.Domain[1].Cardinality) + id[0].Set(&pk.Domain[1].FrMultiplicativeGen) + for i := 1; i < int(pk.Domain[1].Cardinality); i++ { + id[i].Mul(&id[i-1], &pk.Domain[1].Generator) + } + pk.lcIdIOP = iop.NewPolynomial(&id, lagReg) + + // L_{g^{0}} + cap := pk.Domain[1].Cardinality + if cap < pk.Domain[0].Cardinality { + cap = pk.Domain[0].Cardinality // sanity check + } + lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) + lone[0].SetOne() + pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). + ToRegular(). + ToLagrangeCoset(&pk.Domain[1]) +} + +// NbPublicWitness returns the expected public witness size (number of field elements) +func (vk *VerifyingKey) NbPublicWitness() int { + return int(vk.NbPublicVariables) +} + +// VerifyingKey returns pk.Vk +func (pk *ProvingKey) VerifyingKey() interface{} { + return pk.Vk +} + +// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. 
+// Size is the size of the system that is nb_constraints+nb_public_variables +func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) + size := ecc.NextPowerOfTwo(sizeSystem) + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments) + + ql := make([]fr.Element, size) + qr := make([]fr.Element, size) + qm := make([]fr.Element, size) + qo := make([]fr.Element, size) + qk := make([]fr.Element, size) + qcp := make([][]fr.Element, len(commitmentInfo)) + + for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistent + ql[i].SetOne().Neg(&ql[i]) + qr[i].SetZero() + qm[i].SetZero() + qo[i].SetZero() + qk[i].SetZero() // → to be completed by the prover + } + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + ql[offset+j].Set(&spr.Coefficients[c.QL]) + qr[offset+j].Set(&spr.Coefficients[c.QR]) + qm[offset+j].Set(&spr.Coefficients[c.QM]) + qo[offset+j].Set(&spr.Coefficients[c.QO]) + qk[offset+j].Set(&spr.Coefficients[c.QC]) + j++ + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + + pt.Ql = iop.NewPolynomial(&ql, lagReg) + pt.Qr = iop.NewPolynomial(&qr, lagReg) + pt.Qm = iop.NewPolynomial(&qm, lagReg) + pt.Qo = iop.NewPolynomial(&qo, lagReg) + pt.Qk = iop.NewPolynomial(&qk, lagReg) + pt.Qcp = make([]*iop.Polynomial, len(qcp)) + + for i := range commitmentInfo { + qcp[i] = make([]fr.Element, size) + for _, committed := range commitmentInfo[i].Committed { + qcp[i][offset+committed].SetOne() + } + pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + } +} + +// commitTrace commits to every polynomial in the trace, and put +// the commitments int the verifying key. 
+func commitTrace(trace *Trace, pk *ProvingKey) error { + + trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete + trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() + + var err error + pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + for i := range trace.Qcp { + trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() + if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + return err + } + } + if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + return err + } + return nil +} + +func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints + pk.Domain[0] = *fft.NewDomain(sizeSystem) + + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) + } else { + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) + } + +} + +// buildPermutation builds the Permutation associated with a circuit. +// +// The permutation s is composed of cycles of maximum length such that +// +// s. (l∥r∥o) = (l∥r∥o) +// +// , where l∥r∥o is the concatenation of the indices of l, r, o in +// ql.l+qr.r+qm.l.r+qo.O+k = 0. +// +// The permutation is encoded as a slice s of size 3*size(l), where the +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab +// like this: for i in tab: tab[i] = tab[permutation[i]] +func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { + + // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + sizeSolution := len(pt.Ql.Coefficients()) + sizePermutation := 3 * sizeSolution + + // init permutation + permutation := make([]int64, sizePermutation) + for i := 0; i < len(permutation); i++ { + permutation[i] = -1 + } + + // init LRO position -> variable_ID + lro := make([]int, sizePermutation) // position -> variable_ID + for i := 0; i < len(spr.Public); i++ { + lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) + } + + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + + j++ + } + + // init cycle: + // map ID -> last position the ID was seen + cycle := make([]int64, nbVariables) + for i := 0; i < len(cycle); i++ { + cycle[i] = -1 + } + + for i := 0; i < len(lro); i++ { + if cycle[lro[i]] != -1 { + // if != -1, it means we already encountered this value + // so we need to set the corresponding permutation index. 
+ permutation[i] = cycle[lro[i]] + } + cycle[lro[i]] = int64(i) + } + + // complete the Permutation by filling the first IDs encountered + for i := 0; i < sizePermutation; i++ { + if permutation[i] == -1 { + permutation[i] = cycle[lro[i]] + } + } + + pt.S = permutation +} + +// computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. +// We let the permutation act on || u || u^{2}, split the result in 3 parts, +// and interpolate each of the 3 parts on . +func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { + + nbElmts := int(domain.Cardinality) + + var res [3]*iop.Polynomial + + // Lagrange form of ID + evaluationIDSmallDomain := getSupportPermutation(domain) + + // Lagrange form of S1, S2, S3 + s1Canonical := make([]fr.Element, nbElmts) + s2Canonical := make([]fr.Element, nbElmts) + s3Canonical := make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + res[0] = iop.NewPolynomial(&s1Canonical, lagReg) + res[1] = iop.NewPolynomial(&s2Canonical, lagReg) + res[2] = iop.NewPolynomial(&s3Canonical, lagReg) + + return res +} + +// getSupportPermutation returns the support on which the permutation acts, it is +// || u || u^{2} +func getSupportPermutation(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + + return 
res +} diff --git a/internal/backend/bls12-381/plonk/verify.go b/backend/plonk/bls12-381/verify.go similarity index 77% rename from internal/backend/bls12-381/plonk/verify.go rename to backend/plonk/bls12-381/verify.go index bb1d6ebedb..3124e92c5e 100644 --- a/internal/backend/bls12-381/plonk/verify.go +++ b/backend/plonk/bls12-381/verify.go @@ -39,7 +39,7 @@ var ( ) func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - log := logger.Logger().With().Str("curve", "bls12_381").Str("backend", "plonk").Logger() + log := logger.Logger().With().Str("curve", "bls12-381").Str("backend", "plonk").Logger() start := time.Now() // pick a hash function to derive the challenge (the same as in the prover) @@ -51,7 +51,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { // The first challenge is derived using the public data: the commitments to the permutation, // the coefficients of the circuit, and the public inputs. // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *vk, publicWitness); err != nil { + if err := bindPublicData(&fs, "gamma", *vk, publicWitness, proof.Bsb22Commitments); err != nil { return err } gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -85,25 +85,51 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = ∑_{i of size N, so to represent the permutation S we let S acts on the + // set A=(, u*, u^{2}*) of size 3*N, where u is outside (its use is to shift the set ). + // We obtain a permutation of A, A'. We split A' in 3 (A'_{1}, A'_{2}, A'_{3}), and S1, S2, S3 are + // respectively the interpolation of A'_{1}, A'_{2}, A'_{3} on . 
+ S1, S2, S3 *iop.Polynomial + + // S full permutation, i -> S[i] + S []int64 +} + +// VerifyingKey stores the data needed to verify a proof: +// * The commitment scheme +// * Commitments of ql prepended with as many ones as there are public inputs +// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs +// * Commitments to S1, S2, S3 +type VerifyingKey struct { + + // Size circuit + Size uint64 + SizeInv fr.Element + Generator fr.Element + NbPublicVariables uint64 + + // Commitment scheme that is used for an instantiation of PLONK + Kzg kzg.VerifyingKey + + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + + // S commitments to S1, S2, S3 + S [3]kzg.Digest + + // Commitments to ql, qr, qm, qo, qcp prepended with as many zeroes (ones for l) as there are public inputs. + // In particular Qk is not complete. + Ql, Qr, Qm, Qo, Qk kzg.Digest + Qcp []kzg.Digest + + CommitmentConstraintIndexes []uint64 +} + +// ProvingKey stores the data needed to generate a proof: +// * the commitment scheme +// * ql, prepended with as many ones as they are public inputs +// * qr, qm, qo prepended with as many zeroes as there are public inputs. +// * qk, prepended with as many zeroes as public inputs, to be completed by the prover +// with the list of public inputs. +// * sigma_1, sigma_2, sigma_3 in both basis +// * the copy constraint permutation +type ProvingKey struct { + + // stores ql, qr, qm, qo, qk (-> to be completed by the prover) + // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used + // for computing the opening proofs (hence the canonical form). The canonical version + // of qk incomplete is used in the linearisation polynomial. + // The polynomials in trace are in canonical basis. 
+ trace Trace + + Kzg kzg.ProvingKey + + // Verifying Key is embedded into the proving key (needed by Prove) + Vk *VerifyingKey + + // qr,ql,qm,qo,qcp in LagrangeCoset --> these are not serialized, but computed from Ql, Qr, Qm, Qo, Qcp once. + lcQl, lcQr, lcQm, lcQo *iop.Polynomial + lcQcp []*iop.Polynomial + + // LQk qk in Lagrange form -> to be completed by the prover. After being completed, + lQk *iop.Polynomial + + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + + // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. + lcS1, lcS2, lcS3 *iop.Polynomial + + // in lagrange coset basis --> not serialized id and L_{g^{0}} + lcIdIOP, lLoneIOP *iop.Polynomial +} + +func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { + + var pk ProvingKey + var vk VerifyingKey + pk.Vk = &vk + vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) + + // step 0: set the fft domains + pk.initDomains(spr) + + // step 1: set the verifying key + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) + vk.Size = pk.Domain[0].Cardinality + vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) + vk.Generator.Set(&pk.Domain[0].Generator) + vk.NbPublicVariables = uint64(len(spr.Public)) + if len(kzgSrs.Pk.G1) < int(vk.Size) { + return nil, nil, errors.New("kzg srs is too small") + } + pk.Kzg = kzgSrs.Pk + vk.Kzg = kzgSrs.Vk + + // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis + BuildTrace(spr, &pk.trace) + + // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. 
+ // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &pk.trace, nbVariables) + s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) + pk.trace.S1 = s[0] + pk.trace.S2 = s[1] + pk.trace.S3 = s[2] + + // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. + // All the above polynomials are expressed in canonical basis afterwards. This is why + // we save lqk before, because the prover needs to complete it in Lagrange form, and + // then express it on the Lagrange coset basis. + pk.lQk = pk.trace.Qk.Clone() // it will be completed by the prover, and the evaluated on the coset + err := commitTrace(&pk.trace, &pk) + if err != nil { + return nil, nil, err + } + + // step 5: evaluate ql, qr, qm, qo, s1, s2, s3 on LagrangeCoset (NOT qk) + // we clone them, because the canonical versions are going to be used in + // the opening proof + pk.computeLagrangeCosetPolys() + + return &pk, &vk, nil +} + +// computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset +// basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
+func (pk *ProvingKey) computeLagrangeCosetPolys() { + pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) + for i, qcpI := range pk.trace.Qcp { + pk.lcQcp[i] = qcpI.Clone().ToLagrangeCoset(&pk.Domain[1]) + } + pk.lcQl = pk.trace.Ql.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQr = pk.trace.Qr.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQm = pk.trace.Qm.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQo = pk.trace.Qo.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS1 = pk.trace.S1.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS2 = pk.trace.S2.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS3 = pk.trace.S3.Clone().ToLagrangeCoset(&pk.Domain[1]) + + // storing Id + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + id := make([]fr.Element, pk.Domain[1].Cardinality) + id[0].Set(&pk.Domain[1].FrMultiplicativeGen) + for i := 1; i < int(pk.Domain[1].Cardinality); i++ { + id[i].Mul(&id[i-1], &pk.Domain[1].Generator) + } + pk.lcIdIOP = iop.NewPolynomial(&id, lagReg) + + // L_{g^{0}} + cap := pk.Domain[1].Cardinality + if cap < pk.Domain[0].Cardinality { + cap = pk.Domain[0].Cardinality // sanity check + } + lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) + lone[0].SetOne() + pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). + ToRegular(). + ToLagrangeCoset(&pk.Domain[1]) +} + +// NbPublicWitness returns the expected public witness size (number of field elements) +func (vk *VerifyingKey) NbPublicWitness() int { + return int(vk.NbPublicVariables) +} + +// VerifyingKey returns pk.Vk +func (pk *ProvingKey) VerifyingKey() interface{} { + return pk.Vk +} + +// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. 
+// Size is the size of the system that is nb_constraints+nb_public_variables +func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) + size := ecc.NextPowerOfTwo(sizeSystem) + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments) + + ql := make([]fr.Element, size) + qr := make([]fr.Element, size) + qm := make([]fr.Element, size) + qo := make([]fr.Element, size) + qk := make([]fr.Element, size) + qcp := make([][]fr.Element, len(commitmentInfo)) + + for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistent + ql[i].SetOne().Neg(&ql[i]) + qr[i].SetZero() + qm[i].SetZero() + qo[i].SetZero() + qk[i].SetZero() // → to be completed by the prover + } + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + ql[offset+j].Set(&spr.Coefficients[c.QL]) + qr[offset+j].Set(&spr.Coefficients[c.QR]) + qm[offset+j].Set(&spr.Coefficients[c.QM]) + qo[offset+j].Set(&spr.Coefficients[c.QO]) + qk[offset+j].Set(&spr.Coefficients[c.QC]) + j++ + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + + pt.Ql = iop.NewPolynomial(&ql, lagReg) + pt.Qr = iop.NewPolynomial(&qr, lagReg) + pt.Qm = iop.NewPolynomial(&qm, lagReg) + pt.Qo = iop.NewPolynomial(&qo, lagReg) + pt.Qk = iop.NewPolynomial(&qk, lagReg) + pt.Qcp = make([]*iop.Polynomial, len(qcp)) + + for i := range commitmentInfo { + qcp[i] = make([]fr.Element, size) + for _, committed := range commitmentInfo[i].Committed { + qcp[i][offset+committed].SetOne() + } + pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + } +} + +// commitTrace commits to every polynomial in the trace, and put +// the commitments int the verifying key. 
+func commitTrace(trace *Trace, pk *ProvingKey) error { + + trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete + trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() + + var err error + pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + for i := range trace.Qcp { + trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() + if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + return err + } + } + if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + return err + } + return nil +} + +func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints + pk.Domain[0] = *fft.NewDomain(sizeSystem) + + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) + } else { + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) + } + +} + +// buildPermutation builds the Permutation associated with a circuit. +// +// The permutation s is composed of cycles of maximum length such that +// +// s. (l∥r∥o) = (l∥r∥o) +// +// , where l∥r∥o is the concatenation of the indices of l, r, o in +// ql.l+qr.r+qm.l.r+qo.O+k = 0. +// +// The permutation is encoded as a slice s of size 3*size(l), where the +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab +// like this: for i in tab: tab[i] = tab[permutation[i]] +func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { + + // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + sizeSolution := len(pt.Ql.Coefficients()) + sizePermutation := 3 * sizeSolution + + // init permutation + permutation := make([]int64, sizePermutation) + for i := 0; i < len(permutation); i++ { + permutation[i] = -1 + } + + // init LRO position -> variable_ID + lro := make([]int, sizePermutation) // position -> variable_ID + for i := 0; i < len(spr.Public); i++ { + lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) + } + + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + + j++ + } + + // init cycle: + // map ID -> last position the ID was seen + cycle := make([]int64, nbVariables) + for i := 0; i < len(cycle); i++ { + cycle[i] = -1 + } + + for i := 0; i < len(lro); i++ { + if cycle[lro[i]] != -1 { + // if != -1, it means we already encountered this value + // so we need to set the corresponding permutation index. 
+ permutation[i] = cycle[lro[i]] + } + cycle[lro[i]] = int64(i) + } + + // complete the Permutation by filling the first IDs encountered + for i := 0; i < sizePermutation; i++ { + if permutation[i] == -1 { + permutation[i] = cycle[lro[i]] + } + } + + pt.S = permutation +} + +// computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. +// We let the permutation act on || u || u^{2}, split the result in 3 parts, +// and interpolate each of the 3 parts on . +func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { + + nbElmts := int(domain.Cardinality) + + var res [3]*iop.Polynomial + + // Lagrange form of ID + evaluationIDSmallDomain := getSupportPermutation(domain) + + // Lagrange form of S1, S2, S3 + s1Canonical := make([]fr.Element, nbElmts) + s2Canonical := make([]fr.Element, nbElmts) + s3Canonical := make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + res[0] = iop.NewPolynomial(&s1Canonical, lagReg) + res[1] = iop.NewPolynomial(&s2Canonical, lagReg) + res[2] = iop.NewPolynomial(&s3Canonical, lagReg) + + return res +} + +// getSupportPermutation returns the support on which the permutation acts, it is +// || u || u^{2} +func getSupportPermutation(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + + return 
res +} diff --git a/internal/backend/bls24-315/plonk/verify.go b/backend/plonk/bls24-315/verify.go similarity index 77% rename from internal/backend/bls24-315/plonk/verify.go rename to backend/plonk/bls24-315/verify.go index 9e2965275c..81dc747e0b 100644 --- a/internal/backend/bls24-315/plonk/verify.go +++ b/backend/plonk/bls24-315/verify.go @@ -39,7 +39,7 @@ var ( ) func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - log := logger.Logger().With().Str("curve", "bls24_315").Str("backend", "plonk").Logger() + log := logger.Logger().With().Str("curve", "bls24-315").Str("backend", "plonk").Logger() start := time.Now() // pick a hash function to derive the challenge (the same as in the prover) @@ -51,7 +51,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { // The first challenge is derived using the public data: the commitments to the permutation, // the coefficients of the circuit, and the public inputs. // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *vk, publicWitness); err != nil { + if err := bindPublicData(&fs, "gamma", *vk, publicWitness, proof.Bsb22Commitments); err != nil { return err } gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -85,25 +85,51 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = ∑_{i of size N, so to represent the permutation S we let S acts on the + // set A=(, u*, u^{2}*) of size 3*N, where u is outside (its use is to shift the set ). + // We obtain a permutation of A, A'. We split A' in 3 (A'_{1}, A'_{2}, A'_{3}), and S1, S2, S3 are + // respectively the interpolation of A'_{1}, A'_{2}, A'_{3} on . 
+ S1, S2, S3 *iop.Polynomial + + // S full permutation, i -> S[i] + S []int64 +} + +// VerifyingKey stores the data needed to verify a proof: +// * The commitment scheme +// * Commitments of ql prepended with as many ones as there are public inputs +// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs +// * Commitments to S1, S2, S3 +type VerifyingKey struct { + + // Size circuit + Size uint64 + SizeInv fr.Element + Generator fr.Element + NbPublicVariables uint64 + + // Commitment scheme that is used for an instantiation of PLONK + Kzg kzg.VerifyingKey + + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + + // S commitments to S1, S2, S3 + S [3]kzg.Digest + + // Commitments to ql, qr, qm, qo, qcp prepended with as many zeroes (ones for l) as there are public inputs. + // In particular Qk is not complete. + Ql, Qr, Qm, Qo, Qk kzg.Digest + Qcp []kzg.Digest + + CommitmentConstraintIndexes []uint64 +} + +// ProvingKey stores the data needed to generate a proof: +// * the commitment scheme +// * ql, prepended with as many ones as they are public inputs +// * qr, qm, qo prepended with as many zeroes as there are public inputs. +// * qk, prepended with as many zeroes as public inputs, to be completed by the prover +// with the list of public inputs. +// * sigma_1, sigma_2, sigma_3 in both basis +// * the copy constraint permutation +type ProvingKey struct { + + // stores ql, qr, qm, qo, qk (-> to be completed by the prover) + // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used + // for computing the opening proofs (hence the canonical form). The canonical version + // of qk incomplete is used in the linearisation polynomial. + // The polynomials in trace are in canonical basis. 
+ trace Trace + + Kzg kzg.ProvingKey + + // Verifying Key is embedded into the proving key (needed by Prove) + Vk *VerifyingKey + + // qr,ql,qm,qo,qcp in LagrangeCoset --> these are not serialized, but computed from Ql, Qr, Qm, Qo, Qcp once. + lcQl, lcQr, lcQm, lcQo *iop.Polynomial + lcQcp []*iop.Polynomial + + // LQk qk in Lagrange form -> to be completed by the prover. After being completed, + lQk *iop.Polynomial + + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + + // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. + lcS1, lcS2, lcS3 *iop.Polynomial + + // in lagrange coset basis --> not serialized id and L_{g^{0}} + lcIdIOP, lLoneIOP *iop.Polynomial +} + +func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { + + var pk ProvingKey + var vk VerifyingKey + pk.Vk = &vk + vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) + + // step 0: set the fft domains + pk.initDomains(spr) + + // step 1: set the verifying key + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) + vk.Size = pk.Domain[0].Cardinality + vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) + vk.Generator.Set(&pk.Domain[0].Generator) + vk.NbPublicVariables = uint64(len(spr.Public)) + if len(kzgSrs.Pk.G1) < int(vk.Size) { + return nil, nil, errors.New("kzg srs is too small") + } + pk.Kzg = kzgSrs.Pk + vk.Kzg = kzgSrs.Vk + + // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis + BuildTrace(spr, &pk.trace) + + // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. 
+ // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &pk.trace, nbVariables) + s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) + pk.trace.S1 = s[0] + pk.trace.S2 = s[1] + pk.trace.S3 = s[2] + + // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. + // All the above polynomials are expressed in canonical basis afterwards. This is why + // we save lqk before, because the prover needs to complete it in Lagrange form, and + // then express it on the Lagrange coset basis. + pk.lQk = pk.trace.Qk.Clone() // it will be completed by the prover, and the evaluated on the coset + err := commitTrace(&pk.trace, &pk) + if err != nil { + return nil, nil, err + } + + // step 5: evaluate ql, qr, qm, qo, s1, s2, s3 on LagrangeCoset (NOT qk) + // we clone them, because the canonical versions are going to be used in + // the opening proof + pk.computeLagrangeCosetPolys() + + return &pk, &vk, nil +} + +// computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset +// basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
+func (pk *ProvingKey) computeLagrangeCosetPolys() { + pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) + for i, qcpI := range pk.trace.Qcp { + pk.lcQcp[i] = qcpI.Clone().ToLagrangeCoset(&pk.Domain[1]) + } + pk.lcQl = pk.trace.Ql.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQr = pk.trace.Qr.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQm = pk.trace.Qm.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQo = pk.trace.Qo.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS1 = pk.trace.S1.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS2 = pk.trace.S2.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS3 = pk.trace.S3.Clone().ToLagrangeCoset(&pk.Domain[1]) + + // storing Id + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + id := make([]fr.Element, pk.Domain[1].Cardinality) + id[0].Set(&pk.Domain[1].FrMultiplicativeGen) + for i := 1; i < int(pk.Domain[1].Cardinality); i++ { + id[i].Mul(&id[i-1], &pk.Domain[1].Generator) + } + pk.lcIdIOP = iop.NewPolynomial(&id, lagReg) + + // L_{g^{0}} + cap := pk.Domain[1].Cardinality + if cap < pk.Domain[0].Cardinality { + cap = pk.Domain[0].Cardinality // sanity check + } + lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) + lone[0].SetOne() + pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). + ToRegular(). + ToLagrangeCoset(&pk.Domain[1]) +} + +// NbPublicWitness returns the expected public witness size (number of field elements) +func (vk *VerifyingKey) NbPublicWitness() int { + return int(vk.NbPublicVariables) +} + +// VerifyingKey returns pk.Vk +func (pk *ProvingKey) VerifyingKey() interface{} { + return pk.Vk +} + +// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. 
+// Size is the size of the system that is nb_constraints+nb_public_variables +func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) + size := ecc.NextPowerOfTwo(sizeSystem) + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments) + + ql := make([]fr.Element, size) + qr := make([]fr.Element, size) + qm := make([]fr.Element, size) + qo := make([]fr.Element, size) + qk := make([]fr.Element, size) + qcp := make([][]fr.Element, len(commitmentInfo)) + + for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistent + ql[i].SetOne().Neg(&ql[i]) + qr[i].SetZero() + qm[i].SetZero() + qo[i].SetZero() + qk[i].SetZero() // → to be completed by the prover + } + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + ql[offset+j].Set(&spr.Coefficients[c.QL]) + qr[offset+j].Set(&spr.Coefficients[c.QR]) + qm[offset+j].Set(&spr.Coefficients[c.QM]) + qo[offset+j].Set(&spr.Coefficients[c.QO]) + qk[offset+j].Set(&spr.Coefficients[c.QC]) + j++ + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + + pt.Ql = iop.NewPolynomial(&ql, lagReg) + pt.Qr = iop.NewPolynomial(&qr, lagReg) + pt.Qm = iop.NewPolynomial(&qm, lagReg) + pt.Qo = iop.NewPolynomial(&qo, lagReg) + pt.Qk = iop.NewPolynomial(&qk, lagReg) + pt.Qcp = make([]*iop.Polynomial, len(qcp)) + + for i := range commitmentInfo { + qcp[i] = make([]fr.Element, size) + for _, committed := range commitmentInfo[i].Committed { + qcp[i][offset+committed].SetOne() + } + pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + } +} + +// commitTrace commits to every polynomial in the trace, and put +// the commitments int the verifying key. 
+func commitTrace(trace *Trace, pk *ProvingKey) error { + + trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete + trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() + + var err error + pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + for i := range trace.Qcp { + trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() + if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + return err + } + } + if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + return err + } + return nil +} + +func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints + pk.Domain[0] = *fft.NewDomain(sizeSystem) + + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) + } else { + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) + } + +} + +// buildPermutation builds the Permutation associated with a circuit. +// +// The permutation s is composed of cycles of maximum length such that +// +// s. (l∥r∥o) = (l∥r∥o) +// +// , where l∥r∥o is the concatenation of the indices of l, r, o in +// ql.l+qr.r+qm.l.r+qo.O+k = 0. +// +// The permutation is encoded as a slice s of size 3*size(l), where the +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab +// like this: for i in tab: tab[i] = tab[permutation[i]] +func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { + + // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + sizeSolution := len(pt.Ql.Coefficients()) + sizePermutation := 3 * sizeSolution + + // init permutation + permutation := make([]int64, sizePermutation) + for i := 0; i < len(permutation); i++ { + permutation[i] = -1 + } + + // init LRO position -> variable_ID + lro := make([]int, sizePermutation) // position -> variable_ID + for i := 0; i < len(spr.Public); i++ { + lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) + } + + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + + j++ + } + + // init cycle: + // map ID -> last position the ID was seen + cycle := make([]int64, nbVariables) + for i := 0; i < len(cycle); i++ { + cycle[i] = -1 + } + + for i := 0; i < len(lro); i++ { + if cycle[lro[i]] != -1 { + // if != -1, it means we already encountered this value + // so we need to set the corresponding permutation index. 
+ permutation[i] = cycle[lro[i]] + } + cycle[lro[i]] = int64(i) + } + + // complete the Permutation by filling the first IDs encountered + for i := 0; i < sizePermutation; i++ { + if permutation[i] == -1 { + permutation[i] = cycle[lro[i]] + } + } + + pt.S = permutation +} + +// computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. +// We let the permutation act on || u || u^{2}, split the result in 3 parts, +// and interpolate each of the 3 parts on . +func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { + + nbElmts := int(domain.Cardinality) + + var res [3]*iop.Polynomial + + // Lagrange form of ID + evaluationIDSmallDomain := getSupportPermutation(domain) + + // Lagrange form of S1, S2, S3 + s1Canonical := make([]fr.Element, nbElmts) + s2Canonical := make([]fr.Element, nbElmts) + s3Canonical := make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + res[0] = iop.NewPolynomial(&s1Canonical, lagReg) + res[1] = iop.NewPolynomial(&s2Canonical, lagReg) + res[2] = iop.NewPolynomial(&s3Canonical, lagReg) + + return res +} + +// getSupportPermutation returns the support on which the permutation acts, it is +// || u || u^{2} +func getSupportPermutation(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + + return 
res +} diff --git a/internal/backend/bls24-317/plonk/verify.go b/backend/plonk/bls24-317/verify.go similarity index 77% rename from internal/backend/bls24-317/plonk/verify.go rename to backend/plonk/bls24-317/verify.go index ccd8642cb2..a6a7479e08 100644 --- a/internal/backend/bls24-317/plonk/verify.go +++ b/backend/plonk/bls24-317/verify.go @@ -39,7 +39,7 @@ var ( ) func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - log := logger.Logger().With().Str("curve", "bls24_317").Str("backend", "plonk").Logger() + log := logger.Logger().With().Str("curve", "bls24-317").Str("backend", "plonk").Logger() start := time.Now() // pick a hash function to derive the challenge (the same as in the prover) @@ -51,7 +51,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { // The first challenge is derived using the public data: the commitments to the permutation, // the coefficients of the circuit, and the public inputs. // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *vk, publicWitness); err != nil { + if err := bindPublicData(&fs, "gamma", *vk, publicWitness, proof.Bsb22Commitments); err != nil { return err } gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -85,25 +85,51 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = ∑_{i of size N, so to represent the permutation S we let S acts on the + // set A=(, u*, u^{2}*) of size 3*N, where u is outside (its use is to shift the set ). + // We obtain a permutation of A, A'. We split A' in 3 (A'_{1}, A'_{2}, A'_{3}), and S1, S2, S3 are + // respectively the interpolation of A'_{1}, A'_{2}, A'_{3} on . 
+ S1, S2, S3 *iop.Polynomial + + // S full permutation, i -> S[i] + S []int64 +} + +// VerifyingKey stores the data needed to verify a proof: +// * The commitment scheme +// * Commitments of ql prepended with as many ones as there are public inputs +// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs +// * Commitments to S1, S2, S3 +type VerifyingKey struct { + + // Size circuit + Size uint64 + SizeInv fr.Element + Generator fr.Element + NbPublicVariables uint64 + + // Commitment scheme that is used for an instantiation of PLONK + Kzg kzg.VerifyingKey + + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + + // S commitments to S1, S2, S3 + S [3]kzg.Digest + + // Commitments to ql, qr, qm, qo, qcp prepended with as many zeroes (ones for l) as there are public inputs. + // In particular Qk is not complete. + Ql, Qr, Qm, Qo, Qk kzg.Digest + Qcp []kzg.Digest + + CommitmentConstraintIndexes []uint64 +} + +// ProvingKey stores the data needed to generate a proof: +// * the commitment scheme +// * ql, prepended with as many ones as they are public inputs +// * qr, qm, qo prepended with as many zeroes as there are public inputs. +// * qk, prepended with as many zeroes as public inputs, to be completed by the prover +// with the list of public inputs. +// * sigma_1, sigma_2, sigma_3 in both basis +// * the copy constraint permutation +type ProvingKey struct { + + // stores ql, qr, qm, qo, qk (-> to be completed by the prover) + // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used + // for computing the opening proofs (hence the canonical form). The canonical version + // of qk incomplete is used in the linearisation polynomial. + // The polynomials in trace are in canonical basis. 
+	trace Trace
+
+	Kzg kzg.ProvingKey
+
+	// Verifying Key is embedded into the proving key (needed by Prove)
+	Vk *VerifyingKey
+
+	// qr,ql,qm,qo,qcp in LagrangeCoset --> these are not serialized, but computed from Ql, Qr, Qm, Qo, Qcp once.
+	lcQl, lcQr, lcQm, lcQo *iop.Polynomial
+	lcQcp []*iop.Polynomial
+
+	// lQk: qk in Lagrange form -> to be completed by the prover. After being completed, it is expressed on the Lagrange coset basis.
+	lQk *iop.Polynomial
+
+	// Domains used for the FFTs.
+	// Domain[0] = small Domain
+	// Domain[1] = big Domain
+	Domain [2]fft.Domain
+
+	// in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once.
+	lcS1, lcS2, lcS3 *iop.Polynomial
+
+	// in lagrange coset basis --> not serialized id and L_{g^{0}}
+	lcIdIOP, lLoneIOP *iop.Polynomial
+}
+
+func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) {
+
+	var pk ProvingKey
+	var vk VerifyingKey
+	pk.Vk = &vk
+	vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes())
+
+	// step 0: set the fft domains
+	pk.initDomains(spr)
+
+	// step 1: set the verifying key
+	pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen)
+	vk.Size = pk.Domain[0].Cardinality
+	vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv)
+	vk.Generator.Set(&pk.Domain[0].Generator)
+	vk.NbPublicVariables = uint64(len(spr.Public))
+	if len(kzgSrs.Pk.G1) < int(vk.Size) {
+		return nil, nil, errors.New("kzg srs is too small")
+	}
+	pk.Kzg = kzgSrs.Pk
+	vk.Kzg = kzgSrs.Vk
+
+	// step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis
+	BuildTrace(spr, &pk.trace)
+
+	// step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation.
+	// Note: at this stage, the permutation takes into account the placeholders
+	nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret)
+	buildPermutation(spr, &pk.trace, nbVariables)
+	s := computePermutationPolynomials(&pk.trace, &pk.Domain[0])
+	pk.trace.S1 = s[0]
+	pk.trace.S2 = s[1]
+	pk.trace.S3 = s[2]
+
+	// step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk.
+	// All the above polynomials are expressed in canonical basis afterwards. This is why
+	// we save lqk before, because the prover needs to complete it in Lagrange form, and
+	// then express it on the Lagrange coset basis.
+	pk.lQk = pk.trace.Qk.Clone() // it will be completed by the prover, and then evaluated on the coset
+	err := commitTrace(&pk.trace, &pk)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// step 5: evaluate ql, qr, qm, qo, s1, s2, s3 on LagrangeCoset (NOT qk)
+	// we clone them, because the canonical versions are going to be used in
+	// the opening proof
+	pk.computeLagrangeCosetPolys()
+
+	return &pk, &vk, nil
+}
+
+// computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset
+// basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover.
+func (pk *ProvingKey) computeLagrangeCosetPolys() { + pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) + for i, qcpI := range pk.trace.Qcp { + pk.lcQcp[i] = qcpI.Clone().ToLagrangeCoset(&pk.Domain[1]) + } + pk.lcQl = pk.trace.Ql.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQr = pk.trace.Qr.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQm = pk.trace.Qm.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQo = pk.trace.Qo.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS1 = pk.trace.S1.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS2 = pk.trace.S2.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS3 = pk.trace.S3.Clone().ToLagrangeCoset(&pk.Domain[1]) + + // storing Id + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + id := make([]fr.Element, pk.Domain[1].Cardinality) + id[0].Set(&pk.Domain[1].FrMultiplicativeGen) + for i := 1; i < int(pk.Domain[1].Cardinality); i++ { + id[i].Mul(&id[i-1], &pk.Domain[1].Generator) + } + pk.lcIdIOP = iop.NewPolynomial(&id, lagReg) + + // L_{g^{0}} + cap := pk.Domain[1].Cardinality + if cap < pk.Domain[0].Cardinality { + cap = pk.Domain[0].Cardinality // sanity check + } + lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) + lone[0].SetOne() + pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). + ToRegular(). + ToLagrangeCoset(&pk.Domain[1]) +} + +// NbPublicWitness returns the expected public witness size (number of field elements) +func (vk *VerifyingKey) NbPublicWitness() int { + return int(vk.NbPublicVariables) +} + +// VerifyingKey returns pk.Vk +func (pk *ProvingKey) VerifyingKey() interface{} { + return pk.Vk +} + +// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. 
+// Size is the size of the system that is nb_constraints+nb_public_variables
+func BuildTrace(spr *cs.SparseR1CS, pt *Trace) {
+
+	nbConstraints := spr.GetNbConstraints()
+	sizeSystem := uint64(nbConstraints + len(spr.Public))
+	size := ecc.NextPowerOfTwo(sizeSystem)
+	commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments)
+
+	ql := make([]fr.Element, size)
+	qr := make([]fr.Element, size)
+	qm := make([]fr.Element, size)
+	qo := make([]fr.Element, size)
+	qk := make([]fr.Element, size)
+	qcp := make([][]fr.Element, len(commitmentInfo))
+
+	for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error if size is inconsistent
+		ql[i].SetOne().Neg(&ql[i])
+		qr[i].SetZero()
+		qm[i].SetZero()
+		qo[i].SetZero()
+		qk[i].SetZero() // → to be completed by the prover
+	}
+	offset := len(spr.Public)
+
+	j := 0
+	it := spr.GetSparseR1CIterator()
+	for c := it.Next(); c != nil; c = it.Next() {
+		ql[offset+j].Set(&spr.Coefficients[c.QL])
+		qr[offset+j].Set(&spr.Coefficients[c.QR])
+		qm[offset+j].Set(&spr.Coefficients[c.QM])
+		qo[offset+j].Set(&spr.Coefficients[c.QO])
+		qk[offset+j].Set(&spr.Coefficients[c.QC])
+		j++
+	}
+
+	lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular}
+
+	pt.Ql = iop.NewPolynomial(&ql, lagReg)
+	pt.Qr = iop.NewPolynomial(&qr, lagReg)
+	pt.Qm = iop.NewPolynomial(&qm, lagReg)
+	pt.Qo = iop.NewPolynomial(&qo, lagReg)
+	pt.Qk = iop.NewPolynomial(&qk, lagReg)
+	pt.Qcp = make([]*iop.Polynomial, len(qcp))
+
+	for i := range commitmentInfo {
+		qcp[i] = make([]fr.Element, size)
+		for _, committed := range commitmentInfo[i].Committed {
+			qcp[i][offset+committed].SetOne()
+		}
+		pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg)
+	}
+}
+
+// commitTrace commits to every polynomial in the trace, and puts
+// the commitments into the verifying key.
+func commitTrace(trace *Trace, pk *ProvingKey) error { + + trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete + trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() + + var err error + pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + for i := range trace.Qcp { + trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() + if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + return err + } + } + if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + return err + } + return nil +} + +func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints + pk.Domain[0] = *fft.NewDomain(sizeSystem) + + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) + } else { + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) + } + +} + +// buildPermutation builds the Permutation associated with a circuit. +// +// The permutation s is composed of cycles of maximum length such that +// +// s. (l∥r∥o) = (l∥r∥o) +// +// , where l∥r∥o is the concatenation of the indices of l, r, o in +// ql.l+qr.r+qm.l.r+qo.O+k = 0. +// +// The permutation is encoded as a slice s of size 3*size(l), where the +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab +// like this: for i in tab: tab[i] = tab[permutation[i]] +func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { + + // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + sizeSolution := len(pt.Ql.Coefficients()) + sizePermutation := 3 * sizeSolution + + // init permutation + permutation := make([]int64, sizePermutation) + for i := 0; i < len(permutation); i++ { + permutation[i] = -1 + } + + // init LRO position -> variable_ID + lro := make([]int, sizePermutation) // position -> variable_ID + for i := 0; i < len(spr.Public); i++ { + lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) + } + + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + + j++ + } + + // init cycle: + // map ID -> last position the ID was seen + cycle := make([]int64, nbVariables) + for i := 0; i < len(cycle); i++ { + cycle[i] = -1 + } + + for i := 0; i < len(lro); i++ { + if cycle[lro[i]] != -1 { + // if != -1, it means we already encountered this value + // so we need to set the corresponding permutation index. 
+ permutation[i] = cycle[lro[i]] + } + cycle[lro[i]] = int64(i) + } + + // complete the Permutation by filling the first IDs encountered + for i := 0; i < sizePermutation; i++ { + if permutation[i] == -1 { + permutation[i] = cycle[lro[i]] + } + } + + pt.S = permutation +} + +// computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. +// We let the permutation act on || u || u^{2}, split the result in 3 parts, +// and interpolate each of the 3 parts on . +func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { + + nbElmts := int(domain.Cardinality) + + var res [3]*iop.Polynomial + + // Lagrange form of ID + evaluationIDSmallDomain := getSupportPermutation(domain) + + // Lagrange form of S1, S2, S3 + s1Canonical := make([]fr.Element, nbElmts) + s2Canonical := make([]fr.Element, nbElmts) + s3Canonical := make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + res[0] = iop.NewPolynomial(&s1Canonical, lagReg) + res[1] = iop.NewPolynomial(&s2Canonical, lagReg) + res[2] = iop.NewPolynomial(&s3Canonical, lagReg) + + return res +} + +// getSupportPermutation returns the support on which the permutation acts, it is +// || u || u^{2} +func getSupportPermutation(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + + return 
res +} diff --git a/backend/plonk/bn254/solidity.go b/backend/plonk/bn254/solidity.go new file mode 100644 index 0000000000..f3aeeb2a79 --- /dev/null +++ b/backend/plonk/bn254/solidity.go @@ -0,0 +1,1050 @@ +package plonk + +const tmplSolidityVerifier = `// SPDX-License-Identifier: Apache-2.0 + +// Copyright 2023 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +pragma solidity ^0.8.0; + +pragma experimental ABIEncoderV2; + + +library Utils { + + uint256 constant r_mod = 21888242871839275222246405745257275088548364400416034343698204186575808495617; + + /** + * @dev ExpandMsgXmd expands msg to a slice of lenInBytes bytes. + * https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5 + * https://tools.ietf.org/html/rfc8017#section-4.1 (I2OSP/O2ISP) + */ + function expand_msg(uint256 x, uint256 y) public pure returns(uint8[48] memory res){ + + string memory dst = "BSB22-Plonk"; + + //uint8[64] memory pad; // 64 is sha256 block size. 
+ // sha256(pad || msg || (0 || 48 || 0) || dst || 11) + bytes memory tmp; + uint8 zero = 0; + uint8 lenInBytes = 48; + uint8 sizeDomain = 11; // size of dst + + for (uint i=0; i<64; i++){ + tmp = abi.encodePacked(tmp, zero); + } + tmp = abi.encodePacked(tmp, x, y, zero, lenInBytes, zero, dst, sizeDomain); + bytes32 b0 = sha256(tmp); + + tmp = abi.encodePacked(b0, uint8(1), dst, sizeDomain); + bytes32 b1 = sha256(tmp); + for (uint i=0; i<32; i++){ + res[i] = uint8(b1[i]); + } + + tmp = abi.encodePacked(uint8(b0[0]) ^ uint8(b1[0])); + for (uint i=1; i<32; i++){ + tmp = abi.encodePacked(tmp, uint8(b0[i]) ^ uint8(b1[i])); + } + + tmp = abi.encodePacked(tmp, uint8(2), dst, sizeDomain); + b1 = sha256(tmp); + + // TODO handle the size of the dst (check gnark-crypto) + for (uint i=0; i<16; i++){ + res[i+32] = uint8(b1[i]); + } + + return res; + } + + /** + * @dev cf https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 + * corresponds to https://github.com/ConsenSys/gnark-crypto/blob/develop/ecc/bn254/fr/element.go + */ + function hash_fr(uint256 x, uint256 y) internal pure returns(uint256 res) { + + // interpret a as a bigEndian integer and reduce it mod r + uint8[48] memory xmsg = expand_msg(x, y); + // uint8[48] memory xmsg = [0x44, 0x74, 0xb5, 0x29, 0xd7, 0xfb, 0x29, 0x88, 0x3a, 0x7a, 0xc1, 0x65, 0xfd, 0x72, 0xce, 0xd0, 0xd4, 0xd1, 0x3f, 0x9e, 0x85, 0x8a, 0x3, 0x86, 0x1c, 0x90, 0x83, 0x1e, 0x94, 0xdc, 0xfc, 0x1d, 0x70, 0x82, 0xf5, 0xbf, 0x30, 0x3, 0x39, 0x87, 0x21, 0x38, 0x15, 0xed, 0x12, 0x75, 0x44, 0x6a]; + + // reduce xmsg mod r, where xmsg is intrepreted in big endian + // (as SetBytes does for golang's Big.Int library). 
+ for (uint i=0; i<32; i++){ + res += uint256(xmsg[47-i])<<(8*i); + } + res = res % r_mod; + uint256 tmp; + for (uint i=0; i<16; i++){ + tmp += uint256(xmsg[15-i])<<(8*i); + } + + // 2**256%r + uint256 b = 6350874878119819312338956282401532410528162663560392320966563075034087161851; + assembly { + tmp := mulmod(tmp, b, r_mod) + res := addmod(res, tmp, r_mod) + } + + return res; + } + +} + +contract PlonkVerifier { + + using Utils for *; + uint256 constant r_mod = 21888242871839275222246405745257275088548364400416034343698204186575808495617; + uint256 constant p_mod = 21888242871839275222246405745257275088696311157297823662689037894645226208583; + {{ range $index, $element := .Kzg.G2 }} + uint256 constant g2_srs_{{ $index }}_x_0 = {{ (fpstr $element.X.A1) }}; + uint256 constant g2_srs_{{ $index }}_x_1 = {{ (fpstr $element.X.A0) }}; + uint256 constant g2_srs_{{ $index }}_y_0 = {{ (fpstr $element.Y.A1) }}; + uint256 constant g2_srs_{{ $index }}_y_1 = {{ (fpstr $element.Y.A0) }}; + {{ end }} + // ----------------------- vk --------------------- + uint256 constant vk_domain_size = {{ .Size }}; + uint256 constant vk_inv_domain_size = {{ (frstr .SizeInv) }}; + uint256 constant vk_omega = {{ (frstr .Generator) }}; + uint256 constant vk_ql_com_x = {{ (fpstr .Ql.X) }}; + uint256 constant vk_ql_com_y = {{ (fpstr .Ql.Y) }}; + uint256 constant vk_qr_com_x = {{ (fpstr .Qr.X) }}; + uint256 constant vk_qr_com_y = {{ (fpstr .Qr.Y) }}; + uint256 constant vk_qm_com_x = {{ (fpstr .Qm.X) }}; + uint256 constant vk_qm_com_y = {{ (fpstr .Qm.Y) }}; + uint256 constant vk_qo_com_x = {{ (fpstr .Qo.X) }}; + uint256 constant vk_qo_com_y = {{ (fpstr .Qo.Y) }}; + uint256 constant vk_qk_com_x = {{ (fpstr .Qk.X) }}; + uint256 constant vk_qk_com_y = {{ (fpstr .Qk.Y) }}; + {{ range $index, $element := .S }} + uint256 constant vk_s{{ inc $index }}_com_x = {{ (fpstr $element.X) }}; + uint256 constant vk_s{{ inc $index }}_com_y = {{ (fpstr $element.Y) }}; + {{ end }} + uint256 constant vk_coset_shift = 
5; + + {{ range $index, $element := .Qcp}} + uint256 constant vk_selector_commitments_commit_api_{{ $index }}_x = {{ (fpstr $element.X) }}; + uint256 constant vk_selector_commitments_commit_api_{{ $index }}_y = {{ (fpstr $element.Y) }}; + {{ end }} + + {{ if (gt (len .CommitmentConstraintIndexes) 0 )}} + function load_vk_commitments_indices_commit_api(uint256[] memory v) + internal view { + assembly { + let _v := add(v, 0x20) + {{ range .CommitmentConstraintIndexes }} + mstore(_v, {{ . }}) + _v := add(_v, 0x20) + {{ end }} + } + } + {{ end }} + uint256 constant vk_nb_commitments_commit_api = {{ len .CommitmentConstraintIndexes }}; + + // ------------------------------------------------ + + // offset proof + uint256 constant proof_l_com_x = 0x20; + uint256 constant proof_l_com_y = 0x40; + uint256 constant proof_r_com_x = 0x60; + uint256 constant proof_r_com_y = 0x80; + uint256 constant proof_o_com_x = 0xa0; + uint256 constant proof_o_com_y = 0xc0; + + // h = h_0 + x^{n+2}h_1 + x^{2(n+2)}h_2 + uint256 constant proof_h_0_x = 0xe0; + uint256 constant proof_h_0_y = 0x100; + uint256 constant proof_h_1_x = 0x120; + uint256 constant proof_h_1_y = 0x140; + uint256 constant proof_h_2_x = 0x160; + uint256 constant proof_h_2_y = 0x180; + + // wire values at zeta + uint256 constant proof_l_at_zeta = 0x1a0; + uint256 constant proof_r_at_zeta = 0x1c0; + uint256 constant proof_o_at_zeta = 0x1e0; + + //uint256[STATE_WIDTH-1] permutation_polynomials_at_zeta; // Sσ1(zeta),Sσ2(zeta) + uint256 constant proof_s1_at_zeta = 0x200; // Sσ1(zeta) + uint256 constant proof_s2_at_zeta = 0x220; // Sσ2(zeta) + + //Bn254.G1Point grand_product_commitment; // [z(x)] + uint256 constant proof_grand_product_commitment_x = 0x240; + uint256 constant proof_grand_product_commitment_y = 0x260; + + uint256 constant proof_grand_product_at_zeta_omega = 0x280; // z(w*zeta) + uint256 constant proof_quotient_polynomial_at_zeta = 0x2a0; // t(zeta) + uint256 constant proof_linearised_polynomial_at_zeta = 0x2c0; // 
r(zeta) + + // Folded proof for the opening of H, linearised poly, l, r, o, s_1, s_2, qcp + uint256 constant proof_batch_opening_at_zeta_x = 0x2e0; // [Wzeta] + uint256 constant proof_batch_opening_at_zeta_y = 0x300; + + //Bn254.G1Point opening_at_zeta_omega_proof; // [Wzeta*omega] + uint256 constant proof_opening_at_zeta_omega_x = 0x320; + uint256 constant proof_opening_at_zeta_omega_y = 0x340; + + uint256 constant proof_openings_selector_commit_api_at_zeta = 0x360; + // -> next part of proof is + // [ openings_selector_commits || commitments_wires_commit_api] + + // -------- offset state + + // challenges to check the claimed quotient + uint256 constant state_alpha = 0x00; + uint256 constant state_beta = 0x20; + uint256 constant state_gamma = 0x40; + uint256 constant state_zeta = 0x60; + + // challenges related to KZG + uint256 constant state_sv = 0x80; + uint256 constant state_su = 0xa0; + + // reusable value + uint256 constant state_alpha_square_lagrange = 0xc0; + + // commitment to H + // Bn254.G1Point folded_h; + uint256 constant state_folded_h_x = 0xe0; + uint256 constant state_folded_h_y = 0x100; + + // commitment to the linearised polynomial + uint256 constant state_linearised_polynomial_x = 0x120; + uint256 constant state_linearised_polynomial_y = 0x140; + + // Folded proof for the opening of H, linearised poly, l, r, o, s_1, s_2, qcp + // Kzg.OpeningProof folded_proof; + uint256 constant state_folded_claimed_values = 0x160; + + // folded digests of H, linearised poly, l, r, o, s_1, s_2, qcp + // Bn254.G1Point folded_digests; + uint256 constant state_folded_digests_x = 0x180; + uint256 constant state_folded_digests_y = 0x1a0; + + uint256 constant state_pi = 0x1c0; + + uint256 constant state_zeta_power_n_minus_one = 0x1e0; + uint256 constant state_alpha_square_lagrange_one = 0x200; + + uint256 constant state_gamma_kzg = 0x220; + + uint256 constant state_success = 0x240; + uint256 constant state_check_var = 0x260; // /!\ this slot is used for debugging only 
+ + + uint256 constant state_last_mem = 0x280; + + event PrintUint256(uint256 a); + + function derive_gamma_beta_alpha_zeta(bytes memory proof, uint256[] memory public_inputs) + internal view returns(uint256, uint256, uint256, uint256) { + + uint256 gamma; + uint256 beta; + uint256 alpha; + uint256 zeta; + + assembly { + + let mem := mload(0x40) + + derive_gamma(proof, public_inputs) + gamma := mload(mem) + + derive_beta(proof, gamma) + beta := mload(mem) + + derive_alpha(proof, beta) + alpha := mload(mem) + + derive_zeta(proof, alpha) + zeta := mload(mem) + + gamma := mod(gamma, r_mod) + beta := mod(beta, r_mod) + alpha := mod(alpha, r_mod) + zeta := mod(zeta, r_mod) + + function derive_gamma(aproof, pub_inputs) { + + let mPtr := mload(0x40) + + // gamma + mstore(mPtr, 0x67616d6d61) // "gamma" + + mstore(add(mPtr, 0x20), vk_s1_com_x) + mstore(add(mPtr, 0x40), vk_s1_com_y) + mstore(add(mPtr, 0x60), vk_s2_com_x) + mstore(add(mPtr, 0x80), vk_s2_com_y) + mstore(add(mPtr, 0xa0), vk_s3_com_x) + mstore(add(mPtr, 0xc0), vk_s3_com_y) + mstore(add(mPtr, 0xe0), vk_ql_com_x) + mstore(add(mPtr, 0x100), vk_ql_com_y) + mstore(add(mPtr, 0x120), vk_qr_com_x) + mstore(add(mPtr, 0x140), vk_qr_com_y) + mstore(add(mPtr, 0x160), vk_qm_com_x) + mstore(add(mPtr, 0x180), vk_qm_com_y) + mstore(add(mPtr, 0x1a0), vk_qo_com_x) + mstore(add(mPtr, 0x1c0), vk_qo_com_y) + mstore(add(mPtr, 0x1e0), vk_qk_com_x) + mstore(add(mPtr, 0x200), vk_qk_com_y) + + let pi := add(pub_inputs, 0x20) + let _mPtr := add(mPtr, 0x220) + for {let i:=0} lt(i, mload(pub_inputs)) {i:=add(i,1)} + { + mstore(_mPtr, mload(pi)) + pi := add(pi, 0x20) + _mPtr := add(_mPtr, 0x20) + } + + let _proof := add(aproof, proof_openings_selector_commit_api_at_zeta) + _proof := add(_proof, mul(vk_nb_commitments_commit_api, 0x20)) + for {let i:=0} lt(i, vk_nb_commitments_commit_api) {i:=add(i,1)} + { + mstore(_mPtr, mload(_proof)) + mstore(add(_mPtr, 0x20), mload(add(_proof, 0x20))) + _mPtr := add(_mPtr, 0x40) + _proof := add(_proof, 
0x40) + } + // pop(staticcall(sub(gas(), 2000), 0x2, add(mPtr, 0x1b), 0x2a5, mPtr, 0x20)) //0x1b -> 000.."gamma" + + mstore(_mPtr, mload(add(aproof, proof_l_com_x))) + mstore(add(_mPtr, 0x20), mload(add(aproof, proof_l_com_y))) + mstore(add(_mPtr, 0x40), mload(add(aproof, proof_r_com_x))) + mstore(add(_mPtr, 0x60), mload(add(aproof, proof_r_com_y))) + mstore(add(_mPtr, 0x80), mload(add(aproof, proof_o_com_x))) + mstore(add(_mPtr, 0xa0), mload(add(aproof, proof_o_com_y))) + // pop(staticcall(sub(gas(), 2000), 0x2, add(mPtr, 0x1b), 0x365, mPtr, 0x20)) //0x1b -> 000.."gamma" + + let size := add(0x2c5, mul(mload(pub_inputs), 0x20)) // 0x2c5 = 22*32+5 + size := add(size, mul(vk_nb_commitments_commit_api, 0x40)) + pop(staticcall(sub(gas(), 2000), 0x2, add(mPtr, 0x1b), size, mPtr, 0x20)) //0x1b -> 000.."gamma" + } + + function derive_beta(aproof, prev_challenge){ + let mPtr := mload(0x40) + // beta + mstore(mPtr, 0x62657461) // "beta" + mstore(add(mPtr, 0x20), prev_challenge) + pop(staticcall(sub(gas(), 2000), 0x2, add(mPtr, 0x1c), 0x24, mPtr, 0x20)) //0x1b -> 000.."gamma" + } + + function derive_alpha(aproof, prev_challenge){ + let mPtr := mload(0x40) + // alpha + mstore(mPtr, 0x616C706861) // "alpha" + mstore(add(mPtr, 0x20), prev_challenge) + mstore(add(mPtr, 0x40), mload(add(aproof, proof_grand_product_commitment_x))) + mstore(add(mPtr, 0x60), mload(add(aproof, proof_grand_product_commitment_y))) + pop(staticcall(sub(gas(), 2000), 0x2, add(mPtr, 0x1b), 0x65, mPtr, 0x20)) //0x1b -> 000.."gamma" + } + + function derive_zeta(aproof, prev_challenge) { + let mPtr := mload(0x40) + // zeta + mstore(mPtr, 0x7a657461) // "zeta" + mstore(add(mPtr, 0x20), prev_challenge) + mstore(add(mPtr, 0x40), mload(add(aproof, proof_h_0_x))) + mstore(add(mPtr, 0x60), mload(add(aproof, proof_h_0_y))) + mstore(add(mPtr, 0x80), mload(add(aproof, proof_h_1_x))) + mstore(add(mPtr, 0xa0), mload(add(aproof, proof_h_1_y))) + mstore(add(mPtr, 0xc0), mload(add(aproof, proof_h_2_x))) + mstore(add(mPtr, 
0xe0), mload(add(aproof, proof_h_2_y))) + pop(staticcall(sub(gas(), 2000), 0x2, add(mPtr, 0x1c), 0xe4, mPtr, 0x20)) + } + } + + return (gamma, beta, alpha, zeta); + } + + function load_wire_commitments_commit_api(uint256[] memory wire_commitments, bytes memory proof) + pure internal { + assembly { + let w := add(wire_commitments, 0x20) + let p := add(proof, proof_openings_selector_commit_api_at_zeta) + p := add(p, mul(vk_nb_commitments_commit_api, 0x20)) + for {let i:=0} lt(i, mul(vk_nb_commitments_commit_api,2)) {i:=add(i,1)} + { + mstore(w, mload(p)) + w := add(w,0x20) + p := add(p,0x20) + mstore(w, mload(p)) + w := add(w,0x20) + p := add(p,0x20) + } + } + } + + function compute_ith_lagrange_at_z(uint256 zeta, uint256 i) + internal view returns (uint256) { + + uint256 res; + assembly { + + // _n^_i [r] + function pow_local(x, e)->result { + let mPtr := mload(0x40) + mstore(mPtr, 0x20) + mstore(add(mPtr, 0x20), 0x20) + mstore(add(mPtr, 0x40), 0x20) + mstore(add(mPtr, 0x60), x) + mstore(add(mPtr, 0x80), e) + mstore(add(mPtr, 0xa0), r_mod) + pop(staticcall(sub(gas(), 2000),0x05,mPtr,0xc0,0x00,0x20)) + result := mload(0x00) + } + + let w := pow_local(vk_omega,i) // w**i + i := addmod(zeta, sub(r_mod, w), r_mod) // z-w**i + zeta := pow_local(zeta, vk_domain_size) // z**n + zeta := addmod(zeta, sub(r_mod, 1), r_mod) // z**n-1 + w := mulmod(w, vk_inv_domain_size, r_mod) // w**i/n + i := pow_local(i, sub(r_mod,2)) // (z-w**i)**-1 + w := mulmod(w, i, r_mod) // w**i/n*(z-w)**-1 + res := mulmod(w, zeta, r_mod) + } + + return res; + } + + function compute_pi( + bytes memory proof, + uint256[] memory public_inputs, + uint256 zeta + ) internal view returns (uint256) { + + // evaluation of Z=Xⁿ⁻¹ at ζ + // uint256 zeta_power_n_minus_one = Fr.pow(zeta, vk_domain_size); + // zeta_power_n_minus_one = Fr.sub(zeta_power_n_minus_one, 1); + uint256 zeta_power_n_minus_one; + + uint256 pi; + + assembly { + + sum_pi_wo_api_commit(add(public_inputs,0x20), mload(public_inputs), zeta) + pi 
:= mload(mload(0x40)) + + function sum_pi_wo_api_commit(ins, n, z) { + let li := mload(0x40) + batch_compute_lagranges_at_z(z, n, li) + let res := 0 + let tmp := 0 + for {let i:=0} lt(i,n) {i:=add(i,1)} + { + tmp := mulmod(mload(li), mload(ins), r_mod) + res := addmod(res, tmp, r_mod) + li := add(li, 0x20) + ins := add(ins, 0x20) + } + mstore(mload(0x40), res) + } + + // mPtr <- [L_0(z), .., L_{n-1}(z)] + function batch_compute_lagranges_at_z(z, n, mPtr) { + let zn := addmod(pow(z, vk_domain_size, mPtr), sub(r_mod, 1), r_mod) + zn := mulmod(zn, vk_inv_domain_size, r_mod) + let _w := 1 + let _mPtr := mPtr + for {let i:=0} lt(i,n) {i:=add(i,1)} + { + mstore(_mPtr, addmod(z,sub(r_mod, _w), r_mod)) + _w := mulmod(_w, vk_omega, r_mod) + _mPtr := add(_mPtr, 0x20) + } + batch_invert(mPtr, n, _mPtr) + _mPtr := mPtr + _w := 1 + for {let i:=0} lt(i,n) {i:=add(i,1)} + { + mstore(_mPtr, mulmod(mulmod(mload(_mPtr), zn , r_mod), _w, r_mod)) + _mPtr := add(_mPtr, 0x20) + _w := mulmod(_w, vk_omega, r_mod) + } + } + + function batch_invert(ins, nb_ins, mPtr) { + mstore(mPtr, 1) + let offset := 0 + for {let i:=0} lt(i, nb_ins) {i:=add(i,1)} + { + let prev := mload(add(mPtr, offset)) + let cur := mload(add(ins, offset)) + cur := mulmod(prev, cur, r_mod) + offset := add(offset, 0x20) + mstore(add(mPtr, offset), cur) + } + ins := add(ins, sub(offset, 0x20)) + mPtr := add(mPtr, offset) + let inv := pow(mload(mPtr), sub(r_mod,2), add(mPtr, 0x20)) + for {let i:=0} lt(i, nb_ins) {i:=add(i,1)} + { + mPtr := sub(mPtr, 0x20) + let tmp := mload(ins) + let cur := mulmod(inv, mload(mPtr), r_mod) + mstore(ins, cur) + inv := mulmod(inv, tmp, r_mod) + ins := sub(ins, 0x20) + } + } + + // res <- x^e mod r + function pow(x, e, mPtr)->res { + mstore(mPtr, 0x20) + mstore(add(mPtr, 0x20), 0x20) + mstore(add(mPtr, 0x40), 0x20) + mstore(add(mPtr, 0x60), x) + mstore(add(mPtr, 0x80), e) + mstore(add(mPtr, 0xa0), r_mod) + pop(staticcall(sub(gas(), 2000),0x05,mPtr,0xc0,mPtr,0x20)) + res := mload(mPtr) + } + + 
zeta_power_n_minus_one := pow(zeta, vk_domain_size, mload(0x40)) + zeta_power_n_minus_one := addmod(zeta_power_n_minus_one, sub(r_mod, 1), r_mod) + } + + {{ if (gt (len .CommitmentConstraintIndexes) 0 )}} + uint256[] memory commitment_indices = new uint256[](vk_nb_commitments_commit_api); + load_vk_commitments_indices_commit_api(commitment_indices); + + uint256[] memory wire_committed_commitments; + wire_committed_commitments = new uint256[](2*vk_nb_commitments_commit_api); + + load_wire_commitments_commit_api(wire_committed_commitments, proof); + + for (uint256 i=0; ires { + mstore(mPtr, 0x20) + mstore(add(mPtr, 0x20), 0x20) + mstore(add(mPtr, 0x40), 0x20) + mstore(add(mPtr, 0x60), x) + mstore(add(mPtr, 0x80), e) + mstore(add(mPtr, 0xa0), r_mod) + pop(staticcall(sub(gas(), 2000),0x05,mPtr,0xc0,mPtr,0x20)) + res := mload(mPtr) + } + } + + return success; + + } + +} +` + +// MarshalSolidity converts a proof to a byte array that can be used in a +// Solidity contract. +func (proof *Proof) MarshalSolidity() []byte { + + res := make([]byte, 0, 1024) + + // uint256 l_com_x; + // uint256 l_com_y; + // uint256 r_com_x; + // uint256 r_com_y; + // uint256 o_com_x; + // uint256 o_com_y; + var tmp64 [64]byte + for i := 0; i < 3; i++ { + tmp64 = proof.LRO[i].RawBytes() + res = append(res, tmp64[:]...) + } + + // uint256 h_0_x; + // uint256 h_0_y; + // uint256 h_1_x; + // uint256 h_1_y; + // uint256 h_2_x; + // uint256 h_2_y; + for i := 0; i < 3; i++ { + tmp64 = proof.H[i].RawBytes() + res = append(res, tmp64[:]...) + } + var tmp32 [32]byte + + // uint256 l_at_zeta; + // uint256 r_at_zeta; + // uint256 o_at_zeta; + // uint256 s1_at_zeta; + // uint256 s2_at_zeta; + for i := 2; i < 7; i++ { + tmp32 = proof.BatchedProof.ClaimedValues[i].Bytes() + res = append(res, tmp32[:]...) + } + + // uint256 grand_product_commitment_x; + // uint256 grand_product_commitment_y; + tmp64 = proof.Z.RawBytes() + res = append(res, tmp64[:]...) 
+ + // uint256 grand_product_at_zeta_omega; + tmp32 = proof.ZShiftedOpening.ClaimedValue.Bytes() + res = append(res, tmp32[:]...) + + // uint256 quotient_polynomial_at_zeta; + // uint256 linearization_polynomial_at_zeta; + tmp32 = proof.BatchedProof.ClaimedValues[0].Bytes() + res = append(res, tmp32[:]...) + tmp32 = proof.BatchedProof.ClaimedValues[1].Bytes() + res = append(res, tmp32[:]...) + + // uint256 opening_at_zeta_proof_x; + // uint256 opening_at_zeta_proof_y; + tmp64 = proof.BatchedProof.H.RawBytes() + res = append(res, tmp64[:]...) + + // uint256 opening_at_zeta_omega_proof_x; + // uint256 opening_at_zeta_omega_proof_y; + tmp64 = proof.ZShiftedOpening.H.RawBytes() + res = append(res, tmp64[:]...) + + // uint256[] selector_commit_api_at_zeta; + // uint256[] wire_committed_commitments; + if len(proof.Bsb22Commitments) > 0 { + tmp32 = proof.BatchedProof.ClaimedValues[7].Bytes() + res = append(res, tmp32[:]...) + + for _, bc := range proof.Bsb22Commitments { + tmp64 = bc.RawBytes() + res = append(res, tmp64[:]...) + } + } + + return res +} diff --git a/internal/backend/bn254/plonk/verify.go b/backend/plonk/bn254/verify.go similarity index 73% rename from internal/backend/bn254/plonk/verify.go rename to backend/plonk/bn254/verify.go index 74b56c5caf..88bd903f5e 100644 --- a/internal/backend/bn254/plonk/verify.go +++ b/backend/plonk/bn254/verify.go @@ -25,6 +25,8 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fp" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/kzg" curve "github.com/consensys/gnark-crypto/ecc/bn254" @@ -53,7 +55,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { // The first challenge is derived using the public data: the commitments to the permutation, // the coefficients of the circuit, and the public inputs. 
// derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *vk, publicWitness); err != nil { + if err := bindPublicData(&fs, "gamma", *vk, publicWitness, proof.Bsb22Commitments); err != nil { return err } gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -87,25 +89,51 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = ∑_{i of size N, so to represent the permutation S we let S acts on the + // set A=(, u*, u^{2}*) of size 3*N, where u is outside (its use is to shift the set ). + // We obtain a permutation of A, A'. We split A' in 3 (A'_{1}, A'_{2}, A'_{3}), and S1, S2, S3 are + // respectively the interpolation of A'_{1}, A'_{2}, A'_{3} on . + S1, S2, S3 *iop.Polynomial + + // S full permutation, i -> S[i] + S []int64 +} + +// VerifyingKey stores the data needed to verify a proof: +// * The commitment scheme +// * Commitments of ql prepended with as many ones as there are public inputs +// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs +// * Commitments to S1, S2, S3 +type VerifyingKey struct { + + // Size circuit + Size uint64 + SizeInv fr.Element + Generator fr.Element + NbPublicVariables uint64 + + // Commitment scheme that is used for an instantiation of PLONK + Kzg kzg.VerifyingKey + + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + + // S commitments to S1, S2, S3 + S [3]kzg.Digest + + // Commitments to ql, qr, qm, qo, qcp prepended with as many zeroes (ones for l) as there are public inputs. + // In particular Qk is not complete. 
+ Ql, Qr, Qm, Qo, Qk kzg.Digest + Qcp []kzg.Digest + + CommitmentConstraintIndexes []uint64 +} + +// ProvingKey stores the data needed to generate a proof: +// * the commitment scheme +// * ql, prepended with as many ones as they are public inputs +// * qr, qm, qo prepended with as many zeroes as there are public inputs. +// * qk, prepended with as many zeroes as public inputs, to be completed by the prover +// with the list of public inputs. +// * sigma_1, sigma_2, sigma_3 in both basis +// * the copy constraint permutation +type ProvingKey struct { + + // stores ql, qr, qm, qo, qk (-> to be completed by the prover) + // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used + // for computing the opening proofs (hence the canonical form). The canonical version + // of qk incomplete is used in the linearisation polynomial. + // The polynomials in trace are in canonical basis. + trace Trace + + Kzg kzg.ProvingKey + + // Verifying Key is embedded into the proving key (needed by Prove) + Vk *VerifyingKey + + // qr,ql,qm,qo,qcp in LagrangeCoset --> these are not serialized, but computed from Ql, Qr, Qm, Qo, Qcp once. + lcQl, lcQr, lcQm, lcQo *iop.Polynomial + lcQcp []*iop.Polynomial + + // LQk qk in Lagrange form -> to be completed by the prover. After being completed, + lQk *iop.Polynomial + + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + + // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. 
+ lcS1, lcS2, lcS3 *iop.Polynomial + + // in lagrange coset basis --> not serialized id and L_{g^{0}} + lcIdIOP, lLoneIOP *iop.Polynomial +} + +func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { + + var pk ProvingKey + var vk VerifyingKey + pk.Vk = &vk + vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) + + // step 0: set the fft domains + pk.initDomains(spr) + + // step 1: set the verifying key + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) + vk.Size = pk.Domain[0].Cardinality + vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) + vk.Generator.Set(&pk.Domain[0].Generator) + vk.NbPublicVariables = uint64(len(spr.Public)) + if len(kzgSrs.Pk.G1) < int(vk.Size) { + return nil, nil, errors.New("kzg srs is too small") + } + pk.Kzg = kzgSrs.Pk + vk.Kzg = kzgSrs.Vk + + // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis + BuildTrace(spr, &pk.trace) + + // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. + // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &pk.trace, nbVariables) + s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) + pk.trace.S1 = s[0] + pk.trace.S2 = s[1] + pk.trace.S3 = s[2] + + // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. + // All the above polynomials are expressed in canonical basis afterwards. This is why + // we save lqk before, because the prover needs to complete it in Lagrange form, and + // then express it on the Lagrange coset basis. 
+ pk.lQk = pk.trace.Qk.Clone() // it will be completed by the prover, and the evaluated on the coset + err := commitTrace(&pk.trace, &pk) + if err != nil { + return nil, nil, err + } + + // step 5: evaluate ql, qr, qm, qo, s1, s2, s3 on LagrangeCoset (NOT qk) + // we clone them, because the canonical versions are going to be used in + // the opening proof + pk.computeLagrangeCosetPolys() + + return &pk, &vk, nil +} + +// computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset +// basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. +func (pk *ProvingKey) computeLagrangeCosetPolys() { + pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) + for i, qcpI := range pk.trace.Qcp { + pk.lcQcp[i] = qcpI.Clone().ToLagrangeCoset(&pk.Domain[1]) + } + pk.lcQl = pk.trace.Ql.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQr = pk.trace.Qr.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQm = pk.trace.Qm.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQo = pk.trace.Qo.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS1 = pk.trace.S1.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS2 = pk.trace.S2.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS3 = pk.trace.S3.Clone().ToLagrangeCoset(&pk.Domain[1]) + + // storing Id + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + id := make([]fr.Element, pk.Domain[1].Cardinality) + id[0].Set(&pk.Domain[1].FrMultiplicativeGen) + for i := 1; i < int(pk.Domain[1].Cardinality); i++ { + id[i].Mul(&id[i-1], &pk.Domain[1].Generator) + } + pk.lcIdIOP = iop.NewPolynomial(&id, lagReg) + + // L_{g^{0}} + cap := pk.Domain[1].Cardinality + if cap < pk.Domain[0].Cardinality { + cap = pk.Domain[0].Cardinality // sanity check + } + lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) + lone[0].SetOne() + pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). + ToRegular(). 
+ ToLagrangeCoset(&pk.Domain[1]) +} + +// NbPublicWitness returns the expected public witness size (number of field elements) +func (vk *VerifyingKey) NbPublicWitness() int { + return int(vk.NbPublicVariables) +} + +// VerifyingKey returns pk.Vk +func (pk *ProvingKey) VerifyingKey() interface{} { + return pk.Vk +} + +// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. +// Size is the size of the system that is nb_constraints+nb_public_variables +func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) + size := ecc.NextPowerOfTwo(sizeSystem) + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments) + + ql := make([]fr.Element, size) + qr := make([]fr.Element, size) + qm := make([]fr.Element, size) + qo := make([]fr.Element, size) + qk := make([]fr.Element, size) + qcp := make([][]fr.Element, len(commitmentInfo)) + + for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistent + ql[i].SetOne().Neg(&ql[i]) + qr[i].SetZero() + qm[i].SetZero() + qo[i].SetZero() + qk[i].SetZero() // → to be completed by the prover + } + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + ql[offset+j].Set(&spr.Coefficients[c.QL]) + qr[offset+j].Set(&spr.Coefficients[c.QR]) + qm[offset+j].Set(&spr.Coefficients[c.QM]) + qo[offset+j].Set(&spr.Coefficients[c.QO]) + qk[offset+j].Set(&spr.Coefficients[c.QC]) + j++ + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + + pt.Ql = iop.NewPolynomial(&ql, lagReg) + pt.Qr = iop.NewPolynomial(&qr, lagReg) + pt.Qm = iop.NewPolynomial(&qm, lagReg) + pt.Qo = iop.NewPolynomial(&qo, lagReg) + pt.Qk = iop.NewPolynomial(&qk, lagReg) + pt.Qcp = make([]*iop.Polynomial, len(qcp)) + + for i := range commitmentInfo { + qcp[i] = make([]fr.Element, size) + for _, committed := range 
commitmentInfo[i].Committed { + qcp[i][offset+committed].SetOne() + } + pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + } +} + +// commitTrace commits to every polynomial in the trace, and put +// the commitments int the verifying key. +func commitTrace(trace *Trace, pk *ProvingKey) error { + + trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete + trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() + + var err error + pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + for i := range trace.Qcp { + trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() + if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + return err + } + } + if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + return err + } + return nil +} + +func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder 
constraints + pk.Domain[0] = *fft.NewDomain(sizeSystem) + + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) + } else { + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) + } + +} + +// buildPermutation builds the Permutation associated with a circuit. +// +// The permutation s is composed of cycles of maximum length such that +// +// s. (l∥r∥o) = (l∥r∥o) +// +// , where l∥r∥o is the concatenation of the indices of l, r, o in +// ql.l+qr.r+qm.l.r+qo.O+k = 0. +// +// The permutation is encoded as a slice s of size 3*size(l), where the +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab +// like this: for i in tab: tab[i] = tab[permutation[i]] +func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { + + // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + sizeSolution := len(pt.Ql.Coefficients()) + sizePermutation := 3 * sizeSolution + + // init permutation + permutation := make([]int64, sizePermutation) + for i := 0; i < len(permutation); i++ { + permutation[i] = -1 + } + + // init LRO position -> variable_ID + lro := make([]int, sizePermutation) // position -> variable_ID + for i := 0; i < len(spr.Public); i++ { + lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) + } + + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + + j++ + } + + // init cycle: + // map ID -> last position the ID was seen + cycle := make([]int64, nbVariables) + for i := 0; i < len(cycle); i++ { + cycle[i] = -1 + } + + for i := 0; i < len(lro); i++ { + if cycle[lro[i]] != -1 { + // if != -1, it 
means we already encountered this value + // so we need to set the corresponding permutation index. + permutation[i] = cycle[lro[i]] + } + cycle[lro[i]] = int64(i) + } + + // complete the Permutation by filling the first IDs encountered + for i := 0; i < sizePermutation; i++ { + if permutation[i] == -1 { + permutation[i] = cycle[lro[i]] + } + } + + pt.S = permutation +} + +// computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. +// We let the permutation act on || u || u^{2}, split the result in 3 parts, +// and interpolate each of the 3 parts on . +func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { + + nbElmts := int(domain.Cardinality) + + var res [3]*iop.Polynomial + + // Lagrange form of ID + evaluationIDSmallDomain := getSupportPermutation(domain) + + // Lagrange form of S1, S2, S3 + s1Canonical := make([]fr.Element, nbElmts) + s2Canonical := make([]fr.Element, nbElmts) + s3Canonical := make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + res[0] = iop.NewPolynomial(&s1Canonical, lagReg) + res[1] = iop.NewPolynomial(&s2Canonical, lagReg) + res[2] = iop.NewPolynomial(&s3Canonical, lagReg) + + return res +} + +// getSupportPermutation returns the support on which the permutation acts, it is +// || u || u^{2} +func getSupportPermutation(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + 
res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + + return res +} diff --git a/internal/backend/bw6-633/plonk/verify.go b/backend/plonk/bw6-633/verify.go similarity index 77% rename from internal/backend/bw6-633/plonk/verify.go rename to backend/plonk/bw6-633/verify.go index 4d8d69d051..63c4b1dbfb 100644 --- a/internal/backend/bw6-633/plonk/verify.go +++ b/backend/plonk/bw6-633/verify.go @@ -39,7 +39,7 @@ var ( ) func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - log := logger.Logger().With().Str("curve", "bw6_633").Str("backend", "plonk").Logger() + log := logger.Logger().With().Str("curve", "bw6-633").Str("backend", "plonk").Logger() start := time.Now() // pick a hash function to derive the challenge (the same as in the prover) @@ -51,7 +51,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { // The first challenge is derived using the public data: the commitments to the permutation, // the coefficients of the circuit, and the public inputs. // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *vk, publicWitness); err != nil { + if err := bindPublicData(&fs, "gamma", *vk, publicWitness, proof.Bsb22Commitments); err != nil { return err } gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -85,25 +85,51 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = ∑_{i of size N, so to represent the permutation S we let S acts on the + // set A=(, u*, u^{2}*) of size 3*N, where u is outside (its use is to shift the set ). + // We obtain a permutation of A, A'. We split A' in 3 (A'_{1}, A'_{2}, A'_{3}), and S1, S2, S3 are + // respectively the interpolation of A'_{1}, A'_{2}, A'_{3} on . 
+ S1, S2, S3 *iop.Polynomial + + // S full permutation, i -> S[i] + S []int64 +} + +// VerifyingKey stores the data needed to verify a proof: +// * The commitment scheme +// * Commitments of ql prepended with as many ones as there are public inputs +// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs +// * Commitments to S1, S2, S3 +type VerifyingKey struct { + + // Size circuit + Size uint64 + SizeInv fr.Element + Generator fr.Element + NbPublicVariables uint64 + + // Commitment scheme that is used for an instantiation of PLONK + Kzg kzg.VerifyingKey + + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + + // S commitments to S1, S2, S3 + S [3]kzg.Digest + + // Commitments to ql, qr, qm, qo, qcp prepended with as many zeroes (ones for l) as there are public inputs. + // In particular Qk is not complete. + Ql, Qr, Qm, Qo, Qk kzg.Digest + Qcp []kzg.Digest + + CommitmentConstraintIndexes []uint64 +} + +// ProvingKey stores the data needed to generate a proof: +// * the commitment scheme +// * ql, prepended with as many ones as they are public inputs +// * qr, qm, qo prepended with as many zeroes as there are public inputs. +// * qk, prepended with as many zeroes as public inputs, to be completed by the prover +// with the list of public inputs. +// * sigma_1, sigma_2, sigma_3 in both basis +// * the copy constraint permutation +type ProvingKey struct { + + // stores ql, qr, qm, qo, qk (-> to be completed by the prover) + // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used + // for computing the opening proofs (hence the canonical form). The canonical version + // of qk incomplete is used in the linearisation polynomial. + // The polynomials in trace are in canonical basis. 
+ trace Trace + + Kzg kzg.ProvingKey + + // Verifying Key is embedded into the proving key (needed by Prove) + Vk *VerifyingKey + + // qr,ql,qm,qo,qcp in LagrangeCoset --> these are not serialized, but computed from Ql, Qr, Qm, Qo, Qcp once. + lcQl, lcQr, lcQm, lcQo *iop.Polynomial + lcQcp []*iop.Polynomial + + // LQk qk in Lagrange form -> to be completed by the prover. After being completed, + lQk *iop.Polynomial + + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + + // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. + lcS1, lcS2, lcS3 *iop.Polynomial + + // in lagrange coset basis --> not serialized id and L_{g^{0}} + lcIdIOP, lLoneIOP *iop.Polynomial +} + +func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { + + var pk ProvingKey + var vk VerifyingKey + pk.Vk = &vk + vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) + + // step 0: set the fft domains + pk.initDomains(spr) + + // step 1: set the verifying key + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) + vk.Size = pk.Domain[0].Cardinality + vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) + vk.Generator.Set(&pk.Domain[0].Generator) + vk.NbPublicVariables = uint64(len(spr.Public)) + if len(kzgSrs.Pk.G1) < int(vk.Size) { + return nil, nil, errors.New("kzg srs is too small") + } + pk.Kzg = kzgSrs.Pk + vk.Kzg = kzgSrs.Vk + + // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis + BuildTrace(spr, &pk.trace) + + // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. 
+ // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &pk.trace, nbVariables) + s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) + pk.trace.S1 = s[0] + pk.trace.S2 = s[1] + pk.trace.S3 = s[2] + + // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. + // All the above polynomials are expressed in canonical basis afterwards. This is why + // we save lqk before, because the prover needs to complete it in Lagrange form, and + // then express it on the Lagrange coset basis. + pk.lQk = pk.trace.Qk.Clone() // it will be completed by the prover, and the evaluated on the coset + err := commitTrace(&pk.trace, &pk) + if err != nil { + return nil, nil, err + } + + // step 5: evaluate ql, qr, qm, qo, s1, s2, s3 on LagrangeCoset (NOT qk) + // we clone them, because the canonical versions are going to be used in + // the opening proof + pk.computeLagrangeCosetPolys() + + return &pk, &vk, nil +} + +// computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset +// basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
+func (pk *ProvingKey) computeLagrangeCosetPolys() { + pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) + for i, qcpI := range pk.trace.Qcp { + pk.lcQcp[i] = qcpI.Clone().ToLagrangeCoset(&pk.Domain[1]) + } + pk.lcQl = pk.trace.Ql.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQr = pk.trace.Qr.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQm = pk.trace.Qm.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQo = pk.trace.Qo.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS1 = pk.trace.S1.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS2 = pk.trace.S2.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS3 = pk.trace.S3.Clone().ToLagrangeCoset(&pk.Domain[1]) + + // storing Id + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + id := make([]fr.Element, pk.Domain[1].Cardinality) + id[0].Set(&pk.Domain[1].FrMultiplicativeGen) + for i := 1; i < int(pk.Domain[1].Cardinality); i++ { + id[i].Mul(&id[i-1], &pk.Domain[1].Generator) + } + pk.lcIdIOP = iop.NewPolynomial(&id, lagReg) + + // L_{g^{0}} + cap := pk.Domain[1].Cardinality + if cap < pk.Domain[0].Cardinality { + cap = pk.Domain[0].Cardinality // sanity check + } + lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) + lone[0].SetOne() + pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). + ToRegular(). + ToLagrangeCoset(&pk.Domain[1]) +} + +// NbPublicWitness returns the expected public witness size (number of field elements) +func (vk *VerifyingKey) NbPublicWitness() int { + return int(vk.NbPublicVariables) +} + +// VerifyingKey returns pk.Vk +func (pk *ProvingKey) VerifyingKey() interface{} { + return pk.Vk +} + +// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. 
+// Size is the size of the system that is nb_constraints+nb_public_variables +func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) + size := ecc.NextPowerOfTwo(sizeSystem) + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments) + + ql := make([]fr.Element, size) + qr := make([]fr.Element, size) + qm := make([]fr.Element, size) + qo := make([]fr.Element, size) + qk := make([]fr.Element, size) + qcp := make([][]fr.Element, len(commitmentInfo)) + + for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistent + ql[i].SetOne().Neg(&ql[i]) + qr[i].SetZero() + qm[i].SetZero() + qo[i].SetZero() + qk[i].SetZero() // → to be completed by the prover + } + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + ql[offset+j].Set(&spr.Coefficients[c.QL]) + qr[offset+j].Set(&spr.Coefficients[c.QR]) + qm[offset+j].Set(&spr.Coefficients[c.QM]) + qo[offset+j].Set(&spr.Coefficients[c.QO]) + qk[offset+j].Set(&spr.Coefficients[c.QC]) + j++ + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + + pt.Ql = iop.NewPolynomial(&ql, lagReg) + pt.Qr = iop.NewPolynomial(&qr, lagReg) + pt.Qm = iop.NewPolynomial(&qm, lagReg) + pt.Qo = iop.NewPolynomial(&qo, lagReg) + pt.Qk = iop.NewPolynomial(&qk, lagReg) + pt.Qcp = make([]*iop.Polynomial, len(qcp)) + + for i := range commitmentInfo { + qcp[i] = make([]fr.Element, size) + for _, committed := range commitmentInfo[i].Committed { + qcp[i][offset+committed].SetOne() + } + pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + } +} + +// commitTrace commits to every polynomial in the trace, and put +// the commitments int the verifying key. 
+func commitTrace(trace *Trace, pk *ProvingKey) error { + + trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete + trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() + + var err error + pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + for i := range trace.Qcp { + trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() + if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + return err + } + } + if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + return err + } + return nil +} + +func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints + pk.Domain[0] = *fft.NewDomain(sizeSystem) + + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) + } else { + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) + } + +} + +// buildPermutation builds the Permutation associated with a circuit. +// +// The permutation s is composed of cycles of maximum length such that +// +// s. (l∥r∥o) = (l∥r∥o) +// +// , where l∥r∥o is the concatenation of the indices of l, r, o in +// ql.l+qr.r+qm.l.r+qo.O+k = 0. +// +// The permutation is encoded as a slice s of size 3*size(l), where the +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab +// like this: for i in tab: tab[i] = tab[permutation[i]] +func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { + + // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + sizeSolution := len(pt.Ql.Coefficients()) + sizePermutation := 3 * sizeSolution + + // init permutation + permutation := make([]int64, sizePermutation) + for i := 0; i < len(permutation); i++ { + permutation[i] = -1 + } + + // init LRO position -> variable_ID + lro := make([]int, sizePermutation) // position -> variable_ID + for i := 0; i < len(spr.Public); i++ { + lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) + } + + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + + j++ + } + + // init cycle: + // map ID -> last position the ID was seen + cycle := make([]int64, nbVariables) + for i := 0; i < len(cycle); i++ { + cycle[i] = -1 + } + + for i := 0; i < len(lro); i++ { + if cycle[lro[i]] != -1 { + // if != -1, it means we already encountered this value + // so we need to set the corresponding permutation index. 
+ permutation[i] = cycle[lro[i]] + } + cycle[lro[i]] = int64(i) + } + + // complete the Permutation by filling the first IDs encountered + for i := 0; i < sizePermutation; i++ { + if permutation[i] == -1 { + permutation[i] = cycle[lro[i]] + } + } + + pt.S = permutation +} + +// computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. +// We let the permutation act on <g> || u<g> || u^{2}<g>, split the result in 3 parts, +// and interpolate each of the 3 parts on <g>. +func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { + + nbElmts := int(domain.Cardinality) + + var res [3]*iop.Polynomial + + // Lagrange form of ID + evaluationIDSmallDomain := getSupportPermutation(domain) + + // Lagrange form of S1, S2, S3 + s1Canonical := make([]fr.Element, nbElmts) + s2Canonical := make([]fr.Element, nbElmts) + s3Canonical := make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + res[0] = iop.NewPolynomial(&s1Canonical, lagReg) + res[1] = iop.NewPolynomial(&s2Canonical, lagReg) + res[2] = iop.NewPolynomial(&s3Canonical, lagReg) + + return res +} + +// getSupportPermutation returns the support on which the permutation acts, it is +// <g> || u<g> || u^{2}<g> +func getSupportPermutation(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + + return 
res +} diff --git a/internal/backend/bw6-761/plonk/verify.go b/backend/plonk/bw6-761/verify.go similarity index 77% rename from internal/backend/bw6-761/plonk/verify.go rename to backend/plonk/bw6-761/verify.go index 9a1dd7650f..51c0719b8d 100644 --- a/internal/backend/bw6-761/plonk/verify.go +++ b/backend/plonk/bw6-761/verify.go @@ -39,7 +39,7 @@ var ( ) func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - log := logger.Logger().With().Str("curve", "bw6_761").Str("backend", "plonk").Logger() + log := logger.Logger().With().Str("curve", "bw6-761").Str("backend", "plonk").Logger() start := time.Now() // pick a hash function to derive the challenge (the same as in the prover) @@ -51,7 +51,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { // The first challenge is derived using the public data: the commitments to the permutation, // the coefficients of the circuit, and the public inputs. // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *vk, publicWitness); err != nil { + if err := bindPublicData(&fs, "gamma", *vk, publicWitness, proof.Bsb22Commitments); err != nil { return err } gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -85,25 +85,51 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = ∑_{i to be completed by the prover } offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.EvaluationQlDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.EvaluationQrDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.EvaluationQmDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.EvaluationQmDomainBigBitReversed[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.EvaluationQoDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.LQkIncompleteDomainSmall[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.CQkIncomplete[offset+i].Set(&pk.LQkIncompleteDomainSmall[offset+i]) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + pk.EvaluationQlDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QL]) + pk.EvaluationQrDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QR]) + pk.EvaluationQmDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QM]) + pk.EvaluationQoDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QO]) + pk.LQkIncompleteDomainSmall[offset+j].Set(&spr.Coefficients[c.QC]) + pk.CQkIncomplete[offset+j].Set(&pk.LQkIncompleteDomainSmall[offset+j]) + + j++ } pk.Domain[0].FFTInverse(pk.EvaluationQlDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) @@ -209,10 +212,10 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { return &pk, &vk, err } - pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, fft.OnCoset()) // build permutation. 
Note: at this stage, the permutation takes in account the placeholders buildPermutation(spr, &pk) @@ -257,10 +260,14 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + j++ } // init cycle: @@ -349,9 +356,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, fft.OnCoset()) pk.Domain[0].FFTInverse(pk.EvaluationS1BigDomain[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationS2BigDomain[:pk.Domain[0].Cardinality], fft.DIF) @@ -379,9 +386,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, fft.OnCoset()) return nil diff --git a/internal/backend/bls12-377/plonkfri/verify.go 
b/backend/plonkfri/bls12-377/verify.go similarity index 100% rename from internal/backend/bls12-377/plonkfri/verify.go rename to backend/plonkfri/bls12-377/verify.go diff --git a/internal/backend/bls12-381/plonkfri/prove.go b/backend/plonkfri/bls12-381/prove.go similarity index 92% rename from internal/backend/bls12-381/plonkfri/prove.go rename to backend/plonkfri/bls12-381/prove.go index 8909988986..3823ed3b26 100644 --- a/internal/backend/bls12-381/plonkfri/prove.go +++ b/backend/plonkfri/bls12-381/prove.go @@ -22,6 +22,8 @@ import ( "math/bits" "runtime" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" @@ -36,7 +38,6 @@ import ( ) type Proof struct { - // commitments to the solution vectors LROpp [3]fri.ProofOfProximity @@ -67,7 +68,11 @@ type Proof struct { OpeningsId1Id2Id3mp [3]fri.OpeningProof } -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) + if err != nil { + return nil, err + } var proof Proof @@ -78,23 +83,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // 1 - solve the system - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } + _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) 
+ if err != nil { + return nil, err } - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) + solution := _solution.(*cs.SparseR1CSSolution) + evaluationLDomainSmall := []fr.Element(solution.L) + evaluationRDomainSmall := []fr.Element(solution.R) + evaluationODomainSmall := []fr.Element(solution.O) // 2 - commit to lro blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( @@ -119,9 +116,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen } // 3 - compute Z, challenges are derived using L, R, O + public inputs + fw, ok := fullWitness.Vector().(fr.Vector) + if !ok { + return nil, witness.ErrInvalidWitness + } dataFiatShamir := make([][fr.Bytes]byte, len(spr.Public)+3) for i := 0; i < len(spr.Public); i++ { - copy(dataFiatShamir[i][:], fullWitness[i].Marshal()) + copy(dataFiatShamir[i][:], fw[i].Marshal()) } copy(dataFiatShamir[len(spr.Public)][:], proof.LROpp[0].ID) copy(dataFiatShamir[len(spr.Public)+1][:], proof.LROpp[1].ID) @@ -164,7 +165,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen // alpha.SetUint64(11) evaluationQkCompleteDomainBigBitReversed := make([]fr.Element, pk.Domain[1].Cardinality) - copy(evaluationQkCompleteDomainBigBitReversed, fullWitness[:len(spr.Public)]) + copy(evaluationQkCompleteDomainBigBitReversed, fw[:len(spr.Public)]) copy(evaluationQkCompleteDomainBigBitReversed[len(spr.Public):], pk.LQkIncompleteDomainSmall[len(spr.Public):]) pk.Domain[0].FFTInverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) fft.BitReverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality]) @@ -487,7 +488,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse for i := 0; i < int(pk.Domain[0].Cardinality); i++ { startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.Domain[1].FFT(startsAtOne, fft.DIF, true) + 
pk.Domain[1].FFT(startsAtOne, fft.DIF, fft.OnCoset()) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain @@ -515,7 +516,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.Domain[1].FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, fft.OnCoset()) // degree of hi is n+2 because of the blinding h1 := h[:pk.Domain[0].Cardinality+2] @@ -586,41 +587,6 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma } -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. -// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - // computeBlindedLROCanonical // l, r, o in canonical basis with blinding func computeBlindedLROCanonical( diff --git a/internal/backend/bls12-381/plonkfri/setup.go b/backend/plonkfri/bls12-381/setup.go similarity index 91% rename from 
internal/backend/bls12-381/plonkfri/setup.go rename to backend/plonkfri/bls12-381/setup.go index da30df7b0f..4f2b6e1b20 100644 --- a/internal/backend/bls12-381/plonkfri/setup.go +++ b/backend/plonkfri/bls12-381/setup.go @@ -106,7 +106,7 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { // The verifying key shares data with the proving key pk.Vk = &vk - nbConstraints := len(spr.Constraints) + nbConstraints := spr.GetNbConstraints() // fft domains sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints @@ -156,15 +156,18 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { pk.CQkIncomplete[i].Set(&pk.LQkIncompleteDomainSmall[i]) // --> to be completed by the prover } offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.EvaluationQlDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.EvaluationQrDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.EvaluationQmDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.EvaluationQmDomainBigBitReversed[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.EvaluationQoDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.LQkIncompleteDomainSmall[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.CQkIncomplete[offset+i].Set(&pk.LQkIncompleteDomainSmall[offset+i]) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + pk.EvaluationQlDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QL]) + pk.EvaluationQrDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QR]) + pk.EvaluationQmDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QM]) + pk.EvaluationQoDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QO]) + pk.LQkIncompleteDomainSmall[offset+j].Set(&spr.Coefficients[c.QC]) + pk.CQkIncomplete[offset+j].Set(&pk.LQkIncompleteDomainSmall[offset+j]) + + j++ } pk.Domain[0].FFTInverse(pk.EvaluationQlDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) @@ -209,10 +212,10 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { return &pk, &vk, err } - pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, fft.OnCoset()) // build permutation. 
Note: at this stage, the permutation takes in account the placeholders buildPermutation(spr, &pk) @@ -257,10 +260,14 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + j++ } // init cycle: @@ -349,9 +356,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, fft.OnCoset()) pk.Domain[0].FFTInverse(pk.EvaluationS1BigDomain[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationS2BigDomain[:pk.Domain[0].Cardinality], fft.DIF) @@ -379,9 +386,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, fft.OnCoset()) return nil diff --git a/internal/backend/bls12-381/plonkfri/verify.go 
b/backend/plonkfri/bls12-381/verify.go similarity index 100% rename from internal/backend/bls12-381/plonkfri/verify.go rename to backend/plonkfri/bls12-381/verify.go diff --git a/internal/backend/bls24-315/plonkfri/prove.go b/backend/plonkfri/bls24-315/prove.go similarity index 92% rename from internal/backend/bls24-315/plonkfri/prove.go rename to backend/plonkfri/bls24-315/prove.go index a440c5953f..5d7f385ea1 100644 --- a/internal/backend/bls24-315/plonkfri/prove.go +++ b/backend/plonkfri/bls24-315/prove.go @@ -22,6 +22,8 @@ import ( "math/bits" "runtime" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" @@ -36,7 +38,6 @@ import ( ) type Proof struct { - // commitments to the solution vectors LROpp [3]fri.ProofOfProximity @@ -67,7 +68,11 @@ type Proof struct { OpeningsId1Id2Id3mp [3]fri.OpeningProof } -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) + if err != nil { + return nil, err + } var proof Proof @@ -78,23 +83,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // 1 - solve the system - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } + _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) 
+ if err != nil { + return nil, err } - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) + solution := _solution.(*cs.SparseR1CSSolution) + evaluationLDomainSmall := []fr.Element(solution.L) + evaluationRDomainSmall := []fr.Element(solution.R) + evaluationODomainSmall := []fr.Element(solution.O) // 2 - commit to lro blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( @@ -119,9 +116,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen } // 3 - compute Z, challenges are derived using L, R, O + public inputs + fw, ok := fullWitness.Vector().(fr.Vector) + if !ok { + return nil, witness.ErrInvalidWitness + } dataFiatShamir := make([][fr.Bytes]byte, len(spr.Public)+3) for i := 0; i < len(spr.Public); i++ { - copy(dataFiatShamir[i][:], fullWitness[i].Marshal()) + copy(dataFiatShamir[i][:], fw[i].Marshal()) } copy(dataFiatShamir[len(spr.Public)][:], proof.LROpp[0].ID) copy(dataFiatShamir[len(spr.Public)+1][:], proof.LROpp[1].ID) @@ -164,7 +165,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen // alpha.SetUint64(11) evaluationQkCompleteDomainBigBitReversed := make([]fr.Element, pk.Domain[1].Cardinality) - copy(evaluationQkCompleteDomainBigBitReversed, fullWitness[:len(spr.Public)]) + copy(evaluationQkCompleteDomainBigBitReversed, fw[:len(spr.Public)]) copy(evaluationQkCompleteDomainBigBitReversed[len(spr.Public):], pk.LQkIncompleteDomainSmall[len(spr.Public):]) pk.Domain[0].FFTInverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) fft.BitReverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality]) @@ -487,7 +488,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse for i := 0; i < int(pk.Domain[0].Cardinality); i++ { startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.Domain[1].FFT(startsAtOne, fft.DIF, true) + 
pk.Domain[1].FFT(startsAtOne, fft.DIF, fft.OnCoset()) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain @@ -515,7 +516,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.Domain[1].FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, fft.OnCoset()) // degree of hi is n+2 because of the blinding h1 := h[:pk.Domain[0].Cardinality+2] @@ -586,41 +587,6 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma } -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. -// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - // computeBlindedLROCanonical // l, r, o in canonical basis with blinding func computeBlindedLROCanonical( diff --git a/internal/backend/bls24-315/plonkfri/setup.go b/backend/plonkfri/bls24-315/setup.go similarity index 91% rename from 
internal/backend/bls24-315/plonkfri/setup.go rename to backend/plonkfri/bls24-315/setup.go index d02f7f504f..9063dc95d9 100644 --- a/internal/backend/bls24-315/plonkfri/setup.go +++ b/backend/plonkfri/bls24-315/setup.go @@ -106,7 +106,7 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { // The verifying key shares data with the proving key pk.Vk = &vk - nbConstraints := len(spr.Constraints) + nbConstraints := spr.GetNbConstraints() // fft domains sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints @@ -156,15 +156,18 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { pk.CQkIncomplete[i].Set(&pk.LQkIncompleteDomainSmall[i]) // --> to be completed by the prover } offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.EvaluationQlDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.EvaluationQrDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.EvaluationQmDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.EvaluationQmDomainBigBitReversed[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.EvaluationQoDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.LQkIncompleteDomainSmall[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.CQkIncomplete[offset+i].Set(&pk.LQkIncompleteDomainSmall[offset+i]) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + pk.EvaluationQlDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QL]) + pk.EvaluationQrDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QR]) + pk.EvaluationQmDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QM]) + pk.EvaluationQoDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QO]) + pk.LQkIncompleteDomainSmall[offset+j].Set(&spr.Coefficients[c.QC]) + pk.CQkIncomplete[offset+j].Set(&pk.LQkIncompleteDomainSmall[offset+j]) + + j++ } pk.Domain[0].FFTInverse(pk.EvaluationQlDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) @@ -209,10 +212,10 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { return &pk, &vk, err } - pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, fft.OnCoset()) // build permutation. 
Note: at this stage, the permutation takes in account the placeholders buildPermutation(spr, &pk) @@ -257,10 +260,14 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + j++ } // init cycle: @@ -349,9 +356,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, fft.OnCoset()) pk.Domain[0].FFTInverse(pk.EvaluationS1BigDomain[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationS2BigDomain[:pk.Domain[0].Cardinality], fft.DIF) @@ -379,9 +386,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, fft.OnCoset()) return nil diff --git a/internal/backend/bls24-315/plonkfri/verify.go 
b/backend/plonkfri/bls24-315/verify.go similarity index 100% rename from internal/backend/bls24-315/plonkfri/verify.go rename to backend/plonkfri/bls24-315/verify.go diff --git a/internal/backend/bls24-317/plonkfri/prove.go b/backend/plonkfri/bls24-317/prove.go similarity index 92% rename from internal/backend/bls24-317/plonkfri/prove.go rename to backend/plonkfri/bls24-317/prove.go index 752c825069..409d66b34e 100644 --- a/internal/backend/bls24-317/plonkfri/prove.go +++ b/backend/plonkfri/bls24-317/prove.go @@ -22,6 +22,8 @@ import ( "math/bits" "runtime" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" @@ -36,7 +38,6 @@ import ( ) type Proof struct { - // commitments to the solution vectors LROpp [3]fri.ProofOfProximity @@ -67,7 +68,11 @@ type Proof struct { OpeningsId1Id2Id3mp [3]fri.OpeningProof } -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) + if err != nil { + return nil, err + } var proof Proof @@ -78,23 +83,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // 1 - solve the system - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } + _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) 
+ if err != nil { + return nil, err } - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) + solution := _solution.(*cs.SparseR1CSSolution) + evaluationLDomainSmall := []fr.Element(solution.L) + evaluationRDomainSmall := []fr.Element(solution.R) + evaluationODomainSmall := []fr.Element(solution.O) // 2 - commit to lro blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( @@ -119,9 +116,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen } // 3 - compute Z, challenges are derived using L, R, O + public inputs + fw, ok := fullWitness.Vector().(fr.Vector) + if !ok { + return nil, witness.ErrInvalidWitness + } dataFiatShamir := make([][fr.Bytes]byte, len(spr.Public)+3) for i := 0; i < len(spr.Public); i++ { - copy(dataFiatShamir[i][:], fullWitness[i].Marshal()) + copy(dataFiatShamir[i][:], fw[i].Marshal()) } copy(dataFiatShamir[len(spr.Public)][:], proof.LROpp[0].ID) copy(dataFiatShamir[len(spr.Public)+1][:], proof.LROpp[1].ID) @@ -164,7 +165,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen // alpha.SetUint64(11) evaluationQkCompleteDomainBigBitReversed := make([]fr.Element, pk.Domain[1].Cardinality) - copy(evaluationQkCompleteDomainBigBitReversed, fullWitness[:len(spr.Public)]) + copy(evaluationQkCompleteDomainBigBitReversed, fw[:len(spr.Public)]) copy(evaluationQkCompleteDomainBigBitReversed[len(spr.Public):], pk.LQkIncompleteDomainSmall[len(spr.Public):]) pk.Domain[0].FFTInverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) fft.BitReverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality]) @@ -487,7 +488,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse for i := 0; i < int(pk.Domain[0].Cardinality); i++ { startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.Domain[1].FFT(startsAtOne, fft.DIF, true) + 
pk.Domain[1].FFT(startsAtOne, fft.DIF, fft.OnCoset()) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain @@ -515,7 +516,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.Domain[1].FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, fft.OnCoset()) // degree of hi is n+2 because of the blinding h1 := h[:pk.Domain[0].Cardinality+2] @@ -586,41 +587,6 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma } -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. -// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - // computeBlindedLROCanonical // l, r, o in canonical basis with blinding func computeBlindedLROCanonical( diff --git a/internal/backend/bls24-317/plonkfri/setup.go b/backend/plonkfri/bls24-317/setup.go similarity index 91% rename from 
internal/backend/bls24-317/plonkfri/setup.go rename to backend/plonkfri/bls24-317/setup.go index 7cd1fb80ff..c2ab7ee5be 100644 --- a/internal/backend/bls24-317/plonkfri/setup.go +++ b/backend/plonkfri/bls24-317/setup.go @@ -106,7 +106,7 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { // The verifying key shares data with the proving key pk.Vk = &vk - nbConstraints := len(spr.Constraints) + nbConstraints := spr.GetNbConstraints() // fft domains sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints @@ -156,15 +156,18 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { pk.CQkIncomplete[i].Set(&pk.LQkIncompleteDomainSmall[i]) // --> to be completed by the prover } offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.EvaluationQlDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.EvaluationQrDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.EvaluationQmDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.EvaluationQmDomainBigBitReversed[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.EvaluationQoDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.LQkIncompleteDomainSmall[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.CQkIncomplete[offset+i].Set(&pk.LQkIncompleteDomainSmall[offset+i]) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + pk.EvaluationQlDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QL]) + pk.EvaluationQrDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QR]) + pk.EvaluationQmDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QM]) + pk.EvaluationQoDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QO]) + pk.LQkIncompleteDomainSmall[offset+j].Set(&spr.Coefficients[c.QC]) + pk.CQkIncomplete[offset+j].Set(&pk.LQkIncompleteDomainSmall[offset+j]) + + j++ } pk.Domain[0].FFTInverse(pk.EvaluationQlDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) @@ -209,10 +212,10 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { return &pk, &vk, err } - pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, fft.OnCoset()) // build permutation. 
Note: at this stage, the permutation takes in account the placeholders buildPermutation(spr, &pk) @@ -257,10 +260,14 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + j++ } // init cycle: @@ -349,9 +356,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, fft.OnCoset()) pk.Domain[0].FFTInverse(pk.EvaluationS1BigDomain[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationS2BigDomain[:pk.Domain[0].Cardinality], fft.DIF) @@ -379,9 +386,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, fft.OnCoset()) return nil diff --git a/internal/backend/bls24-317/plonkfri/verify.go 
b/backend/plonkfri/bls24-317/verify.go similarity index 100% rename from internal/backend/bls24-317/plonkfri/verify.go rename to backend/plonkfri/bls24-317/verify.go diff --git a/internal/backend/bn254/plonkfri/prove.go b/backend/plonkfri/bn254/prove.go similarity index 92% rename from internal/backend/bn254/plonkfri/prove.go rename to backend/plonkfri/bn254/prove.go index 1166b5cc30..1cb0e6fbc6 100644 --- a/internal/backend/bn254/plonkfri/prove.go +++ b/backend/plonkfri/bn254/prove.go @@ -22,6 +22,8 @@ import ( "math/bits" "runtime" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" @@ -36,7 +38,6 @@ import ( ) type Proof struct { - // commitments to the solution vectors LROpp [3]fri.ProofOfProximity @@ -67,7 +68,11 @@ type Proof struct { OpeningsId1Id2Id3mp [3]fri.OpeningProof } -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) + if err != nil { + return nil, err + } var proof Proof @@ -78,23 +83,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // 1 - solve the system - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } + _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) 
+ if err != nil { + return nil, err } - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) + solution := _solution.(*cs.SparseR1CSSolution) + evaluationLDomainSmall := []fr.Element(solution.L) + evaluationRDomainSmall := []fr.Element(solution.R) + evaluationODomainSmall := []fr.Element(solution.O) // 2 - commit to lro blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( @@ -119,9 +116,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen } // 3 - compute Z, challenges are derived using L, R, O + public inputs + fw, ok := fullWitness.Vector().(fr.Vector) + if !ok { + return nil, witness.ErrInvalidWitness + } dataFiatShamir := make([][fr.Bytes]byte, len(spr.Public)+3) for i := 0; i < len(spr.Public); i++ { - copy(dataFiatShamir[i][:], fullWitness[i].Marshal()) + copy(dataFiatShamir[i][:], fw[i].Marshal()) } copy(dataFiatShamir[len(spr.Public)][:], proof.LROpp[0].ID) copy(dataFiatShamir[len(spr.Public)+1][:], proof.LROpp[1].ID) @@ -164,7 +165,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen // alpha.SetUint64(11) evaluationQkCompleteDomainBigBitReversed := make([]fr.Element, pk.Domain[1].Cardinality) - copy(evaluationQkCompleteDomainBigBitReversed, fullWitness[:len(spr.Public)]) + copy(evaluationQkCompleteDomainBigBitReversed, fw[:len(spr.Public)]) copy(evaluationQkCompleteDomainBigBitReversed[len(spr.Public):], pk.LQkIncompleteDomainSmall[len(spr.Public):]) pk.Domain[0].FFTInverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) fft.BitReverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality]) @@ -487,7 +488,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse for i := 0; i < int(pk.Domain[0].Cardinality); i++ { startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.Domain[1].FFT(startsAtOne, fft.DIF, true) + 
pk.Domain[1].FFT(startsAtOne, fft.DIF, fft.OnCoset()) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain @@ -515,7 +516,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.Domain[1].FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, fft.OnCoset()) // degree of hi is n+2 because of the blinding h1 := h[:pk.Domain[0].Cardinality+2] @@ -586,41 +587,6 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma } -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. -// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - // computeBlindedLROCanonical // l, r, o in canonical basis with blinding func computeBlindedLROCanonical( diff --git a/internal/backend/bn254/plonkfri/setup.go b/backend/plonkfri/bn254/setup.go similarity index 91% rename from 
internal/backend/bn254/plonkfri/setup.go rename to backend/plonkfri/bn254/setup.go index 4a47a70839..6b37f37600 100644 --- a/internal/backend/bn254/plonkfri/setup.go +++ b/backend/plonkfri/bn254/setup.go @@ -106,7 +106,7 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { // The verifying key shares data with the proving key pk.Vk = &vk - nbConstraints := len(spr.Constraints) + nbConstraints := spr.GetNbConstraints() // fft domains sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints @@ -156,15 +156,18 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { pk.CQkIncomplete[i].Set(&pk.LQkIncompleteDomainSmall[i]) // --> to be completed by the prover } offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.EvaluationQlDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.EvaluationQrDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.EvaluationQmDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.EvaluationQmDomainBigBitReversed[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.EvaluationQoDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.LQkIncompleteDomainSmall[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.CQkIncomplete[offset+i].Set(&pk.LQkIncompleteDomainSmall[offset+i]) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + pk.EvaluationQlDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QL]) + pk.EvaluationQrDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QR]) + pk.EvaluationQmDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QM]) + pk.EvaluationQoDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QO]) + pk.LQkIncompleteDomainSmall[offset+j].Set(&spr.Coefficients[c.QC]) + pk.CQkIncomplete[offset+j].Set(&pk.LQkIncompleteDomainSmall[offset+j]) + + j++ } pk.Domain[0].FFTInverse(pk.EvaluationQlDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) @@ -209,10 +212,10 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { return &pk, &vk, err } - pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, fft.OnCoset()) // build permutation. 
Note: at this stage, the permutation takes in account the placeholders buildPermutation(spr, &pk) @@ -257,10 +260,14 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + j++ } // init cycle: @@ -349,9 +356,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, fft.OnCoset()) pk.Domain[0].FFTInverse(pk.EvaluationS1BigDomain[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationS2BigDomain[:pk.Domain[0].Cardinality], fft.DIF) @@ -379,9 +386,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, fft.OnCoset()) return nil diff --git a/internal/backend/bn254/plonkfri/verify.go b/backend/plonkfri/bn254/verify.go 
similarity index 100% rename from internal/backend/bn254/plonkfri/verify.go rename to backend/plonkfri/bn254/verify.go diff --git a/internal/backend/bw6-633/plonkfri/prove.go b/backend/plonkfri/bw6-633/prove.go similarity index 92% rename from internal/backend/bw6-633/plonkfri/prove.go rename to backend/plonkfri/bw6-633/prove.go index aa02bc934b..3e9d8e8ef4 100644 --- a/internal/backend/bw6-633/plonkfri/prove.go +++ b/backend/plonkfri/bw6-633/prove.go @@ -22,6 +22,8 @@ import ( "math/bits" "runtime" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" @@ -36,7 +38,6 @@ import ( ) type Proof struct { - // commitments to the solution vectors LROpp [3]fri.ProofOfProximity @@ -67,7 +68,11 @@ type Proof struct { OpeningsId1Id2Id3mp [3]fri.OpeningProof } -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) + if err != nil { + return nil, err + } var proof Proof @@ -78,23 +83,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // 1 - solve the system - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } + _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) 
+ if err != nil { + return nil, err } - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) + solution := _solution.(*cs.SparseR1CSSolution) + evaluationLDomainSmall := []fr.Element(solution.L) + evaluationRDomainSmall := []fr.Element(solution.R) + evaluationODomainSmall := []fr.Element(solution.O) // 2 - commit to lro blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( @@ -119,9 +116,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen } // 3 - compute Z, challenges are derived using L, R, O + public inputs + fw, ok := fullWitness.Vector().(fr.Vector) + if !ok { + return nil, witness.ErrInvalidWitness + } dataFiatShamir := make([][fr.Bytes]byte, len(spr.Public)+3) for i := 0; i < len(spr.Public); i++ { - copy(dataFiatShamir[i][:], fullWitness[i].Marshal()) + copy(dataFiatShamir[i][:], fw[i].Marshal()) } copy(dataFiatShamir[len(spr.Public)][:], proof.LROpp[0].ID) copy(dataFiatShamir[len(spr.Public)+1][:], proof.LROpp[1].ID) @@ -164,7 +165,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen // alpha.SetUint64(11) evaluationQkCompleteDomainBigBitReversed := make([]fr.Element, pk.Domain[1].Cardinality) - copy(evaluationQkCompleteDomainBigBitReversed, fullWitness[:len(spr.Public)]) + copy(evaluationQkCompleteDomainBigBitReversed, fw[:len(spr.Public)]) copy(evaluationQkCompleteDomainBigBitReversed[len(spr.Public):], pk.LQkIncompleteDomainSmall[len(spr.Public):]) pk.Domain[0].FFTInverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) fft.BitReverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality]) @@ -487,7 +488,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse for i := 0; i < int(pk.Domain[0].Cardinality); i++ { startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.Domain[1].FFT(startsAtOne, fft.DIF, true) + 
pk.Domain[1].FFT(startsAtOne, fft.DIF, fft.OnCoset()) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain @@ -515,7 +516,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.Domain[1].FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, fft.OnCoset()) // degree of hi is n+2 because of the blinding h1 := h[:pk.Domain[0].Cardinality+2] @@ -586,41 +587,6 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma } -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. -// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - // computeBlindedLROCanonical // l, r, o in canonical basis with blinding func computeBlindedLROCanonical( diff --git a/internal/backend/bw6-633/plonkfri/setup.go b/backend/plonkfri/bw6-633/setup.go similarity index 91% rename from 
internal/backend/bw6-633/plonkfri/setup.go rename to backend/plonkfri/bw6-633/setup.go index 04d0062321..dff84dbee9 100644 --- a/internal/backend/bw6-633/plonkfri/setup.go +++ b/backend/plonkfri/bw6-633/setup.go @@ -106,7 +106,7 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { // The verifying key shares data with the proving key pk.Vk = &vk - nbConstraints := len(spr.Constraints) + nbConstraints := spr.GetNbConstraints() // fft domains sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints @@ -156,15 +156,18 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { pk.CQkIncomplete[i].Set(&pk.LQkIncompleteDomainSmall[i]) // --> to be completed by the prover } offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.EvaluationQlDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.EvaluationQrDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.EvaluationQmDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.EvaluationQmDomainBigBitReversed[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.EvaluationQoDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.LQkIncompleteDomainSmall[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.CQkIncomplete[offset+i].Set(&pk.LQkIncompleteDomainSmall[offset+i]) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + pk.EvaluationQlDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QL]) + pk.EvaluationQrDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QR]) + pk.EvaluationQmDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QM]) + pk.EvaluationQoDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QO]) + pk.LQkIncompleteDomainSmall[offset+j].Set(&spr.Coefficients[c.QC]) + pk.CQkIncomplete[offset+j].Set(&pk.LQkIncompleteDomainSmall[offset+j]) + + j++ } pk.Domain[0].FFTInverse(pk.EvaluationQlDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) @@ -209,10 +212,10 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { return &pk, &vk, err } - pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, fft.OnCoset()) // build permutation. 
Note: at this stage, the permutation takes in account the placeholders buildPermutation(spr, &pk) @@ -257,10 +260,14 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + j++ } // init cycle: @@ -349,9 +356,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, fft.OnCoset()) pk.Domain[0].FFTInverse(pk.EvaluationS1BigDomain[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationS2BigDomain[:pk.Domain[0].Cardinality], fft.DIF) @@ -379,9 +386,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, fft.OnCoset()) return nil diff --git a/internal/backend/bw6-633/plonkfri/verify.go 
b/backend/plonkfri/bw6-633/verify.go similarity index 100% rename from internal/backend/bw6-633/plonkfri/verify.go rename to backend/plonkfri/bw6-633/verify.go diff --git a/internal/backend/bw6-761/plonkfri/prove.go b/backend/plonkfri/bw6-761/prove.go similarity index 92% rename from internal/backend/bw6-761/plonkfri/prove.go rename to backend/plonkfri/bw6-761/prove.go index e5f9cb5b78..f57bc4f6d3 100644 --- a/internal/backend/bw6-761/plonkfri/prove.go +++ b/backend/plonkfri/bw6-761/prove.go @@ -22,6 +22,8 @@ import ( "math/bits" "runtime" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" @@ -36,7 +38,6 @@ import ( ) type Proof struct { - // commitments to the solution vectors LROpp [3]fri.ProofOfProximity @@ -67,7 +68,11 @@ type Proof struct { OpeningsId1Id2Id3mp [3]fri.OpeningProof } -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) + if err != nil { + return nil, err + } var proof Proof @@ -78,23 +83,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // 1 - solve the system - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } + _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) 
+ if err != nil { + return nil, err } - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) + solution := _solution.(*cs.SparseR1CSSolution) + evaluationLDomainSmall := []fr.Element(solution.L) + evaluationRDomainSmall := []fr.Element(solution.R) + evaluationODomainSmall := []fr.Element(solution.O) // 2 - commit to lro blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( @@ -119,9 +116,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen } // 3 - compute Z, challenges are derived using L, R, O + public inputs + fw, ok := fullWitness.Vector().(fr.Vector) + if !ok { + return nil, witness.ErrInvalidWitness + } dataFiatShamir := make([][fr.Bytes]byte, len(spr.Public)+3) for i := 0; i < len(spr.Public); i++ { - copy(dataFiatShamir[i][:], fullWitness[i].Marshal()) + copy(dataFiatShamir[i][:], fw[i].Marshal()) } copy(dataFiatShamir[len(spr.Public)][:], proof.LROpp[0].ID) copy(dataFiatShamir[len(spr.Public)+1][:], proof.LROpp[1].ID) @@ -164,7 +165,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen // alpha.SetUint64(11) evaluationQkCompleteDomainBigBitReversed := make([]fr.Element, pk.Domain[1].Cardinality) - copy(evaluationQkCompleteDomainBigBitReversed, fullWitness[:len(spr.Public)]) + copy(evaluationQkCompleteDomainBigBitReversed, fw[:len(spr.Public)]) copy(evaluationQkCompleteDomainBigBitReversed[len(spr.Public):], pk.LQkIncompleteDomainSmall[len(spr.Public):]) pk.Domain[0].FFTInverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) fft.BitReverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality]) @@ -487,7 +488,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse for i := 0; i < int(pk.Domain[0].Cardinality); i++ { startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.Domain[1].FFT(startsAtOne, fft.DIF, true) + 
pk.Domain[1].FFT(startsAtOne, fft.DIF, fft.OnCoset()) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain @@ -515,7 +516,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.Domain[1].FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, fft.OnCoset()) // degree of hi is n+2 because of the blinding h1 := h[:pk.Domain[0].Cardinality+2] @@ -586,41 +587,6 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma } -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. -// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - // computeBlindedLROCanonical // l, r, o in canonical basis with blinding func computeBlindedLROCanonical( diff --git a/internal/backend/bw6-761/plonkfri/setup.go b/backend/plonkfri/bw6-761/setup.go similarity index 91% rename from 
internal/backend/bw6-761/plonkfri/setup.go rename to backend/plonkfri/bw6-761/setup.go index 32be4d26fe..ee278a5eda 100644 --- a/internal/backend/bw6-761/plonkfri/setup.go +++ b/backend/plonkfri/bw6-761/setup.go @@ -106,7 +106,7 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { // The verifying key shares data with the proving key pk.Vk = &vk - nbConstraints := len(spr.Constraints) + nbConstraints := spr.GetNbConstraints() // fft domains sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints @@ -156,15 +156,18 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { pk.CQkIncomplete[i].Set(&pk.LQkIncompleteDomainSmall[i]) // --> to be completed by the prover } offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.EvaluationQlDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.EvaluationQrDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.EvaluationQmDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.EvaluationQmDomainBigBitReversed[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.EvaluationQoDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.LQkIncompleteDomainSmall[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.CQkIncomplete[offset+i].Set(&pk.LQkIncompleteDomainSmall[offset+i]) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + pk.EvaluationQlDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QL]) + pk.EvaluationQrDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QR]) + pk.EvaluationQmDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QM]) + pk.EvaluationQoDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QO]) + pk.LQkIncompleteDomainSmall[offset+j].Set(&spr.Coefficients[c.QC]) + pk.CQkIncomplete[offset+j].Set(&pk.LQkIncompleteDomainSmall[offset+j]) + + j++ } pk.Domain[0].FFTInverse(pk.EvaluationQlDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) @@ -209,10 +212,10 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { return &pk, &vk, err } - pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, fft.OnCoset()) // build permutation. 
Note: at this stage, the permutation takes in account the placeholders buildPermutation(spr, &pk) @@ -257,10 +260,14 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + j++ } // init cycle: @@ -349,9 +356,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, fft.OnCoset()) pk.Domain[0].FFTInverse(pk.EvaluationS1BigDomain[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationS2BigDomain[:pk.Domain[0].Cardinality], fft.DIF) @@ -379,9 +386,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, fft.OnCoset()) return nil diff --git a/internal/backend/bw6-761/plonkfri/verify.go 
b/backend/plonkfri/bw6-761/verify.go similarity index 100% rename from internal/backend/bw6-761/plonkfri/verify.go rename to backend/plonkfri/bw6-761/verify.go diff --git a/backend/plonkfri/plonkfri.go b/backend/plonkfri/plonkfri.go index b2cab82fd6..5fcb3760eb 100644 --- a/backend/plonkfri/plonkfri.go +++ b/backend/plonkfri/plonkfri.go @@ -29,12 +29,13 @@ import ( cs_bw6633 "github.com/consensys/gnark/constraint/bw6-633" cs_bw6761 "github.com/consensys/gnark/constraint/bw6-761" - plonk_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/plonkfri" - plonk_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/plonkfri" - plonk_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/plonkfri" - plonk_bn254 "github.com/consensys/gnark/internal/backend/bn254/plonkfri" - plonk_bw6633 "github.com/consensys/gnark/internal/backend/bw6-633/plonkfri" - plonk_bw6761 "github.com/consensys/gnark/internal/backend/bw6-761/plonkfri" + plonk_bls12377 "github.com/consensys/gnark/backend/plonkfri/bls12-377" + plonk_bls12381 "github.com/consensys/gnark/backend/plonkfri/bls12-381" + plonk_bls24315 "github.com/consensys/gnark/backend/plonkfri/bls24-315" + plonk_bls24317 "github.com/consensys/gnark/backend/plonkfri/bls24-317" + plonk_bn254 "github.com/consensys/gnark/backend/plonkfri/bn254" + plonk_bw6633 "github.com/consensys/gnark/backend/plonkfri/bw6-633" + plonk_bw6761 "github.com/consensys/gnark/backend/plonkfri/bw6-761" fr_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" fr_bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" @@ -43,8 +44,6 @@ import ( fr_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr" fr_bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" fr_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - - plonk_bls24317 "github.com/consensys/gnark/internal/backend/bls24-317/plonkfri" ) // Proof represents a Plonk proof generated by plonk.Prove @@ -106,60 +105,28 @@ func Setup(ccs 
constraint.ConstraintSystem) (ProvingKey, VerifyingKey, error) { // internally, the solution vector to the SparseR1CS will be filled with random values which may impact benchmarking func Prove(ccs constraint.ConstraintSystem, pk ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (Proof, error) { - // apply options - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return nil, err - } - switch tccs := ccs.(type) { case *cs_bn254.SparseR1CS: - w, ok := fullWitness.Vector().(fr_bn254.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return plonk_bn254.Prove(tccs, pk.(*plonk_bn254.ProvingKey), w, opt) + return plonk_bn254.Prove(tccs, pk.(*plonk_bn254.ProvingKey), fullWitness, opts...) case *cs_bls12381.SparseR1CS: - w, ok := fullWitness.Vector().(fr_bls12381.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return plonk_bls12381.Prove(tccs, pk.(*plonk_bls12381.ProvingKey), w, opt) + return plonk_bls12381.Prove(tccs, pk.(*plonk_bls12381.ProvingKey), fullWitness, opts...) case *cs_bls12377.SparseR1CS: - w, ok := fullWitness.Vector().(fr_bls12377.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return plonk_bls12377.Prove(tccs, pk.(*plonk_bls12377.ProvingKey), w, opt) + return plonk_bls12377.Prove(tccs, pk.(*plonk_bls12377.ProvingKey), fullWitness, opts...) case *cs_bw6761.SparseR1CS: - w, ok := fullWitness.Vector().(fr_bw6761.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return plonk_bw6761.Prove(tccs, pk.(*plonk_bw6761.ProvingKey), w, opt) + return plonk_bw6761.Prove(tccs, pk.(*plonk_bw6761.ProvingKey), fullWitness, opts...) case *cs_bw6633.SparseR1CS: - w, ok := fullWitness.Vector().(fr_bw6633.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return plonk_bw6633.Prove(tccs, pk.(*plonk_bw6633.ProvingKey), w, opt) + return plonk_bw6633.Prove(tccs, pk.(*plonk_bw6633.ProvingKey), fullWitness, opts...) 
case *cs_bls24315.SparseR1CS: - w, ok := fullWitness.Vector().(fr_bls24315.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return plonk_bls24315.Prove(tccs, pk.(*plonk_bls24315.ProvingKey), w, opt) + return plonk_bls24315.Prove(tccs, pk.(*plonk_bls24315.ProvingKey), fullWitness, opts...) + case *cs_bls24317.SparseR1CS: - w, ok := fullWitness.Vector().(fr_bls24317.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return plonk_bls24317.Prove(tccs, pk.(*plonk_bls24317.ProvingKey), w, opt) + return plonk_bls24317.Prove(tccs, pk.(*plonk_bls24317.ProvingKey), fullWitness, opts...) + default: panic("unrecognized SparseR1CS curve type") } diff --git a/backend/witness/witness.go b/backend/witness/witness.go index 437ffcd35c..2db96aede3 100644 --- a/backend/witness/witness.go +++ b/backend/witness/witness.go @@ -113,7 +113,7 @@ func New(field *big.Int) (Witness, error) { } func (w *witness) Fill(nbPublic, nbSecret int, values <-chan any) error { - n := int(nbPublic + nbSecret) + n := nbPublic + nbSecret w.vector = resize(w.vector, n) w.nbPublic = uint32(nbPublic) w.nbSecret = uint32(nbSecret) diff --git a/backend/witness/witness_test.go b/backend/witness/witness_test.go index ba127b3240..40446abe81 100644 --- a/backend/witness/witness_test.go +++ b/backend/witness/witness_test.go @@ -1,6 +1,8 @@ package witness_test import ( + "bytes" + "encoding/json" "fmt" "reflect" "testing" @@ -47,11 +49,17 @@ func ExampleWitness() { // first get the circuit expected schema schema, _ := frontend.NewSchema(assignment) - json, _ := reconstructed.ToJSON(schema) + ret, _ := reconstructed.ToJSON(schema) - fmt.Println(string(json)) + var b bytes.Buffer + json.Indent(&b, ret, "", "\t") + fmt.Println(b.String()) // Output: - // {"X":42,"Y":8000,"E":1} + // { + // "X": 42, + // "Y": 8000, + // "E": 1 + // } } @@ -148,3 +156,34 @@ func roundTripMarshalJSON(assert *require.Assertions, assignment circuit, public assert.True(reflect.DeepEqual(rw, w), "witness json 
round trip serialization") } + +type initableVariable struct { + Val []frontend.Variable +} + +func (iv *initableVariable) GnarkInitHook() { + if iv.Val == nil { + iv.Val = []frontend.Variable{1, 2} // need to init value as are assigning to witness + } +} + +type initableCircuit struct { + X [2]initableVariable + Y []initableVariable + Z initableVariable +} + +func (c *initableCircuit) Define(api frontend.API) error { + panic("not called") +} + +func TestVariableInitHook(t *testing.T) { + assert := require.New(t) + + assignment := &initableCircuit{Y: make([]initableVariable, 2)} + w, err := frontend.NewWitness(assignment, ecc.BN254.ScalarField()) + assert.NoError(err) + fw, ok := w.Vector().(fr.Vector) + assert.True(ok) + assert.Len(fw, 10, "invalid length") +} diff --git a/constraint/bls12-377/coeff.go b/constraint/bls12-377/coeff.go index cc8036f26f..a1b86b75c4 100644 --- a/constraint/bls12-377/coeff.go +++ b/constraint/bls12-377/coeff.go @@ -46,7 +46,7 @@ func newCoeffTable(capacity int) CoeffTable { } -func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constraint.Term { +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { c := (*fr.Element)(coeff[:]) var cID uint32 if c.IsZero() { @@ -69,7 +69,11 @@ func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constrai ct.mCoeffs[cc] = cID } } + return cID +} +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) return constraint.Term{VID: uint32(variableID), CID: cID} } @@ -78,7 +82,10 @@ func (ct *CoeffTable) CoeffToString(cID int) string { return ct.Coefficients[cID].String() } -var _ constraint.CoeffEngine = &arithEngine{} +// implements constraint.Field +type field struct{} + +var _ constraint.Field = &field{} var ( two fr.Element @@ -94,10 +101,7 @@ func init() { minusTwo.Neg(&two) } -// implements constraint.CoeffEngine -type arithEngine struct{} - -func (engine *arithEngine) FromInterface(i 
interface{}) constraint.Coeff { +func (engine *field) FromInterface(i interface{}) constraint.Element { var e fr.Element if _, err := e.SetInterface(i); err != nil { // need to clean that --> some code path are dissimilar @@ -106,55 +110,83 @@ func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { b := utils.FromInterface(i) e.SetBigInt(&b) } - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) ToBigInt(c *constraint.Coeff) *big.Int { +func (engine *field) ToBigInt(c constraint.Element) *big.Int { e := (*fr.Element)(c[:]) r := new(big.Int) e.BigInt(r) return r } -func (engine *arithEngine) Mul(a, b *constraint.Coeff) { +func (engine *field) Mul(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Mul(_a, _b) + return a } -func (engine *arithEngine) Add(a, b *constraint.Coeff) { + +func (engine *field) Add(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Add(_a, _b) + return a } -func (engine *arithEngine) Sub(a, b *constraint.Coeff) { +func (engine *field) Sub(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Sub(_a, _b) + return a } -func (engine *arithEngine) Neg(a *constraint.Coeff) { +func (engine *field) Neg(a constraint.Element) constraint.Element { e := (*fr.Element)(a[:]) e.Neg(e) + return a } -func (engine *arithEngine) Inverse(a *constraint.Coeff) { +func (engine *field) Inverse(a constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + e.Inverse(e) + return a, true } -func (engine *arithEngine) IsOne(a *constraint.Coeff) bool { +func (engine *field) IsOne(a constraint.Element) bool { e := (*fr.Element)(a[:]) return e.IsOne() } 
-func (engine *arithEngine) One() constraint.Coeff { +func (engine *field) One() constraint.Element { e := fr.One() - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) String(a *constraint.Coeff) string { +func (engine *field) String(a constraint.Element) string { e := (*fr.Element)(a[:]) return e.String() } + +func (engine *field) Uint64(a constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true +} diff --git a/constraint/bls12-377/r1cs.go b/constraint/bls12-377/r1cs.go deleted file mode 100644 index 5170805f22..0000000000 --- a/constraint/bls12-377/r1cs.go +++ /dev/null @@ -1,453 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/fxamacker/cbor/v2" - "io" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc" - "math" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" -) - -// R1CS describes a set of R1CS constraint -type R1CS struct { - constraint.R1CSCore - CoeffTable - arithEngine -} - -// NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values -// -// capacity pre-allocates memory for capacity nbConstraints -func NewR1CS(capacity int) *R1CS { - r := R1CS{ - R1CSCore: constraint.R1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.R1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - return &r -} - -func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, r1c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - - cs.UpdateLevel(cID, &r1c) - - return cID -} - -// Solve sets all the wires and returns the a, b, c vectors. -// the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. -// a, b, c vectors: ab-c = hz -// witness = [publicWires | secretWires] (without the ONE_WIRE !) 
-// returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() - - nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables - hintFunctions := defineGkrHints(cs.GkrInfo, opt.HintFunctions) - solution, err := newSolution(nbWires, hintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - - if err != nil { - return make(fr.Vector, nbWires), err - } - start := time.Now() - - if len(witness) != len(cs.Public)-1+len(cs.Secret) { // - 1 for ONE_WIRE - err = fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), int(len(cs.Public)-1+len(cs.Secret)), len(cs.Public)-1, len(cs.Secret)) - log.Err(err).Send() - return solution.values, err - } - - // compute the wires and the a, b, c polynomials - if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { - err = errors.New("invalid input size: len(a, b, c) == len(Constraints)") - log.Err(err).Send() - return solution.values, err - } - - solution.solved[0] = true // ONE_WIRE - solution.values[0].SetOne() - copy(solution.values[1:], witness) - for i := range witness { - solution.solved[i+1] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness) + 1) - - // now that we know all inputs are set, defer log printing once all solution.values are computed - // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - if err := cs.parallelSolve(a, b, c, &solution); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - 
log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil -} - -func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - chError <- &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - return &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// IsSolved returns nil if given witness solves the R1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) 
- if err != nil { - return err - } - - a := make(fr.Vector, len(cs.Constraints)) - b := make(fr.Vector, len(cs.Constraints)) - c := make(fr.Vector, len(cs.Constraints)) - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, a, b, c, opt) - return err -} - -// divByCoeff sets res = res / t.Coeff -func (cs *R1CS) divByCoeff(res *fr.Element, t constraint.Term) { - cID := t.CoeffID() - switch cID { - case constraint.CoeffIdOne: - return - case constraint.CoeffIdMinusOne: - res.Neg(res) - case constraint.CoeffIdZero: - panic("division by 0") - default: - // this is slow, but shouldn't happen as divByCoeff is called to - // remove the coeff of an unsolved wire - // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 - res.Div(res, &cs.Coefficients[cID]) - } -} - -// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly -// -// returns an error if the solver called a hint function that errored -// returns false, nil if there was no wire to solve -// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that -// the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r constraint.R1C, solution *solution, a, b, c *fr.Element) error { - - // the index of the non-zero entry shows if L, R or O has an uninstantiated wire - // the content is the ID of the wire non instantiated - var loc uint8 - - var termToCompute constraint.Term - - processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) error { - for _, t := range l { - vID := t.WireID() - - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - solution.accumulateInto(t, val) - continue - } - - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err - } - // now that the wire is saved, accumulate it into a, b or c - solution.accumulateInto(t, val) - continue - } - - if loc != 0 { - panic("found more than one wire to instantiate") - } - termToCompute = t - loc = locValue - } - return nil - } - - if err := processLExp(r.L, a, 1); err != nil { - return err - } - - if err := processLExp(r.R, b, 2); err != nil { - return err - } - - if err := processLExp(r.O, c, 3); err != nil { - return err - } - - if loc == 0 { - // there is nothing to solve, may happen if we have an assertion - // (ie a constraints that doesn't yield any output) - // or if we solved the unsolved wires with hint functions - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - return nil - } - - // we compute the wire value and instantiate it - wID := termToCompute.WireID() - - // solver result - var wire fr.Element - - switch loc { - case 1: - if !b.IsZero() { - wire.Div(c, b). - Sub(&wire, a) - a.Add(a, &wire) - } else { - // we didn't actually ensure that a * b == c - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 2: - if !a.IsZero() { - wire.Div(c, a). 
- Sub(&wire, b) - b.Add(b, &wire) - } else { - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 3: - wire.Mul(a, b). - Sub(&wire, c) - - c.Add(c, &wire) - } - - // wire is the term (coeff * value) - // but in the solution we want to store the value only - // note that in gnark frontend, coeff here is always 1 or -1 - cs.divByCoeff(&wire, termToCompute) - solution.set(wID, wire) - - return nil -} - -// GetConstraints return the list of R1C and a coefficient resolver -func (cs *R1CS) GetConstraints() ([]constraint.R1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *R1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto -func (cs *R1CS) CurveID() ecc.ID { - return ecc.BLS12_377 -} - -// WriteTo encodes R1CS into provided io.Writer using cbor -func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode R1CS from io.Reader using cbor -func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(&cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bls12-377/r1cs_sparse.go 
b/constraint/bls12-377/r1cs_sparse.go deleted file mode 100644 index ed01b36846..0000000000 --- a/constraint/bls12-377/r1cs_sparse.go +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/fxamacker/cbor/v2" - "io" - "math" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" -) - -// SparseR1CS represents a Plonk like circuit -type SparseR1CS struct { - constraint.SparseR1CSCore - CoeffTable - arithEngine -} - -// NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values -func NewSparseR1CS(capacity int) *SparseR1CS { - cs := SparseR1CS{ - SparseR1CSCore: constraint.SparseR1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.SparseR1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - - return &cs -} - -func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = 
append(cs.Constraints, c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - cs.UpdateLevel(cID, &c) - - return cID -} - -// Solve sets all the wires. -// solution.values = [publicInputs | secretInputs | internalVariables ] -// witness: contains the input variables -// it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() - - // set the slices holding the solution.values and monitoring which variables have been solved - nbVariables := cs.NbInternalVariables + len(cs.Secret) + len(cs.Public) - - start := time.Now() - - expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) - if len(witness) != expectedWitnessSize { - return make(fr.Vector, nbVariables), fmt.Errorf( - "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", - len(witness), - expectedWitnessSize, - len(cs.Public), - len(cs.Secret), - ) - } - - // keep track of wire that have a value - solution, err := newSolution(nbVariables, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return solution.values, err - } - - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs - copy(solution.values, witness) - for i := 0; i < len(witness); i++ { - solution.solved[i] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness)) - - // defer log printing once all solution.values are computed - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - // batch invert the coefficients to avoid many divisions in the solver - coefficientsNegInv := 
fr.BatchInvert(cs.Coefficients) - for i := 0; i < len(coefficientsNegInv); i++ { - coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) - } - - if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil - -} - -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - wg.Done() - return - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - chError <- &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } else { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - } - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - return &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// computeHints computes wires associated with a hint function, if any -// if there is no remaining wire to solve, returns -1 -// else returns the wire position (L -> 0, R -> 1, O -> 2) -func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) (int, error) { - r := -1 - lID, rID, oID := c.L.WireID(), c.R.WireID(), c.O.WireID() - - if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { - // check if it's a hint - if hint, ok := cs.MHints[lID]; ok { - if err := solution.solveWithHint(lID, hint); err != nil { - return -1, err - } - } else { - r = 0 - } - - } - - if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { - // check if it's a hint - if hint, ok := cs.MHints[rID]; ok { - if err := solution.solveWithHint(rID, hint); err != nil { - return -1, err - } - } else { - r = 1 - } - } - - if (c.O.CoeffID() != 0) && !solution.solved[oID] { - // check if it's a hint - if hint, ok := cs.MHints[oID]; ok { - if err := solution.solveWithHint(oID, hint); err != nil { - return -1, err - } - } else { - r = 2 - } - } - 
return r, nil -} - -// solveConstraint solve any unsolved wire in given constraint and update the solution -// a SparseR1C may have up to one unsolved wire (excluding hints) -// if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { - - lro, err := cs.computeHints(c, solution) - if err != nil { - return err - } - if lro == -1 { - // no unsolved wire - // can happen if the constraint contained only hint wires. - return nil - } - if lro == 1 { // we solve for R: u1L+u2R+u3LR+u4O+k=0 => R(u2+u3L)+u1L+u4O+k = 0 - if !solution.solved[c.L.WireID()] { - panic("L wire should be instantiated when we solve R") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.L.WireID()]).Add(&den, &u2) - - v1 = solution.computeTerm(c.L) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - } - - if lro == 0 { // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 - if !solution.solved[c.R.WireID()] { - panic("R wire should be instantiated when we solve L") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.R.WireID()]).Add(&den, &u1) - - v1 = solution.computeTerm(c.R) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - - } - // O 
we solve for O - var o fr.Element - cID, vID := c.O.CoeffID(), c.O.WireID() - - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - - // o = - ((m0 * m1) + l + r + c.K) / c.O - o.Mul(&m0, &m1).Add(&o, &l).Add(&o, &r).Add(&o, &cs.Coefficients[c.K]) - o.Mul(&o, &coefficientsNegInv[cID]) - - solution.set(vID, o) - - return nil -} - -// IsSolved returns nil if given witness solves the SparseR1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return err - } - - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, opt) - return err -} - -// GetConstraints return the list of SparseR1C and a coefficient resolver -func (cs *SparseR1CS) GetConstraints() ([]constraint.SparseR1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// checkConstraint verifies that the constraint holds -func (cs *SparseR1CS) checkConstraint(c constraint.SparseR1C, solution *solution) error { - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - o := solution.computeTerm(c.O) - - // l + r + (m0 * m1) + o + c.K == 0 - var t fr.Element - t.Mul(&m0, &m1).Add(&t, &l).Add(&t, &r).Add(&t, &o).Add(&t, &cs.Coefficients[c.K]) - if !t.IsZero() { - return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + (%s × %s) + %s != 0", - l.String(), - r.String(), - o.String(), - m0.String(), - m1.String(), - cs.Coefficients[c.K].String(), - ) - } - return nil - -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto (ecc.BLS12-377) -func (cs 
*SparseR1CS) CurveID() ecc.ID { - return ecc.BLS12_377 -} - -// WriteTo encodes SparseR1CS into provided io.Writer using cbor -func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode SparseR1CS from io.Reader using cbor -func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bls12-377/r1cs_test.go b/constraint/bls12-377/r1cs_test.go index 926f9c1bb6..b3eb222420 100644 --- a/constraint/bls12-377/r1cs_test.go +++ b/constraint/bls12-377/r1cs_test.go @@ -20,6 +20,7 @@ import ( "bytes" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "reflect" "testing" @@ -48,7 +49,7 @@ func TestSerialization(t *testing.T) { return } - // copmpile a second time to ensure determinism + // compile a second time to ensure determinism r1cs2, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) @@ -76,10 +77,10 @@ func TestSerialization(t *testing.T) { if diff := cmp.Diff(r1cs1, &reconstructed, cmpopts.IgnoreFields(cs.R1CS{}, "System.q", - "arithEngine", + "field", "CoeffTable.mCoeffs", 
"System.lbWireLevel", - "System.lbHints", + "System.genericHint", "System.SymbolTable", "System.lbOutputs", "System.bitLen")); diff != "" { @@ -147,12 +148,6 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { - var c circuit - ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c) - if err != nil { - b.Fatal(err) - } - var w circuit w.X = 1 w.Y = 1 @@ -161,8 +156,32 @@ func BenchmarkSolve(b *testing.B) { b.Fatal(err) } - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ccs.IsSolved(witness) - } + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + } diff --git a/constraint/bls12-377/solution.go b/constraint/bls12-377/solution.go deleted file mode 100644 index 8849a30acb..0000000000 --- a/constraint/bls12-377/solution.go +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/debug" - "github.com/rs/zerolog" - "math/big" - "strconv" - "strings" - "sync/atomic" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" -) - -// solution represents elements needed to compute -// a solution to a R1CS or SparseR1CS -type solution struct { - values, coefficients []fr.Element - solved []bool - nbSolved uint64 - mHintsFunctions map[hint.ID]hint.Function // maps hintID to hint function - mHints map[int]*constraint.Hint // maps wireID to hint - st *debug.SymbolTable -} - -func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, mHints map[int]*constraint.Hint, coefficients []fr.Element, st *debug.SymbolTable) (solution, error) { - - s := solution{ - st: st, - values: make([]fr.Element, nbWires), - coefficients: coefficients, - solved: make([]bool, nbWires), - mHintsFunctions: hintFunctions, - mHints: mHints, - } - - // hintsDependencies is from compile time; it contains the list of hints the solver **needs** - var missing []string - for hintUUID, hintID := range hintsDependencies { - if _, ok := s.mHintsFunctions[hintUUID]; !ok { - missing = append(missing, hintID) - } - } - - if len(missing) > 0 { - return s, fmt.Errorf("solver missing hint(s): %v", missing) - } - - return s, nil -} - -func (s *solution) set(id int, value fr.Element) { - if s.solved[id] { - panic("solving the same wire twice should never happen.") - } - s.values[id] = value - s.solved[id] = true - atomic.AddUint64(&s.nbSolved, 1) - // s.nbSolved++ -} - -func (s *solution) isValid() bool { - return int(s.nbSolved) == len(s.values) -} - -// computeTerm computes coeff*variable -func (s *solution) computeTerm(t constraint.Term) fr.Element { - cID, vID := t.CoeffID(), t.WireID() - if cID != 0 && !s.solved[vID] { - panic("computing a 
term with an unsolved wire") - } - switch cID { - case constraint.CoeffIdZero: - return fr.Element{} - case constraint.CoeffIdOne: - return s.values[vID] - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - return res - case constraint.CoeffIdMinusOne: - var res fr.Element - res.Neg(&s.values[vID]) - return res - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - return res - } -} - -// r += (t.coeff*t.value) -func (s *solution) accumulateInto(t constraint.Term, r *fr.Element) { - cID := t.CoeffID() - - if t.IsConstant() { - // needed for logs, we may want to not put this in the hot path if we need to - // optimize constraint system solver further. - r.Add(r, &s.coefficients[cID]) - return - } - - vID := t.WireID() - switch cID { - case constraint.CoeffIdZero: - return - case constraint.CoeffIdOne: - r.Add(r, &s.values[vID]) - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - r.Add(r, &res) - case constraint.CoeffIdMinusOne: - r.Sub(r, &s.values[vID]) - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - r.Add(r, &res) - } -} - -// solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveWithHint(vID int, h *constraint.Hint) error { - // skip if the wire is already solved by a call to the same hint - // function on the same inputs - if s.solved[vID] { - return nil - } - // ensure hint function was provided - f, ok := s.mHintsFunctions[h.ID] - if !ok { - return errors.New("missing hint function") - } - - // tmp IO big int memory - nbInputs := len(h.Inputs) - nbOutputs := len(h.Wires) - inputs := make([]*big.Int, nbInputs) - outputs := make([]*big.Int, nbOutputs) - for i := 0; i < nbOutputs; i++ { - outputs[i] = big.NewInt(0) - } - - q := fr.Modulus() - - // for each input, we set its big int value, IF all the wires are solved - // the only case where all wires may not be solved, is if one of the input of this hint - // 
is the output of another hint. - // it is safe to recursively solve this with the parallel solver, since all hints-output wires - // that we can solve this way are marked to be solved with the current constraint we are processing. - recursiveSolve := func(t constraint.Term) error { - if t.IsConstant() { - return nil - } - wID := t.WireID() - if s.solved[wID] { - return nil - } - // unsolved dependency - if h, ok := s.mHints[wID]; ok { - // solve recursively. - return s.solveWithHint(wID, h) - } - - // it's not a hint, we panic. - panic("solver can't compute hint; one or more input wires are unsolved") - } - - for i := 0; i < nbInputs; i++ { - inputs[i] = big.NewInt(0) - - var v fr.Element - for _, term := range h.Inputs[i] { - if err := recursiveSolve(term); err != nil { - return err - } - s.accumulateInto(term, &v) - } - v.BigInt(inputs[i]) - } - - err := f(q, inputs, outputs) - - var v fr.Element - for i := range outputs { - v.SetBigInt(outputs[i]) - s.set(h.Wires[i], v) - } - - return err -} - -func (s *solution) printLogs(log zerolog.Logger, logs []constraint.LogEntry) { - if log.GetLevel() == zerolog.Disabled { - return - } - - for i := 0; i < len(logs); i++ { - logLine := s.logValue(logs[i]) - log.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) - } -} - -const unsolvedVariable = "" - -func (s *solution) logValue(log constraint.LogEntry) string { - var toResolve []interface{} - var ( - eval fr.Element - missingValue bool - ) - for j := 0; j < len(log.ToResolve); j++ { - // before eval le - - missingValue = false - eval.SetZero() - - for _, t := range log.ToResolve[j] { - // for each term in the linear expression - - cID, vID := t.CoeffID(), t.WireID() - if t.IsConstant() { - // just add the constant - eval.Add(&eval, &s.coefficients[cID]) - continue - } - - if !s.solved[vID] { - missingValue = true - break // stop the loop we can't evaluate. 
- } - - tv := s.computeTerm(t) - eval.Add(&eval, &tv) - } - - // after - if missingValue { - toResolve = append(toResolve, unsolvedVariable) - } else { - // we have to append our accumulator - toResolve = append(toResolve, eval.String()) - } - - } - if len(log.Stack) > 0 { - var sbb strings.Builder - for _, lID := range log.Stack { - location := s.st.Locations[lID] - function := s.st.Functions[location.FunctionID] - - sbb.WriteString(function.Name) - sbb.WriteByte('\n') - sbb.WriteByte('\t') - sbb.WriteString(function.Filename) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(int(location.Line))) - sbb.WriteByte('\n') - } - toResolve = append(toResolve, sbb.String()) - } - return fmt.Sprintf(log.Format, toResolve...) -} - -// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint -type UnsatisfiedConstraintError struct { - Err error - CID int // constraint ID - DebugInfo *string // optional debug info -} - -func (r *UnsatisfiedConstraintError) Error() string { - if r.DebugInfo != nil { - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) - } - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) -} diff --git a/constraint/bls12-377/solver.go b/constraint/bls12-377/solver.go new file mode 100644 index 0000000000..c992d55b4a --- /dev/null +++ b/constraint/bls12-377/solver.go @@ -0,0 +1,647 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/pool" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/rs/zerolog" + "math" + "math/big" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" +) + +// solver represent the state of the solver during a call to System.Solve(...) +type solver struct { + *system + + // values and solved are index by the wire (variable) id + values []fr.Element + solved []bool + nbSolved uint64 + + // maps hintID to hint function + mHintsFunctions map[csolver.HintID]csolver.Hint + + // used to out api.Println + logger zerolog.Logger + + a, b, c fr.Vector // R1CS solver will compute the a,b,c matrices + + q *big.Int +} + +func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { + // parse options + opt, err := csolver.NewConfig(opts...) 
+ if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public) - witnessOffset + len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + witnessOffset) + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + +// computeTerm computes 
coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. 
+func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start)+i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. 
+ } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID: cID, VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. 
+func (s *solver) Read(calldata []uint32) (constraint.Element, int) { + if s.Type == constraint.SystemSparseR1CS { + if calldata[0] != 1 { + panic("invalid calldata") + } + return s.GetValue(calldata[1], calldata[2]), 3 + } + var r fr.Element + n := int(calldata[0]) + j := 1 + for k := 0; k < n; k++ { + // we read k Terms + s.accumulateInto(constraint.Term{CID: calldata[j], VID: calldata[j+1]}, &r) + j += 2 + } + + var ret constraint.Element + copy(ret[:], r[:]) + return ret, j +} + +// processInstruction decodes the instruction and execute blueprint-defined logic. +// an instruction can encode a hint, a custom constraint or a generic constraint. +func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error { + // fetch the blueprint + blueprint := solver.Blueprints[pi.BlueprintID] + inst := pi.Unpack(&solver.System) + cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only + + if solver.Type == constraint.SystemR1CS { + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + // TODO @gbotrel we use the solveR1C method for now, having user-defined + // blueprint for R1CS would require constraint.Solver interface to add methods + // to set a,b,c since it's more efficient to compute these while we solve. + bc.DecompressR1C(&scratch.tR1C, inst) + return solver.solveR1C(cID, &scratch.tR1C) + } + } + + // blueprint declared "I know how to solve this." + if bc, ok := blueprint.(constraint.BlueprintSolvable); ok { + if err := bc.Solve(solver, inst); err != nil { + return solver.wrapErrWithDebugInfo(cID, err) + } + return nil + } + + // blueprint encodes a hint, we execute. + // TODO @gbotrel may be worth it to move hint logic in blueprint "solve" + if bc, ok := blueprint.(constraint.BlueprintHint); ok { + bc.DecompressHint(&scratch.tHint, inst) + return solver.solveWithHint(&scratch.tHint) + } + + return nil +} + +// run runs the solver. 
it return an error if a constraint is not satisfied or if not all wires +// were instantiated. +func (solver *solver) run() error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks. + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + var scratch scratch + for t := range chTasks { + for _, i := range t { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + var scratch scratch + + // for each level, we push the tasks + for _, level := range solver.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + return err + } + } + continue + } + + // number of tasks for this level is set to number of CPU + // but if we don't have enough work for all our CPU, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + if int(solver.nbSolved) != len(solver.values) { + return errors.New("solver didn't assign a value to all wires") + } + + return nil +} + +// solveR1C compute unsolved wires in the constraint, if any and set the solver accordingly +// +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. 
+func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error { + a, b, c := &solver.a[cID], &solver.b[cID], &solver.c[cID] + + // the index of the non-zero entry shows if L, R or O has an uninstantiated wire + // the content is the ID of the wire non instantiated + var loc uint8 + + var termToCompute constraint.Term + + processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) { + for _, t := range l { + vID := t.WireID() + + // wire is already computed, we just accumulate in val + if solver.solved[vID] { + solver.accumulateInto(t, val) + continue + } + + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue + } + } + + processLExp(r.L, a, 1) + processLExp(r.R, b, 2) + processLExp(r.O, c, 3) + + if loc == 0 { + // there is nothing to solve, may happen if we have an assertion + // (ie a constraints that doesn't yield any output) + // or if we solved the unsolved wires with hint functions + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + return nil + } + + // we compute the wire value and instantiate it + wID := termToCompute.WireID() + + // solver result + var wire fr.Element + + switch loc { + case 1: + if !b.IsZero() { + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) + } else { + // we didn't actually ensure that a * b == c + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 2: + if !a.IsZero() { + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) + } else { + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 3: + wire.Mul(a, b). 
+ Sub(&wire, c) + + c.Add(c, &wire) + } + + // wire is the term (coeff * value) + // but in the solver we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + solver.divByCoeff(&wire, termToCompute.CID) + solver.set(wID, wire) + + return nil +} + +// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint +type UnsatisfiedConstraintError struct { + Err error + CID int // constraint ID + DebugInfo *string // optional debug info +} + +func (r *UnsatisfiedConstraintError) Error() string { + if r.DebugInfo != nil { + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) + } + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) +} + +func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { + var debugInfo *string + if dID, ok := solver.MDebug[int(cID)]; ok { + debugInfo = new(string) + *debugInfo = solver.logValue(solver.DebugInfo[dID]) + } + return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo} +} + +// temporary variables to avoid memallocs in hotloop +type scratch struct { + tR1C constraint.R1C + tHint constraint.HintMapping +} diff --git a/constraint/bls12-377/system.go b/constraint/bls12-377/system.go new file mode 100644 index 0000000000..5a35f10a00 --- /dev/null +++ b/constraint/bls12-377/system.go @@ -0,0 +1,374 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/fxamacker/cbor/v2" + "io" + "time" + + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/internal/backend/ioutils" + "github.com/consensys/gnark/logger" + "reflect" + + "github.com/consensys/gnark-crypto/ecc" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" +) + +type R1CS = system +type SparseR1CS = system + +// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element) +type system struct { + constraint.System + CoeffTable + field +} + +func NewR1CS(capacity int) *R1CS { + return newSystem(capacity, constraint.SystemR1CS) +} + +func NewSparseR1CS(capacity int) *SparseR1CS { + return newSystem(capacity, constraint.SystemSparseR1CS) +} + +func newSystem(capacity int, t constraint.SystemType) *system { + return &system{ + System: constraint.NewSystem(fr.Modulus(), capacity, t), + CoeffTable: newCoeffTable(capacity / 10), + } +} + +// Solve solves the constraint system with provided witness. +// If it's a R1CS returns R1CSSolution +// If it's a SparseR1CS returns SparseR1CSSolution +func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) { + log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger() + start := time.Now() + + v := witness.Vector().(fr.Vector) + + // init the solver + solver, err := newSolver(cs, v, opts...) + if err != nil { + log.Err(err).Send() + return nil, err + } + + // defer log printing once all solver.values are computed + // (or sooner, if a constraint is not satisfied) + defer solver.printLogs(cs.Logs) + + // run it. 
+ if err := solver.run(); err != nil { + log.Err(err).Send() + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") + + // format the solution + // TODO @gbotrel revisit post-refactor + if cs.Type == constraint.SystemR1CS { + var res R1CSSolution + res.W = solver.values + res.A = solver.a + res.B = solver.b + res.C = solver.c + return &res, nil + } else { + // sparse R1CS + var res SparseR1CSSolution + // query l, r, o in Lagrange basis, not blinded + res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values) + + return &res, nil + } + +} + +// IsSolved +// Deprecated: use _, err := Solve(...) instead +func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error { + _, err := cs.Solve(witness, opts...) + return err +} + +// GetR1Cs return the list of R1C +func (cs *system) GetR1Cs() []constraint.R1C { + toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + var r1c constraint.R1C + bc.DecompressR1C(&r1c, inst.Unpack(&cs.System)) + toReturn = append(toReturn, r1c) + } else { + panic("not implemented") + } + } + return toReturn +} + +// GetNbCoefficients return the number of unique coefficients needed in the R1CS +func (cs *system) GetNbCoefficients() int { + return len(cs.Coefficients) +} + +// CurveID returns curve ID as defined in gnark-crypto +func (cs *system) CurveID() ecc.ID { + return ecc.BLS12_377 +} + +// WriteTo encodes R1CS into provided io.Writer using cbor +func (cs *system) WriteTo(w io.Writer) (int64, error) { + _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written + ts := getTagSet() + enc, err := cbor.CoreDetEncOptions().EncModeWithTags(ts) + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) + + // encode our object + err = encoder.Encode(cs) + return _w.N, err +} + +// ReadFrom 
attempts to decode R1CS from io.Reader using cbor +func (cs *system) ReadFrom(r io.Reader) (int64, error) { + ts := getTagSet() + dm, err := cbor.DecOptions{ + MaxArrayElements: 134217728, + MaxMapPairs: 134217728, + }.DecModeWithTags(ts) + + if err != nil { + return 0, err + } + decoder := dm.NewDecoder(r) + + // initialize coeff table + cs.CoeffTable = newCoeffTable(0) + + if err := decoder.Decode(&cs); err != nil { + return int64(decoder.NumBytesRead()), err + } + + if err := cs.CheckSerializationHeader(); err != nil { + return int64(decoder.NumBytesRead()), err + } + + switch v := cs.CommitmentInfo.(type) { + case *constraint.Groth16Commitments: + cs.CommitmentInfo = *v + case *constraint.PlonkCommitments: + cs.CommitmentInfo = *v + } + + return int64(decoder.NumBytesRead()), nil +} + +func (cs *system) GetCoefficient(i int) (r constraint.Element) { + copy(r[:], cs.Coefficients[i][:]) + return +} + +// GetSparseR1Cs return the list of SparseR1C +func (cs *system) GetSparseR1Cs() []constraint.SparseR1C { + + toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + var sparseR1C constraint.SparseR1C + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + toReturn = append(toReturn, sparseR1C) + } else { + panic("not implemented") + } + } + return toReturn +} + +// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form. 
+// solver = [ public | secret | internal ] +// TODO @gbotrel refactor; this seems to be a small util function for plonk +func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { + + //s := int(pk.Domain[0].Cardinality) + s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints + s = int(ecc.NextPowerOfTwo(uint64(s))) + + var l, r, o []fr.Element + l = make([]fr.Element, s) + r = make([]fr.Element, s) + o = make([]fr.Element, s) + s0 := solution[0] + + for i := 0; i < len(cs.Public); i++ { // placeholders + l[i] = solution[i] + r[i] = s0 + o[i] = s0 + } + offset := len(cs.Public) + nbConstraints := cs.GetNbConstraints() + + var sparseR1C constraint.SparseR1C + j := 0 + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + + l[offset+j] = solution[sparseR1C.XA] + r[offset+j] = solution[sparseR1C.XB] + o[offset+j] = solution[sparseR1C.XC] + j++ + } + } + + offset += nbConstraints + + for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0]) + l[offset+i] = s0 + r[offset+i] = s0 + o[offset+i] = s0 + } + + return l, r, o + +} + +// R1CSSolution represent a valid assignment to all the variables in the constraint system. 
+// The vector W such that Aw o Bw - Cw = 0 +type R1CSSolution struct { + W fr.Vector + A, B, C fr.Vector +} + +func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.W.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.A.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.B.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.C.WriteTo(w) + n += a + return n, err +} + +func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.W.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.A.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.B.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.C.ReadFrom(r) + n += a + return n, err +} + +// SparseR1CSSolution represent a valid assignment to all the variables in the constraint system. +type SparseR1CSSolution struct { + L, R, O fr.Vector +} + +func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.L.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.R.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.O.WriteTo(w) + n += a + return n, err + +} + +func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.L.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.R.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.O.ReadFrom(r) + n += a + return n, err +} + +func getTagSet() cbor.TagSet { + // temporary for refactor + ts := cbor.NewTagSet() + // https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml + // 65536-15309735 Unassigned + tagNum := uint64(5309735) + addType := func(t reflect.Type) { + if err := ts.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + t, + tagNum, + ); err != nil { + panic(err) + } + tagNum++ + } + + addType(reflect.TypeOf(constraint.BlueprintGenericHint{})) + addType(reflect.TypeOf(constraint.BlueprintGenericR1C{})) + 
addType(reflect.TypeOf(constraint.BlueprintGenericSparseR1C{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CAdd{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CMul{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CBool{})) + addType(reflect.TypeOf(constraint.Groth16Commitments{})) + addType(reflect.TypeOf(constraint.PlonkCommitments{})) + + return ts +} diff --git a/constraint/bls12-381/coeff.go b/constraint/bls12-381/coeff.go index f0ea2ac85f..9eb3786806 100644 --- a/constraint/bls12-381/coeff.go +++ b/constraint/bls12-381/coeff.go @@ -46,7 +46,7 @@ func newCoeffTable(capacity int) CoeffTable { } -func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constraint.Term { +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { c := (*fr.Element)(coeff[:]) var cID uint32 if c.IsZero() { @@ -69,7 +69,11 @@ func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constrai ct.mCoeffs[cc] = cID } } + return cID +} +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) return constraint.Term{VID: uint32(variableID), CID: cID} } @@ -78,7 +82,10 @@ func (ct *CoeffTable) CoeffToString(cID int) string { return ct.Coefficients[cID].String() } -var _ constraint.CoeffEngine = &arithEngine{} +// implements constraint.Field +type field struct{} + +var _ constraint.Field = &field{} var ( two fr.Element @@ -94,10 +101,7 @@ func init() { minusTwo.Neg(&two) } -// implements constraint.CoeffEngine -type arithEngine struct{} - -func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { +func (engine *field) FromInterface(i interface{}) constraint.Element { var e fr.Element if _, err := e.SetInterface(i); err != nil { // need to clean that --> some code path are dissimilar @@ -106,55 +110,83 @@ func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { b := utils.FromInterface(i) e.SetBigInt(&b) } - var r constraint.Coeff + 
var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) ToBigInt(c *constraint.Coeff) *big.Int { +func (engine *field) ToBigInt(c constraint.Element) *big.Int { e := (*fr.Element)(c[:]) r := new(big.Int) e.BigInt(r) return r } -func (engine *arithEngine) Mul(a, b *constraint.Coeff) { +func (engine *field) Mul(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Mul(_a, _b) + return a } -func (engine *arithEngine) Add(a, b *constraint.Coeff) { + +func (engine *field) Add(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Add(_a, _b) + return a } -func (engine *arithEngine) Sub(a, b *constraint.Coeff) { +func (engine *field) Sub(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Sub(_a, _b) + return a } -func (engine *arithEngine) Neg(a *constraint.Coeff) { +func (engine *field) Neg(a constraint.Element) constraint.Element { e := (*fr.Element)(a[:]) e.Neg(e) + return a } -func (engine *arithEngine) Inverse(a *constraint.Coeff) { +func (engine *field) Inverse(a constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + e.Inverse(e) + return a, true } -func (engine *arithEngine) IsOne(a *constraint.Coeff) bool { +func (engine *field) IsOne(a constraint.Element) bool { e := (*fr.Element)(a[:]) return e.IsOne() } -func (engine *arithEngine) One() constraint.Coeff { +func (engine *field) One() constraint.Element { e := fr.One() - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) String(a *constraint.Coeff) string { +func (engine *field) String(a constraint.Element) string { e := (*fr.Element)(a[:]) return e.String() } + +func (engine *field) 
Uint64(a constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true +} diff --git a/constraint/bls12-381/r1cs.go b/constraint/bls12-381/r1cs.go deleted file mode 100644 index e45d4fff8f..0000000000 --- a/constraint/bls12-381/r1cs.go +++ /dev/null @@ -1,453 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/fxamacker/cbor/v2" - "io" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc" - "math" - - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" -) - -// R1CS describes a set of R1CS constraint -type R1CS struct { - constraint.R1CSCore - CoeffTable - arithEngine -} - -// NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values -// -// capacity pre-allocates memory for capacity nbConstraints -func NewR1CS(capacity int) *R1CS { - r := R1CS{ - R1CSCore: constraint.R1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.R1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - 
return &r -} - -func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, r1c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - - cs.UpdateLevel(cID, &r1c) - - return cID -} - -// Solve sets all the wires and returns the a, b, c vectors. -// the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. -// a, b, c vectors: ab-c = hz -// witness = [publicWires | secretWires] (without the ONE_WIRE !) -// returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() - - nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables - hintFunctions := defineGkrHints(cs.GkrInfo, opt.HintFunctions) - solution, err := newSolution(nbWires, hintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - - if err != nil { - return make(fr.Vector, nbWires), err - } - start := time.Now() - - if len(witness) != len(cs.Public)-1+len(cs.Secret) { // - 1 for ONE_WIRE - err = fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), int(len(cs.Public)-1+len(cs.Secret)), len(cs.Public)-1, len(cs.Secret)) - log.Err(err).Send() - return solution.values, err - } - - // compute the wires and the a, b, c polynomials - if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { - err = errors.New("invalid input size: len(a, b, c) == len(Constraints)") - log.Err(err).Send() - return solution.values, err - } - - solution.solved[0] = true // ONE_WIRE - solution.values[0].SetOne() - copy(solution.values[1:], witness) - for 
i := range witness { - solution.solved[i+1] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness) + 1) - - // now that we know all inputs are set, defer log printing once all solution.values are computed - // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - if err := cs.parallelSolve(a, b, c, &solution); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil -} - -func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. 
- const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - chError <- &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - return &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can 
be lower. - nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// IsSolved returns nil if given witness solves the R1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) 
- if err != nil { - return err - } - - a := make(fr.Vector, len(cs.Constraints)) - b := make(fr.Vector, len(cs.Constraints)) - c := make(fr.Vector, len(cs.Constraints)) - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, a, b, c, opt) - return err -} - -// divByCoeff sets res = res / t.Coeff -func (cs *R1CS) divByCoeff(res *fr.Element, t constraint.Term) { - cID := t.CoeffID() - switch cID { - case constraint.CoeffIdOne: - return - case constraint.CoeffIdMinusOne: - res.Neg(res) - case constraint.CoeffIdZero: - panic("division by 0") - default: - // this is slow, but shouldn't happen as divByCoeff is called to - // remove the coeff of an unsolved wire - // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 - res.Div(res, &cs.Coefficients[cID]) - } -} - -// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly -// -// returns an error if the solver called a hint function that errored -// returns false, nil if there was no wire to solve -// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that -// the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r constraint.R1C, solution *solution, a, b, c *fr.Element) error { - - // the index of the non-zero entry shows if L, R or O has an uninstantiated wire - // the content is the ID of the wire non instantiated - var loc uint8 - - var termToCompute constraint.Term - - processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) error { - for _, t := range l { - vID := t.WireID() - - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - solution.accumulateInto(t, val) - continue - } - - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err - } - // now that the wire is saved, accumulate it into a, b or c - solution.accumulateInto(t, val) - continue - } - - if loc != 0 { - panic("found more than one wire to instantiate") - } - termToCompute = t - loc = locValue - } - return nil - } - - if err := processLExp(r.L, a, 1); err != nil { - return err - } - - if err := processLExp(r.R, b, 2); err != nil { - return err - } - - if err := processLExp(r.O, c, 3); err != nil { - return err - } - - if loc == 0 { - // there is nothing to solve, may happen if we have an assertion - // (ie a constraints that doesn't yield any output) - // or if we solved the unsolved wires with hint functions - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - return nil - } - - // we compute the wire value and instantiate it - wID := termToCompute.WireID() - - // solver result - var wire fr.Element - - switch loc { - case 1: - if !b.IsZero() { - wire.Div(c, b). - Sub(&wire, a) - a.Add(a, &wire) - } else { - // we didn't actually ensure that a * b == c - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 2: - if !a.IsZero() { - wire.Div(c, a). 
- Sub(&wire, b) - b.Add(b, &wire) - } else { - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 3: - wire.Mul(a, b). - Sub(&wire, c) - - c.Add(c, &wire) - } - - // wire is the term (coeff * value) - // but in the solution we want to store the value only - // note that in gnark frontend, coeff here is always 1 or -1 - cs.divByCoeff(&wire, termToCompute) - solution.set(wID, wire) - - return nil -} - -// GetConstraints return the list of R1C and a coefficient resolver -func (cs *R1CS) GetConstraints() ([]constraint.R1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *R1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto -func (cs *R1CS) CurveID() ecc.ID { - return ecc.BLS12_381 -} - -// WriteTo encodes R1CS into provided io.Writer using cbor -func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode R1CS from io.Reader using cbor -func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(&cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bls12-381/r1cs_sparse.go 
b/constraint/bls12-381/r1cs_sparse.go deleted file mode 100644 index 393287d780..0000000000 --- a/constraint/bls12-381/r1cs_sparse.go +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/fxamacker/cbor/v2" - "io" - "math" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" -) - -// SparseR1CS represents a Plonk like circuit -type SparseR1CS struct { - constraint.SparseR1CSCore - CoeffTable - arithEngine -} - -// NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values -func NewSparseR1CS(capacity int) *SparseR1CS { - cs := SparseR1CS{ - SparseR1CSCore: constraint.SparseR1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.SparseR1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - - return &cs -} - -func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = 
append(cs.Constraints, c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - cs.UpdateLevel(cID, &c) - - return cID -} - -// Solve sets all the wires. -// solution.values = [publicInputs | secretInputs | internalVariables ] -// witness: contains the input variables -// it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() - - // set the slices holding the solution.values and monitoring which variables have been solved - nbVariables := cs.NbInternalVariables + len(cs.Secret) + len(cs.Public) - - start := time.Now() - - expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) - if len(witness) != expectedWitnessSize { - return make(fr.Vector, nbVariables), fmt.Errorf( - "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", - len(witness), - expectedWitnessSize, - len(cs.Public), - len(cs.Secret), - ) - } - - // keep track of wire that have a value - solution, err := newSolution(nbVariables, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return solution.values, err - } - - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs - copy(solution.values, witness) - for i := 0; i < len(witness); i++ { - solution.solved[i] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness)) - - // defer log printing once all solution.values are computed - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - // batch invert the coefficients to avoid many divisions in the solver - coefficientsNegInv := 
fr.BatchInvert(cs.Coefficients) - for i := 0; i < len(coefficientsNegInv); i++ { - coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) - } - - if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil - -} - -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - wg.Done() - return - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - chError <- &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } else { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - } - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - return &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// computeHints computes wires associated with a hint function, if any -// if there is no remaining wire to solve, returns -1 -// else returns the wire position (L -> 0, R -> 1, O -> 2) -func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) (int, error) { - r := -1 - lID, rID, oID := c.L.WireID(), c.R.WireID(), c.O.WireID() - - if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { - // check if it's a hint - if hint, ok := cs.MHints[lID]; ok { - if err := solution.solveWithHint(lID, hint); err != nil { - return -1, err - } - } else { - r = 0 - } - - } - - if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { - // check if it's a hint - if hint, ok := cs.MHints[rID]; ok { - if err := solution.solveWithHint(rID, hint); err != nil { - return -1, err - } - } else { - r = 1 - } - } - - if (c.O.CoeffID() != 0) && !solution.solved[oID] { - // check if it's a hint - if hint, ok := cs.MHints[oID]; ok { - if err := solution.solveWithHint(oID, hint); err != nil { - return -1, err - } - } else { - r = 2 - } - } - 
return r, nil -} - -// solveConstraint solve any unsolved wire in given constraint and update the solution -// a SparseR1C may have up to one unsolved wire (excluding hints) -// if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { - - lro, err := cs.computeHints(c, solution) - if err != nil { - return err - } - if lro == -1 { - // no unsolved wire - // can happen if the constraint contained only hint wires. - return nil - } - if lro == 1 { // we solve for R: u1L+u2R+u3LR+u4O+k=0 => R(u2+u3L)+u1L+u4O+k = 0 - if !solution.solved[c.L.WireID()] { - panic("L wire should be instantiated when we solve R") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.L.WireID()]).Add(&den, &u2) - - v1 = solution.computeTerm(c.L) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - } - - if lro == 0 { // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 - if !solution.solved[c.R.WireID()] { - panic("R wire should be instantiated when we solve L") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.R.WireID()]).Add(&den, &u1) - - v1 = solution.computeTerm(c.R) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - - } - // O 
we solve for O - var o fr.Element - cID, vID := c.O.CoeffID(), c.O.WireID() - - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - - // o = - ((m0 * m1) + l + r + c.K) / c.O - o.Mul(&m0, &m1).Add(&o, &l).Add(&o, &r).Add(&o, &cs.Coefficients[c.K]) - o.Mul(&o, &coefficientsNegInv[cID]) - - solution.set(vID, o) - - return nil -} - -// IsSolved returns nil if given witness solves the SparseR1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return err - } - - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, opt) - return err -} - -// GetConstraints return the list of SparseR1C and a coefficient resolver -func (cs *SparseR1CS) GetConstraints() ([]constraint.SparseR1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// checkConstraint verifies that the constraint holds -func (cs *SparseR1CS) checkConstraint(c constraint.SparseR1C, solution *solution) error { - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - o := solution.computeTerm(c.O) - - // l + r + (m0 * m1) + o + c.K == 0 - var t fr.Element - t.Mul(&m0, &m1).Add(&t, &l).Add(&t, &r).Add(&t, &o).Add(&t, &cs.Coefficients[c.K]) - if !t.IsZero() { - return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + (%s × %s) + %s != 0", - l.String(), - r.String(), - o.String(), - m0.String(), - m1.String(), - cs.Coefficients[c.K].String(), - ) - } - return nil - -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto (ecc.BLS12-381) -func (cs 
*SparseR1CS) CurveID() ecc.ID { - return ecc.BLS12_381 -} - -// WriteTo encodes SparseR1CS into provided io.Writer using cbor -func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode SparseR1CS from io.Reader using cbor -func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bls12-381/r1cs_test.go b/constraint/bls12-381/r1cs_test.go index 90c9e1e849..c7b0fa6156 100644 --- a/constraint/bls12-381/r1cs_test.go +++ b/constraint/bls12-381/r1cs_test.go @@ -20,6 +20,7 @@ import ( "bytes" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "reflect" "testing" @@ -48,7 +49,7 @@ func TestSerialization(t *testing.T) { return } - // copmpile a second time to ensure determinism + // compile a second time to ensure determinism r1cs2, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) @@ -76,10 +77,10 @@ func TestSerialization(t *testing.T) { if diff := cmp.Diff(r1cs1, &reconstructed, cmpopts.IgnoreFields(cs.R1CS{}, "System.q", - "arithEngine", + "field", "CoeffTable.mCoeffs", 
"System.lbWireLevel", - "System.lbHints", + "System.genericHint", "System.SymbolTable", "System.lbOutputs", "System.bitLen")); diff != "" { @@ -147,12 +148,6 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { - var c circuit - ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c) - if err != nil { - b.Fatal(err) - } - var w circuit w.X = 1 w.Y = 1 @@ -161,8 +156,32 @@ func BenchmarkSolve(b *testing.B) { b.Fatal(err) } - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ccs.IsSolved(witness) - } + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + } diff --git a/constraint/bls12-381/solution.go b/constraint/bls12-381/solution.go deleted file mode 100644 index dbf96fadc4..0000000000 --- a/constraint/bls12-381/solution.go +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/debug" - "github.com/rs/zerolog" - "math/big" - "strconv" - "strings" - "sync/atomic" - - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" -) - -// solution represents elements needed to compute -// a solution to a R1CS or SparseR1CS -type solution struct { - values, coefficients []fr.Element - solved []bool - nbSolved uint64 - mHintsFunctions map[hint.ID]hint.Function // maps hintID to hint function - mHints map[int]*constraint.Hint // maps wireID to hint - st *debug.SymbolTable -} - -func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, mHints map[int]*constraint.Hint, coefficients []fr.Element, st *debug.SymbolTable) (solution, error) { - - s := solution{ - st: st, - values: make([]fr.Element, nbWires), - coefficients: coefficients, - solved: make([]bool, nbWires), - mHintsFunctions: hintFunctions, - mHints: mHints, - } - - // hintsDependencies is from compile time; it contains the list of hints the solver **needs** - var missing []string - for hintUUID, hintID := range hintsDependencies { - if _, ok := s.mHintsFunctions[hintUUID]; !ok { - missing = append(missing, hintID) - } - } - - if len(missing) > 0 { - return s, fmt.Errorf("solver missing hint(s): %v", missing) - } - - return s, nil -} - -func (s *solution) set(id int, value fr.Element) { - if s.solved[id] { - panic("solving the same wire twice should never happen.") - } - s.values[id] = value - s.solved[id] = true - atomic.AddUint64(&s.nbSolved, 1) - // s.nbSolved++ -} - -func (s *solution) isValid() bool { - return int(s.nbSolved) == len(s.values) -} - -// computeTerm computes coeff*variable -func (s *solution) computeTerm(t constraint.Term) fr.Element { - cID, vID := t.CoeffID(), t.WireID() - if cID != 0 && !s.solved[vID] { - panic("computing a 
term with an unsolved wire") - } - switch cID { - case constraint.CoeffIdZero: - return fr.Element{} - case constraint.CoeffIdOne: - return s.values[vID] - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - return res - case constraint.CoeffIdMinusOne: - var res fr.Element - res.Neg(&s.values[vID]) - return res - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - return res - } -} - -// r += (t.coeff*t.value) -func (s *solution) accumulateInto(t constraint.Term, r *fr.Element) { - cID := t.CoeffID() - - if t.IsConstant() { - // needed for logs, we may want to not put this in the hot path if we need to - // optimize constraint system solver further. - r.Add(r, &s.coefficients[cID]) - return - } - - vID := t.WireID() - switch cID { - case constraint.CoeffIdZero: - return - case constraint.CoeffIdOne: - r.Add(r, &s.values[vID]) - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - r.Add(r, &res) - case constraint.CoeffIdMinusOne: - r.Sub(r, &s.values[vID]) - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - r.Add(r, &res) - } -} - -// solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveWithHint(vID int, h *constraint.Hint) error { - // skip if the wire is already solved by a call to the same hint - // function on the same inputs - if s.solved[vID] { - return nil - } - // ensure hint function was provided - f, ok := s.mHintsFunctions[h.ID] - if !ok { - return errors.New("missing hint function") - } - - // tmp IO big int memory - nbInputs := len(h.Inputs) - nbOutputs := len(h.Wires) - inputs := make([]*big.Int, nbInputs) - outputs := make([]*big.Int, nbOutputs) - for i := 0; i < nbOutputs; i++ { - outputs[i] = big.NewInt(0) - } - - q := fr.Modulus() - - // for each input, we set its big int value, IF all the wires are solved - // the only case where all wires may not be solved, is if one of the input of this hint - // 
is the output of another hint. - // it is safe to recursively solve this with the parallel solver, since all hints-output wires - // that we can solve this way are marked to be solved with the current constraint we are processing. - recursiveSolve := func(t constraint.Term) error { - if t.IsConstant() { - return nil - } - wID := t.WireID() - if s.solved[wID] { - return nil - } - // unsolved dependency - if h, ok := s.mHints[wID]; ok { - // solve recursively. - return s.solveWithHint(wID, h) - } - - // it's not a hint, we panic. - panic("solver can't compute hint; one or more input wires are unsolved") - } - - for i := 0; i < nbInputs; i++ { - inputs[i] = big.NewInt(0) - - var v fr.Element - for _, term := range h.Inputs[i] { - if err := recursiveSolve(term); err != nil { - return err - } - s.accumulateInto(term, &v) - } - v.BigInt(inputs[i]) - } - - err := f(q, inputs, outputs) - - var v fr.Element - for i := range outputs { - v.SetBigInt(outputs[i]) - s.set(h.Wires[i], v) - } - - return err -} - -func (s *solution) printLogs(log zerolog.Logger, logs []constraint.LogEntry) { - if log.GetLevel() == zerolog.Disabled { - return - } - - for i := 0; i < len(logs); i++ { - logLine := s.logValue(logs[i]) - log.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) - } -} - -const unsolvedVariable = "" - -func (s *solution) logValue(log constraint.LogEntry) string { - var toResolve []interface{} - var ( - eval fr.Element - missingValue bool - ) - for j := 0; j < len(log.ToResolve); j++ { - // before eval le - - missingValue = false - eval.SetZero() - - for _, t := range log.ToResolve[j] { - // for each term in the linear expression - - cID, vID := t.CoeffID(), t.WireID() - if t.IsConstant() { - // just add the constant - eval.Add(&eval, &s.coefficients[cID]) - continue - } - - if !s.solved[vID] { - missingValue = true - break // stop the loop we can't evaluate. 
- } - - tv := s.computeTerm(t) - eval.Add(&eval, &tv) - } - - // after - if missingValue { - toResolve = append(toResolve, unsolvedVariable) - } else { - // we have to append our accumulator - toResolve = append(toResolve, eval.String()) - } - - } - if len(log.Stack) > 0 { - var sbb strings.Builder - for _, lID := range log.Stack { - location := s.st.Locations[lID] - function := s.st.Functions[location.FunctionID] - - sbb.WriteString(function.Name) - sbb.WriteByte('\n') - sbb.WriteByte('\t') - sbb.WriteString(function.Filename) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(int(location.Line))) - sbb.WriteByte('\n') - } - toResolve = append(toResolve, sbb.String()) - } - return fmt.Sprintf(log.Format, toResolve...) -} - -// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint -type UnsatisfiedConstraintError struct { - Err error - CID int // constraint ID - DebugInfo *string // optional debug info -} - -func (r *UnsatisfiedConstraintError) Error() string { - if r.DebugInfo != nil { - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) - } - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) -} diff --git a/constraint/bls12-381/solver.go b/constraint/bls12-381/solver.go new file mode 100644 index 0000000000..125ce56e8e --- /dev/null +++ b/constraint/bls12-381/solver.go @@ -0,0 +1,647 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/pool" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/rs/zerolog" + "math" + "math/big" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" +) + +// solver represent the state of the solver during a call to System.Solve(...) +type solver struct { + *system + + // values and solved are index by the wire (variable) id + values []fr.Element + solved []bool + nbSolved uint64 + + // maps hintID to hint function + mHintsFunctions map[csolver.HintID]csolver.Hint + + // used to out api.Println + logger zerolog.Logger + + a, b, c fr.Vector // R1CS solver will compute the a,b,c matrices + + q *big.Int +} + +func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { + // parse options + opt, err := csolver.NewConfig(opts...) 
+ if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public) - witnessOffset + len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + witnessOffset) + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + +// computeTerm computes 
coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. 
+func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start)+i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. 
+ } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID: cID, VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. 
+func (s *solver) Read(calldata []uint32) (constraint.Element, int) { + if s.Type == constraint.SystemSparseR1CS { + if calldata[0] != 1 { + panic("invalid calldata") + } + return s.GetValue(calldata[1], calldata[2]), 3 + } + var r fr.Element + n := int(calldata[0]) + j := 1 + for k := 0; k < n; k++ { + // we read k Terms + s.accumulateInto(constraint.Term{CID: calldata[j], VID: calldata[j+1]}, &r) + j += 2 + } + + var ret constraint.Element + copy(ret[:], r[:]) + return ret, j +} + +// processInstruction decodes the instruction and execute blueprint-defined logic. +// an instruction can encode a hint, a custom constraint or a generic constraint. +func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error { + // fetch the blueprint + blueprint := solver.Blueprints[pi.BlueprintID] + inst := pi.Unpack(&solver.System) + cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only + + if solver.Type == constraint.SystemR1CS { + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + // TODO @gbotrel we use the solveR1C method for now, having user-defined + // blueprint for R1CS would require constraint.Solver interface to add methods + // to set a,b,c since it's more efficient to compute these while we solve. + bc.DecompressR1C(&scratch.tR1C, inst) + return solver.solveR1C(cID, &scratch.tR1C) + } + } + + // blueprint declared "I know how to solve this." + if bc, ok := blueprint.(constraint.BlueprintSolvable); ok { + if err := bc.Solve(solver, inst); err != nil { + return solver.wrapErrWithDebugInfo(cID, err) + } + return nil + } + + // blueprint encodes a hint, we execute. + // TODO @gbotrel may be worth it to move hint logic in blueprint "solve" + if bc, ok := blueprint.(constraint.BlueprintHint); ok { + bc.DecompressHint(&scratch.tHint, inst) + return solver.solveWithHint(&scratch.tHint) + } + + return nil +} + +// run runs the solver. 
it returns an error if a constraint is not satisfied or if not all wires +// were instantiated. +func (solver *solver) run() error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks. + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker waits on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + var scratch scratch + for t := range chTasks { + for _, i := range t { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + var scratch scratch + + // for each level, we push the tasks + for _, level := range solver.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + return err + } + } + continue + } + + // number of tasks for this level is set to the number of CPUs + // but if we don't have enough work for all our CPUs, it can be lower.
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + if int(solver.nbSolved) != len(solver.values) { + return errors.New("solver didn't assign a value to all wires") + } + + return nil +} + +// solveR1C computes unsolved wires in the constraint, if any, and sets the solver accordingly +// +// returns an error if the solver called a hint function that errored +// returns nil if there was no wire to solve +// returns nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later.
+func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error { + a, b, c := &solver.a[cID], &solver.b[cID], &solver.c[cID] + + // the index of the non-zero entry shows if L, R or O has an uninstantiated wire + // the content is the ID of the wire non instantiated + var loc uint8 + + var termToCompute constraint.Term + + processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) { + for _, t := range l { + vID := t.WireID() + + // wire is already computed, we just accumulate in val + if solver.solved[vID] { + solver.accumulateInto(t, val) + continue + } + + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue + } + } + + processLExp(r.L, a, 1) + processLExp(r.R, b, 2) + processLExp(r.O, c, 3) + + if loc == 0 { + // there is nothing to solve, may happen if we have an assertion + // (ie a constraints that doesn't yield any output) + // or if we solved the unsolved wires with hint functions + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + return nil + } + + // we compute the wire value and instantiate it + wID := termToCompute.WireID() + + // solver result + var wire fr.Element + + switch loc { + case 1: + if !b.IsZero() { + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) + } else { + // we didn't actually ensure that a * b == c + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 2: + if !a.IsZero() { + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) + } else { + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 3: + wire.Mul(a, b). 
+ Sub(&wire, c) + + c.Add(c, &wire) + } + + // wire is the term (coeff * value) + // but in the solver we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + solver.divByCoeff(&wire, termToCompute.CID) + solver.set(wID, wire) + + return nil +} + +// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint +type UnsatisfiedConstraintError struct { + Err error + CID int // constraint ID + DebugInfo *string // optional debug info +} + +func (r *UnsatisfiedConstraintError) Error() string { + if r.DebugInfo != nil { + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) + } + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) +} + +func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { + var debugInfo *string + if dID, ok := solver.MDebug[int(cID)]; ok { + debugInfo = new(string) + *debugInfo = solver.logValue(solver.DebugInfo[dID]) + } + return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo} +} + +// temporary variables to avoid memallocs in hotloop +type scratch struct { + tR1C constraint.R1C + tHint constraint.HintMapping +} diff --git a/constraint/bls12-381/system.go b/constraint/bls12-381/system.go new file mode 100644 index 0000000000..ff3c074f67 --- /dev/null +++ b/constraint/bls12-381/system.go @@ -0,0 +1,374 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/fxamacker/cbor/v2" + "io" + "time" + + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/internal/backend/ioutils" + "github.com/consensys/gnark/logger" + "reflect" + + "github.com/consensys/gnark-crypto/ecc" + + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" +) + +type R1CS = system +type SparseR1CS = system + +// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element) +type system struct { + constraint.System + CoeffTable + field +} + +func NewR1CS(capacity int) *R1CS { + return newSystem(capacity, constraint.SystemR1CS) +} + +func NewSparseR1CS(capacity int) *SparseR1CS { + return newSystem(capacity, constraint.SystemSparseR1CS) +} + +func newSystem(capacity int, t constraint.SystemType) *system { + return &system{ + System: constraint.NewSystem(fr.Modulus(), capacity, t), + CoeffTable: newCoeffTable(capacity / 10), + } +} + +// Solve solves the constraint system with provided witness. +// If it's a R1CS returns R1CSSolution +// If it's a SparseR1CS returns SparseR1CSSolution +func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) { + log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger() + start := time.Now() + + v := witness.Vector().(fr.Vector) + + // init the solver + solver, err := newSolver(cs, v, opts...) + if err != nil { + log.Err(err).Send() + return nil, err + } + + // defer log printing once all solver.values are computed + // (or sooner, if a constraint is not satisfied) + defer solver.printLogs(cs.Logs) + + // run it. 
+ if err := solver.run(); err != nil { + log.Err(err).Send() + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") + + // format the solution + // TODO @gbotrel revisit post-refactor + if cs.Type == constraint.SystemR1CS { + var res R1CSSolution + res.W = solver.values + res.A = solver.a + res.B = solver.b + res.C = solver.c + return &res, nil + } else { + // sparse R1CS + var res SparseR1CSSolution + // query l, r, o in Lagrange basis, not blinded + res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values) + + return &res, nil + } + +} + +// IsSolved +// Deprecated: use _, err := Solve(...) instead +func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error { + _, err := cs.Solve(witness, opts...) + return err +} + +// GetR1Cs return the list of R1C +func (cs *system) GetR1Cs() []constraint.R1C { + toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + var r1c constraint.R1C + bc.DecompressR1C(&r1c, inst.Unpack(&cs.System)) + toReturn = append(toReturn, r1c) + } else { + panic("not implemented") + } + } + return toReturn +} + +// GetNbCoefficients return the number of unique coefficients needed in the R1CS +func (cs *system) GetNbCoefficients() int { + return len(cs.Coefficients) +} + +// CurveID returns curve ID as defined in gnark-crypto +func (cs *system) CurveID() ecc.ID { + return ecc.BLS12_381 +} + +// WriteTo encodes R1CS into provided io.Writer using cbor +func (cs *system) WriteTo(w io.Writer) (int64, error) { + _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written + ts := getTagSet() + enc, err := cbor.CoreDetEncOptions().EncModeWithTags(ts) + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) + + // encode our object + err = encoder.Encode(cs) + return _w.N, err +} + +// ReadFrom 
attempts to decode R1CS from io.Reader using cbor +func (cs *system) ReadFrom(r io.Reader) (int64, error) { + ts := getTagSet() + dm, err := cbor.DecOptions{ + MaxArrayElements: 134217728, + MaxMapPairs: 134217728, + }.DecModeWithTags(ts) + + if err != nil { + return 0, err + } + decoder := dm.NewDecoder(r) + + // initialize coeff table + cs.CoeffTable = newCoeffTable(0) + + if err := decoder.Decode(&cs); err != nil { + return int64(decoder.NumBytesRead()), err + } + + if err := cs.CheckSerializationHeader(); err != nil { + return int64(decoder.NumBytesRead()), err + } + + switch v := cs.CommitmentInfo.(type) { + case *constraint.Groth16Commitments: + cs.CommitmentInfo = *v + case *constraint.PlonkCommitments: + cs.CommitmentInfo = *v + } + + return int64(decoder.NumBytesRead()), nil +} + +func (cs *system) GetCoefficient(i int) (r constraint.Element) { + copy(r[:], cs.Coefficients[i][:]) + return +} + +// GetSparseR1Cs return the list of SparseR1C +func (cs *system) GetSparseR1Cs() []constraint.SparseR1C { + + toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + var sparseR1C constraint.SparseR1C + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + toReturn = append(toReturn, sparseR1C) + } else { + panic("not implemented") + } + } + return toReturn +} + +// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form. 
+// solver = [ public | secret | internal ] +// TODO @gbotrel refactor; this seems to be a small util function for plonk +func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { + + //s := int(pk.Domain[0].Cardinality) + s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints + s = int(ecc.NextPowerOfTwo(uint64(s))) + + var l, r, o []fr.Element + l = make([]fr.Element, s) + r = make([]fr.Element, s) + o = make([]fr.Element, s) + s0 := solution[0] + + for i := 0; i < len(cs.Public); i++ { // placeholders + l[i] = solution[i] + r[i] = s0 + o[i] = s0 + } + offset := len(cs.Public) + nbConstraints := cs.GetNbConstraints() + + var sparseR1C constraint.SparseR1C + j := 0 + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + + l[offset+j] = solution[sparseR1C.XA] + r[offset+j] = solution[sparseR1C.XB] + o[offset+j] = solution[sparseR1C.XC] + j++ + } + } + + offset += nbConstraints + + for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0]) + l[offset+i] = s0 + r[offset+i] = s0 + o[offset+i] = s0 + } + + return l, r, o + +} + +// R1CSSolution represent a valid assignment to all the variables in the constraint system. 
+// The vector W such that Aw o Bw - Cw = 0 +type R1CSSolution struct { + W fr.Vector + A, B, C fr.Vector +} + +func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.W.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.A.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.B.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.C.WriteTo(w) + n += a + return n, err +} + +func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.W.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.A.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.B.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.C.ReadFrom(r) + n += a + return n, err +} + +// SparseR1CSSolution represent a valid assignment to all the variables in the constraint system. +type SparseR1CSSolution struct { + L, R, O fr.Vector +} + +func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.L.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.R.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.O.WriteTo(w) + n += a + return n, err + +} + +func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.L.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.R.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.O.ReadFrom(r) + n += a + return n, err +} + +func getTagSet() cbor.TagSet { + // temporary for refactor + ts := cbor.NewTagSet() + // https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml + // 65536-15309735 Unassigned + tagNum := uint64(5309735) + addType := func(t reflect.Type) { + if err := ts.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + t, + tagNum, + ); err != nil { + panic(err) + } + tagNum++ + } + + addType(reflect.TypeOf(constraint.BlueprintGenericHint{})) + addType(reflect.TypeOf(constraint.BlueprintGenericR1C{})) + 
addType(reflect.TypeOf(constraint.BlueprintGenericSparseR1C{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CAdd{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CMul{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CBool{})) + addType(reflect.TypeOf(constraint.Groth16Commitments{})) + addType(reflect.TypeOf(constraint.PlonkCommitments{})) + + return ts +} diff --git a/constraint/bls24-315/coeff.go b/constraint/bls24-315/coeff.go index d91528f93e..084652f545 100644 --- a/constraint/bls24-315/coeff.go +++ b/constraint/bls24-315/coeff.go @@ -46,7 +46,7 @@ func newCoeffTable(capacity int) CoeffTable { } -func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constraint.Term { +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { c := (*fr.Element)(coeff[:]) var cID uint32 if c.IsZero() { @@ -69,7 +69,11 @@ func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constrai ct.mCoeffs[cc] = cID } } + return cID +} +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) return constraint.Term{VID: uint32(variableID), CID: cID} } @@ -78,7 +82,10 @@ func (ct *CoeffTable) CoeffToString(cID int) string { return ct.Coefficients[cID].String() } -var _ constraint.CoeffEngine = &arithEngine{} +// implements constraint.Field +type field struct{} + +var _ constraint.Field = &field{} var ( two fr.Element @@ -94,10 +101,7 @@ func init() { minusTwo.Neg(&two) } -// implements constraint.CoeffEngine -type arithEngine struct{} - -func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { +func (engine *field) FromInterface(i interface{}) constraint.Element { var e fr.Element if _, err := e.SetInterface(i); err != nil { // need to clean that --> some code path are dissimilar @@ -106,55 +110,83 @@ func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { b := utils.FromInterface(i) e.SetBigInt(&b) } - var r constraint.Coeff + 
var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) ToBigInt(c *constraint.Coeff) *big.Int { +func (engine *field) ToBigInt(c constraint.Element) *big.Int { e := (*fr.Element)(c[:]) r := new(big.Int) e.BigInt(r) return r } -func (engine *arithEngine) Mul(a, b *constraint.Coeff) { +func (engine *field) Mul(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Mul(_a, _b) + return a } -func (engine *arithEngine) Add(a, b *constraint.Coeff) { + +func (engine *field) Add(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Add(_a, _b) + return a } -func (engine *arithEngine) Sub(a, b *constraint.Coeff) { +func (engine *field) Sub(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Sub(_a, _b) + return a } -func (engine *arithEngine) Neg(a *constraint.Coeff) { +func (engine *field) Neg(a constraint.Element) constraint.Element { e := (*fr.Element)(a[:]) e.Neg(e) + return a } -func (engine *arithEngine) Inverse(a *constraint.Coeff) { +func (engine *field) Inverse(a constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + e.Inverse(e) + return a, true } -func (engine *arithEngine) IsOne(a *constraint.Coeff) bool { +func (engine *field) IsOne(a constraint.Element) bool { e := (*fr.Element)(a[:]) return e.IsOne() } -func (engine *arithEngine) One() constraint.Coeff { +func (engine *field) One() constraint.Element { e := fr.One() - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) String(a *constraint.Coeff) string { +func (engine *field) String(a constraint.Element) string { e := (*fr.Element)(a[:]) return e.String() } + +func (engine *field) 
Uint64(a constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true +} diff --git a/constraint/bls24-315/r1cs.go b/constraint/bls24-315/r1cs.go deleted file mode 100644 index 39f1d832b1..0000000000 --- a/constraint/bls24-315/r1cs.go +++ /dev/null @@ -1,453 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/fxamacker/cbor/v2" - "io" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc" - "math" - - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" -) - -// R1CS describes a set of R1CS constraint -type R1CS struct { - constraint.R1CSCore - CoeffTable - arithEngine -} - -// NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values -// -// capacity pre-allocates memory for capacity nbConstraints -func NewR1CS(capacity int) *R1CS { - r := R1CS{ - R1CSCore: constraint.R1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.R1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - 
return &r -} - -func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, r1c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - - cs.UpdateLevel(cID, &r1c) - - return cID -} - -// Solve sets all the wires and returns the a, b, c vectors. -// the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. -// a, b, c vectors: ab-c = hz -// witness = [publicWires | secretWires] (without the ONE_WIRE !) -// returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() - - nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables - hintFunctions := defineGkrHints(cs.GkrInfo, opt.HintFunctions) - solution, err := newSolution(nbWires, hintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - - if err != nil { - return make(fr.Vector, nbWires), err - } - start := time.Now() - - if len(witness) != len(cs.Public)-1+len(cs.Secret) { // - 1 for ONE_WIRE - err = fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), int(len(cs.Public)-1+len(cs.Secret)), len(cs.Public)-1, len(cs.Secret)) - log.Err(err).Send() - return solution.values, err - } - - // compute the wires and the a, b, c polynomials - if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { - err = errors.New("invalid input size: len(a, b, c) == len(Constraints)") - log.Err(err).Send() - return solution.values, err - } - - solution.solved[0] = true // ONE_WIRE - solution.values[0].SetOne() - copy(solution.values[1:], witness) - for 
i := range witness { - solution.solved[i+1] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness) + 1) - - // now that we know all inputs are set, defer log printing once all solution.values are computed - // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - if err := cs.parallelSolve(a, b, c, &solution); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil -} - -func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. 
- const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - chError <- &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - return &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can 
be lower. - nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// IsSolved returns nil if given witness solves the R1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) 
- if err != nil { - return err - } - - a := make(fr.Vector, len(cs.Constraints)) - b := make(fr.Vector, len(cs.Constraints)) - c := make(fr.Vector, len(cs.Constraints)) - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, a, b, c, opt) - return err -} - -// divByCoeff sets res = res / t.Coeff -func (cs *R1CS) divByCoeff(res *fr.Element, t constraint.Term) { - cID := t.CoeffID() - switch cID { - case constraint.CoeffIdOne: - return - case constraint.CoeffIdMinusOne: - res.Neg(res) - case constraint.CoeffIdZero: - panic("division by 0") - default: - // this is slow, but shouldn't happen as divByCoeff is called to - // remove the coeff of an unsolved wire - // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 - res.Div(res, &cs.Coefficients[cID]) - } -} - -// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly -// -// returns an error if the solver called a hint function that errored -// returns false, nil if there was no wire to solve -// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that -// the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r constraint.R1C, solution *solution, a, b, c *fr.Element) error { - - // the index of the non-zero entry shows if L, R or O has an uninstantiated wire - // the content is the ID of the wire non instantiated - var loc uint8 - - var termToCompute constraint.Term - - processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) error { - for _, t := range l { - vID := t.WireID() - - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - solution.accumulateInto(t, val) - continue - } - - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err - } - // now that the wire is saved, accumulate it into a, b or c - solution.accumulateInto(t, val) - continue - } - - if loc != 0 { - panic("found more than one wire to instantiate") - } - termToCompute = t - loc = locValue - } - return nil - } - - if err := processLExp(r.L, a, 1); err != nil { - return err - } - - if err := processLExp(r.R, b, 2); err != nil { - return err - } - - if err := processLExp(r.O, c, 3); err != nil { - return err - } - - if loc == 0 { - // there is nothing to solve, may happen if we have an assertion - // (ie a constraints that doesn't yield any output) - // or if we solved the unsolved wires with hint functions - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - return nil - } - - // we compute the wire value and instantiate it - wID := termToCompute.WireID() - - // solver result - var wire fr.Element - - switch loc { - case 1: - if !b.IsZero() { - wire.Div(c, b). - Sub(&wire, a) - a.Add(a, &wire) - } else { - // we didn't actually ensure that a * b == c - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 2: - if !a.IsZero() { - wire.Div(c, a). 
- Sub(&wire, b) - b.Add(b, &wire) - } else { - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 3: - wire.Mul(a, b). - Sub(&wire, c) - - c.Add(c, &wire) - } - - // wire is the term (coeff * value) - // but in the solution we want to store the value only - // note that in gnark frontend, coeff here is always 1 or -1 - cs.divByCoeff(&wire, termToCompute) - solution.set(wID, wire) - - return nil -} - -// GetConstraints return the list of R1C and a coefficient resolver -func (cs *R1CS) GetConstraints() ([]constraint.R1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *R1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto -func (cs *R1CS) CurveID() ecc.ID { - return ecc.BLS24_315 -} - -// WriteTo encodes R1CS into provided io.Writer using cbor -func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode R1CS from io.Reader using cbor -func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(&cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bls24-315/r1cs_sparse.go 
b/constraint/bls24-315/r1cs_sparse.go deleted file mode 100644 index 69e3d01d03..0000000000 --- a/constraint/bls24-315/r1cs_sparse.go +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/fxamacker/cbor/v2" - "io" - "math" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" -) - -// SparseR1CS represents a Plonk like circuit -type SparseR1CS struct { - constraint.SparseR1CSCore - CoeffTable - arithEngine -} - -// NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values -func NewSparseR1CS(capacity int) *SparseR1CS { - cs := SparseR1CS{ - SparseR1CSCore: constraint.SparseR1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.SparseR1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - - return &cs -} - -func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = 
append(cs.Constraints, c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - cs.UpdateLevel(cID, &c) - - return cID -} - -// Solve sets all the wires. -// solution.values = [publicInputs | secretInputs | internalVariables ] -// witness: contains the input variables -// it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() - - // set the slices holding the solution.values and monitoring which variables have been solved - nbVariables := cs.NbInternalVariables + len(cs.Secret) + len(cs.Public) - - start := time.Now() - - expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) - if len(witness) != expectedWitnessSize { - return make(fr.Vector, nbVariables), fmt.Errorf( - "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", - len(witness), - expectedWitnessSize, - len(cs.Public), - len(cs.Secret), - ) - } - - // keep track of wire that have a value - solution, err := newSolution(nbVariables, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return solution.values, err - } - - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs - copy(solution.values, witness) - for i := 0; i < len(witness); i++ { - solution.solved[i] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness)) - - // defer log printing once all solution.values are computed - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - // batch invert the coefficients to avoid many divisions in the solver - coefficientsNegInv := 
fr.BatchInvert(cs.Coefficients) - for i := 0; i < len(coefficientsNegInv); i++ { - coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) - } - - if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil - -} - -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - wg.Done() - return - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - chError <- &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } else { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - } - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - return &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// computeHints computes wires associated with a hint function, if any -// if there is no remaining wire to solve, returns -1 -// else returns the wire position (L -> 0, R -> 1, O -> 2) -func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) (int, error) { - r := -1 - lID, rID, oID := c.L.WireID(), c.R.WireID(), c.O.WireID() - - if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { - // check if it's a hint - if hint, ok := cs.MHints[lID]; ok { - if err := solution.solveWithHint(lID, hint); err != nil { - return -1, err - } - } else { - r = 0 - } - - } - - if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { - // check if it's a hint - if hint, ok := cs.MHints[rID]; ok { - if err := solution.solveWithHint(rID, hint); err != nil { - return -1, err - } - } else { - r = 1 - } - } - - if (c.O.CoeffID() != 0) && !solution.solved[oID] { - // check if it's a hint - if hint, ok := cs.MHints[oID]; ok { - if err := solution.solveWithHint(oID, hint); err != nil { - return -1, err - } - } else { - r = 2 - } - } - 
return r, nil -} - -// solveConstraint solve any unsolved wire in given constraint and update the solution -// a SparseR1C may have up to one unsolved wire (excluding hints) -// if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { - - lro, err := cs.computeHints(c, solution) - if err != nil { - return err - } - if lro == -1 { - // no unsolved wire - // can happen if the constraint contained only hint wires. - return nil - } - if lro == 1 { // we solve for R: u1L+u2R+u3LR+u4O+k=0 => R(u2+u3L)+u1L+u4O+k = 0 - if !solution.solved[c.L.WireID()] { - panic("L wire should be instantiated when we solve R") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.L.WireID()]).Add(&den, &u2) - - v1 = solution.computeTerm(c.L) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - } - - if lro == 0 { // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 - if !solution.solved[c.R.WireID()] { - panic("R wire should be instantiated when we solve L") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.R.WireID()]).Add(&den, &u1) - - v1 = solution.computeTerm(c.R) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - - } - // O 
we solve for O - var o fr.Element - cID, vID := c.O.CoeffID(), c.O.WireID() - - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - - // o = - ((m0 * m1) + l + r + c.K) / c.O - o.Mul(&m0, &m1).Add(&o, &l).Add(&o, &r).Add(&o, &cs.Coefficients[c.K]) - o.Mul(&o, &coefficientsNegInv[cID]) - - solution.set(vID, o) - - return nil -} - -// IsSolved returns nil if given witness solves the SparseR1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return err - } - - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, opt) - return err -} - -// GetConstraints return the list of SparseR1C and a coefficient resolver -func (cs *SparseR1CS) GetConstraints() ([]constraint.SparseR1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// checkConstraint verifies that the constraint holds -func (cs *SparseR1CS) checkConstraint(c constraint.SparseR1C, solution *solution) error { - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - o := solution.computeTerm(c.O) - - // l + r + (m0 * m1) + o + c.K == 0 - var t fr.Element - t.Mul(&m0, &m1).Add(&t, &l).Add(&t, &r).Add(&t, &o).Add(&t, &cs.Coefficients[c.K]) - if !t.IsZero() { - return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + (%s × %s) + %s != 0", - l.String(), - r.String(), - o.String(), - m0.String(), - m1.String(), - cs.Coefficients[c.K].String(), - ) - } - return nil - -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto (ecc.BLS24-315) -func (cs 
*SparseR1CS) CurveID() ecc.ID { - return ecc.BLS24_315 -} - -// WriteTo encodes SparseR1CS into provided io.Writer using cbor -func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode SparseR1CS from io.Reader using cbor -func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bls24-315/r1cs_test.go b/constraint/bls24-315/r1cs_test.go index a5cf07e4a3..3885c438b5 100644 --- a/constraint/bls24-315/r1cs_test.go +++ b/constraint/bls24-315/r1cs_test.go @@ -20,6 +20,7 @@ import ( "bytes" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "reflect" "testing" @@ -48,7 +49,7 @@ func TestSerialization(t *testing.T) { return } - // copmpile a second time to ensure determinism + // compile a second time to ensure determinism r1cs2, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) @@ -76,10 +77,10 @@ func TestSerialization(t *testing.T) { if diff := cmp.Diff(r1cs1, &reconstructed, cmpopts.IgnoreFields(cs.R1CS{}, "System.q", - "arithEngine", + "field", "CoeffTable.mCoeffs", 
"System.lbWireLevel", - "System.lbHints", + "System.genericHint", "System.SymbolTable", "System.lbOutputs", "System.bitLen")); diff != "" { @@ -147,12 +148,6 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { - var c circuit - ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c) - if err != nil { - b.Fatal(err) - } - var w circuit w.X = 1 w.Y = 1 @@ -161,8 +156,32 @@ func BenchmarkSolve(b *testing.B) { b.Fatal(err) } - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ccs.IsSolved(witness) - } + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + } diff --git a/constraint/bls24-315/solution.go b/constraint/bls24-315/solution.go deleted file mode 100644 index f65a2821b1..0000000000 --- a/constraint/bls24-315/solution.go +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/debug" - "github.com/rs/zerolog" - "math/big" - "strconv" - "strings" - "sync/atomic" - - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" -) - -// solution represents elements needed to compute -// a solution to a R1CS or SparseR1CS -type solution struct { - values, coefficients []fr.Element - solved []bool - nbSolved uint64 - mHintsFunctions map[hint.ID]hint.Function // maps hintID to hint function - mHints map[int]*constraint.Hint // maps wireID to hint - st *debug.SymbolTable -} - -func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, mHints map[int]*constraint.Hint, coefficients []fr.Element, st *debug.SymbolTable) (solution, error) { - - s := solution{ - st: st, - values: make([]fr.Element, nbWires), - coefficients: coefficients, - solved: make([]bool, nbWires), - mHintsFunctions: hintFunctions, - mHints: mHints, - } - - // hintsDependencies is from compile time; it contains the list of hints the solver **needs** - var missing []string - for hintUUID, hintID := range hintsDependencies { - if _, ok := s.mHintsFunctions[hintUUID]; !ok { - missing = append(missing, hintID) - } - } - - if len(missing) > 0 { - return s, fmt.Errorf("solver missing hint(s): %v", missing) - } - - return s, nil -} - -func (s *solution) set(id int, value fr.Element) { - if s.solved[id] { - panic("solving the same wire twice should never happen.") - } - s.values[id] = value - s.solved[id] = true - atomic.AddUint64(&s.nbSolved, 1) - // s.nbSolved++ -} - -func (s *solution) isValid() bool { - return int(s.nbSolved) == len(s.values) -} - -// computeTerm computes coeff*variable -func (s *solution) computeTerm(t constraint.Term) fr.Element { - cID, vID := t.CoeffID(), t.WireID() - if cID != 0 && !s.solved[vID] { - panic("computing a 
term with an unsolved wire") - } - switch cID { - case constraint.CoeffIdZero: - return fr.Element{} - case constraint.CoeffIdOne: - return s.values[vID] - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - return res - case constraint.CoeffIdMinusOne: - var res fr.Element - res.Neg(&s.values[vID]) - return res - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - return res - } -} - -// r += (t.coeff*t.value) -func (s *solution) accumulateInto(t constraint.Term, r *fr.Element) { - cID := t.CoeffID() - - if t.IsConstant() { - // needed for logs, we may want to not put this in the hot path if we need to - // optimize constraint system solver further. - r.Add(r, &s.coefficients[cID]) - return - } - - vID := t.WireID() - switch cID { - case constraint.CoeffIdZero: - return - case constraint.CoeffIdOne: - r.Add(r, &s.values[vID]) - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - r.Add(r, &res) - case constraint.CoeffIdMinusOne: - r.Sub(r, &s.values[vID]) - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - r.Add(r, &res) - } -} - -// solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveWithHint(vID int, h *constraint.Hint) error { - // skip if the wire is already solved by a call to the same hint - // function on the same inputs - if s.solved[vID] { - return nil - } - // ensure hint function was provided - f, ok := s.mHintsFunctions[h.ID] - if !ok { - return errors.New("missing hint function") - } - - // tmp IO big int memory - nbInputs := len(h.Inputs) - nbOutputs := len(h.Wires) - inputs := make([]*big.Int, nbInputs) - outputs := make([]*big.Int, nbOutputs) - for i := 0; i < nbOutputs; i++ { - outputs[i] = big.NewInt(0) - } - - q := fr.Modulus() - - // for each input, we set its big int value, IF all the wires are solved - // the only case where all wires may not be solved, is if one of the input of this hint - // 
is the output of another hint. - // it is safe to recursively solve this with the parallel solver, since all hints-output wires - // that we can solve this way are marked to be solved with the current constraint we are processing. - recursiveSolve := func(t constraint.Term) error { - if t.IsConstant() { - return nil - } - wID := t.WireID() - if s.solved[wID] { - return nil - } - // unsolved dependency - if h, ok := s.mHints[wID]; ok { - // solve recursively. - return s.solveWithHint(wID, h) - } - - // it's not a hint, we panic. - panic("solver can't compute hint; one or more input wires are unsolved") - } - - for i := 0; i < nbInputs; i++ { - inputs[i] = big.NewInt(0) - - var v fr.Element - for _, term := range h.Inputs[i] { - if err := recursiveSolve(term); err != nil { - return err - } - s.accumulateInto(term, &v) - } - v.BigInt(inputs[i]) - } - - err := f(q, inputs, outputs) - - var v fr.Element - for i := range outputs { - v.SetBigInt(outputs[i]) - s.set(h.Wires[i], v) - } - - return err -} - -func (s *solution) printLogs(log zerolog.Logger, logs []constraint.LogEntry) { - if log.GetLevel() == zerolog.Disabled { - return - } - - for i := 0; i < len(logs); i++ { - logLine := s.logValue(logs[i]) - log.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) - } -} - -const unsolvedVariable = "" - -func (s *solution) logValue(log constraint.LogEntry) string { - var toResolve []interface{} - var ( - eval fr.Element - missingValue bool - ) - for j := 0; j < len(log.ToResolve); j++ { - // before eval le - - missingValue = false - eval.SetZero() - - for _, t := range log.ToResolve[j] { - // for each term in the linear expression - - cID, vID := t.CoeffID(), t.WireID() - if t.IsConstant() { - // just add the constant - eval.Add(&eval, &s.coefficients[cID]) - continue - } - - if !s.solved[vID] { - missingValue = true - break // stop the loop we can't evaluate. 
- } - - tv := s.computeTerm(t) - eval.Add(&eval, &tv) - } - - // after - if missingValue { - toResolve = append(toResolve, unsolvedVariable) - } else { - // we have to append our accumulator - toResolve = append(toResolve, eval.String()) - } - - } - if len(log.Stack) > 0 { - var sbb strings.Builder - for _, lID := range log.Stack { - location := s.st.Locations[lID] - function := s.st.Functions[location.FunctionID] - - sbb.WriteString(function.Name) - sbb.WriteByte('\n') - sbb.WriteByte('\t') - sbb.WriteString(function.Filename) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(int(location.Line))) - sbb.WriteByte('\n') - } - toResolve = append(toResolve, sbb.String()) - } - return fmt.Sprintf(log.Format, toResolve...) -} - -// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint -type UnsatisfiedConstraintError struct { - Err error - CID int // constraint ID - DebugInfo *string // optional debug info -} - -func (r *UnsatisfiedConstraintError) Error() string { - if r.DebugInfo != nil { - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) - } - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) -} diff --git a/constraint/bls24-315/solver.go b/constraint/bls24-315/solver.go new file mode 100644 index 0000000000..41f89208ce --- /dev/null +++ b/constraint/bls24-315/solver.go @@ -0,0 +1,647 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/pool" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/rs/zerolog" + "math" + "math/big" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" +) + +// solver represent the state of the solver during a call to System.Solve(...) +type solver struct { + *system + + // values and solved are index by the wire (variable) id + values []fr.Element + solved []bool + nbSolved uint64 + + // maps hintID to hint function + mHintsFunctions map[csolver.HintID]csolver.Hint + + // used to out api.Println + logger zerolog.Logger + + a, b, c fr.Vector // R1CS solver will compute the a,b,c matrices + + q *big.Int +} + +func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { + // parse options + opt, err := csolver.NewConfig(opts...) 
+ if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public) - witnessOffset + len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + witnessOffset) + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + +// computeTerm computes 
coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. 
+func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start)+i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. 
+ } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID: cID, VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. 
+func (s *solver) Read(calldata []uint32) (constraint.Element, int) { + if s.Type == constraint.SystemSparseR1CS { + if calldata[0] != 1 { + panic("invalid calldata") + } + return s.GetValue(calldata[1], calldata[2]), 3 + } + var r fr.Element + n := int(calldata[0]) + j := 1 + for k := 0; k < n; k++ { + // we read k Terms + s.accumulateInto(constraint.Term{CID: calldata[j], VID: calldata[j+1]}, &r) + j += 2 + } + + var ret constraint.Element + copy(ret[:], r[:]) + return ret, j +} + +// processInstruction decodes the instruction and execute blueprint-defined logic. +// an instruction can encode a hint, a custom constraint or a generic constraint. +func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error { + // fetch the blueprint + blueprint := solver.Blueprints[pi.BlueprintID] + inst := pi.Unpack(&solver.System) + cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only + + if solver.Type == constraint.SystemR1CS { + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + // TODO @gbotrel we use the solveR1C method for now, having user-defined + // blueprint for R1CS would require constraint.Solver interface to add methods + // to set a,b,c since it's more efficient to compute these while we solve. + bc.DecompressR1C(&scratch.tR1C, inst) + return solver.solveR1C(cID, &scratch.tR1C) + } + } + + // blueprint declared "I know how to solve this." + if bc, ok := blueprint.(constraint.BlueprintSolvable); ok { + if err := bc.Solve(solver, inst); err != nil { + return solver.wrapErrWithDebugInfo(cID, err) + } + return nil + } + + // blueprint encodes a hint, we execute. + // TODO @gbotrel may be worth it to move hint logic in blueprint "solve" + if bc, ok := blueprint.(constraint.BlueprintHint); ok { + bc.DecompressHint(&scratch.tHint, inst) + return solver.solveWithHint(&scratch.tHint) + } + + return nil +} + +// run runs the solver. 
it return an error if a constraint is not satisfied or if not all wires +// were instantiated. +func (solver *solver) run() error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks. + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + var scratch scratch + for t := range chTasks { + for _, i := range t { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + var scratch scratch + + // for each level, we push the tasks + for _, level := range solver.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + return err + } + } + continue + } + + // number of tasks for this level is set to number of CPU + // but if we don't have enough work for all our CPU, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + if int(solver.nbSolved) != len(solver.values) { + return errors.New("solver didn't assign a value to all wires") + } + + return nil +} + +// solveR1C compute unsolved wires in the constraint, if any and set the solver accordingly +// +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. 
+func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error { + a, b, c := &solver.a[cID], &solver.b[cID], &solver.c[cID] + + // the index of the non-zero entry shows if L, R or O has an uninstantiated wire + // the content is the ID of the wire non instantiated + var loc uint8 + + var termToCompute constraint.Term + + processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) { + for _, t := range l { + vID := t.WireID() + + // wire is already computed, we just accumulate in val + if solver.solved[vID] { + solver.accumulateInto(t, val) + continue + } + + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue + } + } + + processLExp(r.L, a, 1) + processLExp(r.R, b, 2) + processLExp(r.O, c, 3) + + if loc == 0 { + // there is nothing to solve, may happen if we have an assertion + // (ie a constraints that doesn't yield any output) + // or if we solved the unsolved wires with hint functions + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + return nil + } + + // we compute the wire value and instantiate it + wID := termToCompute.WireID() + + // solver result + var wire fr.Element + + switch loc { + case 1: + if !b.IsZero() { + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) + } else { + // we didn't actually ensure that a * b == c + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 2: + if !a.IsZero() { + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) + } else { + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 3: + wire.Mul(a, b). 
+ Sub(&wire, c) + + c.Add(c, &wire) + } + + // wire is the term (coeff * value) + // but in the solver we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + solver.divByCoeff(&wire, termToCompute.CID) + solver.set(wID, wire) + + return nil +} + +// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint +type UnsatisfiedConstraintError struct { + Err error + CID int // constraint ID + DebugInfo *string // optional debug info +} + +func (r *UnsatisfiedConstraintError) Error() string { + if r.DebugInfo != nil { + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) + } + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) +} + +func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { + var debugInfo *string + if dID, ok := solver.MDebug[int(cID)]; ok { + debugInfo = new(string) + *debugInfo = solver.logValue(solver.DebugInfo[dID]) + } + return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo} +} + +// temporary variables to avoid memallocs in hotloop +type scratch struct { + tR1C constraint.R1C + tHint constraint.HintMapping +} diff --git a/constraint/bls24-315/system.go b/constraint/bls24-315/system.go new file mode 100644 index 0000000000..de1f206275 --- /dev/null +++ b/constraint/bls24-315/system.go @@ -0,0 +1,374 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/fxamacker/cbor/v2" + "io" + "time" + + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/internal/backend/ioutils" + "github.com/consensys/gnark/logger" + "reflect" + + "github.com/consensys/gnark-crypto/ecc" + + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" +) + +type R1CS = system +type SparseR1CS = system + +// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element) +type system struct { + constraint.System + CoeffTable + field +} + +func NewR1CS(capacity int) *R1CS { + return newSystem(capacity, constraint.SystemR1CS) +} + +func NewSparseR1CS(capacity int) *SparseR1CS { + return newSystem(capacity, constraint.SystemSparseR1CS) +} + +func newSystem(capacity int, t constraint.SystemType) *system { + return &system{ + System: constraint.NewSystem(fr.Modulus(), capacity, t), + CoeffTable: newCoeffTable(capacity / 10), + } +} + +// Solve solves the constraint system with provided witness. +// If it's a R1CS returns R1CSSolution +// If it's a SparseR1CS returns SparseR1CSSolution +func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) { + log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger() + start := time.Now() + + v := witness.Vector().(fr.Vector) + + // init the solver + solver, err := newSolver(cs, v, opts...) + if err != nil { + log.Err(err).Send() + return nil, err + } + + // defer log printing once all solver.values are computed + // (or sooner, if a constraint is not satisfied) + defer solver.printLogs(cs.Logs) + + // run it. 
+ if err := solver.run(); err != nil { + log.Err(err).Send() + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") + + // format the solution + // TODO @gbotrel revisit post-refactor + if cs.Type == constraint.SystemR1CS { + var res R1CSSolution + res.W = solver.values + res.A = solver.a + res.B = solver.b + res.C = solver.c + return &res, nil + } else { + // sparse R1CS + var res SparseR1CSSolution + // query l, r, o in Lagrange basis, not blinded + res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values) + + return &res, nil + } + +} + +// IsSolved +// Deprecated: use _, err := Solve(...) instead +func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error { + _, err := cs.Solve(witness, opts...) + return err +} + +// GetR1Cs return the list of R1C +func (cs *system) GetR1Cs() []constraint.R1C { + toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + var r1c constraint.R1C + bc.DecompressR1C(&r1c, inst.Unpack(&cs.System)) + toReturn = append(toReturn, r1c) + } else { + panic("not implemented") + } + } + return toReturn +} + +// GetNbCoefficients return the number of unique coefficients needed in the R1CS +func (cs *system) GetNbCoefficients() int { + return len(cs.Coefficients) +} + +// CurveID returns curve ID as defined in gnark-crypto +func (cs *system) CurveID() ecc.ID { + return ecc.BLS24_315 +} + +// WriteTo encodes R1CS into provided io.Writer using cbor +func (cs *system) WriteTo(w io.Writer) (int64, error) { + _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written + ts := getTagSet() + enc, err := cbor.CoreDetEncOptions().EncModeWithTags(ts) + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) + + // encode our object + err = encoder.Encode(cs) + return _w.N, err +} + +// ReadFrom 
attempts to decode R1CS from io.Reader using cbor +func (cs *system) ReadFrom(r io.Reader) (int64, error) { + ts := getTagSet() + dm, err := cbor.DecOptions{ + MaxArrayElements: 134217728, + MaxMapPairs: 134217728, + }.DecModeWithTags(ts) + + if err != nil { + return 0, err + } + decoder := dm.NewDecoder(r) + + // initialize coeff table + cs.CoeffTable = newCoeffTable(0) + + if err := decoder.Decode(&cs); err != nil { + return int64(decoder.NumBytesRead()), err + } + + if err := cs.CheckSerializationHeader(); err != nil { + return int64(decoder.NumBytesRead()), err + } + + switch v := cs.CommitmentInfo.(type) { + case *constraint.Groth16Commitments: + cs.CommitmentInfo = *v + case *constraint.PlonkCommitments: + cs.CommitmentInfo = *v + } + + return int64(decoder.NumBytesRead()), nil +} + +func (cs *system) GetCoefficient(i int) (r constraint.Element) { + copy(r[:], cs.Coefficients[i][:]) + return +} + +// GetSparseR1Cs return the list of SparseR1C +func (cs *system) GetSparseR1Cs() []constraint.SparseR1C { + + toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + var sparseR1C constraint.SparseR1C + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + toReturn = append(toReturn, sparseR1C) + } else { + panic("not implemented") + } + } + return toReturn +} + +// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form. 
+// solver = [ public | secret | internal ] +// TODO @gbotrel refactor; this seems to be a small util function for plonk +func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { + + //s := int(pk.Domain[0].Cardinality) + s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints + s = int(ecc.NextPowerOfTwo(uint64(s))) + + var l, r, o []fr.Element + l = make([]fr.Element, s) + r = make([]fr.Element, s) + o = make([]fr.Element, s) + s0 := solution[0] + + for i := 0; i < len(cs.Public); i++ { // placeholders + l[i] = solution[i] + r[i] = s0 + o[i] = s0 + } + offset := len(cs.Public) + nbConstraints := cs.GetNbConstraints() + + var sparseR1C constraint.SparseR1C + j := 0 + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + + l[offset+j] = solution[sparseR1C.XA] + r[offset+j] = solution[sparseR1C.XB] + o[offset+j] = solution[sparseR1C.XC] + j++ + } + } + + offset += nbConstraints + + for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0]) + l[offset+i] = s0 + r[offset+i] = s0 + o[offset+i] = s0 + } + + return l, r, o + +} + +// R1CSSolution represent a valid assignment to all the variables in the constraint system. 
+// The vector W such that Aw o Bw - Cw = 0 +type R1CSSolution struct { + W fr.Vector + A, B, C fr.Vector +} + +func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.W.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.A.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.B.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.C.WriteTo(w) + n += a + return n, err +} + +func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.W.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.A.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.B.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.C.ReadFrom(r) + n += a + return n, err +} + +// SparseR1CSSolution represent a valid assignment to all the variables in the constraint system. +type SparseR1CSSolution struct { + L, R, O fr.Vector +} + +func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.L.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.R.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.O.WriteTo(w) + n += a + return n, err + +} + +func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.L.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.R.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.O.ReadFrom(r) + n += a + return n, err +} + +func getTagSet() cbor.TagSet { + // temporary for refactor + ts := cbor.NewTagSet() + // https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml + // 65536-15309735 Unassigned + tagNum := uint64(5309735) + addType := func(t reflect.Type) { + if err := ts.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + t, + tagNum, + ); err != nil { + panic(err) + } + tagNum++ + } + + addType(reflect.TypeOf(constraint.BlueprintGenericHint{})) + addType(reflect.TypeOf(constraint.BlueprintGenericR1C{})) + 
addType(reflect.TypeOf(constraint.BlueprintGenericSparseR1C{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CAdd{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CMul{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CBool{})) + addType(reflect.TypeOf(constraint.Groth16Commitments{})) + addType(reflect.TypeOf(constraint.PlonkCommitments{})) + + return ts +} diff --git a/constraint/bls24-317/coeff.go b/constraint/bls24-317/coeff.go index d011d2bffb..5df92edf1b 100644 --- a/constraint/bls24-317/coeff.go +++ b/constraint/bls24-317/coeff.go @@ -46,7 +46,7 @@ func newCoeffTable(capacity int) CoeffTable { } -func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constraint.Term { +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { c := (*fr.Element)(coeff[:]) var cID uint32 if c.IsZero() { @@ -69,7 +69,11 @@ func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constrai ct.mCoeffs[cc] = cID } } + return cID +} +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) return constraint.Term{VID: uint32(variableID), CID: cID} } @@ -78,7 +82,10 @@ func (ct *CoeffTable) CoeffToString(cID int) string { return ct.Coefficients[cID].String() } -var _ constraint.CoeffEngine = &arithEngine{} +// implements constraint.Field +type field struct{} + +var _ constraint.Field = &field{} var ( two fr.Element @@ -94,10 +101,7 @@ func init() { minusTwo.Neg(&two) } -// implements constraint.CoeffEngine -type arithEngine struct{} - -func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { +func (engine *field) FromInterface(i interface{}) constraint.Element { var e fr.Element if _, err := e.SetInterface(i); err != nil { // need to clean that --> some code path are dissimilar @@ -106,55 +110,83 @@ func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { b := utils.FromInterface(i) e.SetBigInt(&b) } - var r constraint.Coeff + 
var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) ToBigInt(c *constraint.Coeff) *big.Int { +func (engine *field) ToBigInt(c constraint.Element) *big.Int { e := (*fr.Element)(c[:]) r := new(big.Int) e.BigInt(r) return r } -func (engine *arithEngine) Mul(a, b *constraint.Coeff) { +func (engine *field) Mul(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Mul(_a, _b) + return a } -func (engine *arithEngine) Add(a, b *constraint.Coeff) { + +func (engine *field) Add(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Add(_a, _b) + return a } -func (engine *arithEngine) Sub(a, b *constraint.Coeff) { +func (engine *field) Sub(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Sub(_a, _b) + return a } -func (engine *arithEngine) Neg(a *constraint.Coeff) { +func (engine *field) Neg(a constraint.Element) constraint.Element { e := (*fr.Element)(a[:]) e.Neg(e) + return a } -func (engine *arithEngine) Inverse(a *constraint.Coeff) { +func (engine *field) Inverse(a constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + e.Inverse(e) + return a, true } -func (engine *arithEngine) IsOne(a *constraint.Coeff) bool { +func (engine *field) IsOne(a constraint.Element) bool { e := (*fr.Element)(a[:]) return e.IsOne() } -func (engine *arithEngine) One() constraint.Coeff { +func (engine *field) One() constraint.Element { e := fr.One() - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) String(a *constraint.Coeff) string { +func (engine *field) String(a constraint.Element) string { e := (*fr.Element)(a[:]) return e.String() } + +func (engine *field) 
Uint64(a constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true +} diff --git a/constraint/bls24-317/r1cs.go b/constraint/bls24-317/r1cs.go deleted file mode 100644 index a808832fd7..0000000000 --- a/constraint/bls24-317/r1cs.go +++ /dev/null @@ -1,453 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/fxamacker/cbor/v2" - "io" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc" - "math" - - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" -) - -// R1CS describes a set of R1CS constraint -type R1CS struct { - constraint.R1CSCore - CoeffTable - arithEngine -} - -// NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values -// -// capacity pre-allocates memory for capacity nbConstraints -func NewR1CS(capacity int) *R1CS { - r := R1CS{ - R1CSCore: constraint.R1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.R1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - 
return &r -} - -func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, r1c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - - cs.UpdateLevel(cID, &r1c) - - return cID -} - -// Solve sets all the wires and returns the a, b, c vectors. -// the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. -// a, b, c vectors: ab-c = hz -// witness = [publicWires | secretWires] (without the ONE_WIRE !) -// returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() - - nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables - hintFunctions := defineGkrHints(cs.GkrInfo, opt.HintFunctions) - solution, err := newSolution(nbWires, hintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - - if err != nil { - return make(fr.Vector, nbWires), err - } - start := time.Now() - - if len(witness) != len(cs.Public)-1+len(cs.Secret) { // - 1 for ONE_WIRE - err = fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), int(len(cs.Public)-1+len(cs.Secret)), len(cs.Public)-1, len(cs.Secret)) - log.Err(err).Send() - return solution.values, err - } - - // compute the wires and the a, b, c polynomials - if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { - err = errors.New("invalid input size: len(a, b, c) == len(Constraints)") - log.Err(err).Send() - return solution.values, err - } - - solution.solved[0] = true // ONE_WIRE - solution.values[0].SetOne() - copy(solution.values[1:], witness) - for 
i := range witness { - solution.solved[i+1] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness) + 1) - - // now that we know all inputs are set, defer log printing once all solution.values are computed - // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - if err := cs.parallelSolve(a, b, c, &solution); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil -} - -func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. 
- const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - chError <- &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - return &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can 
be lower. - nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// IsSolved returns nil if given witness solves the R1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) 
- if err != nil { - return err - } - - a := make(fr.Vector, len(cs.Constraints)) - b := make(fr.Vector, len(cs.Constraints)) - c := make(fr.Vector, len(cs.Constraints)) - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, a, b, c, opt) - return err -} - -// divByCoeff sets res = res / t.Coeff -func (cs *R1CS) divByCoeff(res *fr.Element, t constraint.Term) { - cID := t.CoeffID() - switch cID { - case constraint.CoeffIdOne: - return - case constraint.CoeffIdMinusOne: - res.Neg(res) - case constraint.CoeffIdZero: - panic("division by 0") - default: - // this is slow, but shouldn't happen as divByCoeff is called to - // remove the coeff of an unsolved wire - // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 - res.Div(res, &cs.Coefficients[cID]) - } -} - -// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly -// -// returns an error if the solver called a hint function that errored -// returns false, nil if there was no wire to solve -// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that -// the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r constraint.R1C, solution *solution, a, b, c *fr.Element) error { - - // the index of the non-zero entry shows if L, R or O has an uninstantiated wire - // the content is the ID of the wire non instantiated - var loc uint8 - - var termToCompute constraint.Term - - processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) error { - for _, t := range l { - vID := t.WireID() - - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - solution.accumulateInto(t, val) - continue - } - - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err - } - // now that the wire is saved, accumulate it into a, b or c - solution.accumulateInto(t, val) - continue - } - - if loc != 0 { - panic("found more than one wire to instantiate") - } - termToCompute = t - loc = locValue - } - return nil - } - - if err := processLExp(r.L, a, 1); err != nil { - return err - } - - if err := processLExp(r.R, b, 2); err != nil { - return err - } - - if err := processLExp(r.O, c, 3); err != nil { - return err - } - - if loc == 0 { - // there is nothing to solve, may happen if we have an assertion - // (ie a constraints that doesn't yield any output) - // or if we solved the unsolved wires with hint functions - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - return nil - } - - // we compute the wire value and instantiate it - wID := termToCompute.WireID() - - // solver result - var wire fr.Element - - switch loc { - case 1: - if !b.IsZero() { - wire.Div(c, b). - Sub(&wire, a) - a.Add(a, &wire) - } else { - // we didn't actually ensure that a * b == c - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 2: - if !a.IsZero() { - wire.Div(c, a). 
- Sub(&wire, b) - b.Add(b, &wire) - } else { - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 3: - wire.Mul(a, b). - Sub(&wire, c) - - c.Add(c, &wire) - } - - // wire is the term (coeff * value) - // but in the solution we want to store the value only - // note that in gnark frontend, coeff here is always 1 or -1 - cs.divByCoeff(&wire, termToCompute) - solution.set(wID, wire) - - return nil -} - -// GetConstraints return the list of R1C and a coefficient resolver -func (cs *R1CS) GetConstraints() ([]constraint.R1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *R1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto -func (cs *R1CS) CurveID() ecc.ID { - return ecc.BLS24_317 -} - -// WriteTo encodes R1CS into provided io.Writer using cbor -func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode R1CS from io.Reader using cbor -func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(&cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bls24-317/r1cs_sparse.go 
b/constraint/bls24-317/r1cs_sparse.go deleted file mode 100644 index 5acd85dd99..0000000000 --- a/constraint/bls24-317/r1cs_sparse.go +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/fxamacker/cbor/v2" - "io" - "math" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" -) - -// SparseR1CS represents a Plonk like circuit -type SparseR1CS struct { - constraint.SparseR1CSCore - CoeffTable - arithEngine -} - -// NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values -func NewSparseR1CS(capacity int) *SparseR1CS { - cs := SparseR1CS{ - SparseR1CSCore: constraint.SparseR1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.SparseR1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - - return &cs -} - -func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = 
append(cs.Constraints, c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - cs.UpdateLevel(cID, &c) - - return cID -} - -// Solve sets all the wires. -// solution.values = [publicInputs | secretInputs | internalVariables ] -// witness: contains the input variables -// it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() - - // set the slices holding the solution.values and monitoring which variables have been solved - nbVariables := cs.NbInternalVariables + len(cs.Secret) + len(cs.Public) - - start := time.Now() - - expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) - if len(witness) != expectedWitnessSize { - return make(fr.Vector, nbVariables), fmt.Errorf( - "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", - len(witness), - expectedWitnessSize, - len(cs.Public), - len(cs.Secret), - ) - } - - // keep track of wire that have a value - solution, err := newSolution(nbVariables, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return solution.values, err - } - - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs - copy(solution.values, witness) - for i := 0; i < len(witness); i++ { - solution.solved[i] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness)) - - // defer log printing once all solution.values are computed - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - // batch invert the coefficients to avoid many divisions in the solver - coefficientsNegInv := 
fr.BatchInvert(cs.Coefficients) - for i := 0; i < len(coefficientsNegInv); i++ { - coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) - } - - if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil - -} - -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - wg.Done() - return - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - chError <- &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } else { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - } - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - return &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// computeHints computes wires associated with a hint function, if any -// if there is no remaining wire to solve, returns -1 -// else returns the wire position (L -> 0, R -> 1, O -> 2) -func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) (int, error) { - r := -1 - lID, rID, oID := c.L.WireID(), c.R.WireID(), c.O.WireID() - - if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { - // check if it's a hint - if hint, ok := cs.MHints[lID]; ok { - if err := solution.solveWithHint(lID, hint); err != nil { - return -1, err - } - } else { - r = 0 - } - - } - - if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { - // check if it's a hint - if hint, ok := cs.MHints[rID]; ok { - if err := solution.solveWithHint(rID, hint); err != nil { - return -1, err - } - } else { - r = 1 - } - } - - if (c.O.CoeffID() != 0) && !solution.solved[oID] { - // check if it's a hint - if hint, ok := cs.MHints[oID]; ok { - if err := solution.solveWithHint(oID, hint); err != nil { - return -1, err - } - } else { - r = 2 - } - } - 
return r, nil -} - -// solveConstraint solve any unsolved wire in given constraint and update the solution -// a SparseR1C may have up to one unsolved wire (excluding hints) -// if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { - - lro, err := cs.computeHints(c, solution) - if err != nil { - return err - } - if lro == -1 { - // no unsolved wire - // can happen if the constraint contained only hint wires. - return nil - } - if lro == 1 { // we solve for R: u1L+u2R+u3LR+u4O+k=0 => R(u2+u3L)+u1L+u4O+k = 0 - if !solution.solved[c.L.WireID()] { - panic("L wire should be instantiated when we solve R") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.L.WireID()]).Add(&den, &u2) - - v1 = solution.computeTerm(c.L) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - } - - if lro == 0 { // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 - if !solution.solved[c.R.WireID()] { - panic("R wire should be instantiated when we solve L") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.R.WireID()]).Add(&den, &u1) - - v1 = solution.computeTerm(c.R) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - - } - // O 
we solve for O - var o fr.Element - cID, vID := c.O.CoeffID(), c.O.WireID() - - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - - // o = - ((m0 * m1) + l + r + c.K) / c.O - o.Mul(&m0, &m1).Add(&o, &l).Add(&o, &r).Add(&o, &cs.Coefficients[c.K]) - o.Mul(&o, &coefficientsNegInv[cID]) - - solution.set(vID, o) - - return nil -} - -// IsSolved returns nil if given witness solves the SparseR1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return err - } - - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, opt) - return err -} - -// GetConstraints return the list of SparseR1C and a coefficient resolver -func (cs *SparseR1CS) GetConstraints() ([]constraint.SparseR1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// checkConstraint verifies that the constraint holds -func (cs *SparseR1CS) checkConstraint(c constraint.SparseR1C, solution *solution) error { - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - o := solution.computeTerm(c.O) - - // l + r + (m0 * m1) + o + c.K == 0 - var t fr.Element - t.Mul(&m0, &m1).Add(&t, &l).Add(&t, &r).Add(&t, &o).Add(&t, &cs.Coefficients[c.K]) - if !t.IsZero() { - return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + (%s × %s) + %s != 0", - l.String(), - r.String(), - o.String(), - m0.String(), - m1.String(), - cs.Coefficients[c.K].String(), - ) - } - return nil - -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto (ecc.BLS24-317) -func (cs 
*SparseR1CS) CurveID() ecc.ID { - return ecc.BLS24_317 -} - -// WriteTo encodes SparseR1CS into provided io.Writer using cbor -func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode SparseR1CS from io.Reader using cbor -func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bls24-317/r1cs_test.go b/constraint/bls24-317/r1cs_test.go index eff1f06499..9ad8d2f586 100644 --- a/constraint/bls24-317/r1cs_test.go +++ b/constraint/bls24-317/r1cs_test.go @@ -20,6 +20,7 @@ import ( "bytes" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "reflect" "testing" @@ -48,7 +49,7 @@ func TestSerialization(t *testing.T) { return } - // copmpile a second time to ensure determinism + // compile a second time to ensure determinism r1cs2, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) @@ -76,10 +77,10 @@ func TestSerialization(t *testing.T) { if diff := cmp.Diff(r1cs1, &reconstructed, cmpopts.IgnoreFields(cs.R1CS{}, "System.q", - "arithEngine", + "field", "CoeffTable.mCoeffs", 
"System.lbWireLevel", - "System.lbHints", + "System.genericHint", "System.SymbolTable", "System.lbOutputs", "System.bitLen")); diff != "" { @@ -147,12 +148,6 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { - var c circuit - ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c) - if err != nil { - b.Fatal(err) - } - var w circuit w.X = 1 w.Y = 1 @@ -161,8 +156,32 @@ func BenchmarkSolve(b *testing.B) { b.Fatal(err) } - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ccs.IsSolved(witness) - } + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + } diff --git a/constraint/bls24-317/solution.go b/constraint/bls24-317/solution.go deleted file mode 100644 index d26c56f931..0000000000 --- a/constraint/bls24-317/solution.go +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/debug" - "github.com/rs/zerolog" - "math/big" - "strconv" - "strings" - "sync/atomic" - - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" -) - -// solution represents elements needed to compute -// a solution to a R1CS or SparseR1CS -type solution struct { - values, coefficients []fr.Element - solved []bool - nbSolved uint64 - mHintsFunctions map[hint.ID]hint.Function // maps hintID to hint function - mHints map[int]*constraint.Hint // maps wireID to hint - st *debug.SymbolTable -} - -func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, mHints map[int]*constraint.Hint, coefficients []fr.Element, st *debug.SymbolTable) (solution, error) { - - s := solution{ - st: st, - values: make([]fr.Element, nbWires), - coefficients: coefficients, - solved: make([]bool, nbWires), - mHintsFunctions: hintFunctions, - mHints: mHints, - } - - // hintsDependencies is from compile time; it contains the list of hints the solver **needs** - var missing []string - for hintUUID, hintID := range hintsDependencies { - if _, ok := s.mHintsFunctions[hintUUID]; !ok { - missing = append(missing, hintID) - } - } - - if len(missing) > 0 { - return s, fmt.Errorf("solver missing hint(s): %v", missing) - } - - return s, nil -} - -func (s *solution) set(id int, value fr.Element) { - if s.solved[id] { - panic("solving the same wire twice should never happen.") - } - s.values[id] = value - s.solved[id] = true - atomic.AddUint64(&s.nbSolved, 1) - // s.nbSolved++ -} - -func (s *solution) isValid() bool { - return int(s.nbSolved) == len(s.values) -} - -// computeTerm computes coeff*variable -func (s *solution) computeTerm(t constraint.Term) fr.Element { - cID, vID := t.CoeffID(), t.WireID() - if cID != 0 && !s.solved[vID] { - panic("computing a 
term with an unsolved wire") - } - switch cID { - case constraint.CoeffIdZero: - return fr.Element{} - case constraint.CoeffIdOne: - return s.values[vID] - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - return res - case constraint.CoeffIdMinusOne: - var res fr.Element - res.Neg(&s.values[vID]) - return res - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - return res - } -} - -// r += (t.coeff*t.value) -func (s *solution) accumulateInto(t constraint.Term, r *fr.Element) { - cID := t.CoeffID() - - if t.IsConstant() { - // needed for logs, we may want to not put this in the hot path if we need to - // optimize constraint system solver further. - r.Add(r, &s.coefficients[cID]) - return - } - - vID := t.WireID() - switch cID { - case constraint.CoeffIdZero: - return - case constraint.CoeffIdOne: - r.Add(r, &s.values[vID]) - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - r.Add(r, &res) - case constraint.CoeffIdMinusOne: - r.Sub(r, &s.values[vID]) - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - r.Add(r, &res) - } -} - -// solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveWithHint(vID int, h *constraint.Hint) error { - // skip if the wire is already solved by a call to the same hint - // function on the same inputs - if s.solved[vID] { - return nil - } - // ensure hint function was provided - f, ok := s.mHintsFunctions[h.ID] - if !ok { - return errors.New("missing hint function") - } - - // tmp IO big int memory - nbInputs := len(h.Inputs) - nbOutputs := len(h.Wires) - inputs := make([]*big.Int, nbInputs) - outputs := make([]*big.Int, nbOutputs) - for i := 0; i < nbOutputs; i++ { - outputs[i] = big.NewInt(0) - } - - q := fr.Modulus() - - // for each input, we set its big int value, IF all the wires are solved - // the only case where all wires may not be solved, is if one of the input of this hint - // 
is the output of another hint. - // it is safe to recursively solve this with the parallel solver, since all hints-output wires - // that we can solve this way are marked to be solved with the current constraint we are processing. - recursiveSolve := func(t constraint.Term) error { - if t.IsConstant() { - return nil - } - wID := t.WireID() - if s.solved[wID] { - return nil - } - // unsolved dependency - if h, ok := s.mHints[wID]; ok { - // solve recursively. - return s.solveWithHint(wID, h) - } - - // it's not a hint, we panic. - panic("solver can't compute hint; one or more input wires are unsolved") - } - - for i := 0; i < nbInputs; i++ { - inputs[i] = big.NewInt(0) - - var v fr.Element - for _, term := range h.Inputs[i] { - if err := recursiveSolve(term); err != nil { - return err - } - s.accumulateInto(term, &v) - } - v.BigInt(inputs[i]) - } - - err := f(q, inputs, outputs) - - var v fr.Element - for i := range outputs { - v.SetBigInt(outputs[i]) - s.set(h.Wires[i], v) - } - - return err -} - -func (s *solution) printLogs(log zerolog.Logger, logs []constraint.LogEntry) { - if log.GetLevel() == zerolog.Disabled { - return - } - - for i := 0; i < len(logs); i++ { - logLine := s.logValue(logs[i]) - log.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) - } -} - -const unsolvedVariable = "" - -func (s *solution) logValue(log constraint.LogEntry) string { - var toResolve []interface{} - var ( - eval fr.Element - missingValue bool - ) - for j := 0; j < len(log.ToResolve); j++ { - // before eval le - - missingValue = false - eval.SetZero() - - for _, t := range log.ToResolve[j] { - // for each term in the linear expression - - cID, vID := t.CoeffID(), t.WireID() - if t.IsConstant() { - // just add the constant - eval.Add(&eval, &s.coefficients[cID]) - continue - } - - if !s.solved[vID] { - missingValue = true - break // stop the loop we can't evaluate. 
- } - - tv := s.computeTerm(t) - eval.Add(&eval, &tv) - } - - // after - if missingValue { - toResolve = append(toResolve, unsolvedVariable) - } else { - // we have to append our accumulator - toResolve = append(toResolve, eval.String()) - } - - } - if len(log.Stack) > 0 { - var sbb strings.Builder - for _, lID := range log.Stack { - location := s.st.Locations[lID] - function := s.st.Functions[location.FunctionID] - - sbb.WriteString(function.Name) - sbb.WriteByte('\n') - sbb.WriteByte('\t') - sbb.WriteString(function.Filename) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(int(location.Line))) - sbb.WriteByte('\n') - } - toResolve = append(toResolve, sbb.String()) - } - return fmt.Sprintf(log.Format, toResolve...) -} - -// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint -type UnsatisfiedConstraintError struct { - Err error - CID int // constraint ID - DebugInfo *string // optional debug info -} - -func (r *UnsatisfiedConstraintError) Error() string { - if r.DebugInfo != nil { - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) - } - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) -} diff --git a/constraint/bls24-317/solver.go b/constraint/bls24-317/solver.go new file mode 100644 index 0000000000..e1dcad13e9 --- /dev/null +++ b/constraint/bls24-317/solver.go @@ -0,0 +1,647 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/pool" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/rs/zerolog" + "math" + "math/big" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" +) + +// solver represent the state of the solver during a call to System.Solve(...) +type solver struct { + *system + + // values and solved are index by the wire (variable) id + values []fr.Element + solved []bool + nbSolved uint64 + + // maps hintID to hint function + mHintsFunctions map[csolver.HintID]csolver.Hint + + // used to out api.Println + logger zerolog.Logger + + a, b, c fr.Vector // R1CS solver will compute the a,b,c matrices + + q *big.Int +} + +func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { + // parse options + opt, err := csolver.NewConfig(opts...) 
+ if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public) - witnessOffset + len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + witnessOffset) + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + +// computeTerm computes 
coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. 
+func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start)+i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. 
+ } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID: cID, VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. 
+func (s *solver) Read(calldata []uint32) (constraint.Element, int) { + if s.Type == constraint.SystemSparseR1CS { + if calldata[0] != 1 { + panic("invalid calldata") + } + return s.GetValue(calldata[1], calldata[2]), 3 + } + var r fr.Element + n := int(calldata[0]) + j := 1 + for k := 0; k < n; k++ { + // we read k Terms + s.accumulateInto(constraint.Term{CID: calldata[j], VID: calldata[j+1]}, &r) + j += 2 + } + + var ret constraint.Element + copy(ret[:], r[:]) + return ret, j +} + +// processInstruction decodes the instruction and execute blueprint-defined logic. +// an instruction can encode a hint, a custom constraint or a generic constraint. +func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error { + // fetch the blueprint + blueprint := solver.Blueprints[pi.BlueprintID] + inst := pi.Unpack(&solver.System) + cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only + + if solver.Type == constraint.SystemR1CS { + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + // TODO @gbotrel we use the solveR1C method for now, having user-defined + // blueprint for R1CS would require constraint.Solver interface to add methods + // to set a,b,c since it's more efficient to compute these while we solve. + bc.DecompressR1C(&scratch.tR1C, inst) + return solver.solveR1C(cID, &scratch.tR1C) + } + } + + // blueprint declared "I know how to solve this." + if bc, ok := blueprint.(constraint.BlueprintSolvable); ok { + if err := bc.Solve(solver, inst); err != nil { + return solver.wrapErrWithDebugInfo(cID, err) + } + return nil + } + + // blueprint encodes a hint, we execute. + // TODO @gbotrel may be worth it to move hint logic in blueprint "solve" + if bc, ok := blueprint.(constraint.BlueprintHint); ok { + bc.DecompressHint(&scratch.tHint, inst) + return solver.solveWithHint(&scratch.tHint) + } + + return nil +} + +// run runs the solver. 
it returns an error if a constraint is not satisfied or if not all wires
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + if int(solver.nbSolved) != len(solver.values) { + return errors.New("solver didn't assign a value to all wires") + } + + return nil +} + +// solveR1C compute unsolved wires in the constraint, if any and set the solver accordingly +// +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. 
+func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error { + a, b, c := &solver.a[cID], &solver.b[cID], &solver.c[cID] + + // the index of the non-zero entry shows if L, R or O has an uninstantiated wire + // the content is the ID of the wire non instantiated + var loc uint8 + + var termToCompute constraint.Term + + processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) { + for _, t := range l { + vID := t.WireID() + + // wire is already computed, we just accumulate in val + if solver.solved[vID] { + solver.accumulateInto(t, val) + continue + } + + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue + } + } + + processLExp(r.L, a, 1) + processLExp(r.R, b, 2) + processLExp(r.O, c, 3) + + if loc == 0 { + // there is nothing to solve, may happen if we have an assertion + // (ie a constraints that doesn't yield any output) + // or if we solved the unsolved wires with hint functions + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + return nil + } + + // we compute the wire value and instantiate it + wID := termToCompute.WireID() + + // solver result + var wire fr.Element + + switch loc { + case 1: + if !b.IsZero() { + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) + } else { + // we didn't actually ensure that a * b == c + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 2: + if !a.IsZero() { + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) + } else { + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 3: + wire.Mul(a, b). 
+ Sub(&wire, c) + + c.Add(c, &wire) + } + + // wire is the term (coeff * value) + // but in the solver we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + solver.divByCoeff(&wire, termToCompute.CID) + solver.set(wID, wire) + + return nil +} + +// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint +type UnsatisfiedConstraintError struct { + Err error + CID int // constraint ID + DebugInfo *string // optional debug info +} + +func (r *UnsatisfiedConstraintError) Error() string { + if r.DebugInfo != nil { + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) + } + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) +} + +func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { + var debugInfo *string + if dID, ok := solver.MDebug[int(cID)]; ok { + debugInfo = new(string) + *debugInfo = solver.logValue(solver.DebugInfo[dID]) + } + return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo} +} + +// temporary variables to avoid memallocs in hotloop +type scratch struct { + tR1C constraint.R1C + tHint constraint.HintMapping +} diff --git a/constraint/bls24-317/system.go b/constraint/bls24-317/system.go new file mode 100644 index 0000000000..f8ac490de2 --- /dev/null +++ b/constraint/bls24-317/system.go @@ -0,0 +1,374 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/fxamacker/cbor/v2" + "io" + "time" + + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/internal/backend/ioutils" + "github.com/consensys/gnark/logger" + "reflect" + + "github.com/consensys/gnark-crypto/ecc" + + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" +) + +type R1CS = system +type SparseR1CS = system + +// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element) +type system struct { + constraint.System + CoeffTable + field +} + +func NewR1CS(capacity int) *R1CS { + return newSystem(capacity, constraint.SystemR1CS) +} + +func NewSparseR1CS(capacity int) *SparseR1CS { + return newSystem(capacity, constraint.SystemSparseR1CS) +} + +func newSystem(capacity int, t constraint.SystemType) *system { + return &system{ + System: constraint.NewSystem(fr.Modulus(), capacity, t), + CoeffTable: newCoeffTable(capacity / 10), + } +} + +// Solve solves the constraint system with provided witness. +// If it's a R1CS returns R1CSSolution +// If it's a SparseR1CS returns SparseR1CSSolution +func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) { + log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger() + start := time.Now() + + v := witness.Vector().(fr.Vector) + + // init the solver + solver, err := newSolver(cs, v, opts...) + if err != nil { + log.Err(err).Send() + return nil, err + } + + // defer log printing once all solver.values are computed + // (or sooner, if a constraint is not satisfied) + defer solver.printLogs(cs.Logs) + + // run it. 
+ if err := solver.run(); err != nil { + log.Err(err).Send() + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") + + // format the solution + // TODO @gbotrel revisit post-refactor + if cs.Type == constraint.SystemR1CS { + var res R1CSSolution + res.W = solver.values + res.A = solver.a + res.B = solver.b + res.C = solver.c + return &res, nil + } else { + // sparse R1CS + var res SparseR1CSSolution + // query l, r, o in Lagrange basis, not blinded + res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values) + + return &res, nil + } + +} + +// IsSolved +// Deprecated: use _, err := Solve(...) instead +func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error { + _, err := cs.Solve(witness, opts...) + return err +} + +// GetR1Cs return the list of R1C +func (cs *system) GetR1Cs() []constraint.R1C { + toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + var r1c constraint.R1C + bc.DecompressR1C(&r1c, inst.Unpack(&cs.System)) + toReturn = append(toReturn, r1c) + } else { + panic("not implemented") + } + } + return toReturn +} + +// GetNbCoefficients return the number of unique coefficients needed in the R1CS +func (cs *system) GetNbCoefficients() int { + return len(cs.Coefficients) +} + +// CurveID returns curve ID as defined in gnark-crypto +func (cs *system) CurveID() ecc.ID { + return ecc.BLS24_317 +} + +// WriteTo encodes R1CS into provided io.Writer using cbor +func (cs *system) WriteTo(w io.Writer) (int64, error) { + _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written + ts := getTagSet() + enc, err := cbor.CoreDetEncOptions().EncModeWithTags(ts) + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) + + // encode our object + err = encoder.Encode(cs) + return _w.N, err +} + +// ReadFrom 
attempts to decode R1CS from io.Reader using cbor +func (cs *system) ReadFrom(r io.Reader) (int64, error) { + ts := getTagSet() + dm, err := cbor.DecOptions{ + MaxArrayElements: 134217728, + MaxMapPairs: 134217728, + }.DecModeWithTags(ts) + + if err != nil { + return 0, err + } + decoder := dm.NewDecoder(r) + + // initialize coeff table + cs.CoeffTable = newCoeffTable(0) + + if err := decoder.Decode(&cs); err != nil { + return int64(decoder.NumBytesRead()), err + } + + if err := cs.CheckSerializationHeader(); err != nil { + return int64(decoder.NumBytesRead()), err + } + + switch v := cs.CommitmentInfo.(type) { + case *constraint.Groth16Commitments: + cs.CommitmentInfo = *v + case *constraint.PlonkCommitments: + cs.CommitmentInfo = *v + } + + return int64(decoder.NumBytesRead()), nil +} + +func (cs *system) GetCoefficient(i int) (r constraint.Element) { + copy(r[:], cs.Coefficients[i][:]) + return +} + +// GetSparseR1Cs return the list of SparseR1C +func (cs *system) GetSparseR1Cs() []constraint.SparseR1C { + + toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + var sparseR1C constraint.SparseR1C + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + toReturn = append(toReturn, sparseR1C) + } else { + panic("not implemented") + } + } + return toReturn +} + +// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form. 
+// solver = [ public | secret | internal ] +// TODO @gbotrel refactor; this seems to be a small util function for plonk +func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { + + //s := int(pk.Domain[0].Cardinality) + s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints + s = int(ecc.NextPowerOfTwo(uint64(s))) + + var l, r, o []fr.Element + l = make([]fr.Element, s) + r = make([]fr.Element, s) + o = make([]fr.Element, s) + s0 := solution[0] + + for i := 0; i < len(cs.Public); i++ { // placeholders + l[i] = solution[i] + r[i] = s0 + o[i] = s0 + } + offset := len(cs.Public) + nbConstraints := cs.GetNbConstraints() + + var sparseR1C constraint.SparseR1C + j := 0 + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + + l[offset+j] = solution[sparseR1C.XA] + r[offset+j] = solution[sparseR1C.XB] + o[offset+j] = solution[sparseR1C.XC] + j++ + } + } + + offset += nbConstraints + + for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0]) + l[offset+i] = s0 + r[offset+i] = s0 + o[offset+i] = s0 + } + + return l, r, o + +} + +// R1CSSolution represent a valid assignment to all the variables in the constraint system. 
+// The vector W such that Aw o Bw - Cw = 0 +type R1CSSolution struct { + W fr.Vector + A, B, C fr.Vector +} + +func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.W.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.A.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.B.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.C.WriteTo(w) + n += a + return n, err +} + +func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.W.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.A.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.B.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.C.ReadFrom(r) + n += a + return n, err +} + +// SparseR1CSSolution represent a valid assignment to all the variables in the constraint system. +type SparseR1CSSolution struct { + L, R, O fr.Vector +} + +func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.L.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.R.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.O.WriteTo(w) + n += a + return n, err + +} + +func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.L.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.R.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.O.ReadFrom(r) + n += a + return n, err +} + +func getTagSet() cbor.TagSet { + // temporary for refactor + ts := cbor.NewTagSet() + // https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml + // 65536-15309735 Unassigned + tagNum := uint64(5309735) + addType := func(t reflect.Type) { + if err := ts.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + t, + tagNum, + ); err != nil { + panic(err) + } + tagNum++ + } + + addType(reflect.TypeOf(constraint.BlueprintGenericHint{})) + addType(reflect.TypeOf(constraint.BlueprintGenericR1C{})) + 
addType(reflect.TypeOf(constraint.BlueprintGenericSparseR1C{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CAdd{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CMul{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CBool{})) + addType(reflect.TypeOf(constraint.Groth16Commitments{})) + addType(reflect.TypeOf(constraint.PlonkCommitments{})) + + return ts +} diff --git a/constraint/blueprint.go b/constraint/blueprint.go new file mode 100644 index 0000000000..1471d9d3e9 --- /dev/null +++ b/constraint/blueprint.go @@ -0,0 +1,73 @@ +package constraint + +type BlueprintID uint32 + +// Blueprint enable representing heterogenous constraints or instructions in a constraint system +// in a memory efficient way. Blueprints essentially help the frontend/ to "compress" +// constraints or instructions, and specify for the solving (or zksnark setup) part how to +// "decompress" and optionally "solve" the associated wires. +type Blueprint interface { + // CalldataSize return the number of calldata input this blueprint expects. + // If this is unknown at compile time, implementation must return -1 and store + // the actual number of inputs in the first index of the calldata. + CalldataSize() int + + // NbConstraints return the number of constraints this blueprint creates. + NbConstraints() int + + // NbOutputs return the number of output wires this blueprint creates. + NbOutputs(inst Instruction) int + + // WireWalker returns a function that walks the wires appearing in the blueprint. + // This is used by the level builder to build a dependency graph between instructions. + WireWalker(inst Instruction) func(cb func(wire uint32)) +} + +// Solver represents the state of a constraint system solver at runtime. Blueprint can interact +// with this object to perform run time logic, solve constraints and assign values in the solution. 
+type Solver interface { + Field + + GetValue(cID, vID uint32) Element + GetCoeff(cID uint32) Element + SetValue(vID uint32, f Element) + IsSolved(vID uint32) bool + + // Read interprets input calldata as a LinearExpression, + // evaluates it and return the result and the number of uint32 word read. + Read(calldata []uint32) (Element, int) +} + +// BlueprintSolvable represents a blueprint that knows how to solve itself. +type BlueprintSolvable interface { + Blueprint + // Solve may return an error if the decoded constraint / calldata is unsolvable. + Solve(s Solver, instruction Instruction) error +} + +// BlueprintR1C indicates that the blueprint and associated calldata encodes a R1C +type BlueprintR1C interface { + Blueprint + CompressR1C(c *R1C, to *[]uint32) + DecompressR1C(into *R1C, instruction Instruction) +} + +// BlueprintSparseR1C indicates that the blueprint and associated calldata encodes a SparseR1C. +type BlueprintSparseR1C interface { + Blueprint + CompressSparseR1C(c *SparseR1C, to *[]uint32) + DecompressSparseR1C(into *SparseR1C, instruction Instruction) +} + +// BlueprintHint indicates that the blueprint and associated calldata encodes a hint. +type BlueprintHint interface { + Blueprint + CompressHint(h HintMapping, to *[]uint32) + DecompressHint(h *HintMapping, instruction Instruction) +} + +// Compressable represent an object that knows how to encode itself as a []uint32. +type Compressable interface { + // Compress interprets the objects as a LinearExpression and encodes it as a []uint32. 
+ Compress(to *[]uint32) +} diff --git a/constraint/blueprint_hint.go b/constraint/blueprint_hint.go new file mode 100644 index 0000000000..ac96413ef0 --- /dev/null +++ b/constraint/blueprint_hint.go @@ -0,0 +1,94 @@ +package constraint + +import ( + "github.com/consensys/gnark/constraint/solver" +) + +type BlueprintGenericHint struct{} + +func (b *BlueprintGenericHint) DecompressHint(h *HintMapping, inst Instruction) { + // ignore first call data == nbInputs + h.HintID = solver.HintID(inst.Calldata[1]) + lenInputs := int(inst.Calldata[2]) + if cap(h.Inputs) >= lenInputs { + h.Inputs = h.Inputs[:lenInputs] + } else { + h.Inputs = make([]LinearExpression, lenInputs) + } + + j := 3 + for i := 0; i < lenInputs; i++ { + n := int(inst.Calldata[j]) // len of linear expr + j++ + if cap(h.Inputs[i]) >= n { + h.Inputs[i] = h.Inputs[i][:0] + } else { + h.Inputs[i] = make(LinearExpression, 0, n) + } + for k := 0; k < n; k++ { + h.Inputs[i] = append(h.Inputs[i], Term{CID: inst.Calldata[j], VID: inst.Calldata[j+1]}) + j += 2 + } + } + h.OutputRange.Start = inst.Calldata[j] + h.OutputRange.End = inst.Calldata[j+1] +} + +func (b *BlueprintGenericHint) CompressHint(h HintMapping, to *[]uint32) { + nbInputs := 1 // storing nb inputs + nbInputs++ // hintID + nbInputs++ // len(h.Inputs) + for i := 0; i < len(h.Inputs); i++ { + nbInputs++ // len of h.Inputs[i] + nbInputs += len(h.Inputs[i]) * 2 + } + + nbInputs += 2 // output range start / end + + (*to) = append((*to), uint32(nbInputs)) + (*to) = append((*to), uint32(h.HintID)) + (*to) = append((*to), uint32(len(h.Inputs))) + + for _, l := range h.Inputs { + (*to) = append((*to), uint32(len(l))) + for _, t := range l { + (*to) = append((*to), uint32(t.CoeffID()), uint32(t.WireID())) + } + } + + (*to) = append((*to), h.OutputRange.Start) + (*to) = append((*to), h.OutputRange.End) +} + +func (b *BlueprintGenericHint) CalldataSize() int { + return -1 +} +func (b *BlueprintGenericHint) NbConstraints() int { + return 0 +} + +func (b 
*BlueprintGenericHint) NbOutputs(inst Instruction) int { + return 0 +} + +func (b *BlueprintGenericHint) WireWalker(inst Instruction) func(cb func(wire uint32)) { + return func(cb func(wire uint32)) { + lenInputs := int(inst.Calldata[2]) + j := 3 + for i := 0; i < lenInputs; i++ { + n := int(inst.Calldata[j]) // len of linear expr + j++ + + for k := 0; k < n; k++ { + t := Term{CID: inst.Calldata[j], VID: inst.Calldata[j+1]} + if !t.IsConstant() { + cb(t.VID) + } + j += 2 + } + } + for k := inst.Calldata[j]; k < inst.Calldata[j+1]; k++ { + cb(k) + } + } +} diff --git a/constraint/blueprint_r1cs.go b/constraint/blueprint_r1cs.go new file mode 100644 index 0000000000..b231eda067 --- /dev/null +++ b/constraint/blueprint_r1cs.go @@ -0,0 +1,80 @@ +package constraint + +// BlueprintGenericR1C implements Blueprint and BlueprintR1C. +// Encodes +// +// L * R == 0 +type BlueprintGenericR1C struct{} + +func (b *BlueprintGenericR1C) CalldataSize() int { + // size of linear expressions are unknown. + return -1 +} +func (b *BlueprintGenericR1C) NbConstraints() int { + return 1 +} +func (b *BlueprintGenericR1C) NbOutputs(inst Instruction) int { + return 0 +} + +func (b *BlueprintGenericR1C) CompressR1C(c *R1C, to *[]uint32) { + // we store total nb inputs, len L, len R, len O, and then the "flatten" linear expressions + nbInputs := 4 + 2*(len(c.L)+len(c.R)+len(c.O)) + (*to) = append((*to), uint32(nbInputs)) + (*to) = append((*to), uint32(len(c.L)), uint32(len(c.R)), uint32(len(c.O))) + for _, t := range c.L { + (*to) = append((*to), uint32(t.CoeffID()), uint32(t.WireID())) + } + for _, t := range c.R { + (*to) = append((*to), uint32(t.CoeffID()), uint32(t.WireID())) + } + for _, t := range c.O { + (*to) = append((*to), uint32(t.CoeffID()), uint32(t.WireID())) + } +} + +func (b *BlueprintGenericR1C) DecompressR1C(c *R1C, inst Instruction) { + copySlice := func(slice *LinearExpression, expectedLen, idx int) { + if cap(*slice) >= expectedLen { + (*slice) = (*slice)[:expectedLen] + } 
else { + (*slice) = make(LinearExpression, expectedLen, expectedLen*2) + } + for k := 0; k < expectedLen; k++ { + (*slice)[k].CID = inst.Calldata[idx] + idx++ + (*slice)[k].VID = inst.Calldata[idx] + idx++ + } + } + + lenL := int(inst.Calldata[1]) + lenR := int(inst.Calldata[2]) + lenO := int(inst.Calldata[3]) + + const offset = 4 + copySlice(&c.L, lenL, offset) + copySlice(&c.R, lenR, offset+2*lenL) + copySlice(&c.O, lenO, offset+2*(lenL+lenR)) +} + +func (b *BlueprintGenericR1C) WireWalker(inst Instruction) func(cb func(wire uint32)) { + return func(cb func(wire uint32)) { + lenL := int(inst.Calldata[1]) + lenR := int(inst.Calldata[2]) + lenO := int(inst.Calldata[3]) + + appendWires := func(expectedLen, idx int) { + for k := 0; k < expectedLen; k++ { + idx++ + cb(inst.Calldata[idx]) + idx++ + } + } + + const offset = 4 + appendWires(lenL, offset) + appendWires(lenR, offset+2*lenL) + appendWires(lenO, offset+2*(lenL+lenR)) + } +} diff --git a/constraint/blueprint_scs.go b/constraint/blueprint_scs.go new file mode 100644 index 0000000000..3f0973e234 --- /dev/null +++ b/constraint/blueprint_scs.go @@ -0,0 +1,305 @@ +package constraint + +import ( + "errors" + "fmt" +) + +var ( + errDivideByZero = errors.New("division by 0") + errBoolConstrain = errors.New("boolean constraint doesn't hold") +) + +// BlueprintGenericSparseR1C implements Blueprint and BlueprintSparseR1C. 
+// Encodes +// +// qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC == 0 +type BlueprintGenericSparseR1C struct { +} + +func (b *BlueprintGenericSparseR1C) CalldataSize() int { + return 9 // number of fields in SparseR1C +} +func (b *BlueprintGenericSparseR1C) NbConstraints() int { + return 1 +} + +func (b *BlueprintGenericSparseR1C) NbOutputs(inst Instruction) int { + return 0 +} + +func (b *BlueprintGenericSparseR1C) WireWalker(inst Instruction) func(cb func(wire uint32)) { + return func(cb func(wire uint32)) { + cb(inst.Calldata[0]) // xa + cb(inst.Calldata[1]) // xb + cb(inst.Calldata[2]) // xc + } +} + +func (b *BlueprintGenericSparseR1C) CompressSparseR1C(c *SparseR1C, to *[]uint32) { + *to = append(*to, c.XA, c.XB, c.XC, c.QL, c.QR, c.QO, c.QM, c.QC, uint32(c.Commitment)) +} + +func (b *BlueprintGenericSparseR1C) DecompressSparseR1C(c *SparseR1C, inst Instruction) { + c.Clear() + + c.XA = inst.Calldata[0] + c.XB = inst.Calldata[1] + c.XC = inst.Calldata[2] + c.QL = inst.Calldata[3] + c.QR = inst.Calldata[4] + c.QO = inst.Calldata[5] + c.QM = inst.Calldata[6] + c.QC = inst.Calldata[7] + c.Commitment = CommitmentConstraint(inst.Calldata[8]) +} + +func (b *BlueprintGenericSparseR1C) Solve(s Solver, inst Instruction) error { + var c SparseR1C + b.DecompressSparseR1C(&c, inst) + if c.Commitment != NOT { + // a constraint of the form f_L - PI_2 = 0 or f_L = Comm. + // these are there for enforcing the correctness of the commitment and can be skipped in solving time + return nil + } + + var ok bool + + // constraint has at most one unsolved wire. 
+ if !s.IsSolved(c.XA) { + // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 + u1 := s.GetCoeff(c.QL) + den := s.GetValue(c.QM, c.XB) + den = s.Add(den, u1) + den, ok = s.Inverse(den) + if !ok { + return errDivideByZero + } + v1 := s.GetValue(c.QR, c.XB) + v2 := s.GetValue(c.QO, c.XC) + num := s.Add(v1, v2) + num = s.Add(num, s.GetCoeff(c.QC)) + num = s.Mul(num, den) + num = s.Neg(num) + s.SetValue(c.XA, num) + } else if !s.IsSolved(c.XB) { + u2 := s.GetCoeff(c.QR) + den := s.GetValue(c.QM, c.XA) + den = s.Add(den, u2) + den, ok = s.Inverse(den) + if !ok { + return errDivideByZero + } + + v1 := s.GetValue(c.QL, c.XA) + v2 := s.GetValue(c.QO, c.XC) + + num := s.Add(v1, v2) + num = s.Add(num, s.GetCoeff(c.QC)) + num = s.Mul(num, den) + num = s.Neg(num) + s.SetValue(c.XB, num) + + } else if !s.IsSolved(c.XC) { + // O we solve for O + l := s.GetValue(c.QL, c.XA) + r := s.GetValue(c.QR, c.XB) + m0 := s.GetValue(c.QM, c.XA) + m1 := s.GetValue(CoeffIdOne, c.XB) + + // o = - ((m0 * m1) + l + r + c.QC) / c.O + o := s.Mul(m0, m1) + o = s.Add(o, l) + o = s.Add(o, r) + o = s.Add(o, s.GetCoeff(c.QC)) + + den := s.GetCoeff(c.QO) + den, ok = s.Inverse(den) + if !ok { + return errDivideByZero + } + o = s.Mul(o, den) + o = s.Neg(o) + + s.SetValue(c.XC, o) + } else { + // all wires are solved, we verify that the constraint hold. + // this can happen when all wires are from hints or if the constraint is an assertion. 
+ return b.checkConstraint(&c, s) + } + return nil +} + +func (b *BlueprintGenericSparseR1C) checkConstraint(c *SparseR1C, s Solver) error { + l := s.GetValue(c.QL, c.XA) + r := s.GetValue(c.QR, c.XB) + m0 := s.GetValue(c.QM, c.XA) + m1 := s.GetValue(CoeffIdOne, c.XB) + m0 = s.Mul(m0, m1) + o := s.GetValue(c.QO, c.XC) + qC := s.GetCoeff(c.QC) + + t := s.Add(m0, l) + t = s.Add(t, r) + t = s.Add(t, o) + t = s.Add(t, qC) + + if !t.IsZero() { + return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + %s + %s != 0", + s.String(l), + s.String(r), + s.String(o), + s.String(m0), + s.String(qC), + ) + } + return nil +} + +// BlueprintSparseR1CMul implements Blueprint, BlueprintSolvable and BlueprintSparseR1C. +// Encodes +// +// qM⋅(xaxb) == xc +type BlueprintSparseR1CMul struct{} + +func (b *BlueprintSparseR1CMul) CalldataSize() int { + return 4 +} +func (b *BlueprintSparseR1CMul) NbConstraints() int { + return 1 +} +func (b *BlueprintSparseR1CMul) NbOutputs(inst Instruction) int { + return 0 +} + +func (b *BlueprintSparseR1CMul) WireWalker(inst Instruction) func(cb func(wire uint32)) { + return func(cb func(wire uint32)) { + cb(inst.Calldata[0]) // xa + cb(inst.Calldata[1]) // xb + cb(inst.Calldata[2]) // xc + } +} + +func (b *BlueprintSparseR1CMul) CompressSparseR1C(c *SparseR1C, to *[]uint32) { + *to = append(*to, c.XA, c.XB, c.XC, c.QM) +} + +func (b *BlueprintSparseR1CMul) Solve(s Solver, inst Instruction) error { + // qM⋅(xaxb) == xc + m0 := s.GetValue(inst.Calldata[3], inst.Calldata[0]) + m1 := s.GetValue(CoeffIdOne, inst.Calldata[1]) + + m0 = s.Mul(m0, m1) + + s.SetValue(inst.Calldata[2], m0) + return nil +} + +func (b *BlueprintSparseR1CMul) DecompressSparseR1C(c *SparseR1C, inst Instruction) { + c.Clear() + c.XA = inst.Calldata[0] + c.XB = inst.Calldata[1] + c.XC = inst.Calldata[2] + c.QO = CoeffIdMinusOne + c.QM = inst.Calldata[3] +} + +// BlueprintSparseR1CAdd implements Blueprint, BlueprintSolvable and BlueprintSparseR1C. 
+// Encodes +// +// qL⋅xa + qR⋅xb + qC == xc +type BlueprintSparseR1CAdd struct{} + +func (b *BlueprintSparseR1CAdd) CalldataSize() int { + return 6 +} +func (b *BlueprintSparseR1CAdd) NbConstraints() int { + return 1 +} +func (b *BlueprintSparseR1CAdd) NbOutputs(inst Instruction) int { + return 0 +} + +func (b *BlueprintSparseR1CAdd) WireWalker(inst Instruction) func(cb func(wire uint32)) { + return func(cb func(wire uint32)) { + cb(inst.Calldata[0]) // xa + cb(inst.Calldata[1]) // xb + cb(inst.Calldata[2]) // xc + } +} + +func (b *BlueprintSparseR1CAdd) CompressSparseR1C(c *SparseR1C, to *[]uint32) { + *to = append(*to, c.XA, c.XB, c.XC, c.QL, c.QR, c.QC) +} + +func (blueprint *BlueprintSparseR1CAdd) Solve(s Solver, inst Instruction) error { + // a + b + k == c + a := s.GetValue(inst.Calldata[3], inst.Calldata[0]) + b := s.GetValue(inst.Calldata[4], inst.Calldata[1]) + k := s.GetCoeff(inst.Calldata[5]) + + a = s.Add(a, b) + a = s.Add(a, k) + + s.SetValue(inst.Calldata[2], a) + return nil +} + +func (b *BlueprintSparseR1CAdd) DecompressSparseR1C(c *SparseR1C, inst Instruction) { + c.Clear() + c.XA = inst.Calldata[0] + c.XB = inst.Calldata[1] + c.XC = inst.Calldata[2] + c.QL = inst.Calldata[3] + c.QR = inst.Calldata[4] + c.QO = CoeffIdMinusOne + c.QC = inst.Calldata[5] +} + +// BlueprintSparseR1CBool implements Blueprint, BlueprintSolvable and BlueprintSparseR1C. 
+// Encodes +// +// qL⋅xa + qM⋅(xa*xa) == 0 +// that is v + -v*v == 0 +type BlueprintSparseR1CBool struct{} + +func (b *BlueprintSparseR1CBool) CalldataSize() int { + return 3 +} +func (b *BlueprintSparseR1CBool) NbConstraints() int { + return 1 +} +func (b *BlueprintSparseR1CBool) NbOutputs(inst Instruction) int { + return 0 +} + +func (b *BlueprintSparseR1CBool) WireWalker(inst Instruction) func(cb func(wire uint32)) { + return func(cb func(wire uint32)) { + cb(inst.Calldata[0]) // xa + } +} + +func (b *BlueprintSparseR1CBool) CompressSparseR1C(c *SparseR1C, to *[]uint32) { + *to = append(*to, c.XA, c.QL, c.QM) +} + +func (blueprint *BlueprintSparseR1CBool) Solve(s Solver, inst Instruction) error { + // all wires are already solved, we just check the constraint. + v1 := s.GetValue(inst.Calldata[1], inst.Calldata[0]) + v2 := s.GetValue(inst.Calldata[2], inst.Calldata[0]) + v := s.GetValue(CoeffIdOne, inst.Calldata[0]) + v = s.Mul(v, v2) + v = s.Add(v1, v) + if !v.IsZero() { + return errBoolConstrain + } + return nil +} + +func (b *BlueprintSparseR1CBool) DecompressSparseR1C(c *SparseR1C, inst Instruction) { + c.Clear() + c.XA = inst.Calldata[0] + c.XB = c.XA + c.QL = inst.Calldata[1] + c.QM = inst.Calldata[2] +} diff --git a/constraint/bn254/coeff.go b/constraint/bn254/coeff.go index 2835fd1555..da49b0e68b 100644 --- a/constraint/bn254/coeff.go +++ b/constraint/bn254/coeff.go @@ -46,7 +46,7 @@ func newCoeffTable(capacity int) CoeffTable { } -func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constraint.Term { +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { c := (*fr.Element)(coeff[:]) var cID uint32 if c.IsZero() { @@ -69,7 +69,11 @@ func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constrai ct.mCoeffs[cc] = cID } } + return cID +} +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) return constraint.Term{VID: uint32(variableID), CID: 
cID} } @@ -78,7 +82,10 @@ func (ct *CoeffTable) CoeffToString(cID int) string { return ct.Coefficients[cID].String() } -var _ constraint.CoeffEngine = &arithEngine{} +// implements constraint.Field +type field struct{} + +var _ constraint.Field = &field{} var ( two fr.Element @@ -94,10 +101,7 @@ func init() { minusTwo.Neg(&two) } -// implements constraint.CoeffEngine -type arithEngine struct{} - -func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { +func (engine *field) FromInterface(i interface{}) constraint.Element { var e fr.Element if _, err := e.SetInterface(i); err != nil { // need to clean that --> some code path are dissimilar @@ -106,55 +110,83 @@ func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { b := utils.FromInterface(i) e.SetBigInt(&b) } - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) ToBigInt(c *constraint.Coeff) *big.Int { +func (engine *field) ToBigInt(c constraint.Element) *big.Int { e := (*fr.Element)(c[:]) r := new(big.Int) e.BigInt(r) return r } -func (engine *arithEngine) Mul(a, b *constraint.Coeff) { +func (engine *field) Mul(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Mul(_a, _b) + return a } -func (engine *arithEngine) Add(a, b *constraint.Coeff) { + +func (engine *field) Add(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Add(_a, _b) + return a } -func (engine *arithEngine) Sub(a, b *constraint.Coeff) { +func (engine *field) Sub(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Sub(_a, _b) + return a } -func (engine *arithEngine) Neg(a *constraint.Coeff) { +func (engine *field) Neg(a constraint.Element) constraint.Element { e := (*fr.Element)(a[:]) e.Neg(e) + return a } -func (engine *arithEngine) Inverse(a *constraint.Coeff) { +func (engine *field) Inverse(a 
constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + e.Inverse(e) + return a, true } -func (engine *arithEngine) IsOne(a *constraint.Coeff) bool { +func (engine *field) IsOne(a constraint.Element) bool { e := (*fr.Element)(a[:]) return e.IsOne() } -func (engine *arithEngine) One() constraint.Coeff { +func (engine *field) One() constraint.Element { e := fr.One() - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) String(a *constraint.Coeff) string { +func (engine *field) String(a constraint.Element) string { e := (*fr.Element)(a[:]) return e.String() } + +func (engine *field) Uint64(a constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true +} diff --git a/constraint/bn254/r1cs.go b/constraint/bn254/r1cs.go deleted file mode 100644 index 895009f09a..0000000000 --- a/constraint/bn254/r1cs.go +++ /dev/null @@ -1,453 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/fxamacker/cbor/v2" - "io" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc" - "math" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr" -) - -// R1CS describes a set of R1CS constraint -type R1CS struct { - constraint.R1CSCore - CoeffTable - arithEngine -} - -// NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values -// -// capacity pre-allocates memory for capacity nbConstraints -func NewR1CS(capacity int) *R1CS { - r := R1CS{ - R1CSCore: constraint.R1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.R1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - return &r -} - -func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, r1c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - - cs.UpdateLevel(cID, &r1c) - - return cID -} - -// Solve sets all the wires and returns the a, b, c vectors. -// the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. -// a, b, c vectors: ab-c = hz -// witness = [publicWires | secretWires] (without the ONE_WIRE !) 
-// returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() - - nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables - hintFunctions := defineGkrHints(cs.GkrInfo, opt.HintFunctions) - solution, err := newSolution(nbWires, hintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - - if err != nil { - return make(fr.Vector, nbWires), err - } - start := time.Now() - - if len(witness) != len(cs.Public)-1+len(cs.Secret) { // - 1 for ONE_WIRE - err = fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), int(len(cs.Public)-1+len(cs.Secret)), len(cs.Public)-1, len(cs.Secret)) - log.Err(err).Send() - return solution.values, err - } - - // compute the wires and the a, b, c polynomials - if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { - err = errors.New("invalid input size: len(a, b, c) == len(Constraints)") - log.Err(err).Send() - return solution.values, err - } - - solution.solved[0] = true // ONE_WIRE - solution.values[0].SetOne() - copy(solution.values[1:], witness) - for i := range witness { - solution.solved[i+1] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness) + 1) - - // now that we know all inputs are set, defer log printing once all solution.values are computed - // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - if err := cs.parallelSolve(a, b, c, &solution); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - 
log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil -} - -func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - chError <- &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - return &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// IsSolved returns nil if given witness solves the R1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) 
- if err != nil { - return err - } - - a := make(fr.Vector, len(cs.Constraints)) - b := make(fr.Vector, len(cs.Constraints)) - c := make(fr.Vector, len(cs.Constraints)) - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, a, b, c, opt) - return err -} - -// divByCoeff sets res = res / t.Coeff -func (cs *R1CS) divByCoeff(res *fr.Element, t constraint.Term) { - cID := t.CoeffID() - switch cID { - case constraint.CoeffIdOne: - return - case constraint.CoeffIdMinusOne: - res.Neg(res) - case constraint.CoeffIdZero: - panic("division by 0") - default: - // this is slow, but shouldn't happen as divByCoeff is called to - // remove the coeff of an unsolved wire - // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 - res.Div(res, &cs.Coefficients[cID]) - } -} - -// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly -// -// returns an error if the solver called a hint function that errored -// returns false, nil if there was no wire to solve -// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that -// the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r constraint.R1C, solution *solution, a, b, c *fr.Element) error { - - // the index of the non-zero entry shows if L, R or O has an uninstantiated wire - // the content is the ID of the wire non instantiated - var loc uint8 - - var termToCompute constraint.Term - - processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) error { - for _, t := range l { - vID := t.WireID() - - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - solution.accumulateInto(t, val) - continue - } - - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err - } - // now that the wire is saved, accumulate it into a, b or c - solution.accumulateInto(t, val) - continue - } - - if loc != 0 { - panic("found more than one wire to instantiate") - } - termToCompute = t - loc = locValue - } - return nil - } - - if err := processLExp(r.L, a, 1); err != nil { - return err - } - - if err := processLExp(r.R, b, 2); err != nil { - return err - } - - if err := processLExp(r.O, c, 3); err != nil { - return err - } - - if loc == 0 { - // there is nothing to solve, may happen if we have an assertion - // (ie a constraints that doesn't yield any output) - // or if we solved the unsolved wires with hint functions - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - return nil - } - - // we compute the wire value and instantiate it - wID := termToCompute.WireID() - - // solver result - var wire fr.Element - - switch loc { - case 1: - if !b.IsZero() { - wire.Div(c, b). - Sub(&wire, a) - a.Add(a, &wire) - } else { - // we didn't actually ensure that a * b == c - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 2: - if !a.IsZero() { - wire.Div(c, a). 
- Sub(&wire, b) - b.Add(b, &wire) - } else { - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 3: - wire.Mul(a, b). - Sub(&wire, c) - - c.Add(c, &wire) - } - - // wire is the term (coeff * value) - // but in the solution we want to store the value only - // note that in gnark frontend, coeff here is always 1 or -1 - cs.divByCoeff(&wire, termToCompute) - solution.set(wID, wire) - - return nil -} - -// GetConstraints return the list of R1C and a coefficient resolver -func (cs *R1CS) GetConstraints() ([]constraint.R1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *R1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto -func (cs *R1CS) CurveID() ecc.ID { - return ecc.BN254 -} - -// WriteTo encodes R1CS into provided io.Writer using cbor -func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode R1CS from io.Reader using cbor -func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(&cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bn254/r1cs_sparse.go 
b/constraint/bn254/r1cs_sparse.go deleted file mode 100644 index 4a9df0856e..0000000000 --- a/constraint/bn254/r1cs_sparse.go +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/fxamacker/cbor/v2" - "io" - "math" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr" -) - -// SparseR1CS represents a Plonk like circuit -type SparseR1CS struct { - constraint.SparseR1CSCore - CoeffTable - arithEngine -} - -// NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values -func NewSparseR1CS(capacity int) *SparseR1CS { - cs := SparseR1CS{ - SparseR1CSCore: constraint.SparseR1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.SparseR1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - - return &cs -} - -func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = 
append(cs.Constraints, c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - cs.UpdateLevel(cID, &c) - - return cID -} - -// Solve sets all the wires. -// solution.values = [publicInputs | secretInputs | internalVariables ] -// witness: contains the input variables -// it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() - - // set the slices holding the solution.values and monitoring which variables have been solved - nbVariables := cs.NbInternalVariables + len(cs.Secret) + len(cs.Public) - - start := time.Now() - - expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) - if len(witness) != expectedWitnessSize { - return make(fr.Vector, nbVariables), fmt.Errorf( - "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", - len(witness), - expectedWitnessSize, - len(cs.Public), - len(cs.Secret), - ) - } - - // keep track of wire that have a value - solution, err := newSolution(nbVariables, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return solution.values, err - } - - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs - copy(solution.values, witness) - for i := 0; i < len(witness); i++ { - solution.solved[i] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness)) - - // defer log printing once all solution.values are computed - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - // batch invert the coefficients to avoid many divisions in the solver - coefficientsNegInv := 
fr.BatchInvert(cs.Coefficients) - for i := 0; i < len(coefficientsNegInv); i++ { - coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) - } - - if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil - -} - -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - wg.Done() - return - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - chError <- &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } else { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - } - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - return &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// computeHints computes wires associated with a hint function, if any -// if there is no remaining wire to solve, returns -1 -// else returns the wire position (L -> 0, R -> 1, O -> 2) -func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) (int, error) { - r := -1 - lID, rID, oID := c.L.WireID(), c.R.WireID(), c.O.WireID() - - if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { - // check if it's a hint - if hint, ok := cs.MHints[lID]; ok { - if err := solution.solveWithHint(lID, hint); err != nil { - return -1, err - } - } else { - r = 0 - } - - } - - if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { - // check if it's a hint - if hint, ok := cs.MHints[rID]; ok { - if err := solution.solveWithHint(rID, hint); err != nil { - return -1, err - } - } else { - r = 1 - } - } - - if (c.O.CoeffID() != 0) && !solution.solved[oID] { - // check if it's a hint - if hint, ok := cs.MHints[oID]; ok { - if err := solution.solveWithHint(oID, hint); err != nil { - return -1, err - } - } else { - r = 2 - } - } - 
return r, nil -} - -// solveConstraint solve any unsolved wire in given constraint and update the solution -// a SparseR1C may have up to one unsolved wire (excluding hints) -// if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { - - lro, err := cs.computeHints(c, solution) - if err != nil { - return err - } - if lro == -1 { - // no unsolved wire - // can happen if the constraint contained only hint wires. - return nil - } - if lro == 1 { // we solve for R: u1L+u2R+u3LR+u4O+k=0 => R(u2+u3L)+u1L+u4O+k = 0 - if !solution.solved[c.L.WireID()] { - panic("L wire should be instantiated when we solve R") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.L.WireID()]).Add(&den, &u2) - - v1 = solution.computeTerm(c.L) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - } - - if lro == 0 { // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 - if !solution.solved[c.R.WireID()] { - panic("R wire should be instantiated when we solve L") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.R.WireID()]).Add(&den, &u1) - - v1 = solution.computeTerm(c.R) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - - } - // O 
we solve for O - var o fr.Element - cID, vID := c.O.CoeffID(), c.O.WireID() - - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - - // o = - ((m0 * m1) + l + r + c.K) / c.O - o.Mul(&m0, &m1).Add(&o, &l).Add(&o, &r).Add(&o, &cs.Coefficients[c.K]) - o.Mul(&o, &coefficientsNegInv[cID]) - - solution.set(vID, o) - - return nil -} - -// IsSolved returns nil if given witness solves the SparseR1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return err - } - - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, opt) - return err -} - -// GetConstraints return the list of SparseR1C and a coefficient resolver -func (cs *SparseR1CS) GetConstraints() ([]constraint.SparseR1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// checkConstraint verifies that the constraint holds -func (cs *SparseR1CS) checkConstraint(c constraint.SparseR1C, solution *solution) error { - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - o := solution.computeTerm(c.O) - - // l + r + (m0 * m1) + o + c.K == 0 - var t fr.Element - t.Mul(&m0, &m1).Add(&t, &l).Add(&t, &r).Add(&t, &o).Add(&t, &cs.Coefficients[c.K]) - if !t.IsZero() { - return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + (%s × %s) + %s != 0", - l.String(), - r.String(), - o.String(), - m0.String(), - m1.String(), - cs.Coefficients[c.K].String(), - ) - } - return nil - -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto (ecc.BN254) -func (cs 
*SparseR1CS) CurveID() ecc.ID { - return ecc.BN254 -} - -// WriteTo encodes SparseR1CS into provided io.Writer using cbor -func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode SparseR1CS from io.Reader using cbor -func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bn254/r1cs_test.go b/constraint/bn254/r1cs_test.go index a772621bcc..8c8ba5da95 100644 --- a/constraint/bn254/r1cs_test.go +++ b/constraint/bn254/r1cs_test.go @@ -20,6 +20,7 @@ import ( "bytes" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "reflect" "testing" @@ -48,7 +49,7 @@ func TestSerialization(t *testing.T) { return } - // copmpile a second time to ensure determinism + // compile a second time to ensure determinism r1cs2, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) @@ -76,10 +77,10 @@ func TestSerialization(t *testing.T) { if diff := cmp.Diff(r1cs1, &reconstructed, cmpopts.IgnoreFields(cs.R1CS{}, "System.q", - "arithEngine", + "field", "CoeffTable.mCoeffs", "System.lbWireLevel", - 
"System.lbHints", + "System.genericHint", "System.SymbolTable", "System.lbOutputs", "System.bitLen")); diff != "" { @@ -147,12 +148,6 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { - var c circuit - ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c) - if err != nil { - b.Fatal(err) - } - var w circuit w.X = 1 w.Y = 1 @@ -161,8 +156,32 @@ func BenchmarkSolve(b *testing.B) { b.Fatal(err) } - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ccs.IsSolved(witness) - } + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + } diff --git a/constraint/bn254/solution.go b/constraint/bn254/solution.go deleted file mode 100644 index 3346fc9f1c..0000000000 --- a/constraint/bn254/solution.go +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/debug" - "github.com/rs/zerolog" - "math/big" - "strconv" - "strings" - "sync/atomic" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr" -) - -// solution represents elements needed to compute -// a solution to a R1CS or SparseR1CS -type solution struct { - values, coefficients []fr.Element - solved []bool - nbSolved uint64 - mHintsFunctions map[hint.ID]hint.Function // maps hintID to hint function - mHints map[int]*constraint.Hint // maps wireID to hint - st *debug.SymbolTable -} - -func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, mHints map[int]*constraint.Hint, coefficients []fr.Element, st *debug.SymbolTable) (solution, error) { - - s := solution{ - st: st, - values: make([]fr.Element, nbWires), - coefficients: coefficients, - solved: make([]bool, nbWires), - mHintsFunctions: hintFunctions, - mHints: mHints, - } - - // hintsDependencies is from compile time; it contains the list of hints the solver **needs** - var missing []string - for hintUUID, hintID := range hintsDependencies { - if _, ok := s.mHintsFunctions[hintUUID]; !ok { - missing = append(missing, hintID) - } - } - - if len(missing) > 0 { - return s, fmt.Errorf("solver missing hint(s): %v", missing) - } - - return s, nil -} - -func (s *solution) set(id int, value fr.Element) { - if s.solved[id] { - panic("solving the same wire twice should never happen.") - } - s.values[id] = value - s.solved[id] = true - atomic.AddUint64(&s.nbSolved, 1) - // s.nbSolved++ -} - -func (s *solution) isValid() bool { - return int(s.nbSolved) == len(s.values) -} - -// computeTerm computes coeff*variable -func (s *solution) computeTerm(t constraint.Term) fr.Element { - cID, vID := t.CoeffID(), t.WireID() - if cID != 0 && !s.solved[vID] { - panic("computing a term 
with an unsolved wire") - } - switch cID { - case constraint.CoeffIdZero: - return fr.Element{} - case constraint.CoeffIdOne: - return s.values[vID] - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - return res - case constraint.CoeffIdMinusOne: - var res fr.Element - res.Neg(&s.values[vID]) - return res - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - return res - } -} - -// r += (t.coeff*t.value) -func (s *solution) accumulateInto(t constraint.Term, r *fr.Element) { - cID := t.CoeffID() - - if t.IsConstant() { - // needed for logs, we may want to not put this in the hot path if we need to - // optimize constraint system solver further. - r.Add(r, &s.coefficients[cID]) - return - } - - vID := t.WireID() - switch cID { - case constraint.CoeffIdZero: - return - case constraint.CoeffIdOne: - r.Add(r, &s.values[vID]) - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - r.Add(r, &res) - case constraint.CoeffIdMinusOne: - r.Sub(r, &s.values[vID]) - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - r.Add(r, &res) - } -} - -// solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveWithHint(vID int, h *constraint.Hint) error { - // skip if the wire is already solved by a call to the same hint - // function on the same inputs - if s.solved[vID] { - return nil - } - // ensure hint function was provided - f, ok := s.mHintsFunctions[h.ID] - if !ok { - return errors.New("missing hint function") - } - - // tmp IO big int memory - nbInputs := len(h.Inputs) - nbOutputs := len(h.Wires) - inputs := make([]*big.Int, nbInputs) - outputs := make([]*big.Int, nbOutputs) - for i := 0; i < nbOutputs; i++ { - outputs[i] = big.NewInt(0) - } - - q := fr.Modulus() - - // for each input, we set its big int value, IF all the wires are solved - // the only case where all wires may not be solved, is if one of the input of this hint - // is 
the output of another hint. - // it is safe to recursively solve this with the parallel solver, since all hints-output wires - // that we can solve this way are marked to be solved with the current constraint we are processing. - recursiveSolve := func(t constraint.Term) error { - if t.IsConstant() { - return nil - } - wID := t.WireID() - if s.solved[wID] { - return nil - } - // unsolved dependency - if h, ok := s.mHints[wID]; ok { - // solve recursively. - return s.solveWithHint(wID, h) - } - - // it's not a hint, we panic. - panic("solver can't compute hint; one or more input wires are unsolved") - } - - for i := 0; i < nbInputs; i++ { - inputs[i] = big.NewInt(0) - - var v fr.Element - for _, term := range h.Inputs[i] { - if err := recursiveSolve(term); err != nil { - return err - } - s.accumulateInto(term, &v) - } - v.BigInt(inputs[i]) - } - - err := f(q, inputs, outputs) - - var v fr.Element - for i := range outputs { - v.SetBigInt(outputs[i]) - s.set(h.Wires[i], v) - } - - return err -} - -func (s *solution) printLogs(log zerolog.Logger, logs []constraint.LogEntry) { - if log.GetLevel() == zerolog.Disabled { - return - } - - for i := 0; i < len(logs); i++ { - logLine := s.logValue(logs[i]) - log.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) - } -} - -const unsolvedVariable = "" - -func (s *solution) logValue(log constraint.LogEntry) string { - var toResolve []interface{} - var ( - eval fr.Element - missingValue bool - ) - for j := 0; j < len(log.ToResolve); j++ { - // before eval le - - missingValue = false - eval.SetZero() - - for _, t := range log.ToResolve[j] { - // for each term in the linear expression - - cID, vID := t.CoeffID(), t.WireID() - if t.IsConstant() { - // just add the constant - eval.Add(&eval, &s.coefficients[cID]) - continue - } - - if !s.solved[vID] { - missingValue = true - break // stop the loop we can't evaluate. 
- } - - tv := s.computeTerm(t) - eval.Add(&eval, &tv) - } - - // after - if missingValue { - toResolve = append(toResolve, unsolvedVariable) - } else { - // we have to append our accumulator - toResolve = append(toResolve, eval.String()) - } - - } - if len(log.Stack) > 0 { - var sbb strings.Builder - for _, lID := range log.Stack { - location := s.st.Locations[lID] - function := s.st.Functions[location.FunctionID] - - sbb.WriteString(function.Name) - sbb.WriteByte('\n') - sbb.WriteByte('\t') - sbb.WriteString(function.Filename) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(int(location.Line))) - sbb.WriteByte('\n') - } - toResolve = append(toResolve, sbb.String()) - } - return fmt.Sprintf(log.Format, toResolve...) -} - -// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint -type UnsatisfiedConstraintError struct { - Err error - CID int // constraint ID - DebugInfo *string // optional debug info -} - -func (r *UnsatisfiedConstraintError) Error() string { - if r.DebugInfo != nil { - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) - } - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) -} diff --git a/constraint/bn254/solver.go b/constraint/bn254/solver.go new file mode 100644 index 0000000000..5038f2ef4b --- /dev/null +++ b/constraint/bn254/solver.go @@ -0,0 +1,647 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/pool" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/rs/zerolog" + "math" + "math/big" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/consensys/gnark-crypto/ecc/bn254/fr" +) + +// solver represent the state of the solver during a call to System.Solve(...) +type solver struct { + *system + + // values and solved are index by the wire (variable) id + values []fr.Element + solved []bool + nbSolved uint64 + + // maps hintID to hint function + mHintsFunctions map[csolver.HintID]csolver.Hint + + // used to out api.Println + logger zerolog.Logger + + a, b, c fr.Vector // R1CS solver will compute the a,b,c matrices + + q *big.Int +} + +func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { + // parse options + opt, err := csolver.NewConfig(opts...) 
+ if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public) - witnessOffset + len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + witnessOffset) + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + +// computeTerm computes 
coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. 
+func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start)+i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. 
+ } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID: cID, VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. 
+func (s *solver) Read(calldata []uint32) (constraint.Element, int) { + if s.Type == constraint.SystemSparseR1CS { + if calldata[0] != 1 { + panic("invalid calldata") + } + return s.GetValue(calldata[1], calldata[2]), 3 + } + var r fr.Element + n := int(calldata[0]) + j := 1 + for k := 0; k < n; k++ { + // we read k Terms + s.accumulateInto(constraint.Term{CID: calldata[j], VID: calldata[j+1]}, &r) + j += 2 + } + + var ret constraint.Element + copy(ret[:], r[:]) + return ret, j +} + +// processInstruction decodes the instruction and execute blueprint-defined logic. +// an instruction can encode a hint, a custom constraint or a generic constraint. +func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error { + // fetch the blueprint + blueprint := solver.Blueprints[pi.BlueprintID] + inst := pi.Unpack(&solver.System) + cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only + + if solver.Type == constraint.SystemR1CS { + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + // TODO @gbotrel we use the solveR1C method for now, having user-defined + // blueprint for R1CS would require constraint.Solver interface to add methods + // to set a,b,c since it's more efficient to compute these while we solve. + bc.DecompressR1C(&scratch.tR1C, inst) + return solver.solveR1C(cID, &scratch.tR1C) + } + } + + // blueprint declared "I know how to solve this." + if bc, ok := blueprint.(constraint.BlueprintSolvable); ok { + if err := bc.Solve(solver, inst); err != nil { + return solver.wrapErrWithDebugInfo(cID, err) + } + return nil + } + + // blueprint encodes a hint, we execute. + // TODO @gbotrel may be worth it to move hint logic in blueprint "solve" + if bc, ok := blueprint.(constraint.BlueprintHint); ok { + bc.DecompressHint(&scratch.tHint, inst) + return solver.solveWithHint(&scratch.tHint) + } + + return nil +} + +// run runs the solver. 
it return an error if a constraint is not satisfied or if not all wires +// were instantiated. +func (solver *solver) run() error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks. + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + var scratch scratch + for t := range chTasks { + for _, i := range t { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + var scratch scratch + + // for each level, we push the tasks + for _, level := range solver.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + return err + } + } + continue + } + + // number of tasks for this level is set to number of CPU + // but if we don't have enough work for all our CPU, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + if int(solver.nbSolved) != len(solver.values) { + return errors.New("solver didn't assign a value to all wires") + } + + return nil +} + +// solveR1C compute unsolved wires in the constraint, if any and set the solver accordingly +// +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. 
+func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error { + a, b, c := &solver.a[cID], &solver.b[cID], &solver.c[cID] + + // the index of the non-zero entry shows if L, R or O has an uninstantiated wire + // the content is the ID of the wire non instantiated + var loc uint8 + + var termToCompute constraint.Term + + processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) { + for _, t := range l { + vID := t.WireID() + + // wire is already computed, we just accumulate in val + if solver.solved[vID] { + solver.accumulateInto(t, val) + continue + } + + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue + } + } + + processLExp(r.L, a, 1) + processLExp(r.R, b, 2) + processLExp(r.O, c, 3) + + if loc == 0 { + // there is nothing to solve, may happen if we have an assertion + // (ie a constraints that doesn't yield any output) + // or if we solved the unsolved wires with hint functions + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + return nil + } + + // we compute the wire value and instantiate it + wID := termToCompute.WireID() + + // solver result + var wire fr.Element + + switch loc { + case 1: + if !b.IsZero() { + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) + } else { + // we didn't actually ensure that a * b == c + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 2: + if !a.IsZero() { + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) + } else { + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 3: + wire.Mul(a, b). 
+ Sub(&wire, c) + + c.Add(c, &wire) + } + + // wire is the term (coeff * value) + // but in the solver we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + solver.divByCoeff(&wire, termToCompute.CID) + solver.set(wID, wire) + + return nil +} + +// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint +type UnsatisfiedConstraintError struct { + Err error + CID int // constraint ID + DebugInfo *string // optional debug info +} + +func (r *UnsatisfiedConstraintError) Error() string { + if r.DebugInfo != nil { + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) + } + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) +} + +func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { + var debugInfo *string + if dID, ok := solver.MDebug[int(cID)]; ok { + debugInfo = new(string) + *debugInfo = solver.logValue(solver.DebugInfo[dID]) + } + return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo} +} + +// temporary variables to avoid memallocs in hotloop +type scratch struct { + tR1C constraint.R1C + tHint constraint.HintMapping +} diff --git a/constraint/bn254/system.go b/constraint/bn254/system.go new file mode 100644 index 0000000000..949e532dbe --- /dev/null +++ b/constraint/bn254/system.go @@ -0,0 +1,374 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/fxamacker/cbor/v2" + "io" + "time" + + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/internal/backend/ioutils" + "github.com/consensys/gnark/logger" + "reflect" + + "github.com/consensys/gnark-crypto/ecc" + + "github.com/consensys/gnark-crypto/ecc/bn254/fr" +) + +type R1CS = system +type SparseR1CS = system + +// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element) +type system struct { + constraint.System + CoeffTable + field +} + +func NewR1CS(capacity int) *R1CS { + return newSystem(capacity, constraint.SystemR1CS) +} + +func NewSparseR1CS(capacity int) *SparseR1CS { + return newSystem(capacity, constraint.SystemSparseR1CS) +} + +func newSystem(capacity int, t constraint.SystemType) *system { + return &system{ + System: constraint.NewSystem(fr.Modulus(), capacity, t), + CoeffTable: newCoeffTable(capacity / 10), + } +} + +// Solve solves the constraint system with provided witness. +// If it's a R1CS returns R1CSSolution +// If it's a SparseR1CS returns SparseR1CSSolution +func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) { + log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger() + start := time.Now() + + v := witness.Vector().(fr.Vector) + + // init the solver + solver, err := newSolver(cs, v, opts...) + if err != nil { + log.Err(err).Send() + return nil, err + } + + // defer log printing once all solver.values are computed + // (or sooner, if a constraint is not satisfied) + defer solver.printLogs(cs.Logs) + + // run it. 
+ if err := solver.run(); err != nil { + log.Err(err).Send() + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") + + // format the solution + // TODO @gbotrel revisit post-refactor + if cs.Type == constraint.SystemR1CS { + var res R1CSSolution + res.W = solver.values + res.A = solver.a + res.B = solver.b + res.C = solver.c + return &res, nil + } else { + // sparse R1CS + var res SparseR1CSSolution + // query l, r, o in Lagrange basis, not blinded + res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values) + + return &res, nil + } + +} + +// IsSolved +// Deprecated: use _, err := Solve(...) instead +func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error { + _, err := cs.Solve(witness, opts...) + return err +} + +// GetR1Cs return the list of R1C +func (cs *system) GetR1Cs() []constraint.R1C { + toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + var r1c constraint.R1C + bc.DecompressR1C(&r1c, inst.Unpack(&cs.System)) + toReturn = append(toReturn, r1c) + } else { + panic("not implemented") + } + } + return toReturn +} + +// GetNbCoefficients return the number of unique coefficients needed in the R1CS +func (cs *system) GetNbCoefficients() int { + return len(cs.Coefficients) +} + +// CurveID returns curve ID as defined in gnark-crypto +func (cs *system) CurveID() ecc.ID { + return ecc.BN254 +} + +// WriteTo encodes R1CS into provided io.Writer using cbor +func (cs *system) WriteTo(w io.Writer) (int64, error) { + _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written + ts := getTagSet() + enc, err := cbor.CoreDetEncOptions().EncModeWithTags(ts) + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) + + // encode our object + err = encoder.Encode(cs) + return _w.N, err +} + +// ReadFrom 
attempts to decode R1CS from io.Reader using cbor +func (cs *system) ReadFrom(r io.Reader) (int64, error) { + ts := getTagSet() + dm, err := cbor.DecOptions{ + MaxArrayElements: 134217728, + MaxMapPairs: 134217728, + }.DecModeWithTags(ts) + + if err != nil { + return 0, err + } + decoder := dm.NewDecoder(r) + + // initialize coeff table + cs.CoeffTable = newCoeffTable(0) + + if err := decoder.Decode(&cs); err != nil { + return int64(decoder.NumBytesRead()), err + } + + if err := cs.CheckSerializationHeader(); err != nil { + return int64(decoder.NumBytesRead()), err + } + + switch v := cs.CommitmentInfo.(type) { + case *constraint.Groth16Commitments: + cs.CommitmentInfo = *v + case *constraint.PlonkCommitments: + cs.CommitmentInfo = *v + } + + return int64(decoder.NumBytesRead()), nil +} + +func (cs *system) GetCoefficient(i int) (r constraint.Element) { + copy(r[:], cs.Coefficients[i][:]) + return +} + +// GetSparseR1Cs return the list of SparseR1C +func (cs *system) GetSparseR1Cs() []constraint.SparseR1C { + + toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + var sparseR1C constraint.SparseR1C + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + toReturn = append(toReturn, sparseR1C) + } else { + panic("not implemented") + } + } + return toReturn +} + +// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form. 
+// solver = [ public | secret | internal ] +// TODO @gbotrel refactor; this seems to be a small util function for plonk +func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { + + //s := int(pk.Domain[0].Cardinality) + s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints + s = int(ecc.NextPowerOfTwo(uint64(s))) + + var l, r, o []fr.Element + l = make([]fr.Element, s) + r = make([]fr.Element, s) + o = make([]fr.Element, s) + s0 := solution[0] + + for i := 0; i < len(cs.Public); i++ { // placeholders + l[i] = solution[i] + r[i] = s0 + o[i] = s0 + } + offset := len(cs.Public) + nbConstraints := cs.GetNbConstraints() + + var sparseR1C constraint.SparseR1C + j := 0 + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + + l[offset+j] = solution[sparseR1C.XA] + r[offset+j] = solution[sparseR1C.XB] + o[offset+j] = solution[sparseR1C.XC] + j++ + } + } + + offset += nbConstraints + + for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0]) + l[offset+i] = s0 + r[offset+i] = s0 + o[offset+i] = s0 + } + + return l, r, o + +} + +// R1CSSolution represent a valid assignment to all the variables in the constraint system. 
+// The vector W such that Aw o Bw - Cw = 0 +type R1CSSolution struct { + W fr.Vector + A, B, C fr.Vector +} + +func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.W.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.A.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.B.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.C.WriteTo(w) + n += a + return n, err +} + +func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.W.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.A.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.B.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.C.ReadFrom(r) + n += a + return n, err +} + +// SparseR1CSSolution represent a valid assignment to all the variables in the constraint system. +type SparseR1CSSolution struct { + L, R, O fr.Vector +} + +func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.L.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.R.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.O.WriteTo(w) + n += a + return n, err + +} + +func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.L.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.R.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.O.ReadFrom(r) + n += a + return n, err +} + +func getTagSet() cbor.TagSet { + // temporary for refactor + ts := cbor.NewTagSet() + // https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml + // 65536-15309735 Unassigned + tagNum := uint64(5309735) + addType := func(t reflect.Type) { + if err := ts.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + t, + tagNum, + ); err != nil { + panic(err) + } + tagNum++ + } + + addType(reflect.TypeOf(constraint.BlueprintGenericHint{})) + addType(reflect.TypeOf(constraint.BlueprintGenericR1C{})) + 
addType(reflect.TypeOf(constraint.BlueprintGenericSparseR1C{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CAdd{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CMul{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CBool{})) + addType(reflect.TypeOf(constraint.Groth16Commitments{})) + addType(reflect.TypeOf(constraint.PlonkCommitments{})) + + return ts +} diff --git a/constraint/bw6-633/coeff.go b/constraint/bw6-633/coeff.go index 48b63963d7..dd1d29ad18 100644 --- a/constraint/bw6-633/coeff.go +++ b/constraint/bw6-633/coeff.go @@ -46,7 +46,7 @@ func newCoeffTable(capacity int) CoeffTable { } -func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constraint.Term { +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { c := (*fr.Element)(coeff[:]) var cID uint32 if c.IsZero() { @@ -69,7 +69,11 @@ func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constrai ct.mCoeffs[cc] = cID } } + return cID +} +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) return constraint.Term{VID: uint32(variableID), CID: cID} } @@ -78,7 +82,10 @@ func (ct *CoeffTable) CoeffToString(cID int) string { return ct.Coefficients[cID].String() } -var _ constraint.CoeffEngine = &arithEngine{} +// implements constraint.Field +type field struct{} + +var _ constraint.Field = &field{} var ( two fr.Element @@ -94,10 +101,7 @@ func init() { minusTwo.Neg(&two) } -// implements constraint.CoeffEngine -type arithEngine struct{} - -func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { +func (engine *field) FromInterface(i interface{}) constraint.Element { var e fr.Element if _, err := e.SetInterface(i); err != nil { // need to clean that --> some code path are dissimilar @@ -106,55 +110,83 @@ func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { b := utils.FromInterface(i) e.SetBigInt(&b) } - var r constraint.Coeff + var r 
constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) ToBigInt(c *constraint.Coeff) *big.Int { +func (engine *field) ToBigInt(c constraint.Element) *big.Int { e := (*fr.Element)(c[:]) r := new(big.Int) e.BigInt(r) return r } -func (engine *arithEngine) Mul(a, b *constraint.Coeff) { +func (engine *field) Mul(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Mul(_a, _b) + return a } -func (engine *arithEngine) Add(a, b *constraint.Coeff) { + +func (engine *field) Add(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Add(_a, _b) + return a } -func (engine *arithEngine) Sub(a, b *constraint.Coeff) { +func (engine *field) Sub(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Sub(_a, _b) + return a } -func (engine *arithEngine) Neg(a *constraint.Coeff) { +func (engine *field) Neg(a constraint.Element) constraint.Element { e := (*fr.Element)(a[:]) e.Neg(e) + return a } -func (engine *arithEngine) Inverse(a *constraint.Coeff) { +func (engine *field) Inverse(a constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + e.Inverse(e) + return a, true } -func (engine *arithEngine) IsOne(a *constraint.Coeff) bool { +func (engine *field) IsOne(a constraint.Element) bool { e := (*fr.Element)(a[:]) return e.IsOne() } -func (engine *arithEngine) One() constraint.Coeff { +func (engine *field) One() constraint.Element { e := fr.One() - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) String(a *constraint.Coeff) string { +func (engine *field) String(a constraint.Element) string { e := (*fr.Element)(a[:]) return e.String() } + +func (engine *field) Uint64(a 
constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true +} diff --git a/constraint/bw6-633/r1cs.go b/constraint/bw6-633/r1cs.go deleted file mode 100644 index 415e7a19ea..0000000000 --- a/constraint/bw6-633/r1cs.go +++ /dev/null @@ -1,453 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/fxamacker/cbor/v2" - "io" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc" - "math" - - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" -) - -// R1CS describes a set of R1CS constraint -type R1CS struct { - constraint.R1CSCore - CoeffTable - arithEngine -} - -// NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values -// -// capacity pre-allocates memory for capacity nbConstraints -func NewR1CS(capacity int) *R1CS { - r := R1CS{ - R1CSCore: constraint.R1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.R1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - return &r -} - 
-func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, r1c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - - cs.UpdateLevel(cID, &r1c) - - return cID -} - -// Solve sets all the wires and returns the a, b, c vectors. -// the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. -// a, b, c vectors: ab-c = hz -// witness = [publicWires | secretWires] (without the ONE_WIRE !) -// returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() - - nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables - hintFunctions := defineGkrHints(cs.GkrInfo, opt.HintFunctions) - solution, err := newSolution(nbWires, hintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - - if err != nil { - return make(fr.Vector, nbWires), err - } - start := time.Now() - - if len(witness) != len(cs.Public)-1+len(cs.Secret) { // - 1 for ONE_WIRE - err = fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), int(len(cs.Public)-1+len(cs.Secret)), len(cs.Public)-1, len(cs.Secret)) - log.Err(err).Send() - return solution.values, err - } - - // compute the wires and the a, b, c polynomials - if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { - err = errors.New("invalid input size: len(a, b, c) == len(Constraints)") - log.Err(err).Send() - return solution.values, err - } - - solution.solved[0] = true // ONE_WIRE - solution.values[0].SetOne() - copy(solution.values[1:], witness) - for i := range 
witness { - solution.solved[i+1] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness) + 1) - - // now that we know all inputs are set, defer log printing once all solution.values are computed - // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - if err := cs.parallelSolve(a, b, c, &solution); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil -} - -func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. 
- const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - chError <- &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - return &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can 
be lower. - nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// IsSolved returns nil if given witness solves the R1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) 
- if err != nil { - return err - } - - a := make(fr.Vector, len(cs.Constraints)) - b := make(fr.Vector, len(cs.Constraints)) - c := make(fr.Vector, len(cs.Constraints)) - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, a, b, c, opt) - return err -} - -// divByCoeff sets res = res / t.Coeff -func (cs *R1CS) divByCoeff(res *fr.Element, t constraint.Term) { - cID := t.CoeffID() - switch cID { - case constraint.CoeffIdOne: - return - case constraint.CoeffIdMinusOne: - res.Neg(res) - case constraint.CoeffIdZero: - panic("division by 0") - default: - // this is slow, but shouldn't happen as divByCoeff is called to - // remove the coeff of an unsolved wire - // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 - res.Div(res, &cs.Coefficients[cID]) - } -} - -// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly -// -// returns an error if the solver called a hint function that errored -// returns false, nil if there was no wire to solve -// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that -// the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r constraint.R1C, solution *solution, a, b, c *fr.Element) error { - - // the index of the non-zero entry shows if L, R or O has an uninstantiated wire - // the content is the ID of the wire non instantiated - var loc uint8 - - var termToCompute constraint.Term - - processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) error { - for _, t := range l { - vID := t.WireID() - - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - solution.accumulateInto(t, val) - continue - } - - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err - } - // now that the wire is saved, accumulate it into a, b or c - solution.accumulateInto(t, val) - continue - } - - if loc != 0 { - panic("found more than one wire to instantiate") - } - termToCompute = t - loc = locValue - } - return nil - } - - if err := processLExp(r.L, a, 1); err != nil { - return err - } - - if err := processLExp(r.R, b, 2); err != nil { - return err - } - - if err := processLExp(r.O, c, 3); err != nil { - return err - } - - if loc == 0 { - // there is nothing to solve, may happen if we have an assertion - // (ie a constraints that doesn't yield any output) - // or if we solved the unsolved wires with hint functions - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - return nil - } - - // we compute the wire value and instantiate it - wID := termToCompute.WireID() - - // solver result - var wire fr.Element - - switch loc { - case 1: - if !b.IsZero() { - wire.Div(c, b). - Sub(&wire, a) - a.Add(a, &wire) - } else { - // we didn't actually ensure that a * b == c - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 2: - if !a.IsZero() { - wire.Div(c, a). 
- Sub(&wire, b) - b.Add(b, &wire) - } else { - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 3: - wire.Mul(a, b). - Sub(&wire, c) - - c.Add(c, &wire) - } - - // wire is the term (coeff * value) - // but in the solution we want to store the value only - // note that in gnark frontend, coeff here is always 1 or -1 - cs.divByCoeff(&wire, termToCompute) - solution.set(wID, wire) - - return nil -} - -// GetConstraints return the list of R1C and a coefficient resolver -func (cs *R1CS) GetConstraints() ([]constraint.R1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *R1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto -func (cs *R1CS) CurveID() ecc.ID { - return ecc.BW6_633 -} - -// WriteTo encodes R1CS into provided io.Writer using cbor -func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode R1CS from io.Reader using cbor -func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(&cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bw6-633/r1cs_sparse.go 
b/constraint/bw6-633/r1cs_sparse.go deleted file mode 100644 index 0cba25ce7b..0000000000 --- a/constraint/bw6-633/r1cs_sparse.go +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/fxamacker/cbor/v2" - "io" - "math" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" -) - -// SparseR1CS represents a Plonk like circuit -type SparseR1CS struct { - constraint.SparseR1CSCore - CoeffTable - arithEngine -} - -// NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values -func NewSparseR1CS(capacity int) *SparseR1CS { - cs := SparseR1CS{ - SparseR1CSCore: constraint.SparseR1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.SparseR1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - - return &cs -} - -func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = 
append(cs.Constraints, c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - cs.UpdateLevel(cID, &c) - - return cID -} - -// Solve sets all the wires. -// solution.values = [publicInputs | secretInputs | internalVariables ] -// witness: contains the input variables -// it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() - - // set the slices holding the solution.values and monitoring which variables have been solved - nbVariables := cs.NbInternalVariables + len(cs.Secret) + len(cs.Public) - - start := time.Now() - - expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) - if len(witness) != expectedWitnessSize { - return make(fr.Vector, nbVariables), fmt.Errorf( - "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", - len(witness), - expectedWitnessSize, - len(cs.Public), - len(cs.Secret), - ) - } - - // keep track of wire that have a value - solution, err := newSolution(nbVariables, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return solution.values, err - } - - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs - copy(solution.values, witness) - for i := 0; i < len(witness); i++ { - solution.solved[i] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness)) - - // defer log printing once all solution.values are computed - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - // batch invert the coefficients to avoid many divisions in the solver - coefficientsNegInv := 
fr.BatchInvert(cs.Coefficients) - for i := 0; i < len(coefficientsNegInv); i++ { - coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) - } - - if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil - -} - -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - wg.Done() - return - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - chError <- &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } else { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - } - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - return &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// computeHints computes wires associated with a hint function, if any -// if there is no remaining wire to solve, returns -1 -// else returns the wire position (L -> 0, R -> 1, O -> 2) -func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) (int, error) { - r := -1 - lID, rID, oID := c.L.WireID(), c.R.WireID(), c.O.WireID() - - if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { - // check if it's a hint - if hint, ok := cs.MHints[lID]; ok { - if err := solution.solveWithHint(lID, hint); err != nil { - return -1, err - } - } else { - r = 0 - } - - } - - if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { - // check if it's a hint - if hint, ok := cs.MHints[rID]; ok { - if err := solution.solveWithHint(rID, hint); err != nil { - return -1, err - } - } else { - r = 1 - } - } - - if (c.O.CoeffID() != 0) && !solution.solved[oID] { - // check if it's a hint - if hint, ok := cs.MHints[oID]; ok { - if err := solution.solveWithHint(oID, hint); err != nil { - return -1, err - } - } else { - r = 2 - } - } - 
return r, nil -} - -// solveConstraint solve any unsolved wire in given constraint and update the solution -// a SparseR1C may have up to one unsolved wire (excluding hints) -// if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { - - lro, err := cs.computeHints(c, solution) - if err != nil { - return err - } - if lro == -1 { - // no unsolved wire - // can happen if the constraint contained only hint wires. - return nil - } - if lro == 1 { // we solve for R: u1L+u2R+u3LR+u4O+k=0 => R(u2+u3L)+u1L+u4O+k = 0 - if !solution.solved[c.L.WireID()] { - panic("L wire should be instantiated when we solve R") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.L.WireID()]).Add(&den, &u2) - - v1 = solution.computeTerm(c.L) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - } - - if lro == 0 { // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 - if !solution.solved[c.R.WireID()] { - panic("R wire should be instantiated when we solve L") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.R.WireID()]).Add(&den, &u1) - - v1 = solution.computeTerm(c.R) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - - } - // O 
we solve for O - var o fr.Element - cID, vID := c.O.CoeffID(), c.O.WireID() - - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - - // o = - ((m0 * m1) + l + r + c.K) / c.O - o.Mul(&m0, &m1).Add(&o, &l).Add(&o, &r).Add(&o, &cs.Coefficients[c.K]) - o.Mul(&o, &coefficientsNegInv[cID]) - - solution.set(vID, o) - - return nil -} - -// IsSolved returns nil if given witness solves the SparseR1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return err - } - - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, opt) - return err -} - -// GetConstraints return the list of SparseR1C and a coefficient resolver -func (cs *SparseR1CS) GetConstraints() ([]constraint.SparseR1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// checkConstraint verifies that the constraint holds -func (cs *SparseR1CS) checkConstraint(c constraint.SparseR1C, solution *solution) error { - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - o := solution.computeTerm(c.O) - - // l + r + (m0 * m1) + o + c.K == 0 - var t fr.Element - t.Mul(&m0, &m1).Add(&t, &l).Add(&t, &r).Add(&t, &o).Add(&t, &cs.Coefficients[c.K]) - if !t.IsZero() { - return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + (%s × %s) + %s != 0", - l.String(), - r.String(), - o.String(), - m0.String(), - m1.String(), - cs.Coefficients[c.K].String(), - ) - } - return nil - -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto (ecc.BW6-633) -func (cs 
*SparseR1CS) CurveID() ecc.ID { - return ecc.BW6_633 -} - -// WriteTo encodes SparseR1CS into provided io.Writer using cbor -func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode SparseR1CS from io.Reader using cbor -func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bw6-633/r1cs_test.go b/constraint/bw6-633/r1cs_test.go index 0e07d67a8a..9f9d6a0d6f 100644 --- a/constraint/bw6-633/r1cs_test.go +++ b/constraint/bw6-633/r1cs_test.go @@ -20,6 +20,7 @@ import ( "bytes" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "reflect" "testing" @@ -48,7 +49,7 @@ func TestSerialization(t *testing.T) { return } - // copmpile a second time to ensure determinism + // compile a second time to ensure determinism r1cs2, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) @@ -76,10 +77,10 @@ func TestSerialization(t *testing.T) { if diff := cmp.Diff(r1cs1, &reconstructed, cmpopts.IgnoreFields(cs.R1CS{}, "System.q", - "arithEngine", + "field", "CoeffTable.mCoeffs", "System.lbWireLevel", - 
"System.lbHints", + "System.genericHint", "System.SymbolTable", "System.lbOutputs", "System.bitLen")); diff != "" { @@ -147,12 +148,6 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { - var c circuit - ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c) - if err != nil { - b.Fatal(err) - } - var w circuit w.X = 1 w.Y = 1 @@ -161,8 +156,32 @@ func BenchmarkSolve(b *testing.B) { b.Fatal(err) } - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ccs.IsSolved(witness) - } + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + } diff --git a/constraint/bw6-633/solution.go b/constraint/bw6-633/solution.go deleted file mode 100644 index aea0b28a49..0000000000 --- a/constraint/bw6-633/solution.go +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/debug" - "github.com/rs/zerolog" - "math/big" - "strconv" - "strings" - "sync/atomic" - - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" -) - -// solution represents elements needed to compute -// a solution to a R1CS or SparseR1CS -type solution struct { - values, coefficients []fr.Element - solved []bool - nbSolved uint64 - mHintsFunctions map[hint.ID]hint.Function // maps hintID to hint function - mHints map[int]*constraint.Hint // maps wireID to hint - st *debug.SymbolTable -} - -func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, mHints map[int]*constraint.Hint, coefficients []fr.Element, st *debug.SymbolTable) (solution, error) { - - s := solution{ - st: st, - values: make([]fr.Element, nbWires), - coefficients: coefficients, - solved: make([]bool, nbWires), - mHintsFunctions: hintFunctions, - mHints: mHints, - } - - // hintsDependencies is from compile time; it contains the list of hints the solver **needs** - var missing []string - for hintUUID, hintID := range hintsDependencies { - if _, ok := s.mHintsFunctions[hintUUID]; !ok { - missing = append(missing, hintID) - } - } - - if len(missing) > 0 { - return s, fmt.Errorf("solver missing hint(s): %v", missing) - } - - return s, nil -} - -func (s *solution) set(id int, value fr.Element) { - if s.solved[id] { - panic("solving the same wire twice should never happen.") - } - s.values[id] = value - s.solved[id] = true - atomic.AddUint64(&s.nbSolved, 1) - // s.nbSolved++ -} - -func (s *solution) isValid() bool { - return int(s.nbSolved) == len(s.values) -} - -// computeTerm computes coeff*variable -func (s *solution) computeTerm(t constraint.Term) fr.Element { - cID, vID := t.CoeffID(), t.WireID() - if cID != 0 && !s.solved[vID] { - panic("computing a 
term with an unsolved wire") - } - switch cID { - case constraint.CoeffIdZero: - return fr.Element{} - case constraint.CoeffIdOne: - return s.values[vID] - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - return res - case constraint.CoeffIdMinusOne: - var res fr.Element - res.Neg(&s.values[vID]) - return res - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - return res - } -} - -// r += (t.coeff*t.value) -func (s *solution) accumulateInto(t constraint.Term, r *fr.Element) { - cID := t.CoeffID() - - if t.IsConstant() { - // needed for logs, we may want to not put this in the hot path if we need to - // optimize constraint system solver further. - r.Add(r, &s.coefficients[cID]) - return - } - - vID := t.WireID() - switch cID { - case constraint.CoeffIdZero: - return - case constraint.CoeffIdOne: - r.Add(r, &s.values[vID]) - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - r.Add(r, &res) - case constraint.CoeffIdMinusOne: - r.Sub(r, &s.values[vID]) - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - r.Add(r, &res) - } -} - -// solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveWithHint(vID int, h *constraint.Hint) error { - // skip if the wire is already solved by a call to the same hint - // function on the same inputs - if s.solved[vID] { - return nil - } - // ensure hint function was provided - f, ok := s.mHintsFunctions[h.ID] - if !ok { - return errors.New("missing hint function") - } - - // tmp IO big int memory - nbInputs := len(h.Inputs) - nbOutputs := len(h.Wires) - inputs := make([]*big.Int, nbInputs) - outputs := make([]*big.Int, nbOutputs) - for i := 0; i < nbOutputs; i++ { - outputs[i] = big.NewInt(0) - } - - q := fr.Modulus() - - // for each input, we set its big int value, IF all the wires are solved - // the only case where all wires may not be solved, is if one of the input of this hint - // 
is the output of another hint. - // it is safe to recursively solve this with the parallel solver, since all hints-output wires - // that we can solve this way are marked to be solved with the current constraint we are processing. - recursiveSolve := func(t constraint.Term) error { - if t.IsConstant() { - return nil - } - wID := t.WireID() - if s.solved[wID] { - return nil - } - // unsolved dependency - if h, ok := s.mHints[wID]; ok { - // solve recursively. - return s.solveWithHint(wID, h) - } - - // it's not a hint, we panic. - panic("solver can't compute hint; one or more input wires are unsolved") - } - - for i := 0; i < nbInputs; i++ { - inputs[i] = big.NewInt(0) - - var v fr.Element - for _, term := range h.Inputs[i] { - if err := recursiveSolve(term); err != nil { - return err - } - s.accumulateInto(term, &v) - } - v.BigInt(inputs[i]) - } - - err := f(q, inputs, outputs) - - var v fr.Element - for i := range outputs { - v.SetBigInt(outputs[i]) - s.set(h.Wires[i], v) - } - - return err -} - -func (s *solution) printLogs(log zerolog.Logger, logs []constraint.LogEntry) { - if log.GetLevel() == zerolog.Disabled { - return - } - - for i := 0; i < len(logs); i++ { - logLine := s.logValue(logs[i]) - log.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) - } -} - -const unsolvedVariable = "" - -func (s *solution) logValue(log constraint.LogEntry) string { - var toResolve []interface{} - var ( - eval fr.Element - missingValue bool - ) - for j := 0; j < len(log.ToResolve); j++ { - // before eval le - - missingValue = false - eval.SetZero() - - for _, t := range log.ToResolve[j] { - // for each term in the linear expression - - cID, vID := t.CoeffID(), t.WireID() - if t.IsConstant() { - // just add the constant - eval.Add(&eval, &s.coefficients[cID]) - continue - } - - if !s.solved[vID] { - missingValue = true - break // stop the loop we can't evaluate. 
- } - - tv := s.computeTerm(t) - eval.Add(&eval, &tv) - } - - // after - if missingValue { - toResolve = append(toResolve, unsolvedVariable) - } else { - // we have to append our accumulator - toResolve = append(toResolve, eval.String()) - } - - } - if len(log.Stack) > 0 { - var sbb strings.Builder - for _, lID := range log.Stack { - location := s.st.Locations[lID] - function := s.st.Functions[location.FunctionID] - - sbb.WriteString(function.Name) - sbb.WriteByte('\n') - sbb.WriteByte('\t') - sbb.WriteString(function.Filename) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(int(location.Line))) - sbb.WriteByte('\n') - } - toResolve = append(toResolve, sbb.String()) - } - return fmt.Sprintf(log.Format, toResolve...) -} - -// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint -type UnsatisfiedConstraintError struct { - Err error - CID int // constraint ID - DebugInfo *string // optional debug info -} - -func (r *UnsatisfiedConstraintError) Error() string { - if r.DebugInfo != nil { - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) - } - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) -} diff --git a/constraint/bw6-633/solver.go b/constraint/bw6-633/solver.go new file mode 100644 index 0000000000..948bb4a2e4 --- /dev/null +++ b/constraint/bw6-633/solver.go @@ -0,0 +1,647 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/pool" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/rs/zerolog" + "math" + "math/big" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" +) + +// solver represent the state of the solver during a call to System.Solve(...) +type solver struct { + *system + + // values and solved are index by the wire (variable) id + values []fr.Element + solved []bool + nbSolved uint64 + + // maps hintID to hint function + mHintsFunctions map[csolver.HintID]csolver.Hint + + // used to out api.Println + logger zerolog.Logger + + a, b, c fr.Vector // R1CS solver will compute the a,b,c matrices + + q *big.Int +} + +func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { + // parse options + opt, err := csolver.NewConfig(opts...) 
+ if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public) - witnessOffset + len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + witnessOffset) + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + +// computeTerm computes 
coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. 
+func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start)+i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. 
+ } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID: cID, VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. 
+func (s *solver) Read(calldata []uint32) (constraint.Element, int) { + if s.Type == constraint.SystemSparseR1CS { + if calldata[0] != 1 { + panic("invalid calldata") + } + return s.GetValue(calldata[1], calldata[2]), 3 + } + var r fr.Element + n := int(calldata[0]) + j := 1 + for k := 0; k < n; k++ { + // we read k Terms + s.accumulateInto(constraint.Term{CID: calldata[j], VID: calldata[j+1]}, &r) + j += 2 + } + + var ret constraint.Element + copy(ret[:], r[:]) + return ret, j +} + +// processInstruction decodes the instruction and execute blueprint-defined logic. +// an instruction can encode a hint, a custom constraint or a generic constraint. +func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error { + // fetch the blueprint + blueprint := solver.Blueprints[pi.BlueprintID] + inst := pi.Unpack(&solver.System) + cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only + + if solver.Type == constraint.SystemR1CS { + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + // TODO @gbotrel we use the solveR1C method for now, having user-defined + // blueprint for R1CS would require constraint.Solver interface to add methods + // to set a,b,c since it's more efficient to compute these while we solve. + bc.DecompressR1C(&scratch.tR1C, inst) + return solver.solveR1C(cID, &scratch.tR1C) + } + } + + // blueprint declared "I know how to solve this." + if bc, ok := blueprint.(constraint.BlueprintSolvable); ok { + if err := bc.Solve(solver, inst); err != nil { + return solver.wrapErrWithDebugInfo(cID, err) + } + return nil + } + + // blueprint encodes a hint, we execute. + // TODO @gbotrel may be worth it to move hint logic in blueprint "solve" + if bc, ok := blueprint.(constraint.BlueprintHint); ok { + bc.DecompressHint(&scratch.tHint, inst) + return solver.solveWithHint(&scratch.tHint) + } + + return nil +} + +// run runs the solver. 
it return an error if a constraint is not satisfied or if not all wires +// were instantiated. +func (solver *solver) run() error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks. + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + var scratch scratch + for t := range chTasks { + for _, i := range t { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + var scratch scratch + + // for each level, we push the tasks + for _, level := range solver.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + return err + } + } + continue + } + + // number of tasks for this level is set to number of CPU + // but if we don't have enough work for all our CPU, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + if int(solver.nbSolved) != len(solver.values) { + return errors.New("solver didn't assign a value to all wires") + } + + return nil +} + +// solveR1C compute unsolved wires in the constraint, if any and set the solver accordingly +// +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. 
+func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error { + a, b, c := &solver.a[cID], &solver.b[cID], &solver.c[cID] + + // the index of the non-zero entry shows if L, R or O has an uninstantiated wire + // the content is the ID of the wire non instantiated + var loc uint8 + + var termToCompute constraint.Term + + processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) { + for _, t := range l { + vID := t.WireID() + + // wire is already computed, we just accumulate in val + if solver.solved[vID] { + solver.accumulateInto(t, val) + continue + } + + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue + } + } + + processLExp(r.L, a, 1) + processLExp(r.R, b, 2) + processLExp(r.O, c, 3) + + if loc == 0 { + // there is nothing to solve, may happen if we have an assertion + // (ie a constraints that doesn't yield any output) + // or if we solved the unsolved wires with hint functions + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + return nil + } + + // we compute the wire value and instantiate it + wID := termToCompute.WireID() + + // solver result + var wire fr.Element + + switch loc { + case 1: + if !b.IsZero() { + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) + } else { + // we didn't actually ensure that a * b == c + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 2: + if !a.IsZero() { + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) + } else { + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 3: + wire.Mul(a, b). 
+ Sub(&wire, c) + + c.Add(c, &wire) + } + + // wire is the term (coeff * value) + // but in the solver we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + solver.divByCoeff(&wire, termToCompute.CID) + solver.set(wID, wire) + + return nil +} + +// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint +type UnsatisfiedConstraintError struct { + Err error + CID int // constraint ID + DebugInfo *string // optional debug info +} + +func (r *UnsatisfiedConstraintError) Error() string { + if r.DebugInfo != nil { + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) + } + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) +} + +func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { + var debugInfo *string + if dID, ok := solver.MDebug[int(cID)]; ok { + debugInfo = new(string) + *debugInfo = solver.logValue(solver.DebugInfo[dID]) + } + return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo} +} + +// temporary variables to avoid memallocs in hotloop +type scratch struct { + tR1C constraint.R1C + tHint constraint.HintMapping +} diff --git a/constraint/bw6-633/system.go b/constraint/bw6-633/system.go new file mode 100644 index 0000000000..f102b47290 --- /dev/null +++ b/constraint/bw6-633/system.go @@ -0,0 +1,374 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/fxamacker/cbor/v2" + "io" + "time" + + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/internal/backend/ioutils" + "github.com/consensys/gnark/logger" + "reflect" + + "github.com/consensys/gnark-crypto/ecc" + + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" +) + +type R1CS = system +type SparseR1CS = system + +// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element) +type system struct { + constraint.System + CoeffTable + field +} + +func NewR1CS(capacity int) *R1CS { + return newSystem(capacity, constraint.SystemR1CS) +} + +func NewSparseR1CS(capacity int) *SparseR1CS { + return newSystem(capacity, constraint.SystemSparseR1CS) +} + +func newSystem(capacity int, t constraint.SystemType) *system { + return &system{ + System: constraint.NewSystem(fr.Modulus(), capacity, t), + CoeffTable: newCoeffTable(capacity / 10), + } +} + +// Solve solves the constraint system with provided witness. +// If it's a R1CS returns R1CSSolution +// If it's a SparseR1CS returns SparseR1CSSolution +func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) { + log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger() + start := time.Now() + + v := witness.Vector().(fr.Vector) + + // init the solver + solver, err := newSolver(cs, v, opts...) + if err != nil { + log.Err(err).Send() + return nil, err + } + + // defer log printing once all solver.values are computed + // (or sooner, if a constraint is not satisfied) + defer solver.printLogs(cs.Logs) + + // run it. 
+ if err := solver.run(); err != nil { + log.Err(err).Send() + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") + + // format the solution + // TODO @gbotrel revisit post-refactor + if cs.Type == constraint.SystemR1CS { + var res R1CSSolution + res.W = solver.values + res.A = solver.a + res.B = solver.b + res.C = solver.c + return &res, nil + } else { + // sparse R1CS + var res SparseR1CSSolution + // query l, r, o in Lagrange basis, not blinded + res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values) + + return &res, nil + } + +} + +// IsSolved +// Deprecated: use _, err := Solve(...) instead +func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error { + _, err := cs.Solve(witness, opts...) + return err +} + +// GetR1Cs return the list of R1C +func (cs *system) GetR1Cs() []constraint.R1C { + toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + var r1c constraint.R1C + bc.DecompressR1C(&r1c, inst.Unpack(&cs.System)) + toReturn = append(toReturn, r1c) + } else { + panic("not implemented") + } + } + return toReturn +} + +// GetNbCoefficients return the number of unique coefficients needed in the R1CS +func (cs *system) GetNbCoefficients() int { + return len(cs.Coefficients) +} + +// CurveID returns curve ID as defined in gnark-crypto +func (cs *system) CurveID() ecc.ID { + return ecc.BW6_633 +} + +// WriteTo encodes R1CS into provided io.Writer using cbor +func (cs *system) WriteTo(w io.Writer) (int64, error) { + _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written + ts := getTagSet() + enc, err := cbor.CoreDetEncOptions().EncModeWithTags(ts) + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) + + // encode our object + err = encoder.Encode(cs) + return _w.N, err +} + +// ReadFrom 
attempts to decode R1CS from io.Reader using cbor +func (cs *system) ReadFrom(r io.Reader) (int64, error) { + ts := getTagSet() + dm, err := cbor.DecOptions{ + MaxArrayElements: 134217728, + MaxMapPairs: 134217728, + }.DecModeWithTags(ts) + + if err != nil { + return 0, err + } + decoder := dm.NewDecoder(r) + + // initialize coeff table + cs.CoeffTable = newCoeffTable(0) + + if err := decoder.Decode(&cs); err != nil { + return int64(decoder.NumBytesRead()), err + } + + if err := cs.CheckSerializationHeader(); err != nil { + return int64(decoder.NumBytesRead()), err + } + + switch v := cs.CommitmentInfo.(type) { + case *constraint.Groth16Commitments: + cs.CommitmentInfo = *v + case *constraint.PlonkCommitments: + cs.CommitmentInfo = *v + } + + return int64(decoder.NumBytesRead()), nil +} + +func (cs *system) GetCoefficient(i int) (r constraint.Element) { + copy(r[:], cs.Coefficients[i][:]) + return +} + +// GetSparseR1Cs return the list of SparseR1C +func (cs *system) GetSparseR1Cs() []constraint.SparseR1C { + + toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + var sparseR1C constraint.SparseR1C + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + toReturn = append(toReturn, sparseR1C) + } else { + panic("not implemented") + } + } + return toReturn +} + +// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form. 
+// solver = [ public | secret | internal ] +// TODO @gbotrel refactor; this seems to be a small util function for plonk +func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { + + //s := int(pk.Domain[0].Cardinality) + s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints + s = int(ecc.NextPowerOfTwo(uint64(s))) + + var l, r, o []fr.Element + l = make([]fr.Element, s) + r = make([]fr.Element, s) + o = make([]fr.Element, s) + s0 := solution[0] + + for i := 0; i < len(cs.Public); i++ { // placeholders + l[i] = solution[i] + r[i] = s0 + o[i] = s0 + } + offset := len(cs.Public) + nbConstraints := cs.GetNbConstraints() + + var sparseR1C constraint.SparseR1C + j := 0 + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + + l[offset+j] = solution[sparseR1C.XA] + r[offset+j] = solution[sparseR1C.XB] + o[offset+j] = solution[sparseR1C.XC] + j++ + } + } + + offset += nbConstraints + + for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0]) + l[offset+i] = s0 + r[offset+i] = s0 + o[offset+i] = s0 + } + + return l, r, o + +} + +// R1CSSolution represents a valid assignment to all the variables in the constraint system. 
+// The vector W such that Aw o Bw - Cw = 0 +type R1CSSolution struct { + W fr.Vector + A, B, C fr.Vector +} + +func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.W.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.A.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.B.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.C.WriteTo(w) + n += a + return n, err +} + +func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.W.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.A.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.B.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.C.ReadFrom(r) + n += a + return n, err +} + +// SparseR1CSSolution represent a valid assignment to all the variables in the constraint system. +type SparseR1CSSolution struct { + L, R, O fr.Vector +} + +func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.L.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.R.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.O.WriteTo(w) + n += a + return n, err + +} + +func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.L.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.R.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.O.ReadFrom(r) + n += a + return n, err +} + +func getTagSet() cbor.TagSet { + // temporary for refactor + ts := cbor.NewTagSet() + // https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml + // 65536-15309735 Unassigned + tagNum := uint64(5309735) + addType := func(t reflect.Type) { + if err := ts.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + t, + tagNum, + ); err != nil { + panic(err) + } + tagNum++ + } + + addType(reflect.TypeOf(constraint.BlueprintGenericHint{})) + addType(reflect.TypeOf(constraint.BlueprintGenericR1C{})) + 
addType(reflect.TypeOf(constraint.BlueprintGenericSparseR1C{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CAdd{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CMul{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CBool{})) + addType(reflect.TypeOf(constraint.Groth16Commitments{})) + addType(reflect.TypeOf(constraint.PlonkCommitments{})) + + return ts +} diff --git a/constraint/bw6-761/coeff.go b/constraint/bw6-761/coeff.go index e63c24f7ca..506ee5e586 100644 --- a/constraint/bw6-761/coeff.go +++ b/constraint/bw6-761/coeff.go @@ -46,7 +46,7 @@ func newCoeffTable(capacity int) CoeffTable { } -func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constraint.Term { +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { c := (*fr.Element)(coeff[:]) var cID uint32 if c.IsZero() { @@ -69,7 +69,11 @@ func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constrai ct.mCoeffs[cc] = cID } } + return cID +} +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) return constraint.Term{VID: uint32(variableID), CID: cID} } @@ -78,7 +82,10 @@ func (ct *CoeffTable) CoeffToString(cID int) string { return ct.Coefficients[cID].String() } -var _ constraint.CoeffEngine = &arithEngine{} +// implements constraint.Field +type field struct{} + +var _ constraint.Field = &field{} var ( two fr.Element @@ -94,10 +101,7 @@ func init() { minusTwo.Neg(&two) } -// implements constraint.CoeffEngine -type arithEngine struct{} - -func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { +func (engine *field) FromInterface(i interface{}) constraint.Element { var e fr.Element if _, err := e.SetInterface(i); err != nil { // need to clean that --> some code path are dissimilar @@ -106,55 +110,83 @@ func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { b := utils.FromInterface(i) e.SetBigInt(&b) } - var r constraint.Coeff + var r 
constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) ToBigInt(c *constraint.Coeff) *big.Int { +func (engine *field) ToBigInt(c constraint.Element) *big.Int { e := (*fr.Element)(c[:]) r := new(big.Int) e.BigInt(r) return r } -func (engine *arithEngine) Mul(a, b *constraint.Coeff) { +func (engine *field) Mul(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Mul(_a, _b) + return a } -func (engine *arithEngine) Add(a, b *constraint.Coeff) { + +func (engine *field) Add(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Add(_a, _b) + return a } -func (engine *arithEngine) Sub(a, b *constraint.Coeff) { +func (engine *field) Sub(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Sub(_a, _b) + return a } -func (engine *arithEngine) Neg(a *constraint.Coeff) { +func (engine *field) Neg(a constraint.Element) constraint.Element { e := (*fr.Element)(a[:]) e.Neg(e) + return a } -func (engine *arithEngine) Inverse(a *constraint.Coeff) { +func (engine *field) Inverse(a constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + e.Inverse(e) + return a, true } -func (engine *arithEngine) IsOne(a *constraint.Coeff) bool { +func (engine *field) IsOne(a constraint.Element) bool { e := (*fr.Element)(a[:]) return e.IsOne() } -func (engine *arithEngine) One() constraint.Coeff { +func (engine *field) One() constraint.Element { e := fr.One() - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) String(a *constraint.Coeff) string { +func (engine *field) String(a constraint.Element) string { e := (*fr.Element)(a[:]) return e.String() } + +func (engine *field) Uint64(a 
constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true +} diff --git a/constraint/bw6-761/r1cs.go b/constraint/bw6-761/r1cs.go deleted file mode 100644 index 2caf05a267..0000000000 --- a/constraint/bw6-761/r1cs.go +++ /dev/null @@ -1,453 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/fxamacker/cbor/v2" - "io" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc" - "math" - - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" -) - -// R1CS describes a set of R1CS constraint -type R1CS struct { - constraint.R1CSCore - CoeffTable - arithEngine -} - -// NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values -// -// capacity pre-allocates memory for capacity nbConstraints -func NewR1CS(capacity int) *R1CS { - r := R1CS{ - R1CSCore: constraint.R1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.R1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - return &r -} - 
-func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, r1c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - - cs.UpdateLevel(cID, &r1c) - - return cID -} - -// Solve sets all the wires and returns the a, b, c vectors. -// the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. -// a, b, c vectors: ab-c = hz -// witness = [publicWires | secretWires] (without the ONE_WIRE !) -// returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() - - nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables - hintFunctions := defineGkrHints(cs.GkrInfo, opt.HintFunctions) - solution, err := newSolution(nbWires, hintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - - if err != nil { - return make(fr.Vector, nbWires), err - } - start := time.Now() - - if len(witness) != len(cs.Public)-1+len(cs.Secret) { // - 1 for ONE_WIRE - err = fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), int(len(cs.Public)-1+len(cs.Secret)), len(cs.Public)-1, len(cs.Secret)) - log.Err(err).Send() - return solution.values, err - } - - // compute the wires and the a, b, c polynomials - if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { - err = errors.New("invalid input size: len(a, b, c) == len(Constraints)") - log.Err(err).Send() - return solution.values, err - } - - solution.solved[0] = true // ONE_WIRE - solution.values[0].SetOne() - copy(solution.values[1:], witness) - for i := range 
witness { - solution.solved[i+1] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness) + 1) - - // now that we know all inputs are set, defer log printing once all solution.values are computed - // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - if err := cs.parallelSolve(a, b, c, &solution); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil -} - -func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. 
- const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - chError <- &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - return &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can 
be lower. - nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// IsSolved returns nil if given witness solves the R1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) 
- if err != nil { - return err - } - - a := make(fr.Vector, len(cs.Constraints)) - b := make(fr.Vector, len(cs.Constraints)) - c := make(fr.Vector, len(cs.Constraints)) - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, a, b, c, opt) - return err -} - -// divByCoeff sets res = res / t.Coeff -func (cs *R1CS) divByCoeff(res *fr.Element, t constraint.Term) { - cID := t.CoeffID() - switch cID { - case constraint.CoeffIdOne: - return - case constraint.CoeffIdMinusOne: - res.Neg(res) - case constraint.CoeffIdZero: - panic("division by 0") - default: - // this is slow, but shouldn't happen as divByCoeff is called to - // remove the coeff of an unsolved wire - // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 - res.Div(res, &cs.Coefficients[cID]) - } -} - -// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly -// -// returns an error if the solver called a hint function that errored -// returns false, nil if there was no wire to solve -// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that -// the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r constraint.R1C, solution *solution, a, b, c *fr.Element) error { - - // the index of the non-zero entry shows if L, R or O has an uninstantiated wire - // the content is the ID of the wire non instantiated - var loc uint8 - - var termToCompute constraint.Term - - processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) error { - for _, t := range l { - vID := t.WireID() - - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - solution.accumulateInto(t, val) - continue - } - - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err - } - // now that the wire is saved, accumulate it into a, b or c - solution.accumulateInto(t, val) - continue - } - - if loc != 0 { - panic("found more than one wire to instantiate") - } - termToCompute = t - loc = locValue - } - return nil - } - - if err := processLExp(r.L, a, 1); err != nil { - return err - } - - if err := processLExp(r.R, b, 2); err != nil { - return err - } - - if err := processLExp(r.O, c, 3); err != nil { - return err - } - - if loc == 0 { - // there is nothing to solve, may happen if we have an assertion - // (ie a constraints that doesn't yield any output) - // or if we solved the unsolved wires with hint functions - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - return nil - } - - // we compute the wire value and instantiate it - wID := termToCompute.WireID() - - // solver result - var wire fr.Element - - switch loc { - case 1: - if !b.IsZero() { - wire.Div(c, b). - Sub(&wire, a) - a.Add(a, &wire) - } else { - // we didn't actually ensure that a * b == c - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 2: - if !a.IsZero() { - wire.Div(c, a). 
- Sub(&wire, b) - b.Add(b, &wire) - } else { - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 3: - wire.Mul(a, b). - Sub(&wire, c) - - c.Add(c, &wire) - } - - // wire is the term (coeff * value) - // but in the solution we want to store the value only - // note that in gnark frontend, coeff here is always 1 or -1 - cs.divByCoeff(&wire, termToCompute) - solution.set(wID, wire) - - return nil -} - -// GetConstraints return the list of R1C and a coefficient resolver -func (cs *R1CS) GetConstraints() ([]constraint.R1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *R1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto -func (cs *R1CS) CurveID() ecc.ID { - return ecc.BW6_761 -} - -// WriteTo encodes R1CS into provided io.Writer using cbor -func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode R1CS from io.Reader using cbor -func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(&cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bw6-761/r1cs_sparse.go 
b/constraint/bw6-761/r1cs_sparse.go deleted file mode 100644 index 6f9f1ff048..0000000000 --- a/constraint/bw6-761/r1cs_sparse.go +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/fxamacker/cbor/v2" - "io" - "math" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" -) - -// SparseR1CS represents a Plonk like circuit -type SparseR1CS struct { - constraint.SparseR1CSCore - CoeffTable - arithEngine -} - -// NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values -func NewSparseR1CS(capacity int) *SparseR1CS { - cs := SparseR1CS{ - SparseR1CSCore: constraint.SparseR1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.SparseR1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - - return &cs -} - -func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = 
append(cs.Constraints, c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - cs.UpdateLevel(cID, &c) - - return cID -} - -// Solve sets all the wires. -// solution.values = [publicInputs | secretInputs | internalVariables ] -// witness: contains the input variables -// it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() - - // set the slices holding the solution.values and monitoring which variables have been solved - nbVariables := cs.NbInternalVariables + len(cs.Secret) + len(cs.Public) - - start := time.Now() - - expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) - if len(witness) != expectedWitnessSize { - return make(fr.Vector, nbVariables), fmt.Errorf( - "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", - len(witness), - expectedWitnessSize, - len(cs.Public), - len(cs.Secret), - ) - } - - // keep track of wire that have a value - solution, err := newSolution(nbVariables, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return solution.values, err - } - - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs - copy(solution.values, witness) - for i := 0; i < len(witness); i++ { - solution.solved[i] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness)) - - // defer log printing once all solution.values are computed - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - // batch invert the coefficients to avoid many divisions in the solver - coefficientsNegInv := 
fr.BatchInvert(cs.Coefficients) - for i := 0; i < len(coefficientsNegInv); i++ { - coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) - } - - if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil - -} - -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - wg.Done() - return - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - chError <- &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } else { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - } - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - return &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// computeHints computes wires associated with a hint function, if any -// if there is no remaining wire to solve, returns -1 -// else returns the wire position (L -> 0, R -> 1, O -> 2) -func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) (int, error) { - r := -1 - lID, rID, oID := c.L.WireID(), c.R.WireID(), c.O.WireID() - - if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { - // check if it's a hint - if hint, ok := cs.MHints[lID]; ok { - if err := solution.solveWithHint(lID, hint); err != nil { - return -1, err - } - } else { - r = 0 - } - - } - - if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { - // check if it's a hint - if hint, ok := cs.MHints[rID]; ok { - if err := solution.solveWithHint(rID, hint); err != nil { - return -1, err - } - } else { - r = 1 - } - } - - if (c.O.CoeffID() != 0) && !solution.solved[oID] { - // check if it's a hint - if hint, ok := cs.MHints[oID]; ok { - if err := solution.solveWithHint(oID, hint); err != nil { - return -1, err - } - } else { - r = 2 - } - } - 
return r, nil -} - -// solveConstraint solve any unsolved wire in given constraint and update the solution -// a SparseR1C may have up to one unsolved wire (excluding hints) -// if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { - - lro, err := cs.computeHints(c, solution) - if err != nil { - return err - } - if lro == -1 { - // no unsolved wire - // can happen if the constraint contained only hint wires. - return nil - } - if lro == 1 { // we solve for R: u1L+u2R+u3LR+u4O+k=0 => R(u2+u3L)+u1L+u4O+k = 0 - if !solution.solved[c.L.WireID()] { - panic("L wire should be instantiated when we solve R") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.L.WireID()]).Add(&den, &u2) - - v1 = solution.computeTerm(c.L) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - } - - if lro == 0 { // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 - if !solution.solved[c.R.WireID()] { - panic("R wire should be instantiated when we solve L") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.R.WireID()]).Add(&den, &u1) - - v1 = solution.computeTerm(c.R) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - - } - // O 
we solve for O - var o fr.Element - cID, vID := c.O.CoeffID(), c.O.WireID() - - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - - // o = - ((m0 * m1) + l + r + c.K) / c.O - o.Mul(&m0, &m1).Add(&o, &l).Add(&o, &r).Add(&o, &cs.Coefficients[c.K]) - o.Mul(&o, &coefficientsNegInv[cID]) - - solution.set(vID, o) - - return nil -} - -// IsSolved returns nil if given witness solves the SparseR1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return err - } - - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, opt) - return err -} - -// GetConstraints return the list of SparseR1C and a coefficient resolver -func (cs *SparseR1CS) GetConstraints() ([]constraint.SparseR1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// checkConstraint verifies that the constraint holds -func (cs *SparseR1CS) checkConstraint(c constraint.SparseR1C, solution *solution) error { - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - o := solution.computeTerm(c.O) - - // l + r + (m0 * m1) + o + c.K == 0 - var t fr.Element - t.Mul(&m0, &m1).Add(&t, &l).Add(&t, &r).Add(&t, &o).Add(&t, &cs.Coefficients[c.K]) - if !t.IsZero() { - return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + (%s × %s) + %s != 0", - l.String(), - r.String(), - o.String(), - m0.String(), - m1.String(), - cs.Coefficients[c.K].String(), - ) - } - return nil - -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto (ecc.BW6-761) -func (cs 
*SparseR1CS) CurveID() ecc.ID { - return ecc.BW6_761 -} - -// WriteTo encodes SparseR1CS into provided io.Writer using cbor -func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode SparseR1CS from io.Reader using cbor -func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bw6-761/r1cs_test.go b/constraint/bw6-761/r1cs_test.go index ae50dafdc3..b6eae1220c 100644 --- a/constraint/bw6-761/r1cs_test.go +++ b/constraint/bw6-761/r1cs_test.go @@ -20,6 +20,7 @@ import ( "bytes" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "reflect" "testing" @@ -51,7 +52,7 @@ func TestSerialization(t *testing.T) { return } - // copmpile a second time to ensure determinism + // compile a second time to ensure determinism r1cs2, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) @@ -79,10 +80,10 @@ func TestSerialization(t *testing.T) { if diff := cmp.Diff(r1cs1, &reconstructed, cmpopts.IgnoreFields(cs.R1CS{}, "System.q", - "arithEngine", + "field", "CoeffTable.mCoeffs", "System.lbWireLevel", - 
"System.lbHints", + "System.genericHint", "System.SymbolTable", "System.lbOutputs", "System.bitLen")); diff != "" { @@ -150,12 +151,6 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { - var c circuit - ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c) - if err != nil { - b.Fatal(err) - } - var w circuit w.X = 1 w.Y = 1 @@ -164,8 +159,32 @@ func BenchmarkSolve(b *testing.B) { b.Fatal(err) } - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ccs.IsSolved(witness) - } + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + } diff --git a/constraint/bw6-761/solution.go b/constraint/bw6-761/solution.go deleted file mode 100644 index 54e9eee0bb..0000000000 --- a/constraint/bw6-761/solution.go +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/debug" - "github.com/rs/zerolog" - "math/big" - "strconv" - "strings" - "sync/atomic" - - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" -) - -// solution represents elements needed to compute -// a solution to a R1CS or SparseR1CS -type solution struct { - values, coefficients []fr.Element - solved []bool - nbSolved uint64 - mHintsFunctions map[hint.ID]hint.Function // maps hintID to hint function - mHints map[int]*constraint.Hint // maps wireID to hint - st *debug.SymbolTable -} - -func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, mHints map[int]*constraint.Hint, coefficients []fr.Element, st *debug.SymbolTable) (solution, error) { - - s := solution{ - st: st, - values: make([]fr.Element, nbWires), - coefficients: coefficients, - solved: make([]bool, nbWires), - mHintsFunctions: hintFunctions, - mHints: mHints, - } - - // hintsDependencies is from compile time; it contains the list of hints the solver **needs** - var missing []string - for hintUUID, hintID := range hintsDependencies { - if _, ok := s.mHintsFunctions[hintUUID]; !ok { - missing = append(missing, hintID) - } - } - - if len(missing) > 0 { - return s, fmt.Errorf("solver missing hint(s): %v", missing) - } - - return s, nil -} - -func (s *solution) set(id int, value fr.Element) { - if s.solved[id] { - panic("solving the same wire twice should never happen.") - } - s.values[id] = value - s.solved[id] = true - atomic.AddUint64(&s.nbSolved, 1) - // s.nbSolved++ -} - -func (s *solution) isValid() bool { - return int(s.nbSolved) == len(s.values) -} - -// computeTerm computes coeff*variable -func (s *solution) computeTerm(t constraint.Term) fr.Element { - cID, vID := t.CoeffID(), t.WireID() - if cID != 0 && !s.solved[vID] { - panic("computing a 
term with an unsolved wire") - } - switch cID { - case constraint.CoeffIdZero: - return fr.Element{} - case constraint.CoeffIdOne: - return s.values[vID] - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - return res - case constraint.CoeffIdMinusOne: - var res fr.Element - res.Neg(&s.values[vID]) - return res - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - return res - } -} - -// r += (t.coeff*t.value) -func (s *solution) accumulateInto(t constraint.Term, r *fr.Element) { - cID := t.CoeffID() - - if t.IsConstant() { - // needed for logs, we may want to not put this in the hot path if we need to - // optimize constraint system solver further. - r.Add(r, &s.coefficients[cID]) - return - } - - vID := t.WireID() - switch cID { - case constraint.CoeffIdZero: - return - case constraint.CoeffIdOne: - r.Add(r, &s.values[vID]) - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - r.Add(r, &res) - case constraint.CoeffIdMinusOne: - r.Sub(r, &s.values[vID]) - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - r.Add(r, &res) - } -} - -// solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveWithHint(vID int, h *constraint.Hint) error { - // skip if the wire is already solved by a call to the same hint - // function on the same inputs - if s.solved[vID] { - return nil - } - // ensure hint function was provided - f, ok := s.mHintsFunctions[h.ID] - if !ok { - return errors.New("missing hint function") - } - - // tmp IO big int memory - nbInputs := len(h.Inputs) - nbOutputs := len(h.Wires) - inputs := make([]*big.Int, nbInputs) - outputs := make([]*big.Int, nbOutputs) - for i := 0; i < nbOutputs; i++ { - outputs[i] = big.NewInt(0) - } - - q := fr.Modulus() - - // for each input, we set its big int value, IF all the wires are solved - // the only case where all wires may not be solved, is if one of the input of this hint - // 
is the output of another hint. - // it is safe to recursively solve this with the parallel solver, since all hints-output wires - // that we can solve this way are marked to be solved with the current constraint we are processing. - recursiveSolve := func(t constraint.Term) error { - if t.IsConstant() { - return nil - } - wID := t.WireID() - if s.solved[wID] { - return nil - } - // unsolved dependency - if h, ok := s.mHints[wID]; ok { - // solve recursively. - return s.solveWithHint(wID, h) - } - - // it's not a hint, we panic. - panic("solver can't compute hint; one or more input wires are unsolved") - } - - for i := 0; i < nbInputs; i++ { - inputs[i] = big.NewInt(0) - - var v fr.Element - for _, term := range h.Inputs[i] { - if err := recursiveSolve(term); err != nil { - return err - } - s.accumulateInto(term, &v) - } - v.BigInt(inputs[i]) - } - - err := f(q, inputs, outputs) - - var v fr.Element - for i := range outputs { - v.SetBigInt(outputs[i]) - s.set(h.Wires[i], v) - } - - return err -} - -func (s *solution) printLogs(log zerolog.Logger, logs []constraint.LogEntry) { - if log.GetLevel() == zerolog.Disabled { - return - } - - for i := 0; i < len(logs); i++ { - logLine := s.logValue(logs[i]) - log.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) - } -} - -const unsolvedVariable = "" - -func (s *solution) logValue(log constraint.LogEntry) string { - var toResolve []interface{} - var ( - eval fr.Element - missingValue bool - ) - for j := 0; j < len(log.ToResolve); j++ { - // before eval le - - missingValue = false - eval.SetZero() - - for _, t := range log.ToResolve[j] { - // for each term in the linear expression - - cID, vID := t.CoeffID(), t.WireID() - if t.IsConstant() { - // just add the constant - eval.Add(&eval, &s.coefficients[cID]) - continue - } - - if !s.solved[vID] { - missingValue = true - break // stop the loop we can't evaluate. 
- } - - tv := s.computeTerm(t) - eval.Add(&eval, &tv) - } - - // after - if missingValue { - toResolve = append(toResolve, unsolvedVariable) - } else { - // we have to append our accumulator - toResolve = append(toResolve, eval.String()) - } - - } - if len(log.Stack) > 0 { - var sbb strings.Builder - for _, lID := range log.Stack { - location := s.st.Locations[lID] - function := s.st.Functions[location.FunctionID] - - sbb.WriteString(function.Name) - sbb.WriteByte('\n') - sbb.WriteByte('\t') - sbb.WriteString(function.Filename) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(int(location.Line))) - sbb.WriteByte('\n') - } - toResolve = append(toResolve, sbb.String()) - } - return fmt.Sprintf(log.Format, toResolve...) -} - -// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint -type UnsatisfiedConstraintError struct { - Err error - CID int // constraint ID - DebugInfo *string // optional debug info -} - -func (r *UnsatisfiedConstraintError) Error() string { - if r.DebugInfo != nil { - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) - } - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) -} diff --git a/constraint/bw6-761/solver.go b/constraint/bw6-761/solver.go new file mode 100644 index 0000000000..803035e1a9 --- /dev/null +++ b/constraint/bw6-761/solver.go @@ -0,0 +1,647 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/pool" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/rs/zerolog" + "math" + "math/big" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" +) + +// solver represent the state of the solver during a call to System.Solve(...) +type solver struct { + *system + + // values and solved are index by the wire (variable) id + values []fr.Element + solved []bool + nbSolved uint64 + + // maps hintID to hint function + mHintsFunctions map[csolver.HintID]csolver.Hint + + // used to out api.Println + logger zerolog.Logger + + a, b, c fr.Vector // R1CS solver will compute the a,b,c matrices + + q *big.Int +} + +func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { + // parse options + opt, err := csolver.NewConfig(opts...) 
+ if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public) - witnessOffset + len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + witnessOffset) + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + +// computeTerm computes 
coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. 
+func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start)+i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. 
+ } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID: cID, VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. 
func (s *solver) Read(calldata []uint32) (constraint.Element, int) {
	if s.Type == constraint.SystemSparseR1CS {
		// Plonkish path: calldata encodes a single Term as [1, CID, VID].
		if calldata[0] != 1 {
			panic("invalid calldata")
		}
		return s.GetValue(calldata[1], calldata[2]), 3
	}
	// R1CS path: calldata encodes a LinearExpression as
	// [n, CID_0, VID_0, ..., CID_{n-1}, VID_{n-1}]; evaluate by accumulating terms.
	var r fr.Element
	n := int(calldata[0])
	j := 1
	for k := 0; k < n; k++ {
		// read the k-th (coeff, wire) term and accumulate it into r
		s.accumulateInto(constraint.Term{CID: calldata[j], VID: calldata[j+1]}, &r)
		j += 2
	}

	var ret constraint.Element
	copy(ret[:], r[:])
	return ret, j
}

// processInstruction decodes the instruction and executes blueprint-defined logic.
// An instruction can encode a hint, a custom constraint or a generic constraint.
func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error {
	// fetch the blueprint
	blueprint := solver.Blueprints[pi.BlueprintID]
	inst := pi.Unpack(&solver.System)
	cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only

	if solver.Type == constraint.SystemR1CS {
		if bc, ok := blueprint.(constraint.BlueprintR1C); ok {
			// TODO @gbotrel we use the solveR1C method for now, having user-defined
			// blueprint for R1CS would require constraint.Solver interface to add methods
			// to set a,b,c since it's more efficient to compute these while we solve.
			bc.DecompressR1C(&scratch.tR1C, inst)
			return solver.solveR1C(cID, &scratch.tR1C)
		}
	}

	// blueprint declared "I know how to solve this."
	if bc, ok := blueprint.(constraint.BlueprintSolvable); ok {
		if err := bc.Solve(solver, inst); err != nil {
			return solver.wrapErrWithDebugInfo(cID, err)
		}
		return nil
	}

	// blueprint encodes a hint, we execute.
	// TODO @gbotrel may be worth it to move hint logic in blueprint "solve"
	if bc, ok := blueprint.(constraint.BlueprintHint); ok {
		bc.DecompressHint(&scratch.tHint, inst)
		return solver.solveWithHint(&scratch.tHint)
	}

	// NOTE(review): a blueprint implementing none of the interfaces above is
	// silently skipped — TODO confirm this is intentional and not an error case.
	return nil
}

// run runs the solver. It returns an error if a constraint is not satisfied or if not all wires
// were instantiated.
func (solver *solver) run() error {
	// minWorkPerCPU is the minimum target number of constraints a task should hold;
	// in other words, if a level has less than minWorkPerCPU constraints, it will not be
	// parallelized and is executed sequentially without sync.
	const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks.

	// cs.Levels has a list of levels, where all constraints in a level l(n) are independent
	// and may only have dependencies on previous levels
	// for each constraint
	// we are guaranteed that each R1C contains at most one unsolved wire
	// first we solve the unsolved wire (if any)
	// then we check that the constraint is valid
	// if a[i] * b[i] != c[i]; it means the constraint is not satisfied
	var wg sync.WaitGroup
	chTasks := make(chan []int, runtime.NumCPU())
	chError := make(chan error, runtime.NumCPU())

	// start a worker pool
	// each worker waits on chTasks
	// a task is a slice of constraint indexes to be solved
	for i := 0; i < runtime.NumCPU(); i++ {
		go func() {
			var scratch scratch
			for t := range chTasks {
				for _, i := range t {
					if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil {
						// signal the error and exit this worker; wg.Done() is still called
						// once for the task being processed, so wg.Wait() below cannot hang.
						chError <- err
						wg.Done()
						return
					}
				}
				wg.Done()
			}
		}()
	}

	// clean up pool go routines
	defer func() {
		close(chTasks)
		close(chError)
	}()

	var scratch scratch

	// for each level, we push the tasks
	for _, level := range solver.Levels {

		// max CPU to use
		maxCPU := float64(len(level)) / minWorkPerCPU

		if maxCPU <= 1.0 {
			// level too small to be worth parallelizing: we do it sequentially
			for _, i := range level {
				if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil {
					return err
				}
			}
			continue
		}

		// number of tasks for this level is set to number of CPU
		// but if we don't have enough work for all our CPU, it can be lower.
		nbTasks := runtime.NumCPU()
		maxTasks := int(math.Ceil(maxCPU))
		if nbTasks > maxTasks {
			nbTasks = maxTasks
		}
		nbIterationsPerCpus := len(level) / nbTasks

		// more CPUs than tasks: a CPU will work on exactly one iteration
		// note: this depends on minWorkPerCPU constant
		if nbIterationsPerCpus < 1 {
			nbIterationsPerCpus = 1
			nbTasks = len(level)
		}

		// distribute the remainder (len(level) mod nbTasks) one extra iteration at a time
		extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
		extraTasksOffset := 0

		for i := 0; i < nbTasks; i++ {
			wg.Add(1)
			_start := i*nbIterationsPerCpus + extraTasksOffset
			_end := _start + nbIterationsPerCpus
			if extraTasks > 0 {
				_end++
				extraTasks--
				extraTasksOffset++
			}
			// since we're never pushing more than num CPU tasks
			// we will never be blocked here
			chTasks <- level[_start:_end]
		}

		// wait for the level to be done
		wg.Wait()

		if len(chError) > 0 {
			return <-chError
		}
	}

	if int(solver.nbSolved) != len(solver.values) {
		return errors.New("solver didn't assign a value to all wires")
	}

	return nil
}

// solveR1C computes the unsolved wire in the constraint, if any, and sets the solver accordingly.
//
// It returns an error if the solver called a hint function that errored, or if the
// constraint is not satisfied. When exactly one wire is solved here, it is redundant to
// check that the constraint is satisfied later (the wire value is derived from a⋅b=c).
func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error {
	a, b, c := &solver.a[cID], &solver.b[cID], &solver.c[cID]

	// loc records which linear expression holds the uninstantiated wire:
	// 0 = none, 1 = L, 2 = R, 3 = O (values set by the processLExp calls below)
	var loc uint8

	// the single term whose wire is not yet solved (valid only when loc != 0)
	var termToCompute constraint.Term

	processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) {
		for _, t := range l {
			vID := t.WireID()

			// wire is already computed, we just accumulate in val
			if solver.solved[vID] {
				solver.accumulateInto(t, val)
				continue
			}

			// at most one wire per R1C may be unsolved
			if loc != 0 {
				panic("found more than one wire to instantiate")
			}
			termToCompute = t
			loc = locValue
		}
	}

	processLExp(r.L, a, 1)
	processLExp(r.R, b, 2)
	processLExp(r.O, c, 3)

	if loc == 0 {
		// there is nothing to solve, may happen if we have an assertion
		// (i.e. a constraint that doesn't yield any output)
		// or if we solved the unsolved wires with hint functions
		var check fr.Element
		if !check.Mul(a, b).Equal(c) {
			return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()))
		}
		return nil
	}

	// we compute the wire value and instantiate it
	wID := termToCompute.WireID()

	// solver result
	var wire fr.Element

	switch loc {
	case 1:
		// unsolved wire is in L: wire = c/b - a (and a is completed in place)
		if !b.IsZero() {
			wire.Div(c, b).
				Sub(&wire, a)
			a.Add(a, &wire)
		} else {
			// b == 0: wire stays 0; we didn't actually ensure that a * b == c
			var check fr.Element
			if !check.Mul(a, b).Equal(c) {
				return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()))
			}
		}
	case 2:
		// unsolved wire is in R: wire = c/a - b
		if !a.IsZero() {
			wire.Div(c, a).
				Sub(&wire, b)
			b.Add(b, &wire)
		} else {
			var check fr.Element
			if !check.Mul(a, b).Equal(c) {
				return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()))
			}
		}
	case 3:
		// unsolved wire is in O: wire = a*b - c
		wire.Mul(a, b).
			Sub(&wire, c)

		c.Add(c, &wire)
	}

	// wire is the term (coeff * value)
	// but in the solver we want to store the value only
	// note that in gnark frontend, coeff here is always 1 or -1
	solver.divByCoeff(&wire, termToCompute.CID)
	solver.set(wID, wire)

	return nil
}

// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint.
type UnsatisfiedConstraintError struct {
	Err       error
	CID       int     // constraint ID
	DebugInfo *string // optional debug info (stack trace resolved from the symbol table)
}

// Error implements the error interface; it prefers the resolved debug info when present.
func (r *UnsatisfiedConstraintError) Error() string {
	if r.DebugInfo != nil {
		return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo)
	}
	return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error())
}

// wrapErrWithDebugInfo attaches the constraint ID and, if available, the debug log entry
// registered for this constraint in MDebug.
func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError {
	var debugInfo *string
	if dID, ok := solver.MDebug[int(cID)]; ok {
		debugInfo = new(string)
		*debugInfo = solver.logValue(solver.DebugInfo[dID])
	}
	return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo}
}

// scratch holds temporary variables to avoid memallocs in the hot loop;
// each worker goroutine owns its own scratch.
type scratch struct {
	tR1C  constraint.R1C
	tHint constraint.HintMapping
}
diff --git a/constraint/bw6-761/system.go b/constraint/bw6-761/system.go
new file mode 100644
index 0000000000..10e7990dfa
--- /dev/null
+++ b/constraint/bw6-761/system.go
@@ -0,0 +1,374 @@
// Copyright 2020 ConsenSys Software Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Code generated by gnark DO NOT EDIT

package cs

import (
	"github.com/fxamacker/cbor/v2"
	"io"
	"time"

	"github.com/consensys/gnark/backend/witness"
	"github.com/consensys/gnark/constraint"
	csolver "github.com/consensys/gnark/constraint/solver"
	"github.com/consensys/gnark/internal/backend/ioutils"
	"github.com/consensys/gnark/logger"
	"reflect"

	"github.com/consensys/gnark-crypto/ecc"

	"github.com/consensys/gnark-crypto/ecc/bw6-761/fr"
)

type R1CS = system
type SparseR1CS = system

// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element)
type system struct {
	constraint.System
	CoeffTable
	field
}

// NewR1CS returns a new R1CS with an initial capacity hint for constraints.
func NewR1CS(capacity int) *R1CS {
	return newSystem(capacity, constraint.SystemR1CS)
}

// NewSparseR1CS returns a new SparseR1CS with an initial capacity hint for constraints.
func NewSparseR1CS(capacity int) *SparseR1CS {
	return newSystem(capacity, constraint.SystemSparseR1CS)
}

func newSystem(capacity int, t constraint.SystemType) *system {
	return &system{
		System:     constraint.NewSystem(fr.Modulus(), capacity, t),
		CoeffTable: newCoeffTable(capacity / 10),
	}
}

// Solve solves the constraint system with provided witness.
// If it's a R1CS it returns a *R1CSSolution,
// if it's a SparseR1CS it returns a *SparseR1CSSolution.
func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) {
	log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger()
	start := time.Now()

	v := witness.Vector().(fr.Vector)

	// init the solver
	solver, err := newSolver(cs, v, opts...)
	if err != nil {
		log.Err(err).Send()
		return nil, err
	}

	// defer log printing once all solver.values are computed
	// (or sooner, if a constraint is not satisfied)
	defer solver.printLogs(cs.Logs)

	// run it.
	if err := solver.run(); err != nil {
		log.Err(err).Send()
		return nil, err
	}

	log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done")

	// format the solution
	// TODO @gbotrel revisit post-refactor
	if cs.Type == constraint.SystemR1CS {
		var res R1CSSolution
		res.W = solver.values
		res.A = solver.a
		res.B = solver.b
		res.C = solver.c
		return &res, nil
	} else {
		// sparse R1CS
		var res SparseR1CSSolution
		// query l, r, o in Lagrange basis, not blinded
		res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values)

		return &res, nil
	}

}

// IsSolved returns nil if the given witness solves the constraint system.
//
// Deprecated: use _, err := Solve(...) instead
func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error {
	_, err := cs.Solve(witness, opts...)
	return err
}

// GetR1Cs returns the list of R1C; panics if an instruction's blueprint is not a BlueprintR1C.
func (cs *system) GetR1Cs() []constraint.R1C {
	toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints())

	for _, inst := range cs.Instructions {
		blueprint := cs.Blueprints[inst.BlueprintID]
		if bc, ok := blueprint.(constraint.BlueprintR1C); ok {
			var r1c constraint.R1C
			bc.DecompressR1C(&r1c, inst.Unpack(&cs.System))
			toReturn = append(toReturn, r1c)
		} else {
			panic("not implemented")
		}
	}
	return toReturn
}

// GetNbCoefficients return the number of unique coefficients needed in the R1CS
func (cs *system) GetNbCoefficients() int {
	return len(cs.Coefficients)
}

// CurveID returns curve ID as defined in gnark-crypto
func (cs *system) CurveID() ecc.ID {
	return ecc.BW6_761
}

// WriteTo encodes R1CS into provided io.Writer using cbor
func (cs *system) WriteTo(w io.Writer) (int64, error) {
	_w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written
	ts := getTagSet()
	enc, err := cbor.CoreDetEncOptions().EncModeWithTags(ts)
	if err != nil {
		return 0, err
	}
	encoder := enc.NewEncoder(&_w)

	// encode our object
	err = encoder.Encode(cs)
	return _w.N, err
}

// ReadFrom attempts to decode R1CS from io.Reader using cbor
func (cs *system) ReadFrom(r io.Reader) (int64, error) {
	ts := getTagSet()
	dm, err := cbor.DecOptions{
		MaxArrayElements: 134217728,
		MaxMapPairs:      134217728,
	}.DecModeWithTags(ts)

	if err != nil {
		return 0, err
	}
	decoder := dm.NewDecoder(r)

	// initialize coeff table
	cs.CoeffTable = newCoeffTable(0)

	if err := decoder.Decode(&cs); err != nil {
		return int64(decoder.NumBytesRead()), err
	}

	if err := cs.CheckSerializationHeader(); err != nil {
		return int64(decoder.NumBytesRead()), err
	}

	// cbor decodes the commitment info as a pointer; flatten it back to the value type
	switch v := cs.CommitmentInfo.(type) {
	case *constraint.Groth16Commitments:
		cs.CommitmentInfo = *v
	case *constraint.PlonkCommitments:
		cs.CommitmentInfo = *v
	}

	return int64(decoder.NumBytesRead()), nil
}

// GetCoefficient returns the i-th coefficient as a constraint.Element.
func (cs *system) GetCoefficient(i int) (r constraint.Element) {
	copy(r[:], cs.Coefficients[i][:])
	return
}

// GetSparseR1Cs returns the list of SparseR1C; panics if an instruction's blueprint
// is not a BlueprintSparseR1C.
func (cs *system) GetSparseR1Cs() []constraint.SparseR1C {

	toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints())

	for _, inst := range cs.Instructions {
		blueprint := cs.Blueprints[inst.BlueprintID]
		if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok {
			var sparseR1C constraint.SparseR1C
			bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System))
			toReturn = append(toReturn, sparseR1C)
		} else {
			panic("not implemented")
		}
	}
	return toReturn
}

// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form.
// solver = [ public | secret | internal ]
// TODO @gbotrel refactor; this seems to be a small util function for plonk
func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) {

	//s := int(pk.Domain[0].Cardinality)
	// domain size: constraints + placeholder constraints for public inputs, rounded up to a power of two
	s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints
	s = int(ecc.NextPowerOfTwo(uint64(s)))

	var l, r, o []fr.Element
	l = make([]fr.Element, s)
	r = make([]fr.Element, s)
	o = make([]fr.Element, s)
	s0 := solution[0]

	for i := 0; i < len(cs.Public); i++ { // placeholders
		l[i] = solution[i]
		r[i] = s0
		o[i] = s0
	}
	offset := len(cs.Public)
	nbConstraints := cs.GetNbConstraints()

	var sparseR1C constraint.SparseR1C
	j := 0
	for _, inst := range cs.Instructions {
		blueprint := cs.Blueprints[inst.BlueprintID]
		if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok {
			bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System))

			l[offset+j] = solution[sparseR1C.XA]
			r[offset+j] = solution[sparseR1C.XB]
			o[offset+j] = solution[sparseR1C.XC]
			j++
		}
	}

	offset += nbConstraints

	for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0])
		l[offset+i] = s0
		r[offset+i] = s0
		o[offset+i] = s0
	}

	return l, r, o

}

// R1CSSolution represents a valid assignment to all the variables in the constraint system.
// The vector W such that Aw o Bw - Cw = 0
type R1CSSolution struct {
	W       fr.Vector
	A, B, C fr.Vector
}

// WriteTo serializes W, A, B, C (in that order) to w and returns the bytes written.
func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) {
	n, err := t.W.WriteTo(w)
	if err != nil {
		return n, err
	}
	a, err := t.A.WriteTo(w)
	n += a
	if err != nil {
		return n, err
	}
	a, err = t.B.WriteTo(w)
	n += a
	if err != nil {
		return n, err
	}
	a, err = t.C.WriteTo(w)
	n += a
	return n, err
}

// ReadFrom deserializes W, A, B, C (in that order, mirroring WriteTo) from r.
func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) {
	n, err := t.W.ReadFrom(r)
	if err != nil {
		return n, err
	}
	a, err := t.A.ReadFrom(r)
	n += a
	if err != nil {
		return n, err
	}
	a, err = t.B.ReadFrom(r)
	n += a
	if err != nil {
		return n, err
	}
	a, err = t.C.ReadFrom(r)
	n += a
	return n, err
}

// SparseR1CSSolution represents a valid assignment to all the variables in the constraint system.
type SparseR1CSSolution struct {
	L, R, O fr.Vector
}

// WriteTo serializes L, R, O (in that order) to w and returns the bytes written.
func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) {
	n, err := t.L.WriteTo(w)
	if err != nil {
		return n, err
	}
	a, err := t.R.WriteTo(w)
	n += a
	if err != nil {
		return n, err
	}
	a, err = t.O.WriteTo(w)
	n += a
	return n, err

}

// ReadFrom deserializes L, R, O (in that order, mirroring WriteTo) from r.
func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) {
	n, err := t.L.ReadFrom(r)
	if err != nil {
		return n, err
	}
	a, err := t.R.ReadFrom(r)
	n += a
	if err != nil {
		return n, err
	}
	a, err = t.O.ReadFrom(r)
	n += a
	return n, err
}

// getTagSet builds the cbor tag set used to round-trip interface-typed fields
// (blueprints, commitments) through serialization.
func getTagSet() cbor.TagSet {
	// temporary for refactor
	ts := cbor.NewTagSet()
	// https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml
	// 65536-15309735 Unassigned
	tagNum := uint64(5309735)
	addType := func(t reflect.Type) {
		if err := ts.Add(
			cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired},
			t,
			tagNum,
		); err != nil {
			panic(err)
		}
		tagNum++
	}

	addType(reflect.TypeOf(constraint.BlueprintGenericHint{}))
	addType(reflect.TypeOf(constraint.BlueprintGenericR1C{}))
addType(reflect.TypeOf(constraint.BlueprintGenericSparseR1C{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CAdd{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CMul{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CBool{})) + addType(reflect.TypeOf(constraint.Groth16Commitments{})) + addType(reflect.TypeOf(constraint.PlonkCommitments{})) + + return ts +} diff --git a/constraint/coeff.go b/constraint/coeff.go deleted file mode 100644 index 434919269f..0000000000 --- a/constraint/coeff.go +++ /dev/null @@ -1,39 +0,0 @@ -package constraint - -import ( - "math/big" -) - -// Coeff represents a term coefficient data. It is instantiated by the concrete -// constraint system implementation. -// Most of the scalar field used in gnark are on 4 uint64, so we have a clear memory overhead here. -type Coeff [6]uint64 - -// IsZero returns true if coefficient == 0 -func (z *Coeff) IsZero() bool { - return (z[5] | z[4] | z[3] | z[2] | z[1] | z[0]) == 0 -} - -// CoeffEngine capability to perform arithmetic on Coeff -type CoeffEngine interface { - FromInterface(interface{}) Coeff - ToBigInt(*Coeff) *big.Int - Mul(a, b *Coeff) - Add(a, b *Coeff) - Sub(a, b *Coeff) - Neg(a *Coeff) - Inverse(a *Coeff) - One() Coeff - IsOne(*Coeff) bool - String(*Coeff) string -} - -// ids of the coefficients with simple values in any cs.coeffs slice. 
-// TODO @gbotrel let's keep that here for the refactoring -- and move it to concrete cs package after -const ( - CoeffIdZero = 0 - CoeffIdOne = 1 - CoeffIdTwo = 2 - CoeffIdMinusOne = 3 - CoeffIdMinusTwo = 4 -) diff --git a/constraint/commitment.go b/constraint/commitment.go index bc6a0e8d0a..2e812ce906 100644 --- a/constraint/commitment.go +++ b/constraint/commitment.go @@ -1,62 +1,109 @@ package constraint import ( + "github.com/consensys/gnark/constraint/solver" "math/big" - - "github.com/consensys/gnark/backend/hint" ) const CommitmentDst = "bsb22-commitment" -type Commitment struct { - Committed []int // sorted list of id's of committed variables - NbPrivateCommitted int - HintID hint.ID // TODO @gbotrel we probably don't need that here - CommitmentIndex int - CommittedAndCommitment []int // sorted list of id's of committed variables AND the commitment itself +type Groth16Commitment struct { + PublicAndCommitmentCommitted []int // PublicAndCommitmentCommitted sorted list of id's of public and commitment committed wires + PrivateCommitted []int // PrivateCommitted sorted list of id's of private/internal committed wires + CommitmentIndex int // CommitmentIndex the wire index of the commitment + HintID solver.HintID + NbPublicCommitted int } -func (i *Commitment) NbPublicCommitted() int { - return i.NbCommitted() - i.NbPrivateCommitted +type PlonkCommitment struct { + Committed []int // sorted list of id's of committed variables in groth16. 
in plonk, list of indexes of constraints defining committed values + CommitmentIndex int // CommitmentIndex index of the constraint defining the commitment + HintID solver.HintID } -func (i *Commitment) NbCommitted() int { - return len(i.Committed) +type Commitment interface{} +type Commitments interface{ CommitmentIndexes() []int } + +type Groth16Commitments []Groth16Commitment +type PlonkCommitments []PlonkCommitment + +func (c Groth16Commitments) CommitmentIndexes() []int { + commitmentWires := make([]int, len(c)) + for i := range c { + commitmentWires[i] = c[i].CommitmentIndex + } + return commitmentWires } -func (i *Commitment) Is() bool { - return len(i.Committed) != 0 +func (c PlonkCommitments) CommitmentIndexes() []int { + commitmentWires := make([]int, len(c)) + for i := range c { + commitmentWires[i] = c[i].CommitmentIndex + } + return commitmentWires } -// NewCommitment initialize a Commitment object -// - committed are the sorted wireID to commit to (without duplicate) -// - nbPublicCommited is the number of public inputs among the commited wireIDs -func NewCommitment(committed []int, nbPublicCommitted int) Commitment { - return Commitment{ - Committed: committed, - NbPrivateCommitted: len(committed) - nbPublicCommitted, +func (c Groth16Commitments) GetPrivateCommitted() [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = c[i].PrivateCommitted } + return res } -func (i *Commitment) SerializeCommitment(privateCommitment []byte, publicCommitted []*big.Int, fieldByteLen int) []byte { +// GetPublicAndCommitmentCommitted returns the list of public and commitment committed wires +// if committedTranslationList is not nil, commitment indexes are translated into their relative positions on the list plus the offset +func (c Groth16Commitments) GetPublicAndCommitmentCommitted(committedTranslationList []int, offset int) [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, len(c[i].PublicAndCommitmentCommitted)) 
+ copy(res[i], c[i].GetPublicCommitted()) + translatedCommitmentCommitted := res[i][c[i].NbPublicCommitted:] + commitmentCommitted := c[i].GetCommitmentCommitted() + // convert commitment indexes to verifier understandable ones + if committedTranslationList == nil { + copy(translatedCommitmentCommitted, commitmentCommitted) + } else { + k := 0 + for j := range translatedCommitmentCommitted { + for committedTranslationList[k] != commitmentCommitted[j] { + k++ + } // find it in the translation list + translatedCommitmentCommitted[j] = k + offset + } + } + } + return res +} + +func SerializeCommitment(privateCommitment []byte, publicCommitted []*big.Int, fieldByteLen int) []byte { res := make([]byte, len(privateCommitment)+len(publicCommitted)*fieldByteLen) copy(res, privateCommitment) offset := len(privateCommitment) - for j, inJ := range publicCommitted { - offset += j * fieldByteLen + for _, inJ := range publicCommitted { inJ.FillBytes(res[offset : offset+fieldByteLen]) + offset += fieldByteLen } return res } -// PrivateToPublic returns indexes of variables which are private to the constraint system, but public to Groth16. 
That is, private committed variables and the commitment itself -func (i *Commitment) PrivateToPublic() []int { - return i.CommittedAndCommitment[i.NbPublicCommitted():] +func NewCommitments(t SystemType) Commitments { + switch t { + case SystemR1CS: + return Groth16Commitments{} + case SystemSparseR1CS: + return PlonkCommitments{} + } + panic("unknown cs type") +} + +func (c Groth16Commitment) GetPublicCommitted() []int { + return c.PublicAndCommitmentCommitted[:c.NbPublicCommitted] } -func (i *Commitment) PrivateCommitted() []int { - return i.Committed[i.NbPublicCommitted():] +func (c Groth16Commitment) GetCommitmentCommitted() []int { + return c.PublicAndCommitmentCommitted[c.NbPublicCommitted:] } diff --git a/constraint/core.go b/constraint/core.go new file mode 100644 index 0000000000..e54fab3149 --- /dev/null +++ b/constraint/core.go @@ -0,0 +1,459 @@ +package constraint + +import ( + "fmt" + "math/big" + "strconv" + "sync" + + "github.com/blang/semver/v4" + "github.com/consensys/gnark" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/debug" + "github.com/consensys/gnark/internal/tinyfield" + "github.com/consensys/gnark/internal/utils" + "github.com/consensys/gnark/logger" + "github.com/consensys/gnark/profile" +) + +type SystemType uint16 + +const ( + SystemUnknown SystemType = iota + SystemR1CS + SystemSparseR1CS +) + +// PackedInstruction is the lowest element of a constraint system. It stores just enough data to +// reconstruct a constraint of any shape or a hint at solving time. +type PackedInstruction struct { + // BlueprintID maps this instruction to a blueprint + BlueprintID BlueprintID + + // ConstraintOffset stores the starting constraint ID of this instruction. + // Might not be strictly necessary; but speeds up solver for instructions that represents + // multiple constraints. + ConstraintOffset uint32 + + // WireOffset stores the starting internal wire ID of this instruction. 
Blueprints may use this + // and refer to output wires by their offset. + // For example, if a blueprint declared 5 outputs, the first output wire will be WireOffset, + // the last one WireOffset+4. + WireOffset uint32 + + // The constraint system stores a single []uint32 calldata slice. StartCallData + // points to the starting index in the mentioned slice. This avoid storing a slice per + // instruction (3 * uint64 in memory). + StartCallData uint64 +} + +// Unpack returns the instruction corresponding to the packed instruction. +func (pi PackedInstruction) Unpack(cs *System) Instruction { + + blueprint := cs.Blueprints[pi.BlueprintID] + cSize := blueprint.CalldataSize() + if cSize < 0 { + // by convention, we store nbInputs < 0 for non-static input length. + cSize = int(cs.CallData[pi.StartCallData]) + } + + return Instruction{ + ConstraintOffset: pi.ConstraintOffset, + WireOffset: pi.WireOffset, + Calldata: cs.CallData[pi.StartCallData : pi.StartCallData+uint64(cSize)], + } +} + +// Instruction is the lowest element of a constraint system. It stores all the data needed to +// reconstruct a constraint of any shape or a hint at solving time. +type Instruction struct { + ConstraintOffset uint32 + WireOffset uint32 + Calldata []uint32 +} + +// System contains core elements for a constraint System +type System struct { + // serialization header + GnarkVersion string + ScalarField string + + Type SystemType + + Instructions []PackedInstruction + Blueprints []Blueprint + CallData []uint32 // huge slice. 
+ + // can be != than len(instructions) + NbConstraints int + + // number of internal wires + NbInternalVariables int + + // input wires names + Public, Secret []string + + // logs (added with system.Println, resolved when solver sets a value to a wire) + Logs []LogEntry + + // debug info contains stack trace (including line number) of a call to a system.API that + // results in an unsolved constraint + DebugInfo []LogEntry + SymbolTable debug.SymbolTable + // maps constraint id to debugInfo id + // several constraints may point to the same debug info + MDebug map[int]int + + // maps hintID to hint string identifier + MHintsDependencies map[solver.HintID]string + + // each level contains independent constraints and can be parallelized + // it is guaranteed that all dependencies for constraints in a level l are solved + // in previous levels + // TODO @gbotrel these are currently updated after we add a constraint. + // but in case the object is built from a serialized representation + // we need to init the level builder lbWireLevel from the existing constraints. + Levels [][]int + + // scalar field + q *big.Int `cbor:"-"` + bitLen int `cbor:"-"` + + // level builder + lbWireLevel []int `cbor:"-"` // at which level we solve a wire. init at -1. + lbOutputs []uint32 `cbor:"-"` // wire outputs for current constraint. 
+ + CommitmentInfo Commitments + + genericHint BlueprintID +} + +// NewSystem initialize the common structure among constraint system +func NewSystem(scalarField *big.Int, capacity int, t SystemType) System { + system := System{ + Type: t, + SymbolTable: debug.NewSymbolTable(), + MDebug: map[int]int{}, + GnarkVersion: gnark.Version.String(), + ScalarField: scalarField.Text(16), + MHintsDependencies: make(map[solver.HintID]string), + q: new(big.Int).Set(scalarField), + bitLen: scalarField.BitLen(), + Instructions: make([]PackedInstruction, 0, capacity), + CallData: make([]uint32, 0, capacity*8), + lbOutputs: make([]uint32, 0, 256), + lbWireLevel: make([]int, 0, capacity), + Levels: make([][]int, 0, capacity/2), + CommitmentInfo: NewCommitments(t), + } + + system.genericHint = system.AddBlueprint(&BlueprintGenericHint{}) + return system +} + +// GetNbInstructions returns the number of instructions in the system +func (system *System) GetNbInstructions() int { + return len(system.Instructions) +} + +// GetInstruction returns the instruction at index id +func (system *System) GetInstruction(id int) Instruction { + return system.Instructions[id].Unpack(system) +} + +// AddBlueprint adds a blueprint to the system and returns its ID +func (system *System) AddBlueprint(b Blueprint) BlueprintID { + system.Blueprints = append(system.Blueprints, b) + return BlueprintID(len(system.Blueprints) - 1) +} + +func (system *System) GetNbSecretVariables() int { + return len(system.Secret) +} +func (system *System) GetNbPublicVariables() int { + return len(system.Public) +} +func (system *System) GetNbInternalVariables() int { + return system.NbInternalVariables +} + +// CheckSerializationHeader parses the scalar field and gnark version headers +// +// This is meant to be use at the deserialization step, and will error for illegal values +func (system *System) CheckSerializationHeader() error { + // check gnark version + binaryVersion := gnark.Version + objectVersion, err := 
semver.Parse(system.GnarkVersion) + if err != nil { + return fmt.Errorf("when parsing gnark version: %w", err) + } + + if binaryVersion.Compare(objectVersion) != 0 { + log := logger.Logger() + log.Warn().Str("binary", binaryVersion.String()).Str("object", objectVersion.String()).Msg("gnark version (binary) mismatch with constraint system. there are no guarantees on compatibilty") + } + + // TODO @gbotrel maintain version changes and compare versions properly + // (ie if major didn't change,we shouldn't have a compatibility issue) + + scalarField := new(big.Int) + _, ok := scalarField.SetString(system.ScalarField, 16) + if !ok { + return fmt.Errorf("when parsing serialized modulus: %s", system.ScalarField) + } + curveID := utils.FieldToCurve(scalarField) + if curveID == ecc.UNKNOWN && scalarField.Cmp(tinyfield.Modulus()) != 0 { + return fmt.Errorf("unsupported scalar field %s", scalarField.Text(16)) + } + system.q = new(big.Int).Set(scalarField) + system.bitLen = system.q.BitLen() + return nil +} + +// GetNbVariables return number of internal, secret and public variables +func (system *System) GetNbVariables() (internal, secret, public int) { + return system.NbInternalVariables, system.GetNbSecretVariables(), system.GetNbPublicVariables() +} + +func (system *System) Field() *big.Int { + return new(big.Int).Set(system.q) +} + +// bitLen returns the number of bits needed to represent a fr.Element +func (system *System) FieldBitLen() int { + return system.bitLen +} + +func (system *System) AddInternalVariable() (idx int) { + idx = system.NbInternalVariables + system.GetNbPublicVariables() + system.GetNbSecretVariables() + system.NbInternalVariables++ + return idx +} + +func (system *System) AddPublicVariable(name string) (idx int) { + idx = system.GetNbPublicVariables() + system.Public = append(system.Public, name) + return idx +} + +func (system *System) AddSecretVariable(name string) (idx int) { + idx = system.GetNbSecretVariables() + system.GetNbPublicVariables() + 
system.Secret = append(system.Secret, name) + return idx +} + +func (system *System) AddSolverHint(f solver.Hint, id solver.HintID, input []LinearExpression, nbOutput int) (internalVariables []int, err error) { + if nbOutput <= 0 { + return nil, fmt.Errorf("hint function must return at least one output") + } + + var name string + if id == solver.GetHintID(f) { + name = solver.GetHintName(f) + } else { + name = strconv.Itoa(int(id)) + } + + // register the hint as dependency + if registeredName, ok := system.MHintsDependencies[id]; ok { + // hint already registered, let's ensure string registeredName matches + if registeredName != name { + return nil, fmt.Errorf("hint dependency registration failed; %s previously register with same UUID as %s", name, registeredName) + } + } else { + system.MHintsDependencies[id] = name + } + + // prepare wires + internalVariables = make([]int, nbOutput) + for i := 0; i < len(internalVariables); i++ { + internalVariables[i] = system.AddInternalVariable() + } + + // associate these wires with the solver hint + hm := HintMapping{ + HintID: id, + Inputs: input, + OutputRange: struct { + Start uint32 + End uint32 + }{ + uint32(internalVariables[0]), + uint32(internalVariables[len(internalVariables)-1]) + 1, + }, + } + + blueprint := system.Blueprints[system.genericHint] + + // get []uint32 from the pool + calldata := getBuffer() + + blueprint.(BlueprintHint).CompressHint(hm, calldata) + + system.AddInstruction(system.genericHint, *calldata) + + // return []uint32 to the pool + putBuffer(calldata) + + return +} + +func (system *System) AddCommitment(c Commitment) error { + switch v := c.(type) { + case Groth16Commitment: + system.CommitmentInfo = append(system.CommitmentInfo.(Groth16Commitments), v) + case PlonkCommitment: + system.CommitmentInfo = append(system.CommitmentInfo.(PlonkCommitments), v) + default: + return fmt.Errorf("unknown commitment type %T", v) + } + return nil +} + +func (system *System) AddLog(l LogEntry) { + 
system.Logs = append(system.Logs, l) +} + +func (system *System) AttachDebugInfo(debugInfo DebugInfo, constraintID []int) { + system.DebugInfo = append(system.DebugInfo, LogEntry(debugInfo)) + id := len(system.DebugInfo) - 1 + for _, cID := range constraintID { + system.MDebug[cID] = id + } +} + +// VariableToString implements Resolver +func (system *System) VariableToString(vID int) string { + nbPublic := system.GetNbPublicVariables() + nbSecret := system.GetNbSecretVariables() + + if vID < nbPublic { + return system.Public[vID] + } + vID -= nbPublic + if vID < nbSecret { + return system.Secret[vID] + } + vID -= nbSecret + return fmt.Sprintf("v%d", vID) // TODO @gbotrel vs strconv.Itoa. +} + +func (cs *System) AddR1C(c R1C, bID BlueprintID) int { + profile.RecordConstraint() + + blueprint := cs.Blueprints[bID] + + // get a []uint32 from a pool + calldata := getBuffer() + + // compress the R1C into a []uint32 and add the instruction + blueprint.(BlueprintR1C).CompressR1C(&c, calldata) + cs.AddInstruction(bID, *calldata) + + // release the []uint32 to the pool + putBuffer(calldata) + + return cs.NbConstraints - 1 +} + +func (cs *System) AddSparseR1C(c SparseR1C, bID BlueprintID) int { + profile.RecordConstraint() + + blueprint := cs.Blueprints[bID] + + // get a []uint32 from a pool + calldata := getBuffer() + + // compress the SparceR1C into a []uint32 and add the instruction + blueprint.(BlueprintSparseR1C).CompressSparseR1C(&c, calldata) + + cs.AddInstruction(bID, *calldata) + + // release the []uint32 to the pool + putBuffer(calldata) + + return cs.NbConstraints - 1 +} + +func (cs *System) AddInstruction(bID BlueprintID, calldata []uint32) []uint32 { + // set the offsets + pi := PackedInstruction{ + StartCallData: uint64(len(cs.CallData)), + ConstraintOffset: uint32(cs.NbConstraints), + WireOffset: uint32(cs.NbInternalVariables + cs.GetNbPublicVariables() + cs.GetNbSecretVariables()), + BlueprintID: bID, + } + + // append the call data + cs.CallData = 
append(cs.CallData, calldata...) + + // update the total number of constraints + blueprint := cs.Blueprints[pi.BlueprintID] + cs.NbConstraints += blueprint.NbConstraints() + + // add the output wires + inst := pi.Unpack(cs) + nbOutputs := blueprint.NbOutputs(inst) + var wires []uint32 + for i := 0; i < nbOutputs; i++ { + wires = append(wires, uint32(cs.AddInternalVariable())) + } + + // add the instruction + cs.Instructions = append(cs.Instructions, pi) + + // update the instruction dependency tree + cs.updateLevel(len(cs.Instructions)-1, blueprint.WireWalker(inst)) + + return wires +} + +// GetNbConstraints returns the number of constraints +func (cs *System) GetNbConstraints() int { + return cs.NbConstraints +} + +func (cs *System) CheckUnconstrainedWires() error { + // TODO @gbotrel + return nil +} + +func (cs *System) GetR1CIterator() R1CIterator { + return R1CIterator{cs: cs} +} + +func (cs *System) GetSparseR1CIterator() SparseR1CIterator { + return SparseR1CIterator{cs: cs} +} + +func (cs *System) GetCommitments() Commitments { + return cs.CommitmentInfo +} + +// bufPool is a pool of buffers used by getBuffer and putBuffer. +// It is used to avoid allocating buffers for each constraint. +var bufPool = sync.Pool{ + New: func() interface{} { + r := make([]uint32, 0, 20) + return &r + }, +} + +// getBuffer returns a buffer of at least the given size. +// The buffer is taken from the pool if it is large enough, +// otherwise a new buffer is allocated. +// Caller must call putBuffer when done with the buffer. +func getBuffer() *[]uint32 { + to := bufPool.Get().(*[]uint32) + *to = (*to)[:0] + return to +} + +// putBuffer returns a buffer to the pool. 
+func putBuffer(buf *[]uint32) { + if buf == nil { + panic("invalid entry in putBuffer") + } + bufPool.Put(buf) +} diff --git a/constraint/field.go b/constraint/field.go new file mode 100644 index 0000000000..f63bc910e2 --- /dev/null +++ b/constraint/field.go @@ -0,0 +1,53 @@ +package constraint + +import ( + "encoding/binary" + "math/big" +) + +// Element represents a term coefficient data. It is instantiated by the concrete +// constraint system implementation. +// Most of the scalar field used in gnark are on 4 uint64, so we have a clear memory overhead here. +type Element [6]uint64 + +// IsZero returns true if coefficient == 0 +func (z *Element) IsZero() bool { + return (z[5] | z[4] | z[3] | z[2] | z[1] | z[0]) == 0 +} + +// Bytes return the Element as a big-endian byte slice +func (z *Element) Bytes() [48]byte { + var b [48]byte + binary.BigEndian.PutUint64(b[40:48], z[0]) + binary.BigEndian.PutUint64(b[32:40], z[1]) + binary.BigEndian.PutUint64(b[24:32], z[2]) + binary.BigEndian.PutUint64(b[16:24], z[3]) + binary.BigEndian.PutUint64(b[8:16], z[4]) + binary.BigEndian.PutUint64(b[0:8], z[5]) + return b +} + +// SetBytes sets the Element from a big-endian byte slice +func (z *Element) SetBytes(b [48]byte) { + z[0] = binary.BigEndian.Uint64(b[40:48]) + z[1] = binary.BigEndian.Uint64(b[32:40]) + z[2] = binary.BigEndian.Uint64(b[24:32]) + z[3] = binary.BigEndian.Uint64(b[16:24]) + z[4] = binary.BigEndian.Uint64(b[8:16]) + z[5] = binary.BigEndian.Uint64(b[0:8]) +} + +// Field capability to perform arithmetic on Coeff +type Field interface { + FromInterface(interface{}) Element + ToBigInt(Element) *big.Int + Mul(a, b Element) Element + Add(a, b Element) Element + Sub(a, b Element) Element + Neg(a Element) Element + Inverse(a Element) (Element, bool) + One() Element + IsOne(Element) bool + String(Element) string + Uint64(Element) (uint64, bool) +} diff --git a/constraint/hint.go b/constraint/hint.go index 0e92d6afd6..ac1dbf96ba 100644 --- a/constraint/hint.go +++ 
b/constraint/hint.go @@ -1,14 +1,14 @@ package constraint import ( - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" ) -// Hint represents a solver hint -// it enables the solver to compute a Wire with a function provided at solving time -// using pre-defined inputs -type Hint struct { - ID hint.ID // hint function id - Inputs []LinearExpression // terms to inject in the hint function - Wires []int // IDs of wires the hint outputs map to +// HintMapping mark a list of output variables to be computed using provided hint and inputs. +type HintMapping struct { + HintID solver.HintID // Hint function id + Inputs []LinearExpression // Terms to inject in the hint function + OutputRange struct { // IDs of wires the hint outputs map to + Start, End uint32 + } } diff --git a/constraint/level_builder.go b/constraint/level_builder.go index 80c24614b8..e6cc47cbe6 100644 --- a/constraint/level_builder.go +++ b/constraint/level_builder.go @@ -8,34 +8,31 @@ package constraint // // We build a graph of dependency; we say that a wire is solved at a level l // --> l = max(level_of_dependencies(wire)) + 1 -func (system *System) updateLevel(cID int, c Iterable) { - system.lbOutputs = system.lbOutputs[:0] - system.lbHints = map[*Hint]struct{}{} +func (system *System) updateLevel(iID int, walkWires func(cb func(wire uint32))) { level := -1 - wireIterator := c.WireIterator() - for wID := wireIterator(); wID != -1; wID = wireIterator() { - // iterate over all wires of the R1C - system.processWire(uint32(wID), &level) - } + + // process all wires of the instruction + walkWires(func(wire uint32) { + system.processWire(wire, &level) + }) // level = max(dependencies) + 1 level++ // mark output wire with level for _, wireID := range system.lbOutputs { - for int(wireID) >= len(system.lbWireLevel) { - // we didn't encounter this wire yet, we need to grow b.wireLevels - system.lbWireLevel = append(system.lbWireLevel, -1) - } system.lbWireLevel[wireID] = 
level } // we can't skip levels, so appending is fine. if level >= len(system.Levels) { - system.Levels = append(system.Levels, []int{cID}) + system.Levels = append(system.Levels, []int{iID}) } else { - system.Levels[level] = append(system.Levels[level], cID) + system.Levels[level] = append(system.Levels[level], iID) } + // clean the table. NB! Do not remove or move, this is required to make the + // compilation deterministic. + system.lbOutputs = system.lbOutputs[:0] } func (system *System) processWire(wireID uint32, maxLevel *int) { @@ -53,32 +50,6 @@ func (system *System) processWire(wireID uint32, maxLevel *int) { } return } - // we don't know how to solve this wire; it's either THE wire we have to solve or a hint. - if h, ok := system.MHints[int(wireID)]; ok { - // check that we didn't process that hint already; performance wise, if many wires in a - // constraint are the output of the same hint, and input to parent hint are themselves - // computed with a hint, we can suffer. - // (nominal case: not too many different hints involved for a single constraint) - if _, ok := system.lbHints[h]; ok { - // skip - return - } - system.lbHints[h] = struct{}{} - - for _, hwid := range h.Wires { - system.lbOutputs = append(system.lbOutputs, uint32(hwid)) - } - for _, in := range h.Inputs { - for _, t := range in { - if !t.IsConstant() { - system.processWire(t.VID, maxLevel) - } - } - } - - return - } - - // it's the missing wire + // this wire is an output to the instruction system.lbOutputs = append(system.lbOutputs, wireID) } diff --git a/constraint/level_builder_test.go b/constraint/level_builder_test.go new file mode 100644 index 0000000000..c7489409e1 --- /dev/null +++ b/constraint/level_builder_test.go @@ -0,0 +1,42 @@ +package constraint_test + +import ( + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + 
"github.com/consensys/gnark/test" +) + +func idHint(_ *big.Int, in []*big.Int, out []*big.Int) error { + if len(in) != len(out) { + return fmt.Errorf("in/out length mismatch %d≠%d", len(in), len(out)) + } + for i := range in { + out[i].Set(in[i]) + } + return nil +} + +type idHintCircuit struct { + X frontend.Variable +} + +func (c *idHintCircuit) Define(api frontend.API) error { + x, err := api.Compiler().NewHint(idHint, 1, c.X) + if err != nil { + return err + } + api.AssertIsEqual(x[0], c.X) + return nil +} + +func TestIdHint(t *testing.T) { + solver.RegisterHint(idHint) + assignment := idHintCircuit{0} + test.NewAssert(t).SolvingSucceeded(&idHintCircuit{}, &assignment, test.WithBackends(backend.GROTH16), test.WithCurves(ecc.BN254)) +} diff --git a/constraint/linear_expression.go b/constraint/linear_expression.go index 32a66b7843..20fb9e9986 100644 --- a/constraint/linear_expression.go +++ b/constraint/linear_expression.go @@ -29,3 +29,10 @@ func (l LinearExpression) String(r Resolver) string { sbb.WriteLinearExpression(l) return sbb.String() } + +func (l LinearExpression) Compress(to *[]uint32) { + (*to) = append((*to), uint32(len(l))) + for i := 0; i < len(l); i++ { + (*to) = append((*to), l[i].CID, l[i].VID) + } +} diff --git a/constraint/r1cs.go b/constraint/r1cs.go index 02c327e30e..b58ff95d40 100644 --- a/constraint/r1cs.go +++ b/constraint/r1cs.go @@ -14,161 +14,143 @@ package constraint -import ( - "errors" - "strconv" - "strings" - - "github.com/consensys/gnark/logger" -) - type R1CS interface { ConstraintSystem - // AddConstraint adds a constraint to the system and returns its id + // AddR1C adds a constraint to the system and returns its id // This does not check for validity of the constraint. - // If a debugInfo parameter is provided, it will be appended to the debug info structure - // and will grow the memory usage of the constraint system. 
- AddConstraint(r1c R1C, debugInfo ...DebugInfo) int + AddR1C(r1c R1C, bID BlueprintID) int - // GetConstraints return the list of R1C and a helper for pretty printing. + // GetR1Cs return the list of R1C // See StringBuilder for more info. // ! this is an experimental API. - GetConstraints() ([]R1C, Resolver) -} - -// R1CS describes a set of R1C constraint -type R1CSCore struct { - System - Constraints []R1C -} + GetR1Cs() []R1C -// GetNbConstraints returns the number of constraints -func (r1cs *R1CSCore) GetNbConstraints() int { - return len(r1cs.Constraints) + // GetR1CIterator returns an R1CIterator to iterate on the R1C constraints of the system. + GetR1CIterator() R1CIterator } -func (r1cs *R1CSCore) UpdateLevel(cID int, c Iterable) { - r1cs.updateLevel(cID, c) +// R1CIterator facilitates iterating through R1C constraints. +type R1CIterator struct { + R1C + cs *System + n int } -// IsValid perform post compilation checks on the Variables -// -// 1. checks that all user inputs are referenced in at least one constraint -// 2. checks that all hints are constrained -func (r1cs *R1CSCore) CheckUnconstrainedWires() error { - - // TODO @gbotrel add unit test for that. 
- - inputConstrained := make([]bool, r1cs.GetNbSecretVariables()+r1cs.GetNbPublicVariables()) - // one wire does not need to be constrained - inputConstrained[0] = true - cptInputs := len(inputConstrained) - 1 // marking 1 wire as already constrained // TODO @gbotrel check that - if cptInputs == 0 { - return errors.New("invalid constraint system: no input defined") - } - - cptHints := len(r1cs.MHints) - mHintsConstrained := make(map[int]bool) - - // for each constraint, we check the linear expressions and mark our inputs / hints as constrained - processLinearExpression := func(l LinearExpression) { - for _, t := range l { - if t.CoeffID() == CoeffIdZero { - // ignore zero coefficient, as it does not constraint the Variable - // though, we may want to flag that IF the Variable doesn't appear else where - continue - } - vID := t.WireID() - if vID < len(inputConstrained) { - if !inputConstrained[vID] { - inputConstrained[vID] = true - cptInputs-- - } - } else { - // internal variable, let's check if it's a hint - if _, ok := r1cs.MHints[vID]; ok { - if !mHintsConstrained[vID] { - mHintsConstrained[vID] = true - cptHints-- - } - } - } - - } - } - for _, r1c := range r1cs.Constraints { - processLinearExpression(r1c.L) - processLinearExpression(r1c.R) - processLinearExpression(r1c.O) - - if cptHints|cptInputs == 0 { - return nil // we can stop. 
- } - - } - - // something is a miss, we build the error string - var sbb strings.Builder - if cptInputs != 0 { - sbb.WriteString(strconv.Itoa(cptInputs)) - sbb.WriteString(" unconstrained input(s):") - sbb.WriteByte('\n') - for i := 0; i < len(inputConstrained) && cptInputs != 0; i++ { - if !inputConstrained[i] { - if i < len(r1cs.Public) { - sbb.WriteString(r1cs.Public[i]) - } else { - sbb.WriteString(r1cs.Secret[i-len(r1cs.Public)]) - } - - sbb.WriteByte('\n') - cptInputs-- - } - } - sbb.WriteByte('\n') - return errors.New(sbb.String()) - } - - if cptHints != 0 { - // TODO @gbotrel @ivokub investigate --> emulated hints seems to go in this path a lot. - sbb.WriteString(strconv.Itoa(cptHints)) - sbb.WriteString(" unconstrained hints; i.e. wire created through NewHint() but doesn't not appear in the constraint system") - sbb.WriteByte('\n') - log := logger.Logger() - log.Warn().Err(errors.New(sbb.String())).Send() +// Next returns the next R1C or nil if end. Caller must not store the result since the +// same memory space is re-used for subsequent calls to Next. +func (it *R1CIterator) Next() *R1C { + if it.n >= it.cs.GetNbInstructions() { return nil - // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some - // debugInfo to find where a hint was declared (and not constrained) } - return errors.New(sbb.String()) + inst := it.cs.Instructions[it.n] + it.n++ + blueprint := it.cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(BlueprintR1C); ok { + bc.DecompressR1C(&it.R1C, inst.Unpack(it.cs)) + return &it.R1C + } + return it.Next() } +// // IsValid perform post compilation checks on the Variables +// // +// // 1. checks that all user inputs are referenced in at least one constraint +// // 2. checks that all hints are constrained +// func (r1cs *R1CSCore) CheckUnconstrainedWires() error { +// return nil + +// // TODO @gbotrel add unit test for that. 
+ +// inputConstrained := make([]bool, r1cs.GetNbSecretVariables()+r1cs.GetNbPublicVariables()) +// // one wire does not need to be constrained +// inputConstrained[0] = true +// cptInputs := len(inputConstrained) - 1 // marking 1 wire as already constrained // TODO @gbotrel check that +// if cptInputs == 0 { +// return errors.New("invalid constraint system: no input defined") +// } + +// cptHints := len(r1cs.MHints) +// mHintsConstrained := make(map[int]bool) + +// // for each constraint, we check the linear expressions and mark our inputs / hints as constrained +// processLinearExpression := func(l LinearExpression) { +// for _, t := range l { +// if t.CoeffID() == CoeffIdZero { +// // ignore zero coefficient, as it does not constraint the Variable +// // though, we may want to flag that IF the Variable doesn't appear else where +// continue +// } +// vID := t.WireID() +// if vID < len(inputConstrained) { +// if !inputConstrained[vID] { +// inputConstrained[vID] = true +// cptInputs-- +// } +// } else { +// // internal variable, let's check if it's a hint +// if _, ok := r1cs.MHints[vID]; ok { +// if !mHintsConstrained[vID] { +// mHintsConstrained[vID] = true +// cptHints-- +// } +// } +// } + +// } +// } +// for _, r1c := range r1cs.Constraints { +// processLinearExpression(r1c.L) +// processLinearExpression(r1c.R) +// processLinearExpression(r1c.O) + +// if cptHints|cptInputs == 0 { +// return nil // we can stop. 
+// } + +// } + +// // something is a miss, we build the error string +// var sbb strings.Builder +// if cptInputs != 0 { +// sbb.WriteString(strconv.Itoa(cptInputs)) +// sbb.WriteString(" unconstrained input(s):") +// sbb.WriteByte('\n') +// for i := 0; i < len(inputConstrained) && cptInputs != 0; i++ { +// if !inputConstrained[i] { +// if i < len(r1cs.Public) { +// sbb.WriteString(r1cs.Public[i]) +// } else { +// sbb.WriteString(r1cs.Secret[i-len(r1cs.Public)]) +// } + +// sbb.WriteByte('\n') +// cptInputs-- +// } +// } +// sbb.WriteByte('\n') +// return errors.New(sbb.String()) +// } + +// if cptHints != 0 { +// // TODO @gbotrel @ivokub investigate --> emulated hints seems to go in this path a lot. +// sbb.WriteString(strconv.Itoa(cptHints)) +// sbb.WriteString(" unconstrained hints; i.e. wire created through NewHint() but doesn't not appear in the constraint system") +// sbb.WriteByte('\n') +// log := logger.Logger() +// log.Warn().Err(errors.New(sbb.String())).Send() +// return nil +// // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some +// // debugInfo to find where a hint was declared (and not constrained) +// } +// return errors.New(sbb.String()) +// } + // R1C used to compute the wires type R1C struct { L, R, O LinearExpression } -// WireIterator implements constraint.Iterable -func (r1c *R1C) WireIterator() func() int { - curr := 0 - return func() int { - if curr < len(r1c.L) { - curr++ - return r1c.L[curr-1].WireID() - } - if curr < len(r1c.L)+len(r1c.R) { - curr++ - return r1c.R[curr-1-len(r1c.L)].WireID() - } - if curr < len(r1c.L)+len(r1c.R)+len(r1c.O) { - curr++ - return r1c.O[curr-1-len(r1c.L)-len(r1c.R)].WireID() - } - return -1 - } -} - // String formats a R1C as L⋅R == O func (r1c *R1C) String(r Resolver) string { sbb := NewStringBuilder(r) diff --git a/constraint/r1cs_sparse.go b/constraint/r1cs_sparse.go index 4d99c82857..742521e9d5 100644 --- a/constraint/r1cs_sparse.go +++ 
b/constraint/r1cs_sparse.go @@ -14,160 +14,153 @@ package constraint -import ( - "errors" - "strconv" - "strings" -) - type SparseR1CS interface { ConstraintSystem - // AddConstraint adds a constraint to the sytem and returns its id - // This does not check for validity of the constraint. - // If a debugInfo parameter is provided, it will be appended to the debug info structure - // and will grow the memory usage of the constraint system. - AddConstraint(c SparseR1C, debugInfo ...DebugInfo) int + // AddSparseR1C adds a constraint to the constraint system. + AddSparseR1C(c SparseR1C, bID BlueprintID) int - // GetConstraints return the list of SparseR1C and a helper for pretty printing. + // GetSparseR1Cs return the list of SparseR1C // See StringBuilder for more info. // ! this is an experimental API. - GetConstraints() ([]SparseR1C, Resolver) -} - -// R1CS describes a set of SparseR1C constraint -// TODO @gbotrel maybe SparseR1CSCore and R1CSCore should go in code generation directly to avoid confusing this package. -type SparseR1CSCore struct { - System - Constraints []SparseR1C -} + GetSparseR1Cs() []SparseR1C -// GetNbConstraints returns the number of constraints -func (cs *SparseR1CSCore) GetNbConstraints() int { - return len(cs.Constraints) + // GetSparseR1CIterator returns an SparseR1CIterator to iterate on the SparseR1C constraints of the system. + GetSparseR1CIterator() SparseR1CIterator } -func (cs *SparseR1CSCore) UpdateLevel(cID int, c Iterable) { - cs.updateLevel(cID, c) +// SparseR1CIterator facilitates iterating through SparseR1C constraints. +type SparseR1CIterator struct { + SparseR1C + cs *System + n int } -func (system *SparseR1CSCore) CheckUnconstrainedWires() error { - // TODO @gbotrel add unit test for that. 
- - inputConstrained := make([]bool, system.GetNbSecretVariables()+system.GetNbPublicVariables()) - cptInputs := len(inputConstrained) - if cptInputs == 0 { - return errors.New("invalid constraint system: no input defined") +// Next returns the next SparseR1C or nil if end. Caller must not store the result since the +// same memory space is re-used for subsequent calls to Next. +func (it *SparseR1CIterator) Next() *SparseR1C { + if it.n >= it.cs.GetNbInstructions() { + return nil } - - cptHints := len(system.MHints) - mHintsConstrained := make(map[int]bool) - - // for each constraint, we check the terms and mark our inputs / hints as constrained - processTerm := func(t Term) { - - // L and M[0] handles the same wire but with a different coeff - vID := t.WireID() - if t.CoeffID() != CoeffIdZero { - if vID < len(inputConstrained) { - if !inputConstrained[vID] { - inputConstrained[vID] = true - cptInputs-- - } - } else { - // internal variable, let's check if it's a hint - if _, ok := system.MHints[vID]; ok { - vID -= (system.GetNbPublicVariables() + system.GetNbSecretVariables()) - if !mHintsConstrained[vID] { - mHintsConstrained[vID] = true - cptHints-- - } - } - } - } - + inst := it.cs.Instructions[it.n] + it.n++ + blueprint := it.cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&it.SparseR1C, inst.Unpack(it.cs)) + return &it.SparseR1C } - for _, c := range system.Constraints { - processTerm(c.L) - processTerm(c.R) - processTerm(c.M[0]) - processTerm(c.M[1]) - processTerm(c.O) - if cptHints|cptInputs == 0 { - return nil // we can stop. 
- } - - } - - // something is a miss, we build the error string - var sbb strings.Builder - if cptInputs != 0 { - sbb.WriteString(strconv.Itoa(cptInputs)) - sbb.WriteString(" unconstrained input(s):") - sbb.WriteByte('\n') - for i := 0; i < len(inputConstrained) && cptInputs != 0; i++ { - if !inputConstrained[i] { - if i < len(system.Public) { - sbb.WriteString(system.Public[i]) - } else { - sbb.WriteString(system.Secret[i-len(system.Public)]) - } - sbb.WriteByte('\n') - cptInputs-- - } - } - sbb.WriteByte('\n') - } - - if cptHints != 0 { - sbb.WriteString(strconv.Itoa(cptHints)) - sbb.WriteString(" unconstrained hints") - sbb.WriteByte('\n') - // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some - // debugInfo to find where a hint was declared (and not constrained) - } - return errors.New(sbb.String()) + return it.Next() } -// SparseR1C used to compute the wires -// L+R+M[0]M[1]+O+k=0 -// if a Term is zero, it means the field doesn't exist (ex M=[0,0] means there is no multiplicative term) +// func (system *SparseR1CSCore) CheckUnconstrainedWires() error { +// // TODO @gbotrel add unit test for that. 
+// return nil +// inputConstrained := make([]bool, system.GetNbSecretVariables()+system.GetNbPublicVariables()) +// cptInputs := len(inputConstrained) +// if cptInputs == 0 { +// return errors.New("invalid constraint system: no input defined") +// } + +// cptHints := len(system.MHints) +// mHintsConstrained := make(map[int]bool) + +// // for each constraint, we check the terms and mark our inputs / hints as constrained +// processTerm := func(t Term) { + +// // L and M[0] handles the same wire but with a different coeff +// vID := t.WireID() +// if t.CoeffID() != CoeffIdZero { +// if vID < len(inputConstrained) { +// if !inputConstrained[vID] { +// inputConstrained[vID] = true +// cptInputs-- +// } +// } else { +// // internal variable, let's check if it's a hint +// if _, ok := system.MHints[vID]; ok { +// vID -= (system.GetNbPublicVariables() + system.GetNbSecretVariables()) +// if !mHintsConstrained[vID] { +// mHintsConstrained[vID] = true +// cptHints-- +// } +// } +// } +// } + +// } +// for _, c := range system.Constraints { +// processTerm(c.L) +// processTerm(c.R) +// processTerm(c.M[0]) +// processTerm(c.M[1]) +// processTerm(c.O) +// if cptHints|cptInputs == 0 { +// return nil // we can stop. 
+// } + +// } + +// // something is a miss, we build the error string +// var sbb strings.Builder +// if cptInputs != 0 { +// sbb.WriteString(strconv.Itoa(cptInputs)) +// sbb.WriteString(" unconstrained input(s):") +// sbb.WriteByte('\n') +// for i := 0; i < len(inputConstrained) && cptInputs != 0; i++ { +// if !inputConstrained[i] { +// if i < len(system.Public) { +// sbb.WriteString(system.Public[i]) +// } else { +// sbb.WriteString(system.Secret[i-len(system.Public)]) +// } +// sbb.WriteByte('\n') +// cptInputs-- +// } +// } +// sbb.WriteByte('\n') +// } + +// if cptHints != 0 { +// sbb.WriteString(strconv.Itoa(cptHints)) +// sbb.WriteString(" unconstrained hints") +// sbb.WriteByte('\n') +// // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some +// // debugInfo to find where a hint was declared (and not constrained) +// } +// return errors.New(sbb.String()) +// } + +type CommitmentConstraint uint32 + +const ( + NOT CommitmentConstraint = 0 + COMMITTED CommitmentConstraint = 1 + COMMITMENT CommitmentConstraint = 2 +) + +// SparseR1C represent a PlonK-ish constraint +// qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC -committed?*Bsb22Commitments-commitment?*commitmentValue == 0 type SparseR1C struct { - L, R, O Term - M [2]Term - K int // stores only the ID of the constant term that is used + XA, XB, XC uint32 + QL, QR, QO, QM, QC uint32 + Commitment CommitmentConstraint } -// WireIterator implements constraint.Iterable -func (c *SparseR1C) WireIterator() func() int { - curr := 0 - return func() int { - switch curr { - case 0: - curr++ - return c.L.WireID() - case 1: - curr++ - return c.R.WireID() - case 2: - curr++ - return c.O.WireID() - } - return -1 - } +func (c *SparseR1C) Clear() { + *c = SparseR1C{} } // String formats the constraint as qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC == 0 func (c *SparseR1C) String(r Resolver) string { sbb := NewStringBuilder(r) - sbb.WriteTerm(c.L) + sbb.WriteTerm(Term{CID: c.QL, 
VID: c.XA}) sbb.WriteString(" + ") - sbb.WriteTerm(c.R) + sbb.WriteTerm(Term{CID: c.QR, VID: c.XB}) sbb.WriteString(" + ") - sbb.WriteTerm(c.O) - if qM := sbb.CoeffToString(c.M[0].CoeffID()); qM != "0" { - xa := sbb.VariableToString(c.M[0].WireID()) - xb := sbb.VariableToString(c.M[1].WireID()) + sbb.WriteTerm(Term{CID: c.QO, VID: c.XC}) + if qM := sbb.CoeffToString(int(c.QM)); qM != "0" { + xa := sbb.VariableToString(int(c.XA)) + xb := sbb.VariableToString(int(c.XB)) sbb.WriteString(" + ") sbb.WriteString(qM) sbb.WriteString("⋅(") @@ -177,7 +170,7 @@ func (c *SparseR1C) String(r Resolver) string { sbb.WriteByte(')') } sbb.WriteString(" + ") - sbb.WriteString(r.CoeffToString(c.K)) + sbb.WriteString(r.CoeffToString(int(c.QC))) sbb.WriteString(" == 0") return sbb.String() } diff --git a/constraint/r1cs_sparse_test.go b/constraint/r1cs_sparse_test.go index 024560ff93..5a0bfd647c 100644 --- a/constraint/r1cs_sparse_test.go +++ b/constraint/r1cs_sparse_test.go @@ -7,12 +7,13 @@ import ( cs "github.com/consensys/gnark/constraint/bn254" ) -func ExampleSparseR1CS_GetConstraints() { +func ExampleSparseR1CS_GetSparseR1Cs() { // build a constraint system; this is (usually) done by the frontend package // for this Example we want to manipulate the constraints and output a string representation // and build the linear expressions "manually". // note: R1CS apis are more mature; SparseR1CS apis are going to change in the next release(s). 
scs := cs.NewSparseR1CS(0) + blueprint := scs.AddBlueprint(&constraint.BlueprintGenericSparseR1C{}) Y := scs.AddPublicVariable("Y") X := scs.AddSecretVariable("X") @@ -20,40 +21,34 @@ func ExampleSparseR1CS_GetConstraints() { v0 := scs.AddInternalVariable() // X² // coefficients - cZero := scs.FromInterface(0) cOne := scs.FromInterface(1) - cMinusOne := scs.FromInterface(-1) cFive := scs.FromInterface(5) // X² == X * X - scs.AddConstraint(constraint.SparseR1C{ - L: scs.MakeTerm(&cZero, X), - R: scs.MakeTerm(&cZero, X), - O: scs.MakeTerm(&cMinusOne, v0), - M: [2]constraint.Term{ - scs.MakeTerm(&cOne, X), - scs.MakeTerm(&cOne, X), - }, - K: int(scs.MakeTerm(&cZero, 0).CID), - }) + scs.AddSparseR1C(constraint.SparseR1C{ + XA: uint32(X), + XB: uint32(X), + XC: uint32(v0), + QO: constraint.CoeffIdMinusOne, + QM: constraint.CoeffIdOne, + }, blueprint) // X² + 5X + 5 == Y - scs.AddConstraint(constraint.SparseR1C{ - R: scs.MakeTerm(&cOne, v0), - L: scs.MakeTerm(&cFive, X), - O: scs.MakeTerm(&cMinusOne, Y), - M: [2]constraint.Term{ - scs.MakeTerm(&cZero, v0), - scs.MakeTerm(&cZero, X), - }, - K: int(scs.MakeTerm(&cFive, 0).CID), - }) + scs.AddSparseR1C(constraint.SparseR1C{ + XA: uint32(X), + XB: uint32(v0), + XC: uint32(Y), + QO: constraint.CoeffIdMinusOne, + QL: scs.AddCoeff(cFive), + QR: scs.AddCoeff(cOne), + QC: scs.AddCoeff(cFive), + }, blueprint) // get the constraints - constraints, r := scs.GetConstraints() + constraints := scs.GetSparseR1Cs() for _, c := range constraints { - fmt.Println(c.String(r)) + fmt.Println(c.String(scs)) // for more granularity use constraint.NewStringBuilder(r) that embeds a string.Builder // and has WriteLinearExpression and WriteTerm methods. 
} diff --git a/constraint/r1cs_test.go b/constraint/r1cs_test.go index b2736d3073..9444b5ef53 100644 --- a/constraint/r1cs_test.go +++ b/constraint/r1cs_test.go @@ -3,16 +3,21 @@ package constraint_test import ( "fmt" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/constraint" cs "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" ) -func ExampleR1CS_GetConstraints() { +func ExampleR1CS_GetR1Cs() { // build a constraint system; this is (usually) done by the frontend package // for this Example we want to manipulate the constraints and output a string representation // and build the linear expressions "manually". r1cs := cs.NewR1CS(0) + blueprint := r1cs.AddBlueprint(&constraint.BlueprintGenericR1C{}) + ONE := r1cs.AddPublicVariable("1") // the "ONE" wire Y := r1cs.AddPublicVariable("Y") X := r1cs.AddSecretVariable("X") @@ -25,35 +30,35 @@ func ExampleR1CS_GetConstraints() { cFive := r1cs.FromInterface(5) // X² == X * X - r1cs.AddConstraint(constraint.R1C{ - L: constraint.LinearExpression{r1cs.MakeTerm(&cOne, X)}, - R: constraint.LinearExpression{r1cs.MakeTerm(&cOne, X)}, - O: constraint.LinearExpression{r1cs.MakeTerm(&cOne, v0)}, - }) + r1cs.AddR1C(constraint.R1C{ + L: constraint.LinearExpression{r1cs.MakeTerm(cOne, X)}, + R: constraint.LinearExpression{r1cs.MakeTerm(cOne, X)}, + O: constraint.LinearExpression{r1cs.MakeTerm(cOne, v0)}, + }, blueprint) // X³ == X² * X - r1cs.AddConstraint(constraint.R1C{ - L: constraint.LinearExpression{r1cs.MakeTerm(&cOne, v0)}, - R: constraint.LinearExpression{r1cs.MakeTerm(&cOne, X)}, - O: constraint.LinearExpression{r1cs.MakeTerm(&cOne, v1)}, - }) + r1cs.AddR1C(constraint.R1C{ + L: constraint.LinearExpression{r1cs.MakeTerm(cOne, v0)}, + R: constraint.LinearExpression{r1cs.MakeTerm(cOne, X)}, + O: constraint.LinearExpression{r1cs.MakeTerm(cOne, v1)}, + }, blueprint) // Y == X³ + X + 5 - r1cs.AddConstraint(constraint.R1C{ - R: 
constraint.LinearExpression{r1cs.MakeTerm(&cOne, ONE)}, - L: constraint.LinearExpression{r1cs.MakeTerm(&cOne, Y)}, + r1cs.AddR1C(constraint.R1C{ + R: constraint.LinearExpression{r1cs.MakeTerm(cOne, ONE)}, + L: constraint.LinearExpression{r1cs.MakeTerm(cOne, Y)}, O: constraint.LinearExpression{ - r1cs.MakeTerm(&cFive, ONE), - r1cs.MakeTerm(&cOne, X), - r1cs.MakeTerm(&cOne, v1), + r1cs.MakeTerm(cFive, ONE), + r1cs.MakeTerm(cOne, X), + r1cs.MakeTerm(cOne, v1), }, - }) + }, blueprint) // get the constraints - constraints, r := r1cs.GetConstraints() + constraints := r1cs.GetR1Cs() for _, r1c := range constraints { - fmt.Println(r1c.String(r)) + fmt.Println(r1c.String(r1cs)) // for more granularity use constraint.NewStringBuilder(r) that embeds a string.Builder // and has WriteLinearExpression and WriteTerm methods. } @@ -63,3 +68,38 @@ func ExampleR1CS_GetConstraints() { // v0 ⋅ X == v1 // Y ⋅ 1 == 5 + X + v1 } + +func ExampleR1CS_Solve() { + // build a constraint system and a witness; + ccs, _ := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &cubic{}) + w, _ := frontend.NewWitness(&cubic{X: 3, Y: 35}, ecc.BN254.ScalarField()) + + _solution, _ := ccs.Solve(w) + + // concrete solution + solution := _solution.(*cs.R1CSSolution) + + // solution vector should have [1, 3, 35, 9, 27] + for _, v := range solution.W { + fmt.Println(v.String()) + } + + // Output: + // 1 + // 3 + // 35 + // 9 + // 27 +} + +type cubic struct { + X, Y frontend.Variable +} + +// Define declares the circuit constraints +// x**3 + x + 5 == y +func (circuit *cubic) Define(api frontend.API) error { + x3 := api.Mul(circuit.X, circuit.X, circuit.X) + api.AssertIsEqual(circuit.Y, api.Add(x3, circuit.X, 5)) + return nil +} diff --git a/constraint/solver/hint.go b/constraint/solver/hint.go new file mode 100644 index 0000000000..5ffc3f1f31 --- /dev/null +++ b/constraint/solver/hint.go @@ -0,0 +1,101 @@ +package solver + +import ( + "hash/fnv" + "math/big" + "reflect" + "runtime" +) + +// HintID 
is a unique identifier for a hint function used for lookup. +type HintID uint32 + +// Hint allows to define computations outside of a circuit. +// +// It defines an annotated hint function; the number of inputs and outputs injected at solving +// time is defined in the circuit (compile time). +// +// For example: +// +// b := api.NewHint(hint, 2, a) +// --> at solving time, hint is going to be invoked with 1 input (a) and is expected to return 2 outputs +// b[0] and b[1]. +// +// Usually, it is expected that computations in circuits are performed on +// variables. However, in some cases defining the computations in circuits may be +// complicated or computationally expensive. By using hints, the computations are +// performed outside of the circuit on integers (compared to the frontend.Variable +// values inside the circuits) and the result of a hint function is assigned to a +// newly created variable in a circuit. +// +// As the computations are performed outside of the circuit, then the correctness of +// the result is not guaranteed. This also means that the result of a hint function +// is unconstrained by default, leading to failure while composing circuit proof. +// Thus, it is the circuit developer's responsibility to verify the correctness of the +// hint result by adding necessary constraints in the circuit. +// +// As an example, lets say the hint function computes a factorization of a +// semiprime n: +// +// p, q <- hint(n) st. p * q = n +// +// into primes p and q. Then, the circuit developer needs to assert in the circuit +// that p*q indeed equals to n: +// +// n == p * q. +// +// However, if the hint function is incorrectly defined (e.g. in the previous +// example, it returns 1 and n instead of p and q), then the assertion may still +// hold, but the constructed proof is semantically invalid. Thus, the user +// constructing the proof must be extremely cautious when using hints.
+// +// # Using hint functions in circuits +// +// To use a hint function in a circuit, the developer first needs to define a hint +// function hintFn according to the Function interface. Then, in a circuit, the +// developer applies the hint function with frontend.API.NewHint(hintFn, vars...), +// where vars are the variables the hint function will be applied to (and +// correspond to the argument inputs in the Function type) which returns a new +// unconstrained variable. The returned variables must be constrained using +// frontend.API.Assert[.*] methods. +// +// As explained, the hints are essentially black boxes from the circuit point of +// view and thus the defined hints in circuits are not used when constructing a +// proof. To allow the particular hint functions to be used during proof +// construction, the user needs to supply a solver.Option indicating the +// enabled hints. Such options can be obtained by a call to +// solver.WithHints(hintFns...), where hintFns are the corresponding hint +// functions. +// +// # Using hint functions in gadgets +// +// Similar considerations apply for hint functions used in gadgets as in +// user-defined circuits. However, listing all hint functions used in a particular +// gadget for constructing solver.Option puts high overhead for the user to +// enable all necessary hints. +// +// For that, this package also provides a registry of trusted hint functions. When +// a gadget registers a hint function, then it is automatically enabled during +// proof computation and the prover does not need to provide a corresponding +// proving option. +// +// In the init() method of the gadget, call the method RegisterHint(hintFn) function on +// the hint function hintFn to register a hint function in the package registry. 
+type Hint func(field *big.Int, inputs []*big.Int, outputs []*big.Int) error + +// GetHintID is a reference function for computing the hint ID based on a function name +func GetHintID(fn Hint) HintID { + hf := fnv.New32a() + name := GetHintName(fn) + + // TODO relying on name to derive UUID is risky; if fn is an anonymous func, will be package.glob..funcN + // and if new anonymous functions are added in the package, N may change, so will UUID. + hf.Write([]byte(name)) // #nosec G104 -- does not err + + return HintID(hf.Sum32()) +} + +func GetHintName(fn Hint) string { + fnptr := reflect.ValueOf(fn).Pointer() + return runtime.FuncForPC(fnptr).Name() +} diff --git a/constraint/solver/hint_registry.go b/constraint/solver/hint_registry.go new file mode 100644 index 0000000000..7021de5024 --- /dev/null +++ b/constraint/solver/hint_registry.go @@ -0,0 +1,88 @@ +package solver + +import ( + "fmt" + "math/big" + "sync" + + "github.com/consensys/gnark/logger" +) + +func init() { + RegisterHint(InvZeroHint) +} + +var ( + registry = make(map[HintID]Hint) + registryM sync.RWMutex +) + +// RegisterHint registers a hint function in the global registry. +func RegisterHint(hintFns ...Hint) { + registryM.Lock() + defer registryM.Unlock() + for _, hintFn := range hintFns { + key := GetHintID(hintFn) + name := GetHintName(hintFn) + if _, ok := registry[key]; ok { + log := logger.Logger() + log.Warn().Str("name", name).Msg("function registered multiple times") + return + } + registry[key] = hintFn + } +} + +func GetRegisteredHint(key HintID) Hint { + return registry[key] +} + +func RegisterNamedHint(hintFn Hint, key HintID) { + registryM.Lock() + defer registryM.Unlock() + if _, ok := registry[key]; ok { + panic(fmt.Errorf("hint id %d already taken", key)) + } + registry[key] = hintFn +} + +// GetRegisteredHints returns all registered hint functions.
+func GetRegisteredHints() []Hint { + registryM.RLock() + defer registryM.RUnlock() + ret := make([]Hint, 0, len(registry)) + for _, v := range registry { + ret = append(ret, v) + } + return ret +} + +func cloneMap[K comparable, V any](src map[K]V) map[K]V { + res := make(map[K]V, len(registry)) + for k, v := range src { + res[k] = v + } + return res +} + +func cloneHintRegistry() map[HintID]Hint { + registryM.Lock() + defer registryM.Unlock() + return cloneMap(registry) +} + +// InvZeroHint computes the value 1/a for the single input a. If a == 0, returns 0. +func InvZeroHint(q *big.Int, inputs []*big.Int, results []*big.Int) error { + result := results[0] + + // save input + result.Set(inputs[0]) + + // a == 0, return + if result.IsUint64() && result.Uint64() == 0 { + return nil + } + + result.ModInverse(result, q) + return nil +} diff --git a/constraint/solver/options.go b/constraint/solver/options.go new file mode 100644 index 0000000000..d7aa7cac4e --- /dev/null +++ b/constraint/solver/options.go @@ -0,0 +1,67 @@ +package solver + +import ( + "github.com/consensys/gnark/logger" + "github.com/rs/zerolog" +) + +// Option defines option for altering the behavior of a constraint system +// solver (Solve() method). See the descriptions of functions returning instances +// of this type for implemented options. +type Option func(*Config) error + +// Config is the configuration for the solver with the options applied. +type Config struct { + HintFunctions map[HintID]Hint // defaults to all built-in hint functions + Logger zerolog.Logger // defaults to gnark.Logger +} + +// WithHints is a solver option that specifies additional hint functions to be used +// by the constraint solver. +func WithHints(hintFunctions ...Hint) Option { + log := logger.Logger() + return func(opt *Config) error { + // it is an error to register hint function several times, but as the + // prover already checks it then omit here. 
+ for _, h := range hintFunctions { + uuid := GetHintID(h) + if _, ok := opt.HintFunctions[uuid]; ok { + log.Warn().Int("hintID", int(uuid)).Str("name", GetHintName(h)).Msg("duplicate hint function") + } else { + opt.HintFunctions[uuid] = h + } + } + return nil + } +} + +// OverrideHint forces the solver to use provided hint function for given id. +func OverrideHint(id HintID, f Hint) Option { + return func(opt *Config) error { + opt.HintFunctions[id] = f + return nil + } +} + +// WithLogger is a prover option that specifies zerolog.Logger as a destination for the +// logs printed by api.Println(). By default, uses gnark/logger. +// zerolog.Nop() will disable logging +func WithLogger(l zerolog.Logger) Option { + return func(opt *Config) error { + opt.Logger = l + return nil + } +} + +// NewConfig returns a default SolverConfig with given prover options opts applied. +func NewConfig(opts ...Option) (Config, error) { + log := logger.Logger() + opt := Config{Logger: log} + opt.HintFunctions = cloneHintRegistry() + for _, option := range opts { + if err := option(&opt); err != nil { + return Config{}, err + } + } + return opt, nil +} diff --git a/constraint/system.go b/constraint/system.go index a389212b18..19f8167f32 100644 --- a/constraint/system.go +++ b/constraint/system.go @@ -1,30 +1,29 @@ package constraint import ( - "fmt" "io" "math/big" - "github.com/blang/semver/v4" - "github.com/consensys/gnark" - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/debug" - "github.com/consensys/gnark/internal/tinyfield" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/logger" + "github.com/consensys/gnark/constraint/solver" ) // ConstraintSystem interface that all constraint systems implement. 
type ConstraintSystem interface { io.WriterTo io.ReaderFrom - CoeffEngine + Field + Resolver + CustomizableSystem // IsSolved returns nil if given witness solves the constraint system and error otherwise - IsSolved(witness witness.Witness, opts ...backend.ProverOption) error + // Deprecated: use _, err := Solve(...) instead + IsSolved(witness witness.Witness, opts ...solver.Option) error + + // Solve attempts to solve the constraint system using provided witness. + // Returns an error if the witness does not allow all the constraints to be satisfied. + // Returns a typed solution (R1CSSolution or SparseR1CSSolution) and nil otherwise. + Solve(witness witness.Witness, opts ...solver.Option) (any, error) // GetNbVariables return number of internal, secret and public Variables // Deprecated: use GetNbSecretVariables() instead @@ -34,6 +33,7 @@ type ConstraintSystem interface { GetNbSecretVariables() int GetNbPublicVariables() int + GetNbInstructions() int GetNbConstraints() int GetNbCoefficients() int @@ -46,254 +46,44 @@ type ConstraintSystem interface { // AddSolverHint adds a hint to the solver such that the output variables will be computed // using a call to output := f(input...) at solve time. - AddSolverHint(f hint.Function, input []LinearExpression, nbOutput int) (internalVariables []int, err error) + // Providing the function f is optional. If it is provided, id will be ignored and one will be derived from f's name. + // Otherwise, the provided id will be used to register the hint with. + AddSolverHint(f solver.Hint, id solver.HintID, input []LinearExpression, nbOutput int) (internalVariables []int, err error) AddCommitment(c Commitment) error - AddGkr(gkr GkrInfo) error + GetCommitments() Commitments AddLog(l LogEntry) // MakeTerm returns a new Term. The constraint system may store coefficients in a map, so // calls to this function will grow the memory usage of the constraint system.
- MakeTerm(coeff *Coeff, variableID int) Term + MakeTerm(coeff Element, variableID int) Term + + // AddCoeff adds a coefficient to the underlying constraint system. The system will not store duplicate, + // but is not purging for unused coeff either, so this grows memory usage. + AddCoeff(coeff Element) uint32 NewDebugInfo(errName string, i ...interface{}) DebugInfo // AttachDebugInfo enables attaching debug information to multiple constraints. - // This is more efficient than using the AddConstraint(.., debugInfo) since it will store the + // This is more efficient than using the AddR1C(.., debugInfo) since it will store the // debug information only once. AttachDebugInfo(debugInfo DebugInfo, constraintID []int) // CheckUnconstrainedWires returns and error if the constraint system has wires that are not uniquely constrained. // This is experimental. CheckUnconstrainedWires() error -} - -type Iterable interface { - // WireIterator returns a new iterator to iterate over the wires of the implementer (usually, a constraint) - // Call to next() returns the next wireID of the Iterable object and -1 when iteration is over. 
- // - // For example a R1C constraint with L, R, O linear expressions, each of size 2, calling several times - // next := r1c.WireIterator(); - // for wID := next(); wID != -1; wID = next() {} - // // will return in order L[0],L[1],R[0],R[1],O[0],O[1],-1 - WireIterator() (next func() int) -} - -var _ Iterable = &SparseR1C{} -var _ Iterable = &R1C{} - -// System contains core elements for a constraint System -type System struct { - // serialization header - GnarkVersion string - ScalarField string - - // number of internal wires - NbInternalVariables int - - // input wires names - Public, Secret []string - - // logs (added with system.Println, resolved when solver sets a value to a wire) - Logs []LogEntry - - // debug info contains stack trace (including line number) of a call to a system.API that - // results in an unsolved constraint - DebugInfo []LogEntry - SymbolTable debug.SymbolTable - // maps constraint id to debugInfo id - // several constraints may point to the same debug info - MDebug map[int]int - - MHints map[int]*Hint // maps wireID to hint - MHintsDependencies map[hint.ID]string // maps hintID to hint string identifier - - // each level contains independent constraints and can be parallelized - // it is guaranteed that all dependencies for constraints in a level l are solved - // in previous levels - // TODO @gbotrel these are currently updated after we add a constraint. - // but in case the object is built from a serialized representation - // we need to init the level builder lbWireLevel from the existing constraints. - Levels [][]int - - // scalar field - q *big.Int `cbor:"-"` - bitLen int `cbor:"-"` - - // level builder - lbWireLevel []int `cbor:"-"` // at which level we solve a wire. init at -1. - lbOutputs []uint32 `cbor:"-"` // wire outputs for current constraint. 
- lbHints map[*Hint]struct{} `cbor:"-"` // hints we processed in current round - - CommitmentInfo Commitment - GkrInfo GkrInfo -} - -// NewSystem initialize the common structure among constraint system -func NewSystem(scalarField *big.Int) System { - return System{ - SymbolTable: debug.NewSymbolTable(), - MDebug: map[int]int{}, - GnarkVersion: gnark.Version.String(), - ScalarField: scalarField.Text(16), - MHints: make(map[int]*Hint), - MHintsDependencies: make(map[hint.ID]string), - q: new(big.Int).Set(scalarField), - bitLen: scalarField.BitLen(), - lbHints: map[*Hint]struct{}{}, - } -} - -func (system *System) GetNbSecretVariables() int { - return len(system.Secret) -} -func (system *System) GetNbPublicVariables() int { - return len(system.Public) -} -func (system *System) GetNbInternalVariables() int { - return system.NbInternalVariables -} - -// CheckSerializationHeader parses the scalar field and gnark version headers -// -// This is meant to be use at the deserialization step, and will error for illegal values -func (system *System) CheckSerializationHeader() error { - // check gnark version - binaryVersion := gnark.Version - objectVersion, err := semver.Parse(system.GnarkVersion) - if err != nil { - return fmt.Errorf("when parsing gnark version: %w", err) - } - - if binaryVersion.Compare(objectVersion) != 0 { - log := logger.Logger() - log.Warn().Str("binary", binaryVersion.String()).Str("object", objectVersion.String()).Msg("gnark version (binary) mismatch with constraint system. 
there are no guarantees on compatibilty") - } - - // TODO @gbotrel maintain version changes and compare versions properly - // (ie if major didn't change,we shouldn't have a compat issue) - - scalarField := new(big.Int) - _, ok := scalarField.SetString(system.ScalarField, 16) - if !ok { - return fmt.Errorf("when parsing serialized modulus: %s", system.ScalarField) - } - curveID := utils.FieldToCurve(scalarField) - if curveID == ecc.UNKNOWN && scalarField.Cmp(tinyfield.Modulus()) != 0 { - return fmt.Errorf("unsupported scalard field %s", scalarField.Text(16)) - } - system.q = new(big.Int).Set(scalarField) - system.bitLen = system.q.BitLen() - return nil -} - -// GetNbVariables return number of internal, secret and public variables -func (system *System) GetNbVariables() (internal, secret, public int) { - return system.NbInternalVariables, system.GetNbSecretVariables(), system.GetNbPublicVariables() -} - -func (system *System) Field() *big.Int { - return new(big.Int).Set(system.q) -} - -// bitLen returns the number of bits needed to represent a fr.Element -func (system *System) FieldBitLen() int { - return system.bitLen -} - -func (system *System) AddInternalVariable() (idx int) { - idx = system.NbInternalVariables + system.GetNbPublicVariables() + system.GetNbSecretVariables() - system.NbInternalVariables++ - return idx -} - -func (system *System) AddPublicVariable(name string) (idx int) { - idx = system.GetNbPublicVariables() - system.Public = append(system.Public, name) - return idx -} - -func (system *System) AddSecretVariable(name string) (idx int) { - idx = system.GetNbSecretVariables() + system.GetNbPublicVariables() - system.Secret = append(system.Secret, name) - return idx -} - -func (system *System) AddSolverHint(f hint.Function, input []LinearExpression, nbOutput int) (internalVariables []int, err error) { - if nbOutput <= 0 { - return nil, fmt.Errorf("hint function must return at least one output") - } - // register the hint as dependency - hintUUID, 
hintID := hint.UUID(f), hint.Name(f) - if id, ok := system.MHintsDependencies[hintUUID]; ok { - // hint already registered, let's ensure string id matches - if id != hintID { - return nil, fmt.Errorf("hint dependency registration failed; %s previously register with same UUID as %s", hintID, id) - } - } else { - system.MHintsDependencies[hintUUID] = hintID - } - - // prepare wires - internalVariables = make([]int, nbOutput) - for i := 0; i < len(internalVariables); i++ { - internalVariables[i] = system.AddInternalVariable() - } - - // associate these wires with the solver hint - ch := &Hint{ID: hintUUID, Inputs: input, Wires: internalVariables} - for _, vID := range internalVariables { - system.MHints[vID] = ch - } - - return -} - -func (system *System) AddGkr(gkr GkrInfo) error { - if system.GkrInfo.Is() { - return fmt.Errorf("currently only one GKR sub-circuit per SNARK is supported") - } - - system.GkrInfo = gkr - return nil -} - -func (system *System) AddCommitment(c Commitment) error { - if system.CommitmentInfo.Is() { - return fmt.Errorf("currently only one commitment per circuit is supported") - } - - system.CommitmentInfo = c - - return nil -} - -func (system *System) AddLog(l LogEntry) { - system.Logs = append(system.Logs, l) -} + GetInstruction(int) Instruction -func (system *System) AttachDebugInfo(debugInfo DebugInfo, constraintID []int) { - system.DebugInfo = append(system.DebugInfo, LogEntry(debugInfo)) - id := len(system.DebugInfo) - 1 - for _, cID := range constraintID { - system.MDebug[cID] = id - } + GetCoefficient(i int) Element } -// VariableToString implements Resolver -func (system *System) VariableToString(vID int) string { - nbPublic := system.GetNbPublicVariables() - nbSecret := system.GetNbSecretVariables() +type CustomizableSystem interface { + // AddBlueprint registers the given blueprint and returns its id. This should be called only once per blueprint. 
+ AddBlueprint(b Blueprint) BlueprintID - if vID < nbPublic { - return system.Public[vID] - } - vID -= nbPublic - if vID < nbSecret { - return system.Secret[vID] - } - vID -= nbSecret - return fmt.Sprintf("v%d", vID) // TODO @gbotrel vs strconv.Itoa. + // AddInstruction adds an instruction to the system and returns a list of created wires + // if the blueprint declared any outputs. + AddInstruction(bID BlueprintID, calldata []uint32) []uint32 } diff --git a/constraint/term.go b/constraint/term.go index 182329e576..857edd0da0 100644 --- a/constraint/term.go +++ b/constraint/term.go @@ -18,6 +18,15 @@ import ( "math" ) +// ids of the coefficients with simple values in any cs.coeffs slice. +const ( + CoeffIdZero = iota + CoeffIdOne + CoeffIdTwo + CoeffIdMinusOne + CoeffIdMinusTwo +) + // Term represents a coeff * variable in a constraint system type Term struct { CID, VID uint32 @@ -44,3 +53,12 @@ func (t Term) String(r Resolver) string { sbb.WriteTerm(t) return sbb.String() } + +// implements constraint.Compressable + +// Compress compresses the term into a slice of uint32 words. 
+// For compatibility with test engine and LinearExpression, the term is encoded as: +// 1, CID, VID (i.e a LinearExpression with a single term) +func (t Term) Compress(to *[]uint32) { + (*to) = append((*to), 1, t.CID, t.VID) +} diff --git a/constraint/tinyfield/coeff.go b/constraint/tinyfield/coeff.go index 3bd799d820..3aec8da553 100644 --- a/constraint/tinyfield/coeff.go +++ b/constraint/tinyfield/coeff.go @@ -46,7 +46,7 @@ func newCoeffTable(capacity int) CoeffTable { } -func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constraint.Term { +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { c := (*fr.Element)(coeff[:]) var cID uint32 if c.IsZero() { @@ -69,7 +69,11 @@ func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constrai ct.mCoeffs[cc] = cID } } + return cID +} +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) return constraint.Term{VID: uint32(variableID), CID: cID} } @@ -78,7 +82,10 @@ func (ct *CoeffTable) CoeffToString(cID int) string { return ct.Coefficients[cID].String() } -var _ constraint.CoeffEngine = &arithEngine{} +// implements constraint.Field +type field struct{} + +var _ constraint.Field = &field{} var ( two fr.Element @@ -94,10 +101,7 @@ func init() { minusTwo.Neg(&two) } -// implements constraint.CoeffEngine -type arithEngine struct{} - -func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { +func (engine *field) FromInterface(i interface{}) constraint.Element { var e fr.Element if _, err := e.SetInterface(i); err != nil { // need to clean that --> some code path are dissimilar @@ -106,55 +110,83 @@ func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { b := utils.FromInterface(i) e.SetBigInt(&b) } - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) ToBigInt(c *constraint.Coeff) *big.Int { +func (engine *field) ToBigInt(c 
constraint.Element) *big.Int { e := (*fr.Element)(c[:]) r := new(big.Int) e.BigInt(r) return r } -func (engine *arithEngine) Mul(a, b *constraint.Coeff) { +func (engine *field) Mul(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Mul(_a, _b) + return a } -func (engine *arithEngine) Add(a, b *constraint.Coeff) { + +func (engine *field) Add(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Add(_a, _b) + return a } -func (engine *arithEngine) Sub(a, b *constraint.Coeff) { +func (engine *field) Sub(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Sub(_a, _b) + return a } -func (engine *arithEngine) Neg(a *constraint.Coeff) { +func (engine *field) Neg(a constraint.Element) constraint.Element { e := (*fr.Element)(a[:]) e.Neg(e) + return a } -func (engine *arithEngine) Inverse(a *constraint.Coeff) { +func (engine *field) Inverse(a constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + e.Inverse(e) + return a, true } -func (engine *arithEngine) IsOne(a *constraint.Coeff) bool { +func (engine *field) IsOne(a constraint.Element) bool { e := (*fr.Element)(a[:]) return e.IsOne() } -func (engine *arithEngine) One() constraint.Coeff { +func (engine *field) One() constraint.Element { e := fr.One() - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) String(a *constraint.Coeff) string { +func (engine *field) String(a constraint.Element) string { e := (*fr.Element)(a[:]) return e.String() } + +func (engine *field) Uint64(a constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true +} diff --git 
a/constraint/tinyfield/r1cs.go b/constraint/tinyfield/r1cs.go deleted file mode 100644 index e181affa19..0000000000 --- a/constraint/tinyfield/r1cs.go +++ /dev/null @@ -1,452 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/fxamacker/cbor/v2" - "io" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc" - "math" - - fr "github.com/consensys/gnark/internal/tinyfield" -) - -// R1CS describes a set of R1CS constraint -type R1CS struct { - constraint.R1CSCore - CoeffTable - arithEngine -} - -// NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values -// -// capacity pre-allocates memory for capacity nbConstraints -func NewR1CS(capacity int) *R1CS { - r := R1CS{ - R1CSCore: constraint.R1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.R1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - return &r -} - -func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = 
append(cs.Constraints, r1c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - - cs.UpdateLevel(cID, &r1c) - - return cID -} - -// Solve sets all the wires and returns the a, b, c vectors. -// the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. -// a, b, c vectors: ab-c = hz -// witness = [publicWires | secretWires] (without the ONE_WIRE !) -// returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() - - nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables - solution, err := newSolution(nbWires, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - - if err != nil { - return make(fr.Vector, nbWires), err - } - start := time.Now() - - if len(witness) != len(cs.Public)-1+len(cs.Secret) { // - 1 for ONE_WIRE - err = fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), int(len(cs.Public)-1+len(cs.Secret)), len(cs.Public)-1, len(cs.Secret)) - log.Err(err).Send() - return solution.values, err - } - - // compute the wires and the a, b, c polynomials - if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { - err = errors.New("invalid input size: len(a, b, c) == len(Constraints)") - log.Err(err).Send() - return solution.values, err - } - - solution.solved[0] = true // ONE_WIRE - solution.values[0].SetOne() - copy(solution.values[1:], witness) - for i := range witness { - solution.solved[i+1] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += 
uint64(len(witness) + 1) - - // now that we know all inputs are set, defer log printing once all solution.values are computed - // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - if err := cs.parallelSolve(a, b, c, &solution); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil -} - -func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. 
- const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - chError <- &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - return &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can 
be lower. - nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// IsSolved returns nil if given witness solves the R1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) 
- if err != nil { - return err - } - - a := make(fr.Vector, len(cs.Constraints)) - b := make(fr.Vector, len(cs.Constraints)) - c := make(fr.Vector, len(cs.Constraints)) - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, a, b, c, opt) - return err -} - -// divByCoeff sets res = res / t.Coeff -func (cs *R1CS) divByCoeff(res *fr.Element, t constraint.Term) { - cID := t.CoeffID() - switch cID { - case constraint.CoeffIdOne: - return - case constraint.CoeffIdMinusOne: - res.Neg(res) - case constraint.CoeffIdZero: - panic("division by 0") - default: - // this is slow, but shouldn't happen as divByCoeff is called to - // remove the coeff of an unsolved wire - // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 - res.Div(res, &cs.Coefficients[cID]) - } -} - -// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly -// -// returns an error if the solver called a hint function that errored -// returns false, nil if there was no wire to solve -// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that -// the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r constraint.R1C, solution *solution, a, b, c *fr.Element) error { - - // the index of the non-zero entry shows if L, R or O has an uninstantiated wire - // the content is the ID of the wire non instantiated - var loc uint8 - - var termToCompute constraint.Term - - processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) error { - for _, t := range l { - vID := t.WireID() - - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - solution.accumulateInto(t, val) - continue - } - - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err - } - // now that the wire is saved, accumulate it into a, b or c - solution.accumulateInto(t, val) - continue - } - - if loc != 0 { - panic("found more than one wire to instantiate") - } - termToCompute = t - loc = locValue - } - return nil - } - - if err := processLExp(r.L, a, 1); err != nil { - return err - } - - if err := processLExp(r.R, b, 2); err != nil { - return err - } - - if err := processLExp(r.O, c, 3); err != nil { - return err - } - - if loc == 0 { - // there is nothing to solve, may happen if we have an assertion - // (ie a constraints that doesn't yield any output) - // or if we solved the unsolved wires with hint functions - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - return nil - } - - // we compute the wire value and instantiate it - wID := termToCompute.WireID() - - // solver result - var wire fr.Element - - switch loc { - case 1: - if !b.IsZero() { - wire.Div(c, b). - Sub(&wire, a) - a.Add(a, &wire) - } else { - // we didn't actually ensure that a * b == c - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 2: - if !a.IsZero() { - wire.Div(c, a). 
- Sub(&wire, b) - b.Add(b, &wire) - } else { - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 3: - wire.Mul(a, b). - Sub(&wire, c) - - c.Add(c, &wire) - } - - // wire is the term (coeff * value) - // but in the solution we want to store the value only - // note that in gnark frontend, coeff here is always 1 or -1 - cs.divByCoeff(&wire, termToCompute) - solution.set(wID, wire) - - return nil -} - -// GetConstraints return the list of R1C and a coefficient resolver -func (cs *R1CS) GetConstraints() ([]constraint.R1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *R1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto -func (cs *R1CS) CurveID() ecc.ID { - return ecc.UNKNOWN -} - -// WriteTo encodes R1CS into provided io.Writer using cbor -func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode R1CS from io.Reader using cbor -func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(&cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/tinyfield/r1cs_sparse.go 
b/constraint/tinyfield/r1cs_sparse.go deleted file mode 100644 index 4ad9f0023c..0000000000 --- a/constraint/tinyfield/r1cs_sparse.go +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/fxamacker/cbor/v2" - "io" - "math" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - fr "github.com/consensys/gnark/internal/tinyfield" -) - -// SparseR1CS represents a Plonk like circuit -type SparseR1CS struct { - constraint.SparseR1CSCore - CoeffTable - arithEngine -} - -// NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values -func NewSparseR1CS(capacity int) *SparseR1CS { - cs := SparseR1CS{ - SparseR1CSCore: constraint.SparseR1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.SparseR1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - - return &cs -} - -func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = 
append(cs.Constraints, c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - cs.UpdateLevel(cID, &c) - - return cID -} - -// Solve sets all the wires. -// solution.values = [publicInputs | secretInputs | internalVariables ] -// witness: contains the input variables -// it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() - - // set the slices holding the solution.values and monitoring which variables have been solved - nbVariables := cs.NbInternalVariables + len(cs.Secret) + len(cs.Public) - - start := time.Now() - - expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) - if len(witness) != expectedWitnessSize { - return make(fr.Vector, nbVariables), fmt.Errorf( - "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", - len(witness), - expectedWitnessSize, - len(cs.Public), - len(cs.Secret), - ) - } - - // keep track of wire that have a value - solution, err := newSolution(nbVariables, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return solution.values, err - } - - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs - copy(solution.values, witness) - for i := 0; i < len(witness); i++ { - solution.solved[i] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness)) - - // defer log printing once all solution.values are computed - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - // batch invert the coefficients to avoid many divisions in the solver - coefficientsNegInv := 
fr.BatchInvert(cs.Coefficients) - for i := 0; i < len(coefficientsNegInv); i++ { - coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) - } - - if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil - -} - -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - wg.Done() - return - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - chError <- &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } else { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - } - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - return &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// computeHints computes wires associated with a hint function, if any -// if there is no remaining wire to solve, returns -1 -// else returns the wire position (L -> 0, R -> 1, O -> 2) -func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) (int, error) { - r := -1 - lID, rID, oID := c.L.WireID(), c.R.WireID(), c.O.WireID() - - if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { - // check if it's a hint - if hint, ok := cs.MHints[lID]; ok { - if err := solution.solveWithHint(lID, hint); err != nil { - return -1, err - } - } else { - r = 0 - } - - } - - if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { - // check if it's a hint - if hint, ok := cs.MHints[rID]; ok { - if err := solution.solveWithHint(rID, hint); err != nil { - return -1, err - } - } else { - r = 1 - } - } - - if (c.O.CoeffID() != 0) && !solution.solved[oID] { - // check if it's a hint - if hint, ok := cs.MHints[oID]; ok { - if err := solution.solveWithHint(oID, hint); err != nil { - return -1, err - } - } else { - r = 2 - } - } - 
return r, nil -} - -// solveConstraint solve any unsolved wire in given constraint and update the solution -// a SparseR1C may have up to one unsolved wire (excluding hints) -// if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { - - lro, err := cs.computeHints(c, solution) - if err != nil { - return err - } - if lro == -1 { - // no unsolved wire - // can happen if the constraint contained only hint wires. - return nil - } - if lro == 1 { // we solve for R: u1L+u2R+u3LR+u4O+k=0 => R(u2+u3L)+u1L+u4O+k = 0 - if !solution.solved[c.L.WireID()] { - panic("L wire should be instantiated when we solve R") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.L.WireID()]).Add(&den, &u2) - - v1 = solution.computeTerm(c.L) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - } - - if lro == 0 { // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 - if !solution.solved[c.R.WireID()] { - panic("R wire should be instantiated when we solve L") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.R.WireID()]).Add(&den, &u1) - - v1 = solution.computeTerm(c.R) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - - } - // O 
we solve for O - var o fr.Element - cID, vID := c.O.CoeffID(), c.O.WireID() - - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - - // o = - ((m0 * m1) + l + r + c.K) / c.O - o.Mul(&m0, &m1).Add(&o, &l).Add(&o, &r).Add(&o, &cs.Coefficients[c.K]) - o.Mul(&o, &coefficientsNegInv[cID]) - - solution.set(vID, o) - - return nil -} - -// IsSolved returns nil if given witness solves the SparseR1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return err - } - - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, opt) - return err -} - -// GetConstraints return the list of SparseR1C and a coefficient resolver -func (cs *SparseR1CS) GetConstraints() ([]constraint.SparseR1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// checkConstraint verifies that the constraint holds -func (cs *SparseR1CS) checkConstraint(c constraint.SparseR1C, solution *solution) error { - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - o := solution.computeTerm(c.O) - - // l + r + (m0 * m1) + o + c.K == 0 - var t fr.Element - t.Mul(&m0, &m1).Add(&t, &l).Add(&t, &r).Add(&t, &o).Add(&t, &cs.Coefficients[c.K]) - if !t.IsZero() { - return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + (%s × %s) + %s != 0", - l.String(), - r.String(), - o.String(), - m0.String(), - m1.String(), - cs.Coefficients[c.K].String(), - ) - } - return nil - -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto (ecc.tinyfield) -func (cs 
*SparseR1CS) CurveID() ecc.ID { - return ecc.UNKNOWN -} - -// WriteTo encodes SparseR1CS into provided io.Writer using cbor -func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode SparseR1CS from io.Reader using cbor -func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/tinyfield/r1cs_test.go b/constraint/tinyfield/r1cs_test.go index 09bb5f0968..5015c06972 100644 --- a/constraint/tinyfield/r1cs_test.go +++ b/constraint/tinyfield/r1cs_test.go @@ -20,6 +20,7 @@ import ( "bytes" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "reflect" "testing" @@ -51,7 +52,7 @@ func TestSerialization(t *testing.T) { return } - // copmpile a second time to ensure determinism + // compile a second time to ensure determinism r1cs2, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) @@ -79,10 +80,10 @@ func TestSerialization(t *testing.T) { if diff := cmp.Diff(r1cs1, &reconstructed, cmpopts.IgnoreFields(cs.R1CS{}, "System.q", - "arithEngine", + "field", "CoeffTable.mCoeffs", 
"System.lbWireLevel", - "System.lbHints", + "System.genericHint", "System.SymbolTable", "System.lbOutputs", "System.bitLen")); diff != "" { @@ -150,12 +151,6 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { - var c circuit - ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c) - if err != nil { - b.Fatal(err) - } - var w circuit w.X = 1 w.Y = 1 @@ -164,8 +159,32 @@ func BenchmarkSolve(b *testing.B) { b.Fatal(err) } - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ccs.IsSolved(witness) - } + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + } diff --git a/constraint/tinyfield/solution.go b/constraint/tinyfield/solution.go deleted file mode 100644 index 2f5eb4caec..0000000000 --- a/constraint/tinyfield/solution.go +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/debug" - "github.com/rs/zerolog" - "math/big" - "strconv" - "strings" - "sync/atomic" - - fr "github.com/consensys/gnark/internal/tinyfield" -) - -// solution represents elements needed to compute -// a solution to a R1CS or SparseR1CS -type solution struct { - values, coefficients []fr.Element - solved []bool - nbSolved uint64 - mHintsFunctions map[hint.ID]hint.Function // maps hintID to hint function - mHints map[int]*constraint.Hint // maps wireID to hint - st *debug.SymbolTable -} - -func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, mHints map[int]*constraint.Hint, coefficients []fr.Element, st *debug.SymbolTable) (solution, error) { - - s := solution{ - st: st, - values: make([]fr.Element, nbWires), - coefficients: coefficients, - solved: make([]bool, nbWires), - mHintsFunctions: hintFunctions, - mHints: mHints, - } - - // hintsDependencies is from compile time; it contains the list of hints the solver **needs** - var missing []string - for hintUUID, hintID := range hintsDependencies { - if _, ok := s.mHintsFunctions[hintUUID]; !ok { - missing = append(missing, hintID) - } - } - - if len(missing) > 0 { - return s, fmt.Errorf("solver missing hint(s): %v", missing) - } - - return s, nil -} - -func (s *solution) set(id int, value fr.Element) { - if s.solved[id] { - panic("solving the same wire twice should never happen.") - } - s.values[id] = value - s.solved[id] = true - atomic.AddUint64(&s.nbSolved, 1) - // s.nbSolved++ -} - -func (s *solution) isValid() bool { - return int(s.nbSolved) == len(s.values) -} - -// computeTerm computes coeff*variable -func (s *solution) computeTerm(t constraint.Term) fr.Element { - cID, vID := t.CoeffID(), t.WireID() - if cID != 0 && !s.solved[vID] { - panic("computing a 
term with an unsolved wire") - } - switch cID { - case constraint.CoeffIdZero: - return fr.Element{} - case constraint.CoeffIdOne: - return s.values[vID] - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - return res - case constraint.CoeffIdMinusOne: - var res fr.Element - res.Neg(&s.values[vID]) - return res - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - return res - } -} - -// r += (t.coeff*t.value) -func (s *solution) accumulateInto(t constraint.Term, r *fr.Element) { - cID := t.CoeffID() - - if t.IsConstant() { - // needed for logs, we may want to not put this in the hot path if we need to - // optimize constraint system solver further. - r.Add(r, &s.coefficients[cID]) - return - } - - vID := t.WireID() - switch cID { - case constraint.CoeffIdZero: - return - case constraint.CoeffIdOne: - r.Add(r, &s.values[vID]) - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - r.Add(r, &res) - case constraint.CoeffIdMinusOne: - r.Sub(r, &s.values[vID]) - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - r.Add(r, &res) - } -} - -// solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveWithHint(vID int, h *constraint.Hint) error { - // skip if the wire is already solved by a call to the same hint - // function on the same inputs - if s.solved[vID] { - return nil - } - // ensure hint function was provided - f, ok := s.mHintsFunctions[h.ID] - if !ok { - return errors.New("missing hint function") - } - - // tmp IO big int memory - nbInputs := len(h.Inputs) - nbOutputs := len(h.Wires) - inputs := make([]*big.Int, nbInputs) - outputs := make([]*big.Int, nbOutputs) - for i := 0; i < nbOutputs; i++ { - outputs[i] = big.NewInt(0) - } - - q := fr.Modulus() - - // for each input, we set its big int value, IF all the wires are solved - // the only case where all wires may not be solved, is if one of the input of this hint - // 
is the output of another hint. - // it is safe to recursively solve this with the parallel solver, since all hints-output wires - // that we can solve this way are marked to be solved with the current constraint we are processing. - recursiveSolve := func(t constraint.Term) error { - if t.IsConstant() { - return nil - } - wID := t.WireID() - if s.solved[wID] { - return nil - } - // unsolved dependency - if h, ok := s.mHints[wID]; ok { - // solve recursively. - return s.solveWithHint(wID, h) - } - - // it's not a hint, we panic. - panic("solver can't compute hint; one or more input wires are unsolved") - } - - for i := 0; i < nbInputs; i++ { - inputs[i] = big.NewInt(0) - - var v fr.Element - for _, term := range h.Inputs[i] { - if err := recursiveSolve(term); err != nil { - return err - } - s.accumulateInto(term, &v) - } - v.BigInt(inputs[i]) - } - - err := f(q, inputs, outputs) - - var v fr.Element - for i := range outputs { - v.SetBigInt(outputs[i]) - s.set(h.Wires[i], v) - } - - return err -} - -func (s *solution) printLogs(log zerolog.Logger, logs []constraint.LogEntry) { - if log.GetLevel() == zerolog.Disabled { - return - } - - for i := 0; i < len(logs); i++ { - logLine := s.logValue(logs[i]) - log.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) - } -} - -const unsolvedVariable = "" - -func (s *solution) logValue(log constraint.LogEntry) string { - var toResolve []interface{} - var ( - eval fr.Element - missingValue bool - ) - for j := 0; j < len(log.ToResolve); j++ { - // before eval le - - missingValue = false - eval.SetZero() - - for _, t := range log.ToResolve[j] { - // for each term in the linear expression - - cID, vID := t.CoeffID(), t.WireID() - if t.IsConstant() { - // just add the constant - eval.Add(&eval, &s.coefficients[cID]) - continue - } - - if !s.solved[vID] { - missingValue = true - break // stop the loop we can't evaluate. 
- } - - tv := s.computeTerm(t) - eval.Add(&eval, &tv) - } - - // after - if missingValue { - toResolve = append(toResolve, unsolvedVariable) - } else { - // we have to append our accumulator - toResolve = append(toResolve, eval.String()) - } - - } - if len(log.Stack) > 0 { - var sbb strings.Builder - for _, lID := range log.Stack { - location := s.st.Locations[lID] - function := s.st.Functions[location.FunctionID] - - sbb.WriteString(function.Name) - sbb.WriteByte('\n') - sbb.WriteByte('\t') - sbb.WriteString(function.Filename) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(int(location.Line))) - sbb.WriteByte('\n') - } - toResolve = append(toResolve, sbb.String()) - } - return fmt.Sprintf(log.Format, toResolve...) -} - -// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint -type UnsatisfiedConstraintError struct { - Err error - CID int // constraint ID - DebugInfo *string // optional debug info -} - -func (r *UnsatisfiedConstraintError) Error() string { - if r.DebugInfo != nil { - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) - } - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) -} diff --git a/constraint/tinyfield/solver.go b/constraint/tinyfield/solver.go new file mode 100644 index 0000000000..7cb146906c --- /dev/null +++ b/constraint/tinyfield/solver.go @@ -0,0 +1,647 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/pool" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/rs/zerolog" + "math" + "math/big" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + + fr "github.com/consensys/gnark/internal/tinyfield" +) + +// solver represent the state of the solver during a call to System.Solve(...) +type solver struct { + *system + + // values and solved are index by the wire (variable) id + values []fr.Element + solved []bool + nbSolved uint64 + + // maps hintID to hint function + mHintsFunctions map[csolver.HintID]csolver.Hint + + // used to out api.Println + logger zerolog.Logger + + a, b, c fr.Vector // R1CS solver will compute the a,b,c matrices + + q *big.Int +} + +func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { + // parse options + opt, err := csolver.NewConfig(opts...) 
+ if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public) - witnessOffset + len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + witnessOffset) + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + +// computeTerm computes 
coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. 
+func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start)+i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. 
+ } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID: cID, VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. 
+func (s *solver) Read(calldata []uint32) (constraint.Element, int) { + if s.Type == constraint.SystemSparseR1CS { + if calldata[0] != 1 { + panic("invalid calldata") + } + return s.GetValue(calldata[1], calldata[2]), 3 + } + var r fr.Element + n := int(calldata[0]) + j := 1 + for k := 0; k < n; k++ { + // we read k Terms + s.accumulateInto(constraint.Term{CID: calldata[j], VID: calldata[j+1]}, &r) + j += 2 + } + + var ret constraint.Element + copy(ret[:], r[:]) + return ret, j +} + +// processInstruction decodes the instruction and execute blueprint-defined logic. +// an instruction can encode a hint, a custom constraint or a generic constraint. +func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error { + // fetch the blueprint + blueprint := solver.Blueprints[pi.BlueprintID] + inst := pi.Unpack(&solver.System) + cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only + + if solver.Type == constraint.SystemR1CS { + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + // TODO @gbotrel we use the solveR1C method for now, having user-defined + // blueprint for R1CS would require constraint.Solver interface to add methods + // to set a,b,c since it's more efficient to compute these while we solve. + bc.DecompressR1C(&scratch.tR1C, inst) + return solver.solveR1C(cID, &scratch.tR1C) + } + } + + // blueprint declared "I know how to solve this." + if bc, ok := blueprint.(constraint.BlueprintSolvable); ok { + if err := bc.Solve(solver, inst); err != nil { + return solver.wrapErrWithDebugInfo(cID, err) + } + return nil + } + + // blueprint encodes a hint, we execute. + // TODO @gbotrel may be worth it to move hint logic in blueprint "solve" + if bc, ok := blueprint.(constraint.BlueprintHint); ok { + bc.DecompressHint(&scratch.tHint, inst) + return solver.solveWithHint(&scratch.tHint) + } + + return nil +} + +// run runs the solver. 
it return an error if a constraint is not satisfied or if not all wires +// were instantiated. +func (solver *solver) run() error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks. + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + var scratch scratch + for t := range chTasks { + for _, i := range t { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + var scratch scratch + + // for each level, we push the tasks + for _, level := range solver.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + return err + } + } + continue + } + + // number of tasks for this level is set to number of CPU + // but if we don't have enough work for all our CPU, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + if int(solver.nbSolved) != len(solver.values) { + return errors.New("solver didn't assign a value to all wires") + } + + return nil +} + +// solveR1C compute unsolved wires in the constraint, if any and set the solver accordingly +// +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. 
+func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error { + a, b, c := &solver.a[cID], &solver.b[cID], &solver.c[cID] + + // the index of the non-zero entry shows if L, R or O has an uninstantiated wire + // the content is the ID of the wire non instantiated + var loc uint8 + + var termToCompute constraint.Term + + processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) { + for _, t := range l { + vID := t.WireID() + + // wire is already computed, we just accumulate in val + if solver.solved[vID] { + solver.accumulateInto(t, val) + continue + } + + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue + } + } + + processLExp(r.L, a, 1) + processLExp(r.R, b, 2) + processLExp(r.O, c, 3) + + if loc == 0 { + // there is nothing to solve, may happen if we have an assertion + // (ie a constraints that doesn't yield any output) + // or if we solved the unsolved wires with hint functions + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + return nil + } + + // we compute the wire value and instantiate it + wID := termToCompute.WireID() + + // solver result + var wire fr.Element + + switch loc { + case 1: + if !b.IsZero() { + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) + } else { + // we didn't actually ensure that a * b == c + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 2: + if !a.IsZero() { + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) + } else { + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 3: + wire.Mul(a, b). 
+ Sub(&wire, c) + + c.Add(c, &wire) + } + + // wire is the term (coeff * value) + // but in the solver we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + solver.divByCoeff(&wire, termToCompute.CID) + solver.set(wID, wire) + + return nil +} + +// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint +type UnsatisfiedConstraintError struct { + Err error + CID int // constraint ID + DebugInfo *string // optional debug info +} + +func (r *UnsatisfiedConstraintError) Error() string { + if r.DebugInfo != nil { + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) + } + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) +} + +func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { + var debugInfo *string + if dID, ok := solver.MDebug[int(cID)]; ok { + debugInfo = new(string) + *debugInfo = solver.logValue(solver.DebugInfo[dID]) + } + return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo} +} + +// temporary variables to avoid memallocs in hotloop +type scratch struct { + tR1C constraint.R1C + tHint constraint.HintMapping +} diff --git a/constraint/tinyfield/system.go b/constraint/tinyfield/system.go new file mode 100644 index 0000000000..d75471e445 --- /dev/null +++ b/constraint/tinyfield/system.go @@ -0,0 +1,374 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/fxamacker/cbor/v2" + "io" + "time" + + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/internal/backend/ioutils" + "github.com/consensys/gnark/logger" + "reflect" + + "github.com/consensys/gnark-crypto/ecc" + + fr "github.com/consensys/gnark/internal/tinyfield" +) + +type R1CS = system +type SparseR1CS = system + +// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element) +type system struct { + constraint.System + CoeffTable + field +} + +func NewR1CS(capacity int) *R1CS { + return newSystem(capacity, constraint.SystemR1CS) +} + +func NewSparseR1CS(capacity int) *SparseR1CS { + return newSystem(capacity, constraint.SystemSparseR1CS) +} + +func newSystem(capacity int, t constraint.SystemType) *system { + return &system{ + System: constraint.NewSystem(fr.Modulus(), capacity, t), + CoeffTable: newCoeffTable(capacity / 10), + } +} + +// Solve solves the constraint system with provided witness. +// If it's a R1CS returns R1CSSolution +// If it's a SparseR1CS returns SparseR1CSSolution +func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) { + log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger() + start := time.Now() + + v := witness.Vector().(fr.Vector) + + // init the solver + solver, err := newSolver(cs, v, opts...) + if err != nil { + log.Err(err).Send() + return nil, err + } + + // defer log printing once all solver.values are computed + // (or sooner, if a constraint is not satisfied) + defer solver.printLogs(cs.Logs) + + // run it. 
+ if err := solver.run(); err != nil { + log.Err(err).Send() + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") + + // format the solution + // TODO @gbotrel revisit post-refactor + if cs.Type == constraint.SystemR1CS { + var res R1CSSolution + res.W = solver.values + res.A = solver.a + res.B = solver.b + res.C = solver.c + return &res, nil + } else { + // sparse R1CS + var res SparseR1CSSolution + // query l, r, o in Lagrange basis, not blinded + res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values) + + return &res, nil + } + +} + +// IsSolved +// Deprecated: use _, err := Solve(...) instead +func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error { + _, err := cs.Solve(witness, opts...) + return err +} + +// GetR1Cs return the list of R1C +func (cs *system) GetR1Cs() []constraint.R1C { + toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + var r1c constraint.R1C + bc.DecompressR1C(&r1c, inst.Unpack(&cs.System)) + toReturn = append(toReturn, r1c) + } else { + panic("not implemented") + } + } + return toReturn +} + +// GetNbCoefficients return the number of unique coefficients needed in the R1CS +func (cs *system) GetNbCoefficients() int { + return len(cs.Coefficients) +} + +// CurveID returns curve ID as defined in gnark-crypto +func (cs *system) CurveID() ecc.ID { + return ecc.UNKNOWN +} + +// WriteTo encodes R1CS into provided io.Writer using cbor +func (cs *system) WriteTo(w io.Writer) (int64, error) { + _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written + ts := getTagSet() + enc, err := cbor.CoreDetEncOptions().EncModeWithTags(ts) + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) + + // encode our object + err = encoder.Encode(cs) + return _w.N, err +} + +// ReadFrom 
attempts to decode R1CS from io.Reader using cbor +func (cs *system) ReadFrom(r io.Reader) (int64, error) { + ts := getTagSet() + dm, err := cbor.DecOptions{ + MaxArrayElements: 134217728, + MaxMapPairs: 134217728, + }.DecModeWithTags(ts) + + if err != nil { + return 0, err + } + decoder := dm.NewDecoder(r) + + // initialize coeff table + cs.CoeffTable = newCoeffTable(0) + + if err := decoder.Decode(&cs); err != nil { + return int64(decoder.NumBytesRead()), err + } + + if err := cs.CheckSerializationHeader(); err != nil { + return int64(decoder.NumBytesRead()), err + } + + switch v := cs.CommitmentInfo.(type) { + case *constraint.Groth16Commitments: + cs.CommitmentInfo = *v + case *constraint.PlonkCommitments: + cs.CommitmentInfo = *v + } + + return int64(decoder.NumBytesRead()), nil +} + +func (cs *system) GetCoefficient(i int) (r constraint.Element) { + copy(r[:], cs.Coefficients[i][:]) + return +} + +// GetSparseR1Cs return the list of SparseR1C +func (cs *system) GetSparseR1Cs() []constraint.SparseR1C { + + toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + var sparseR1C constraint.SparseR1C + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + toReturn = append(toReturn, sparseR1C) + } else { + panic("not implemented") + } + } + return toReturn +} + +// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form. 
+// solver = [ public | secret | internal ] +// TODO @gbotrel refactor; this seems to be a small util function for plonk +func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { + + //s := int(pk.Domain[0].Cardinality) + s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints + s = int(ecc.NextPowerOfTwo(uint64(s))) + + var l, r, o []fr.Element + l = make([]fr.Element, s) + r = make([]fr.Element, s) + o = make([]fr.Element, s) + s0 := solution[0] + + for i := 0; i < len(cs.Public); i++ { // placeholders + l[i] = solution[i] + r[i] = s0 + o[i] = s0 + } + offset := len(cs.Public) + nbConstraints := cs.GetNbConstraints() + + var sparseR1C constraint.SparseR1C + j := 0 + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + + l[offset+j] = solution[sparseR1C.XA] + r[offset+j] = solution[sparseR1C.XB] + o[offset+j] = solution[sparseR1C.XC] + j++ + } + } + + offset += nbConstraints + + for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0]) + l[offset+i] = s0 + r[offset+i] = s0 + o[offset+i] = s0 + } + + return l, r, o + +} + +// R1CSSolution represent a valid assignment to all the variables in the constraint system. 
+// The vector W such that Aw o Bw - Cw = 0 +type R1CSSolution struct { + W fr.Vector + A, B, C fr.Vector +} + +func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.W.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.A.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.B.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.C.WriteTo(w) + n += a + return n, err +} + +func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.W.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.A.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.B.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.C.ReadFrom(r) + n += a + return n, err +} + +// SparseR1CSSolution represent a valid assignment to all the variables in the constraint system. +type SparseR1CSSolution struct { + L, R, O fr.Vector +} + +func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.L.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.R.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.O.WriteTo(w) + n += a + return n, err + +} + +func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.L.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.R.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.O.ReadFrom(r) + n += a + return n, err +} + +func getTagSet() cbor.TagSet { + // temporary for refactor + ts := cbor.NewTagSet() + // https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml + // 65536-15309735 Unassigned + tagNum := uint64(5309735) + addType := func(t reflect.Type) { + if err := ts.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + t, + tagNum, + ); err != nil { + panic(err) + } + tagNum++ + } + + addType(reflect.TypeOf(constraint.BlueprintGenericHint{})) + addType(reflect.TypeOf(constraint.BlueprintGenericR1C{})) + 
addType(reflect.TypeOf(constraint.BlueprintGenericSparseR1C{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CAdd{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CMul{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CBool{})) + addType(reflect.TypeOf(constraint.Groth16Commitments{})) + addType(reflect.TypeOf(constraint.PlonkCommitments{})) + + return ts +} diff --git a/debug/debug.go b/debug/debug.go index f41bac9e83..32b891c38a 100644 --- a/debug/debug.go +++ b/debug/debug.go @@ -56,7 +56,7 @@ func writeStack(sbb *strings.Builder, forceClean ...bool) { if strings.Contains(frame.File, "test/engine.go") { continue } - if strings.Contains(frame.File, "gnark/frontend") { + if strings.Contains(frame.File, "gnark/frontend/cs") { continue } file = filepath.Base(file) @@ -75,5 +75,8 @@ func writeStack(sbb *strings.Builder, forceClean ...bool) { if strings.HasSuffix(function, "Define") { break } + if strings.HasSuffix(function, "callDeferred") { + break + } } } diff --git a/debug/symbol_table.go b/debug/symbol_table.go index a840aeb7ae..dc25c3e8d8 100644 --- a/debug/symbol_table.go +++ b/debug/symbol_table.go @@ -61,7 +61,7 @@ func (st *SymbolTable) CollectStack() []int { if strings.Contains(frame.File, "test/engine.go") { continue } - if strings.Contains(frame.File, "gnark/frontend") { + if strings.Contains(frame.File, "gnark/frontend/cs") { continue } frame.File = filepath.Base(frame.File) @@ -76,6 +76,9 @@ func (st *SymbolTable) CollectStack() []int { if strings.HasSuffix(function, "Define") { break } + if strings.HasSuffix(function, "callDeferred") { + break + } } return r } diff --git a/debug_test.go b/debug_test.go index d3efb39e1e..620e9860d0 100644 --- a/debug_test.go +++ b/debug_test.go @@ -9,6 +9,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/plonk" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/debug" 
"github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/frontend/cs/scs" @@ -46,11 +48,11 @@ func TestPrintln(t *testing.T) { witness.B = 11 var expected bytes.Buffer - expected.WriteString("debug_test.go:28 > 13 is the addition\n") - expected.WriteString("debug_test.go:30 > 26 42\n") - expected.WriteString("debug_test.go:32 > bits 1\n") - expected.WriteString("debug_test.go:33 > circuit {A: 2, B: 11}\n") - expected.WriteString("debug_test.go:37 > m .*\n") + expected.WriteString("debug_test.go:30 > 13 is the addition\n") + expected.WriteString("debug_test.go:32 > 26 42\n") + expected.WriteString("debug_test.go:34 > bits 1\n") + expected.WriteString("debug_test.go:35 > circuit {A: 2, B: 11}\n") + expected.WriteString("debug_test.go:39 > m .*\n") { trace, _ := getGroth16Trace(&circuit, &witness) @@ -94,9 +96,14 @@ func TestTraceDivBy0(t *testing.T) { { _, err := getPlonkTrace(&circuit, &witness) assert.Error(err) - assert.Contains(err.Error(), "constraint #1 is not satisfied: [inverse] 1/0 < ∞") - assert.Contains(err.Error(), "(*divBy0Trace).Define") - assert.Contains(err.Error(), "debug_test.go:") + if debug.Debug { + assert.Contains(err.Error(), "constraint #1 is not satisfied: [inverse] 1/0 < ∞") + assert.Contains(err.Error(), "(*divBy0Trace).Define") + assert.Contains(err.Error(), "debug_test.go:") + } else { + assert.Contains(err.Error(), "constraint #1 is not satisfied: division by 0") + } + } } @@ -123,17 +130,26 @@ func TestTraceNotEqual(t *testing.T) { { _, err := getGroth16Trace(&circuit, &witness) assert.Error(err) - assert.Contains(err.Error(), "constraint #0 is not satisfied: [assertIsEqual] 1 == 66") - assert.Contains(err.Error(), "(*notEqualTrace).Define") - assert.Contains(err.Error(), "debug_test.go:") + if debug.Debug { + assert.Contains(err.Error(), "constraint #0 is not satisfied: [assertIsEqual] 1 == 66") + assert.Contains(err.Error(), "(*notEqualTrace).Define") + 
assert.Contains(err.Error(), "debug_test.go:") + } else { + assert.Contains(err.Error(), "constraint #0 is not satisfied: 1 ⋅ 1 != 66") + } } { _, err := getPlonkTrace(&circuit, &witness) assert.Error(err) - assert.Contains(err.Error(), "constraint #1 is not satisfied: [assertIsEqual] 1 + -66 == 0") - assert.Contains(err.Error(), "(*notEqualTrace).Define") - assert.Contains(err.Error(), "debug_test.go:") + if debug.Debug { + assert.Contains(err.Error(), "constraint #1 is not satisfied: [assertIsEqual] 1 == 66") + assert.Contains(err.Error(), "(*notEqualTrace).Define") + assert.Contains(err.Error(), "debug_test.go:") + } else { + assert.Contains(err.Error(), "constraint #1 is not satisfied: qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → 1 + -66 + 0 + 0 + 0 != 0") + } + } } @@ -158,7 +174,7 @@ func getPlonkTrace(circuit, w frontend.Circuit) (string, error) { return "", err } log := zerolog.New(&zerolog.ConsoleWriter{Out: &buf, NoColor: true, PartsExclude: []string{zerolog.LevelFieldName, zerolog.TimestampFieldName}}) - _, err = plonk.Prove(ccs, pk, sw, backend.WithCircuitLogger(log)) + _, err = plonk.Prove(ccs, pk, sw, backend.WithSolverOptions(solver.WithLogger(log))) return buf.String(), err } @@ -179,6 +195,6 @@ func getGroth16Trace(circuit, w frontend.Circuit) (string, error) { return "", err } log := zerolog.New(&zerolog.ConsoleWriter{Out: &buf, NoColor: true, PartsExclude: []string{zerolog.LevelFieldName, zerolog.TimestampFieldName}}) - _, err = groth16.Prove(ccs, pk, sw, backend.WithCircuitLogger(log)) + _, err = groth16.Prove(ccs, pk, sw, backend.WithSolverOptions(solver.WithLogger(log))) return buf.String(), err } diff --git a/doc.go b/doc.go index 697944b7d7..b871010f37 100644 --- a/doc.go +++ b/doc.go @@ -22,7 +22,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" ) -var Version = semver.MustParse("0.8.0") +var Version = semver.MustParse("0.8.1-alpha") // Curves return the curves supported by gnark func Curves() []ecc.ID { diff --git 
a/examples/plonk/main.go b/examples/plonk/main.go index 36e355ccde..80aeea5eaf 100644 --- a/examples/plonk/main.go +++ b/examples/plonk/main.go @@ -80,7 +80,7 @@ func main() { // create the necessary data for KZG. // This is a toy example, normally the trusted setup to build ZKG - // has been ran before. + // has been run before. // The size of the data in KZG should be the closest power of 2 bounding // // above max(nbConstraints, nbVariables). _r1cs := ccs.(*cs.SparseR1CS) diff --git a/examples/rollup/circuit.go b/examples/rollup/circuit.go index 0b3b5000ba..27b47f4fda 100644 --- a/examples/rollup/circuit.go +++ b/examples/rollup/circuit.go @@ -20,7 +20,7 @@ import ( tedwards "github.com/consensys/gnark-crypto/ecc/twistededwards" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/accumulator/merkle" - "github.com/consensys/gnark/std/algebra/twistededwards" + "github.com/consensys/gnark/std/algebra/native/twistededwards" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/std/signature/eddsa" ) diff --git a/frontend/api.go b/frontend/api.go index 3645f04197..a54e2ec3c0 100644 --- a/frontend/api.go +++ b/frontend/api.go @@ -19,7 +19,7 @@ package frontend import ( "math/big" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" ) // API represents the available functions to circuit developers @@ -30,9 +30,16 @@ type API interface { // Add returns res = i1+i2+...in Add(i1, i2 Variable, in ...Variable) Variable - // MulAcc sets and return a = a + (b*c) - // ! may mutate a without allocating a new result - // ! always use MulAcc(...) result for correctness + // MulAcc sets and return a = a + (b*c). + // + // ! The method may mutate a without allocating a new result. If the input + // is used elsewhere, then first initialize new variable, for example by + // doing: + // + // acopy := api.Mul(a, 1) + // acopy = MulAcc(acopy, b, c) + // + // ! But it may not modify a, always use MulAcc(...) 
result for correctness. MulAcc(a, b, c Variable) Variable // Neg returns -i @@ -122,7 +129,7 @@ type API interface { // NewHint is a shortcut to api.Compiler().NewHint() // Deprecated: use api.Compiler().NewHint() instead - NewHint(f hint.Function, nbOutputs int, inputs ...Variable) ([]Variable, error) + NewHint(f solver.Hint, nbOutputs int, inputs ...Variable) ([]Variable, error) // ConstantValue is a shortcut to api.Compiler().ConstantValue() // Deprecated: use api.Compiler().ConstantValue() instead diff --git a/frontend/builder.go b/frontend/builder.go index 9393a45dc7..680a789b7a 100644 --- a/frontend/builder.go +++ b/frontend/builder.go @@ -3,8 +3,8 @@ package frontend import ( "math/big" - "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend/schema" ) @@ -12,6 +12,8 @@ type NewBuilder func(*big.Int, CompileConfig) (Builder, error) // Compiler represents a constraint system compiler type Compiler interface { + constraint.CustomizableSystem + // MarkBoolean sets (but do not constraint!) v to be boolean // This is useful in scenarios where a variable is known to be boolean through a constraint // that is not api.AssertIsBoolean. If v is a constant, this is a no-op. @@ -36,7 +38,8 @@ type Compiler interface { // manually in the circuit. Failing to do so leads to solver failure. // // If nbOutputs is specified, it must be >= 1 and <= f.NbOutputs - NewHint(f hint.Function, nbOutputs int, inputs ...Variable) ([]Variable, error) + NewHint(f solver.Hint, nbOutputs int, inputs ...Variable) ([]Variable, error) + NewHintForId(id solver.HintID, nbOutputs int, inputs ...Variable) ([]Variable, error) // ConstantValue returns the big.Int value of v and true if op is a success. // nil and false if failure. 
This API returns a boolean to allow for future refactoring @@ -49,13 +52,18 @@ type Compiler interface { // FieldBitLen returns the number of bits needed to represent an element in the scalar field FieldBitLen() int - // Commit returns a commitment to the given variables, to be used as initial randomness in - // Fiat-Shamir when the statement to prove is particularly large. - // TODO cite paper - // ! Experimental - // TENTATIVE: Functions regarding fiat-shamir-ed proofs over enormous statements TODO finalize - Commit(...Variable) (Variable, error) - SetGkrInfo(constraint.GkrInfo) error + // Defer is called after circuit.Define() and before Compile(). This method + // allows for the circuits to register callbacks which finalize batching + // operations etc. Unlike Go defer, it is not locally scoped. + Defer(cb func(api API) error) + + // InternalVariable returns the internal variable associated with the given wireID + // ! Experimental: use in conjunction with constraint.CustomizableSystem + InternalVariable(wireID uint32) Variable + + // ToCanonicalVariable converts a frontend.Variable to a constraint system specific Variable + // ! Experimental: use in conjunction with constraint.CustomizableSystem + ToCanonicalVariable(Variable) CanonicalVariable } // Builder represents a constraint system builder @@ -74,3 +82,27 @@ type Builder interface { // called inside circuit.Define() SecretVariable(schema.LeafInfo) Variable } + +// Committer allows to commit to the variables and returns the commitment. The +// commitment can be used as a challenge using Fiat-Shamir heuristic. +type Committer interface { + // Commit commits to the variables and returns the commitment. + Commit(toCommit ...Variable) (commitment Variable, err error) +} + +// Rangechecker allows to externally range-check the variables to be of +// specified width. Not all compilers implement this interface. 
Users should +// instead use [github.com/consensys/gnark/std/rangecheck] package which +// automatically chooses most optimal method for range checking the variables. +type Rangechecker interface { + // Check checks that the given variable v has bit-length bits. + Check(v Variable, bits int) +} + +// CanonicalVariable represents a variable that's encoded in a constraint system specific way. +// For example a R1CS builder may represent this as a constraint.LinearExpression, +// a PLONK builder --> constraint.Term +// and the test/Engine --> ~*big.Int. +type CanonicalVariable interface { + constraint.Compressable +} diff --git a/frontend/compile.go b/frontend/compile.go index d39ff5dd7b..072aca1c34 100644 --- a/frontend/compile.go +++ b/frontend/compile.go @@ -9,6 +9,7 @@ import ( "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend/schema" + "github.com/consensys/gnark/internal/circuitdefer" "github.com/consensys/gnark/logger" ) @@ -36,7 +37,7 @@ func Compile(field *big.Int, newBuilder NewBuilder, circuit Circuit, opts ...Com log := logger.Logger() log.Info().Msg("compiling circuit") // parse options - opt := CompileConfig{} + opt := defaultCompileConfig() for _, o := range opts { if err := o(&opt); err != nil { log.Err(err).Msg("applying compile option") @@ -122,15 +123,33 @@ func parseCircuit(builder Builder, circuit Circuit) (err error) { if err = circuit.Define(builder); err != nil { return fmt.Errorf("define circuit: %w", err) } + if err = callDeferred(builder); err != nil { + return fmt.Errorf("deferred: %w", err) + } return } +func callDeferred(builder Builder) error { + for i := 0; i < len(circuitdefer.GetAll[func(API) error](builder)); i++ { + if err := circuitdefer.GetAll[func(API) error](builder)[i](builder); err != nil { + return fmt.Errorf("defer fn %d: %w", i, err) + } + } + return nil +} + // CompileOption defines option for altering the behaviour of the Compile // method. 
See the descriptions of the functions returning instances of this // type for available options. type CompileOption func(opt *CompileConfig) error +func defaultCompileConfig() CompileConfig { + return CompileConfig{ + CompressThreshold: 300, + } +} + type CompileConfig struct { Capacity int IgnoreUnconstrainedInputs bool @@ -174,6 +193,9 @@ func IgnoreUnconstrainedInputs() CompileOption { // fast. The compression adds some overhead in the number of constraints. The // overhead and compile performance depends on threshold value, and it should be // chosen carefully. +// +// If this option is not given then by default we use the compress threshold of +// 300. func WithCompressThreshold(threshold int) CompileOption { return func(opt *CompileConfig) error { opt.CompressThreshold = threshold diff --git a/frontend/cs/commitment.go b/frontend/cs/commitment.go new file mode 100644 index 0000000000..e8b92478f9 --- /dev/null +++ b/frontend/cs/commitment.go @@ -0,0 +1,52 @@ +package cs + +import ( + "crypto/rand" + "fmt" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/debug" + "github.com/consensys/gnark/logger" + "hash/fnv" + "math/big" + "os" + "strconv" + "strings" + "sync" +) + +func Bsb22CommitmentComputePlaceholder(mod *big.Int, _ []*big.Int, output []*big.Int) (err error) { + if (len(os.Args) > 0 && (strings.HasSuffix(os.Args[0], ".test") || strings.HasSuffix(os.Args[0], ".test.exe"))) || debug.Debug { + // usually we only run solver without prover during testing + log := logger.Logger() + log.Error().Msg("Augmented commitment hint not replaced. 
Proof will not be sound and verification will fail!") + output[0], err = rand.Int(rand.Reader, mod) + return + } + return fmt.Errorf("placeholder function: to be replaced by commitment computation") +} + +var maxNbCommitments int +var m sync.Mutex + +func RegisterBsb22CommitmentComputePlaceholder(index int) (hintId solver.HintID, err error) { + + hintName := "bsb22 commitment #" + strconv.Itoa(index) + // mimic solver.GetHintID + hf := fnv.New32a() + if _, err = hf.Write([]byte(hintName)); err != nil { + return + } + hintId = solver.HintID(hf.Sum32()) + + m.Lock() + if maxNbCommitments == index { + maxNbCommitments++ + solver.RegisterNamedHint(Bsb22CommitmentComputePlaceholder, hintId) + } + m.Unlock() + + return +} +func init() { + solver.RegisterHint(Bsb22CommitmentComputePlaceholder) +} diff --git a/frontend/cs/r1cs/api.go b/frontend/cs/r1cs/api.go index 19eaf80759..7bb01fa76c 100644 --- a/frontend/cs/r1cs/api.go +++ b/frontend/cs/r1cs/api.go @@ -19,14 +19,17 @@ package r1cs import ( "errors" "fmt" - "math/big" + "github.com/consensys/gnark/internal/utils" "path/filepath" "reflect" "runtime" "strings" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/debug" + "github.com/consensys/gnark/frontend/cs" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/internal/expr" "github.com/consensys/gnark/frontend/schema" @@ -55,14 +58,14 @@ func (builder *builder) MulAcc(a, b, c frontend.Variable) frontend.Variable { // v1 and v2 are both unknown, this is the only case we add a constraint if !v1Constant && !v2Constant { res := builder.newInternalVariable() - builder.cs.AddConstraint(builder.newR1C(b, c, res)) + builder.cs.AddR1C(builder.newR1C(b, c, res), builder.genericGate) builder.mbuf1 = append(builder.mbuf1, res...) 
return } // v1 and v2 are constants, we multiply big.Int values and return resulting constant if v1Constant && v2Constant { - builder.cs.Mul(&n1, &n2) + n1 = builder.cs.Mul(n1, n2) builder.mbuf1 = append(builder.mbuf1, expr.NewTerm(0, n1)) return } @@ -143,9 +146,9 @@ func (builder *builder) add(vars []expr.LinearExpression, sub bool, capacity int if curr != -1 && t.VID == (*res)[curr].VID { // accumulate, it's the same variable ID if sub && lID != 0 { - builder.cs.Sub(&(*res)[curr].Coeff, &t.Coeff) + (*res)[curr].Coeff = builder.cs.Sub((*res)[curr].Coeff, t.Coeff) } else { - builder.cs.Add(&(*res)[curr].Coeff, &t.Coeff) + (*res)[curr].Coeff = builder.cs.Add((*res)[curr].Coeff, t.Coeff) } if (*res)[curr].Coeff.IsZero() { // remove self. @@ -157,14 +160,14 @@ func (builder *builder) add(vars []expr.LinearExpression, sub bool, capacity int (*res) = append((*res), *t) curr++ if sub && lID != 0 { - builder.cs.Neg(&(*res)[curr].Coeff) + (*res)[curr].Coeff = builder.cs.Neg((*res)[curr].Coeff) } } } if len((*res)) == 0 { // keep the linear expression valid (assertIsSet) - (*res) = append((*res), expr.NewTerm(0, constraint.Coeff{})) + (*res) = append((*res), expr.NewTerm(0, constraint.Element{})) } // if the linear expression LE is too long then record an equality // constraint LE * 1 = t and return short linear expression instead. 
@@ -183,7 +186,7 @@ func (builder *builder) Neg(i frontend.Variable) frontend.Variable { v := builder.toVariable(i) if n, ok := builder.constantValue(v); ok { - builder.cs.Neg(&n) + n = builder.cs.Neg(n) return expr.NewLinearExpression(0, n) } @@ -202,13 +205,13 @@ func (builder *builder) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) f // v1 and v2 are both unknown, this is the only case we add a constraint if !v1Constant && !v2Constant { res := builder.newInternalVariable() - builder.cs.AddConstraint(builder.newR1C(v1, v2, res)) + builder.cs.AddR1C(builder.newR1C(v1, v2, res), builder.genericGate) return res } // v1 and v2 are constants, we multiply big.Int values and return resulting constant if v1Constant && v2Constant { - builder.cs.Mul(&n1, &n2) + n1 = builder.cs.Mul(n1, n2) return expr.NewLinearExpression(0, n1) } @@ -227,7 +230,7 @@ func (builder *builder) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) f return res } -func (builder *builder) mulConstant(v1 expr.LinearExpression, lambda constraint.Coeff, inPlace bool) expr.LinearExpression { +func (builder *builder) mulConstant(v1 expr.LinearExpression, lambda constraint.Element, inPlace bool) expr.LinearExpression { // multiplying a frontend.Variable by a constant -> we updated the coefficients in the linear expression // leading to that frontend.Variable var res expr.LinearExpression @@ -238,7 +241,7 @@ func (builder *builder) mulConstant(v1 expr.LinearExpression, lambda constraint. 
} for i := 0; i < len(res); i++ { - builder.cs.Mul(&res[i].Coeff, &lambda) + res[i].Coeff = builder.cs.Mul(res[i].Coeff, lambda) } return res } @@ -254,9 +257,12 @@ func (builder *builder) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable if !v2Constant { res := builder.newInternalVariable() - debug := builder.newDebugInfo("div", v1, "/", v2, " == ", res) // note that here we don't ensure that divisor is != 0 - builder.cs.AddConstraint(builder.newR1C(v2, res, v1), debug) + cID := builder.cs.AddR1C(builder.newR1C(v2, res, v1), builder.genericGate) + if debug.Debug { + debug := builder.newDebugInfo("div", v1, "/", v2, " == ", res) + builder.cs.AttachDebugInfo(debug, []int{cID}) + } return res } @@ -264,12 +270,10 @@ func (builder *builder) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable if n2.IsZero() { panic("div by constant(0)") } - // q := builder.q - builder.cs.Inverse(&n2) - // n2.ModInverse(n2, q) + n2, _ = builder.cs.Inverse(n2) if v1Constant { - builder.cs.Mul(&n2, &n1) + n2 = builder.cs.Mul(n2, n1) return expr.NewLinearExpression(0, n2) } @@ -292,8 +296,8 @@ func (builder *builder) Div(i1, i2 frontend.Variable) frontend.Variable { debug := builder.newDebugInfo("div", v1, "/", v2, " == ", res) v2Inv := builder.newInternalVariable() // note that here we ensure that v2 can't be 0, but it costs us one extra constraint - c1 := builder.cs.AddConstraint(builder.newR1C(v2, v2Inv, builder.cstOne())) - c2 := builder.cs.AddConstraint(builder.newR1C(v1, v2Inv, res)) + c1 := builder.cs.AddR1C(builder.newR1C(v2, v2Inv, builder.cstOne()), builder.genericGate) + c2 := builder.cs.AddR1C(builder.newR1C(v1, v2Inv, res), builder.genericGate) builder.cs.AttachDebugInfo(debug, []int{c1, c2}) return res } @@ -302,10 +306,10 @@ func (builder *builder) Div(i1, i2 frontend.Variable) frontend.Variable { if n2.IsZero() { panic("div by constant(0)") } - builder.cs.Inverse(&n2) + n2, _ = builder.cs.Inverse(n2) if v1Constant { - builder.cs.Mul(&n2, &n1) + n2 = 
builder.cs.Mul(n2, n1) return expr.NewLinearExpression(0, n2) } @@ -322,15 +326,18 @@ func (builder *builder) Inverse(i1 frontend.Variable) frontend.Variable { panic("inverse by constant(0)") } - builder.cs.Inverse(&c) + c, _ = builder.cs.Inverse(c) return expr.NewLinearExpression(0, c) } // allocate resulting frontend.Variable res := builder.newInternalVariable() - debug := builder.newDebugInfo("inverse", vars[0], "*", res, " == 1") - builder.cs.AddConstraint(builder.newR1C(res, vars[0], builder.cstOne()), debug) + cID := builder.cs.AddR1C(builder.newR1C(res, vars[0], builder.cstOne()), builder.genericGate) + if debug.Debug { + debug := builder.newDebugInfo("inverse", vars[0], "*", res, " == 1") + builder.cs.AttachDebugInfo(debug, []int{cID}) + } return res } @@ -406,7 +413,7 @@ func (builder *builder) Or(_a, _b frontend.Variable) frontend.Variable { c = append(c, a...) c = append(c, b...) - builder.cs.AddConstraint(builder.newR1C(a, b, c)) + builder.cs.AddR1C(builder.newR1C(a, b, c), builder.genericGate) return res } @@ -441,7 +448,7 @@ func (builder *builder) Select(i0, i1, i2 frontend.Variable) frontend.Variable { if c, ok := builder.constantValue(cond); ok { // condition is a constant return i1 if true, i2 if false - if builder.isCstOne(&c) { + if builder.isCstOne(c) { return vars[1] } return vars[2] @@ -451,7 +458,7 @@ func (builder *builder) Select(i0, i1, i2 frontend.Variable) frontend.Variable { n2, ok2 := builder.constantValue(vars[2]) if ok1 && ok2 { - builder.cs.Sub(&n1, &n2) + n1 = builder.cs.Sub(n1, n2) res := builder.Mul(cond, n1) // no constraint is recorded res = builder.Add(res, vars[2]) // no constraint is recorded return res @@ -488,8 +495,8 @@ func (builder *builder) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 fronten c1, b1IsConstant := builder.constantValue(s1) if b0IsConstant && b1IsConstant { - b0 := builder.isCstOne(&c0) - b1 := builder.isCstOne(&c1) + b0 := builder.isCstOne(c0) + b1 := builder.isCstOne(c1) if !b0 && !b1 { return in0 
@@ -543,17 +550,17 @@ func (builder *builder) IsZero(i1 frontend.Variable) frontend.Variable { m := builder.newInternalVariable() // x = 1/a // in a hint (x == 0 if a == 0) - x, err := builder.NewHint(hint.InvZero, 1, a) + x, err := builder.NewHint(solver.InvZeroHint, 1, a) if err != nil { // the function errs only if the number of inputs is invalid. panic(err) } // m = -a*x + 1 // constrain m to be 1 if a == 0 - c1 := builder.cs.AddConstraint(builder.newR1C(builder.Neg(a), x[0], builder.Sub(m, 1))) + c1 := builder.cs.AddR1C(builder.newR1C(builder.Neg(a), x[0], builder.Sub(m, 1)), builder.genericGate) // a * m = 0 // constrain m to be 0 if a != 0 - c2 := builder.cs.AddConstraint(builder.newR1C(a, m, builder.cstZero())) + c2 := builder.cs.AddR1C(builder.newR1C(a, m, builder.cstZero()), builder.genericGate) builder.cs.AttachDebugInfo(debug, []int{c1, c2}) @@ -660,7 +667,7 @@ func (builder *builder) negateLinExp(l expr.LinearExpression) expr.LinearExpress res := make(expr.LinearExpression, len(l)) copy(res, l) for i := 0; i < len(res); i++ { - builder.cs.Neg(&res[i].Coeff) + res[i].Coeff = builder.cs.Neg(res[i].Coeff) } return res } @@ -670,23 +677,32 @@ func (builder *builder) Compiler() frontend.Compiler { } func (builder *builder) Commit(v ...frontend.Variable) (frontend.Variable, error) { - // we want to build a sorted slice of commited variables, without duplicates + + commitments := builder.cs.GetCommitments().(constraint.Groth16Commitments) + existingCommitmentIndexes := commitments.CommitmentIndexes() + privateCommittedSeeker := utils.MultiListSeeker(commitments.GetPrivateCommitted()) + + // we want to build a sorted slice of committed variables, without duplicates // this is the same algorithm as builder.add(...); but we expect len(v) to be quite large. vars, s := builder.toVariables(v...) 
+ nbPublicCommitted := 0 // initialize the min-heap // this is the same algorithm as api.add --> we want to merge k sorted linear expression for lID, v := range vars { - builder.heap = append(builder.heap, linMeta{val: v[0].VID, lID: lID}) + if v[0].VID < builder.cs.GetNbPublicVariables() { + nbPublicCommitted++ + } + builder.heap = append(builder.heap, linMeta{val: v[0].VID, lID: lID}) // TODO: Use int heap } builder.heap.heapify() // sort all the wires - committed := make([]int, 0, s) - - curr := -1 - nbPublicCommitted := 0 + publicAndCommitmentCommitted := make([]int, 0, nbPublicCommitted+len(existingCommitmentIndexes)) // right now nbPublicCommitted is an upper bound + privateCommitted := make([]int, 0, s) + lastInsertedWireId := -1 + nbPublicCommitted = 0 // process all the terms from all the inputs, in sorted order for len(builder.heap) > 0 { @@ -704,43 +720,79 @@ func (builder *builder) Commit(v ...frontend.Variable) (frontend.Variable, error if t.VID == 0 { continue // don't commit to ONE_WIRE } - if curr != -1 && t.VID == committed[curr] { + if lastInsertedWireId == t.VID { // it's the same variable ID, do nothing continue - } else { - // append, it's a new variable ID - committed = append(committed, t.VID) - if t.VID < builder.cs.GetNbPublicVariables() { - nbPublicCommitted++ - } - curr++ } + + if t.VID < builder.cs.GetNbPublicVariables() { // public + publicAndCommitmentCommitted = append(publicAndCommitmentCommitted, t.VID) + lastInsertedWireId = t.VID + nbPublicCommitted++ + continue + } + + // private or commitment + for len(existingCommitmentIndexes) > 0 && existingCommitmentIndexes[0] < t.VID { + existingCommitmentIndexes = existingCommitmentIndexes[1:] + } + if len(existingCommitmentIndexes) > 0 && existingCommitmentIndexes[0] == t.VID { // commitment + publicAndCommitmentCommitted = append(publicAndCommitmentCommitted, t.VID) + existingCommitmentIndexes = existingCommitmentIndexes[1:] // technically unnecessary + lastInsertedWireId = t.VID + 
continue + } + + // private + // Cannot commit to a secret variable that has already been committed to + // instead we commit to its commitment + if committer := privateCommittedSeeker.Seek(t.VID); committer != -1 { + committerWireIndex := existingCommitmentIndexes[committer] // commit to this commitment instead + vars = append(vars, expr.LinearExpression{{Coeff: constraint.Element{1}, VID: committerWireIndex}}) // TODO Replace with mont 1 + builder.heap.push(linMeta{lID: len(vars) - 1, tID: 0, val: committerWireIndex}) // pushing to heap mid-op is okay because toCommit > t.VID > anything popped so far + continue + } + + // so it's a new, so-far-uncommitted private variable + privateCommitted = append(privateCommitted, t.VID) + lastInsertedWireId = t.VID } - if len(committed) == 0 { + if len(privateCommitted)+len(publicAndCommitmentCommitted) == 0 { // TODO @tabaie Necessary? return nil, errors.New("must commit to at least one variable") } // build commitment - commitment := constraint.NewCommitment(committed, nbPublicCommitted) + commitment := constraint.Groth16Commitment{ + PublicAndCommitmentCommitted: publicAndCommitmentCommitted, + NbPublicCommitted: nbPublicCommitted, + PrivateCommitted: privateCommitted, + } // hint is used at solving time to compute the actual value of the commitment // it is going to be dynamically replaced at solving time. - hintOut, err := builder.NewHint(bsb22CommitmentComputePlaceholder, 1, builder.getCommittedVariables(&commitment)...) 
+ + var ( + hintOut []frontend.Variable + err error + ) + + commitment.HintID, err = cs.RegisterBsb22CommitmentComputePlaceholder(len(commitments)) if err != nil { return nil, err } + + if hintOut, err = builder.NewHintForId(commitment.HintID, 1, builder.wireIDsToVars( + commitment.PublicAndCommitmentCommitted, + commitment.PrivateCommitted, + )...); err != nil { + return nil, err + } + cVar := hintOut[0] - commitment.HintID = hint.UUID(bsb22CommitmentComputePlaceholder) // TODO @gbotrel probably not needed commitment.CommitmentIndex = (cVar.(expr.LinearExpression))[0].WireID() - // TODO @Tabaie: Get rid of this field - commitment.CommittedAndCommitment = append(commitment.Committed, commitment.CommitmentIndex) - if commitment.CommitmentIndex <= commitment.Committed[len(commitment.Committed)-1] { - return nil, fmt.Errorf("commitment variable index smaller than some committed variable indices") - } - if err := builder.cs.AddCommitment(commitment); err != nil { return nil, err } @@ -748,18 +800,18 @@ func (builder *builder) Commit(v ...frontend.Variable) (frontend.Variable, error return cVar, nil } -func (builder *builder) getCommittedVariables(i *constraint.Commitment) []frontend.Variable { - res := make([]frontend.Variable, len(i.Committed)) - for j, wireIndex := range i.Committed { - res[j] = expr.NewLinearExpression(wireIndex, builder.tOne) +func (builder *builder) wireIDsToVars(wireIDs ...[]int) []frontend.Variable { + n := 0 + for i := range wireIDs { + n += len(wireIDs[i]) + } + res := make([]frontend.Variable, n) + n = 0 + for _, list := range wireIDs { + for i := range list { + res[n+i] = expr.NewLinearExpression(list[i], builder.tOne) + } + n += len(list) } return res } - -func (builder *builder) SetGkrInfo(info constraint.GkrInfo) error { - return builder.cs.AddGkr(info) -} - -func bsb22CommitmentComputePlaceholder(*big.Int, []*big.Int, []*big.Int) error { - return fmt.Errorf("placeholder function: to be replaced by commitment computation") -} diff --git 
a/frontend/cs/r1cs/api_assertions.go b/frontend/cs/r1cs/api_assertions.go index 2d2db8f010..3c249f4845 100644 --- a/frontend/cs/r1cs/api_assertions.go +++ b/frontend/cs/r1cs/api_assertions.go @@ -17,12 +17,12 @@ limitations under the License. package r1cs import ( + "fmt" "math/big" "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/internal/expr" - "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/math/bits" ) @@ -32,14 +32,22 @@ func (builder *builder) AssertIsEqual(i1, i2 frontend.Variable) { r := builder.getLinearExpression(builder.toVariable(i1)) o := builder.getLinearExpression(builder.toVariable(i2)) - debug := builder.newDebugInfo("assertIsEqual", r, " == ", o) + cID := builder.cs.AddR1C(builder.newR1C(builder.cstOne(), r, o), builder.genericGate) - builder.cs.AddConstraint(builder.newR1C(builder.cstOne(), r, o), debug) + if debug.Debug { + debug := builder.newDebugInfo("assertIsEqual", r, " == ", o) + builder.cs.AttachDebugInfo(debug, []int{cID}) + } } // AssertIsDifferent constrain i1 and i2 to be different func (builder *builder) AssertIsDifferent(i1, i2 frontend.Variable) { - builder.Inverse(builder.Sub(i1, i2)) + s := builder.Sub(i1, i2).(expr.LinearExpression) + if len(s) == 1 && s[0].Coeff.IsZero() { + panic("AssertIsDifferent(x,x) will never be satisfied") + } + + builder.Inverse(s) } // AssertIsBoolean adds an assertion in the constraint builder (v == 0 ∥ v == 1) @@ -48,7 +56,7 @@ func (builder *builder) AssertIsBoolean(i1 frontend.Variable) { v := builder.toVariable(i1) if b, ok := builder.constantValue(v); ok { - if !(builder.isCstZero(&b) || builder.isCstOne(&b)) { + if !(b.IsZero() || builder.isCstOne(b)) { panic("assertIsBoolean failed: constant is not 0 or 1") // TODO @gbotrel print } return @@ -66,11 +74,10 @@ func (builder *builder) AssertIsBoolean(i1 frontend.Variable) { V := builder.getLinearExpression(v) + cID := builder.cs.AddR1C(builder.newR1C(V, _v, 
o), builder.genericGate) if debug.Debug { debug := builder.newDebugInfo("assertIsBoolean", V, " == (0|1)") - builder.cs.AddConstraint(builder.newR1C(V, _v, o), debug) - } else { - builder.cs.AddConstraint(builder.newR1C(V, _v, o)) + builder.cs.AttachDebugInfo(debug, []int{cID}) } } @@ -80,18 +87,34 @@ func (builder *builder) AssertIsBoolean(i1 frontend.Variable) { // // derived from: // https://github.com/zcash/zips/blob/main/protocol/protocol.pdf -func (builder *builder) AssertIsLessOrEqual(_v frontend.Variable, bound frontend.Variable) { - v := builder.toVariable(_v) - - if b, ok := bound.(expr.LinearExpression); ok { - assertIsSet(b) - builder.mustBeLessOrEqVar(v, b) - } else { - builder.mustBeLessOrEqCst(v, utils.FromInterface(bound)) +func (builder *builder) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) { + cv, vConst := builder.constantValue(v) + cb, bConst := builder.constantValue(bound) + + // both inputs are constants + if vConst && bConst { + bv, bb := builder.cs.ToBigInt(cv), builder.cs.ToBigInt(cb) + if bv.Cmp(bb) == 1 { + panic(fmt.Sprintf("AssertIsLessOrEqual: %s > %s", bv.String(), bb.String())) + } } + + // bound is constant + if bConst { + vv := builder.toVariable(v) + builder.mustBeLessOrEqCst(vv, builder.cs.ToBigInt(cb)) + return + } + + builder.mustBeLessOrEqVar(v, bound) } -func (builder *builder) mustBeLessOrEqVar(a, bound expr.LinearExpression) { +func (builder *builder) mustBeLessOrEqVar(a, bound frontend.Variable) { + // here bound is NOT a constant, + // but a can be either constant or a wire. 
+ + _, aConst := builder.constantValue(a) + debug := builder.newDebugInfo("mustBeLessOrEq", a, " <= ", bound) nbBits := builder.cs.FieldBitLen() @@ -128,16 +151,22 @@ func (builder *builder) mustBeLessOrEqVar(a, bound expr.LinearExpression) { // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 // → this is a boolean constraint // if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too - builder.MarkBoolean(aBits[i].(expr.LinearExpression)) // this does not create a constraint - added = append(added, builder.cs.AddConstraint(builder.newR1C(l, aBits[i], zero))) + if aConst { + // aBits[i] is a constant; + l = builder.Mul(l, aBits[i]) + // TODO @gbotrel this constraint seems useless. + added = append(added, builder.cs.AddR1C(builder.newR1C(l, zero, zero), builder.genericGate)) + } else { + added = append(added, builder.cs.AddR1C(builder.newR1C(l, aBits[i], zero), builder.genericGate)) + } } builder.cs.AttachDebugInfo(debug, added) } -func (builder *builder) mustBeLessOrEqCst(a expr.LinearExpression, bound big.Int) { +func (builder *builder) mustBeLessOrEqCst(a expr.LinearExpression, bound *big.Int) { nbBits := builder.cs.FieldBitLen() @@ -186,8 +215,7 @@ func (builder *builder) mustBeLessOrEqCst(a expr.LinearExpression, bound big.Int l := builder.Sub(1, p[i+1]) l = builder.Sub(l, aBits[i]) - added = append(added, builder.cs.AddConstraint(builder.newR1C(l, aBits[i], builder.cstZero()))) - builder.MarkBoolean(aBits[i].(expr.LinearExpression)) + added = append(added, builder.cs.AddR1C(builder.newR1C(l, aBits[i], builder.cstZero()), builder.genericGate)) } else { builder.AssertIsBoolean(aBits[i]) } diff --git a/frontend/cs/r1cs/builder.go b/frontend/cs/r1cs/builder.go index d06bd9c3d9..1a27261bda 100644 --- a/frontend/cs/r1cs/builder.go +++ b/frontend/cs/r1cs/builder.go @@ -23,12 +23,14 @@ import ( "sort" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/constraint" 
"github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/internal/expr" "github.com/consensys/gnark/frontend/schema" + "github.com/consensys/gnark/internal/circuitdefer" + "github.com/consensys/gnark/internal/frontendtype" + "github.com/consensys/gnark/internal/kvstore" "github.com/consensys/gnark/internal/tinyfield" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" @@ -40,29 +42,36 @@ import ( bn254r1cs "github.com/consensys/gnark/constraint/bn254" bw6633r1cs "github.com/consensys/gnark/constraint/bw6-633" bw6761r1cs "github.com/consensys/gnark/constraint/bw6-761" + "github.com/consensys/gnark/constraint/solver" tinyfieldr1cs "github.com/consensys/gnark/constraint/tinyfield" ) // NewBuilder returns a new R1CS builder which implements frontend.API. +// Additionally, this builder also implements [frontend.Committer]. func NewBuilder(field *big.Int, config frontend.CompileConfig) (frontend.Builder, error) { return newBuilder(field, config), nil } type builder struct { - cs constraint.R1CS - + cs constraint.R1CS config frontend.CompileConfig + kvstore.Store // map for recording boolean constrained variables (to not constrain them twice) mtBooleans map[uint64][]expr.LinearExpression - q *big.Int - tOne constraint.Coeff - heap minHeap // helps merge k sorted linear expressions + tOne constraint.Element + eZero, eOne expr.LinearExpression + cZero, cOne constraint.LinearExpression + + // helps merge k sorted linear expressions + heap minHeap // buffers used to do in place api.MAC mbuf1 expr.LinearExpression mbuf2 expr.LinearExpression + + genericGate constraint.BlueprintID } // initialCapacity has quite some impact on frontend performance, especially on large circuits size @@ -78,6 +87,7 @@ func newBuilder(field *big.Int, config frontend.CompileConfig) *builder { heap: make(minHeap, 0, 100), mbuf1: make(expr.LinearExpression, 0, macCapacity), mbuf2: make(expr.LinearExpression, 0, 
macCapacity), + Store: kvstore.New(), } // by default the circuit is given a public wire equal to 1 @@ -110,10 +120,13 @@ func newBuilder(field *big.Int, config frontend.CompileConfig) *builder { builder.tOne = builder.cs.One() builder.cs.AddPublicVariable("1") - builder.q = builder.cs.Field() - if builder.q.Cmp(field) != 0 { - panic("invalid modulus on cs impl") // sanity check - } + builder.genericGate = builder.cs.AddBlueprint(&constraint.BlueprintGenericR1C{}) + + builder.eZero = expr.NewLinearExpression(0, constraint.Element{}) + builder.eOne = expr.NewLinearExpression(0, builder.tOne) + + builder.cOne = constraint.LinearExpression{constraint.Term{VID: 0, CID: constraint.CoeffIdOne}} + builder.cZero = constraint.LinearExpression{constraint.Term{VID: 0, CID: constraint.CoeffIdZero}} return &builder } @@ -139,19 +152,15 @@ func (builder *builder) SecretVariable(f schema.LeafInfo) frontend.Variable { // cstOne return the one constant func (builder *builder) cstOne() expr.LinearExpression { - return expr.NewLinearExpression(0, builder.tOne) + return builder.eOne } // cstZero return the zero constant func (builder *builder) cstZero() expr.LinearExpression { - return expr.NewLinearExpression(0, constraint.Coeff{}) -} - -func (builder *builder) isCstZero(c *constraint.Coeff) bool { - return c.IsZero() + return builder.eZero } -func (builder *builder) isCstOne(c *constraint.Coeff) bool { +func (builder *builder) isCstOne(c constraint.Element) bool { return builder.cs.IsOne(c) } @@ -188,9 +197,16 @@ func (builder *builder) getLinearExpression(_l interface{}) constraint.LinearExp var L constraint.LinearExpression switch tl := _l.(type) { case expr.LinearExpression: + if len(tl) == 1 && tl[0].VID == 0 { + if tl[0].Coeff.IsZero() { + return builder.cZero + } else if tl[0].Coeff == builder.tOne { + return builder.cOne + } + } L = make(constraint.LinearExpression, 0, len(tl)) for _, t := range tl { - L = append(L, builder.cs.MakeTerm(&t.Coeff, t.VID)) + L = append(L, 
builder.cs.MakeTerm(t.Coeff, t.VID)) } case constraint.LinearExpression: L = tl @@ -208,7 +224,7 @@ func (builder *builder) getLinearExpression(_l interface{}) constraint.LinearExp // that is not api.AssertIsBoolean. If v is a constant, this is a no-op. func (builder *builder) MarkBoolean(v frontend.Variable) { if b, ok := builder.constantValue(v); ok { - if !(builder.isCstZero(&b) || builder.isCstOne(&b)) { + if !(b.IsZero() || builder.isCstOne(b)) { panic("MarkBoolean called a non-boolean constant") } return @@ -228,7 +244,7 @@ func (builder *builder) MarkBoolean(v frontend.Variable) { // This returns true if the v is a constant and v == 0 || v == 1. func (builder *builder) IsBoolean(v frontend.Variable) bool { if b, ok := builder.constantValue(v); ok { - return (builder.isCstZero(&b) || builder.isCstOne(&b)) + return (b.IsZero() || builder.isCstOne(b)) } // v is a linear expression l := v.(expr.LinearExpression) @@ -280,20 +296,20 @@ func (builder *builder) ConstantValue(v frontend.Variable) (*big.Int, bool) { if !ok { return nil, false } - return builder.cs.ToBigInt(&coeff), true + return builder.cs.ToBigInt(coeff), true } -func (builder *builder) constantValue(v frontend.Variable) (constraint.Coeff, bool) { +func (builder *builder) constantValue(v frontend.Variable) (constraint.Element, bool) { if _v, ok := v.(expr.LinearExpression); ok { assertIsSet(_v) if len(_v) != 1 { // TODO @gbotrel this assumes linear expressions of coeff are not possible // and are always reduced to one element. may not always be true? 
- return constraint.Coeff{}, false + return constraint.Element{}, false } if !(_v[0].WireID() == 0) { // public ONE WIRE - return constraint.Coeff{}, false + return constraint.Element{}, false } return _v[0].Coeff, true } @@ -314,9 +330,9 @@ func (builder *builder) toVariable(input interface{}) expr.LinearExpression { case *expr.LinearExpression: assertIsSet(*t) return *t - case constraint.Coeff: + case constraint.Element: return expr.NewLinearExpression(0, t) - case *constraint.Coeff: + case *constraint.Element: return expr.NewLinearExpression(0, *t) default: // try to make it into a constant @@ -354,7 +370,15 @@ func (builder *builder) toVariables(in ...frontend.Variable) ([]expr.LinearExpre // // No new constraints are added to the newly created wire and must be added // manually in the circuit. Failing to do so leads to solver failure. -func (builder *builder) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { +func (builder *builder) NewHint(f solver.Hint, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { + return builder.newHint(f, solver.GetHintID(f), nbOutputs, inputs) +} + +func (builder *builder) NewHintForId(id solver.HintID, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { + return builder.newHint(nil, id, nbOutputs, inputs) +} + +func (builder *builder) newHint(f solver.Hint, id solver.HintID, nbOutputs int, inputs []frontend.Variable) ([]frontend.Variable, error) { hintInputs := make([]constraint.LinearExpression, len(inputs)) // TODO @gbotrel hint input pass @@ -365,13 +389,13 @@ func (builder *builder) NewHint(f hint.Function, nbOutputs int, inputs ...fronte hintInputs[i] = builder.getLinearExpression(t) } else { c := builder.cs.FromInterface(in) - term := builder.cs.MakeTerm(&c, 0) + term := builder.cs.MakeTerm(c, 0) term.MarkConstant() hintInputs[i] = constraint.LinearExpression{term} } } - internalVariables, err := builder.cs.AddSolverHint(f, 
hintInputs, nbOutputs) + internalVariables, err := builder.cs.AddSolverHint(f, id, hintInputs, nbOutputs) if err != nil { return nil, err } @@ -382,7 +406,6 @@ func (builder *builder) NewHint(f hint.Function, nbOutputs int, inputs ...fronte res[i] = expr.NewLinearExpression(idx, builder.tOne) } return res, nil - } // assertIsSet panics if the variable is unset @@ -423,10 +446,10 @@ func (builder *builder) newDebugInfo(errName string, in ...interface{}) constrai in[i] = builder.getLinearExpression(expr.LinearExpression{t}) case *expr.Term: in[i] = builder.getLinearExpression(expr.LinearExpression{*t}) - case constraint.Coeff: - in[i] = builder.cs.String(&t) - case *constraint.Coeff: + case constraint.Element: in[i] = builder.cs.String(t) + case *constraint.Element: + in[i] = builder.cs.String(*t) } } @@ -445,6 +468,42 @@ func (builder *builder) compress(le expr.LinearExpression) expr.LinearExpression one := builder.cstOne() t := builder.newInternalVariable() - builder.cs.AddConstraint(builder.newR1C(le, one, t)) + builder.cs.AddR1C(builder.newR1C(le, one, t), builder.genericGate) return t } + +func (builder *builder) Defer(cb func(frontend.API) error) { + circuitdefer.Put(builder, cb) +} + +func (*builder) FrontendType() frontendtype.Type { + return frontendtype.R1CS +} + +// AddInstruction is used to add custom instructions to the constraint system. +func (builder *builder) AddInstruction(bID constraint.BlueprintID, calldata []uint32) []uint32 { + return builder.cs.AddInstruction(bID, calldata) +} + +// AddBlueprint adds a custom blueprint to the constraint system. +func (builder *builder) AddBlueprint(b constraint.Blueprint) constraint.BlueprintID { + return builder.cs.AddBlueprint(b) +} + +func (builder *builder) InternalVariable(wireID uint32) frontend.Variable { + return expr.NewLinearExpression(int(wireID), builder.tOne) +} + +// ToCanonicalVariable converts a frontend.Variable to a constraint system specific Variable +// ! 
Experimental: use in conjunction with constraint.CustomizableSystem +func (builder *builder) ToCanonicalVariable(in frontend.Variable) frontend.CanonicalVariable { + if t, ok := in.(expr.LinearExpression); ok { + assertIsSet(t) + return builder.getLinearExpression(t) + } else { + c := builder.cs.FromInterface(in) + term := builder.cs.MakeTerm(c, 0) + term.MarkConstant() + return constraint.LinearExpression{term} + } +} diff --git a/frontend/cs/r1cs/heap.go b/frontend/cs/r1cs/heap.go index a0d6035411..07c6bc436e 100644 --- a/frontend/cs/r1cs/heap.go +++ b/frontend/cs/r1cs/heap.go @@ -21,7 +21,7 @@ func (h *minHeap) heapify() { } } -// push pushes the element x onto the heap. +// push the element x onto the heap. // The complexity is O(log n) where n = len(*h). func (h *minHeap) push(x linMeta) { *h = append(*h, x) diff --git a/frontend/cs/r1cs/r1cs_test.go b/frontend/cs/r1cs/r1cs_test.go index 51370f9802..ceb965322c 100644 --- a/frontend/cs/r1cs/r1cs_test.go +++ b/frontend/cs/r1cs/r1cs_test.go @@ -99,7 +99,7 @@ func BenchmarkReduce(b *testing.B) { // Add few large linear expressions // Add many large linear expressions // Doubling of large linear expressions - rand.Seed(time.Now().Unix()) + rand := rand.New(rand.NewSource(time.Now().Unix())) //#nosec G404 weak rng is fine here const nbTerms = 100000 terms := make([]frontend.Variable, nbTerms) for i := 0; i < len(terms); i++ { @@ -135,3 +135,28 @@ func BenchmarkReduce(b *testing.B) { } }) } + +type EmptyCircuit struct { + X frontend.Variable + cb func(frontend.API) error +} + +func (c *EmptyCircuit) Define(api frontend.API) error { + api.AssertIsEqual(c.X, 0) + api.Compiler().Defer(c.cb) + return nil +} + +func TestPreCompileHook(t *testing.T) { + var called bool + c := &EmptyCircuit{ + cb: func(a frontend.API) error { called = true; return nil }, + } + _, err := frontend.Compile(ecc.BN254.ScalarField(), NewBuilder, c) + if err != nil { + t.Fatal(err) + } + if !called { + t.Error("callback not called") + } +} diff 
--git a/frontend/cs/scs/api.go b/frontend/cs/scs/api.go index 0971715911..12eff1e730 100644 --- a/frontend/cs/scs/api.go +++ b/frontend/cs/scs/api.go @@ -18,49 +18,48 @@ package scs import ( "fmt" - "math/big" "path/filepath" "reflect" "runtime" "strings" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/debug" + "github.com/consensys/gnark/frontend/cs" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/internal/expr" "github.com/consensys/gnark/frontend/schema" + "github.com/consensys/gnark/internal/frontendtype" "github.com/consensys/gnark/std/math/bits" ) // Add returns res = i1+i2+...in -func (builder *scs) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - zero := big.NewInt(0) +func (builder *builder) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + // separate the constant part from the variables vars, k := builder.filterConstantSum(append([]frontend.Variable{i1, i2}, in...)) if len(vars) == 0 { - return k + // no variables, we return the constant. 
+ return builder.cs.ToBigInt(k) } + vars = builder.reduce(vars) - if k.Cmp(zero) == 0 { - return builder.splitSum(vars[0], vars[1:]) + if k.IsZero() { + return builder.splitSum(vars[0], vars[1:], nil) } - cl, _ := vars[0].Unpack() - kID := builder.st.CoeffID(&k) - o := builder.newInternalVariable() - builder.addPlonkConstraint(vars[0], builder.zero(), o, cl, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdMinusOne, kID) - return builder.splitSum(o, vars[1:]) - + // no constant we decompose the linear expressions in additions of 2 terms + return builder.splitSum(vars[0], vars[1:], &k) } -func (builder *scs) MulAcc(a, b, c frontend.Variable) frontend.Variable { +func (builder *builder) MulAcc(a, b, c frontend.Variable) frontend.Variable { // TODO can we do better here to limit allocations? - // technically we could do that in one PlonK constraint (against 2 for separate Add & Mul) return builder.Add(a, builder.Mul(b, c)) } // neg returns -in -func (builder *scs) neg(in []frontend.Variable) []frontend.Variable { - +func (builder *builder) neg(in []frontend.Variable) []frontend.Variable { res := make([]frontend.Variable, len(in)) for i := 0; i < len(in); i++ { @@ -70,87 +69,79 @@ func (builder *scs) neg(in []frontend.Variable) []frontend.Variable { } // Sub returns res = i1 - i2 - ...in -func (builder *scs) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (builder *builder) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { r := builder.neg(append([]frontend.Variable{i2}, in...)) return builder.Add(i1, r[0], r[1:]...) 
} // Neg returns -i -func (builder *scs) Neg(i1 frontend.Variable) frontend.Variable { - if n, ok := builder.ConstantValue(i1); ok { - n.Neg(n) - return *n - } else { - v := i1.(expr.TermToRefactor) - c, _ := v.Unpack() - var coef big.Int - coef.Set(&builder.st.Coeffs[c]) - coef.Neg(&coef) - c = builder.st.CoeffID(&coef) - v.SetCoeffID(c) - return v +func (builder *builder) Neg(i1 frontend.Variable) frontend.Variable { + if n, ok := builder.constantValue(i1); ok { + n = builder.cs.Neg(n) + return builder.cs.ToBigInt(n) } + v := i1.(expr.Term) + v.Coeff = builder.cs.Neg(v.Coeff) + return v } // Mul returns res = i1 * i2 * ... in -func (builder *scs) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - +func (builder *builder) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { vars, k := builder.filterConstantProd(append([]frontend.Variable{i1, i2}, in...)) if len(vars) == 0 { - return k + return builder.cs.ToBigInt(k) } - l := builder.mulConstant(vars[0], &k) - return builder.splitProd(l, vars[1:]) + l := builder.mulConstant(vars[0], k) + return builder.splitProd(l, vars[1:]) } // returns t*m -func (builder *scs) mulConstant(t expr.TermToRefactor, m *big.Int) expr.TermToRefactor { - var coef big.Int - cid, _ := t.Unpack() - coef.Set(&builder.st.Coeffs[cid]) - coef.Mul(m, &coef).Mod(&coef, builder.q) - cid = builder.st.CoeffID(&coef) - t.SetCoeffID(cid) +func (builder *builder) mulConstant(t expr.Term, m constraint.Element) expr.Term { + t.Coeff = builder.cs.Mul(t.Coeff, m) return t } // DivUnchecked returns i1 / i2 . 
if i1 == i2 == 0, returns 0 -func (builder *scs) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { - c1, i1Constant := builder.ConstantValue(i1) - c2, i2Constant := builder.ConstantValue(i2) +func (builder *builder) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { + c1, i1Constant := builder.constantValue(i1) + c2, i2Constant := builder.constantValue(i2) if i1Constant && i2Constant { - l := c1 - r := c2 - q := builder.q - return r.ModInverse(r, q). - Mul(l, r). - Mod(r, q) + if c2.IsZero() { + panic("inverse by constant(0)") + } + c2, _ = builder.cs.Inverse(c2) + c2 = builder.cs.Mul(c2, c1) + return builder.cs.ToBigInt(c2) } if i2Constant { - c := c2 - q := builder.q - c.ModInverse(c, q) - return builder.mulConstant(i1.(expr.TermToRefactor), c) + if c2.IsZero() { + panic("inverse by constant(0)") + } + c2, _ = builder.cs.Inverse(c2) + return builder.mulConstant(i1.(expr.Term), c2) } if i1Constant { res := builder.Inverse(i2) - return builder.mulConstant(res.(expr.TermToRefactor), c1) + return builder.mulConstant(res.(expr.Term), c1) } + // res * i2 == i1 res := builder.newInternalVariable() - r := i2.(expr.TermToRefactor) - o := builder.Neg(i1).(expr.TermToRefactor) - cr, _ := r.Unpack() - co, _ := o.Unpack() - builder.addPlonkConstraint(res, r, o, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdOne, cr, co, constraint.CoeffIdZero) + builder.addPlonkConstraint(sparseR1C{ + xa: res.VID, + xb: i2.(expr.Term).VID, + xc: i1.(expr.Term).VID, + qM: i2.(expr.Term).Coeff, + qO: builder.cs.Neg(i1.(expr.Term).Coeff), + }) + return res } // Div returns i1 / i2 -func (builder *scs) Div(i1, i2 frontend.Variable) frontend.Variable { - +func (builder *builder) Div(i1, i2 frontend.Variable) frontend.Variable { // note that here we ensure that v2 can't be 0, but it costs us one extra constraint builder.Inverse(i2) @@ -158,16 +149,32 @@ func (builder *scs) Div(i1, i2 frontend.Variable) frontend.Variable { } // Inverse returns res = 1 / i1 -func 
(builder *scs) Inverse(i1 frontend.Variable) frontend.Variable { - if c, ok := builder.ConstantValue(i1); ok { - c.ModInverse(c, builder.q) - return c +func (builder *builder) Inverse(i1 frontend.Variable) frontend.Variable { + if c, ok := builder.constantValue(i1); ok { + if c.IsZero() { + panic("inverse by constant(0)") + } + c, _ = builder.cs.Inverse(c) + return builder.cs.ToBigInt(c) } - t := i1.(expr.TermToRefactor) - cr, _ := t.Unpack() - debug := builder.newDebugInfo("inverse", "1/", i1, " < ∞") + t := i1.(expr.Term) res := builder.newInternalVariable() - builder.addPlonkConstraint(res, t, builder.zero(), constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdOne, cr, constraint.CoeffIdZero, constraint.CoeffIdMinusOne, debug) + + // res * i1 - 1 == 0 + constraint := sparseR1C{ + xa: res.VID, + xb: t.VID, + qM: t.Coeff, + qC: builder.tMinusOne, + } + + if debug.Debug { + debug := builder.newDebugInfo("inverse", "1/", i1, " < ∞") + builder.addPlonkConstraint(constraint, debug) + } else { + builder.addPlonkConstraint(constraint) + } + return res } @@ -179,7 +186,7 @@ func (builder *scs) Inverse(i1 frontend.Variable) frontend.Variable { // n default value is fr.Bits the number of bits needed to represent a field element // // The result in in little endian (first bit= lsb) -func (builder *scs) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { +func (builder *builder) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { // nbBits nbBits := builder.cs.FieldBitLen() if len(n) == 1 { @@ -193,85 +200,151 @@ func (builder *scs) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable } // FromBinary packs b, seen as a fr.Element in little endian -func (builder *scs) FromBinary(b ...frontend.Variable) frontend.Variable { +func (builder *builder) FromBinary(b ...frontend.Variable) frontend.Variable { return bits.FromBinary(builder, b) } // Xor returns a ^ b // a and b must be 0 or 1 -func (builder *scs) Xor(a, b frontend.Variable) 
frontend.Variable { - +func (builder *builder) Xor(a, b frontend.Variable) frontend.Variable { + // pre condition: a, b must be booleans builder.AssertIsBoolean(a) builder.AssertIsBoolean(b) - _a, aConstant := builder.ConstantValue(a) - _b, bConstant := builder.ConstantValue(b) + _a, aConstant := builder.constantValue(a) + _b, bConstant := builder.constantValue(b) + + // if both inputs are constants if aConstant && bConstant { - _a.Xor(_a, _b) - return _a + b0 := 0 + b1 := 0 + if builder.cs.IsOne(_a) { + b0 = 1 + } + if builder.cs.IsOne(_b) { + b1 = 1 + } + return b0 ^ b1 } res := builder.newInternalVariable() builder.MarkBoolean(res) + + // if one input is constant, ensure we put it in b. if aConstant { a, b = b, a bConstant = aConstant _b = _a } if bConstant { - l := a.(expr.TermToRefactor) - r := l - oneMinusTwoB := big.NewInt(1) - oneMinusTwoB.Sub(oneMinusTwoB, _b).Sub(oneMinusTwoB, _b) - builder.addPlonkConstraint(l, r, res, builder.st.CoeffID(oneMinusTwoB), constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdMinusOne, builder.st.CoeffID(_b)) + xa := a.(expr.Term) + // 1 - 2b + qL := builder.tOne + qL = builder.cs.Sub(qL, _b) + qL = builder.cs.Sub(qL, _b) + qL = builder.cs.Mul(qL, xa.Coeff) + + // (1-2b)a + b == res + builder.addPlonkConstraint(sparseR1C{ + xa: xa.VID, + xc: res.VID, + qL: qL, + qO: builder.tMinusOne, + qC: _b, + }) + // builder.addPlonkConstraint(xa, xb, res, builder.st.CoeffID(oneMinusTwoB), constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdMinusOne, builder.st.CoeffID(_b)) return res } - l := a.(expr.TermToRefactor) - r := b.(expr.TermToRefactor) - builder.addPlonkConstraint(l, r, res, constraint.CoeffIdMinusOne, constraint.CoeffIdMinusOne, constraint.CoeffIdTwo, constraint.CoeffIdOne, constraint.CoeffIdOne, constraint.CoeffIdZero) + xa := a.(expr.Term) + xb := b.(expr.Term) + + // -a - b + 2ab + res == 0 + qM := builder.tOne + qM = builder.cs.Add(qM, qM) + qM = 
builder.cs.Mul(qM, xa.Coeff) + qM = builder.cs.Mul(qM, xb.Coeff) + + qL := builder.cs.Neg(xa.Coeff) + qR := builder.cs.Neg(xb.Coeff) + + builder.addPlonkConstraint(sparseR1C{ + xa: xa.VID, + xb: xb.VID, + xc: res.VID, + qL: qL, + qR: qR, + qO: builder.tOne, + qM: qM, + }) + // builder.addPlonkConstraint(xa, xb, res, constraint.CoeffIdMinusOne, constraint.CoeffIdMinusOne, constraint.CoeffIdTwo, constraint.CoeffIdOne, constraint.CoeffIdOne, constraint.CoeffIdZero) return res } // Or returns a | b // a and b must be 0 or 1 -func (builder *scs) Or(a, b frontend.Variable) frontend.Variable { - +func (builder *builder) Or(a, b frontend.Variable) frontend.Variable { builder.AssertIsBoolean(a) builder.AssertIsBoolean(b) - _a, aConstant := builder.ConstantValue(a) - _b, bConstant := builder.ConstantValue(b) + _a, aConstant := builder.constantValue(a) + _b, bConstant := builder.constantValue(b) if aConstant && bConstant { - _a.Or(_a, _b) - return _a + if builder.cs.IsOne(_a) || builder.cs.IsOne(_b) { + return 1 + } + return 0 } + res := builder.newInternalVariable() builder.MarkBoolean(res) + + // if one input is constant, ensure we put it in b if aConstant { a, b = b, a _b = _a bConstant = aConstant } - if bConstant { - l := a.(expr.TermToRefactor) - r := l - one := big.NewInt(1) - _b.Sub(_b, one) - idl := builder.st.CoeffID(_b) - builder.addPlonkConstraint(l, r, res, idl, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdOne, constraint.CoeffIdZero) + if bConstant { + xa := a.(expr.Term) + // b = b - 1 + qL := _b + qL = builder.cs.Sub(qL, builder.tOne) + qL = builder.cs.Mul(qL, xa.Coeff) + // a * (b-1) + res == 0 + builder.addPlonkConstraint(sparseR1C{ + xa: xa.VID, + xc: res.VID, + qL: qL, + qO: builder.tOne, + }) return res } - l := a.(expr.TermToRefactor) - r := b.(expr.TermToRefactor) - builder.addPlonkConstraint(l, r, res, constraint.CoeffIdMinusOne, constraint.CoeffIdMinusOne, constraint.CoeffIdOne, constraint.CoeffIdOne, 
constraint.CoeffIdOne, constraint.CoeffIdZero) + xa := a.(expr.Term) + xb := b.(expr.Term) + // -a - b + ab + res == 0 + + qM := builder.cs.Mul(xa.Coeff, xb.Coeff) + + qL := builder.cs.Neg(xa.Coeff) + qR := builder.cs.Neg(xb.Coeff) + + builder.addPlonkConstraint(sparseR1C{ + xa: xa.VID, + xb: xb.VID, + xc: res.VID, + qL: qL, + qR: qR, + qM: qM, + qO: builder.tOne, + }) return res } // Or returns a & b // a and b must be 0 or 1 -func (builder *scs) And(a, b frontend.Variable) frontend.Variable { +func (builder *builder) And(a, b frontend.Variable) frontend.Variable { builder.AssertIsBoolean(a) builder.AssertIsBoolean(b) res := builder.Mul(a, b) @@ -283,14 +356,14 @@ func (builder *scs) And(a, b frontend.Variable) frontend.Variable { // Conditionals // Select if b is true, yields i1 else yields i2 -func (builder *scs) Select(b frontend.Variable, i1, i2 frontend.Variable) frontend.Variable { - _b, bConstant := builder.ConstantValue(b) +func (builder *builder) Select(b frontend.Variable, i1, i2 frontend.Variable) frontend.Variable { + _b, bConstant := builder.constantValue(b) if bConstant { - if !(_b.IsUint64() && (_b.Uint64() <= 1)) { - panic(fmt.Sprintf("%s should be 0 or 1", _b.String())) + if !builder.IsBoolean(b) { + panic(fmt.Sprintf("%s should be 0 or 1", builder.cs.String(_b))) } - if _b.Uint64() == 0 { + if _b.IsZero() { return i2 } return i1 @@ -305,23 +378,18 @@ func (builder *scs) Select(b frontend.Variable, i1, i2 frontend.Variable) fronte // Lookup2 performs a 2-bit lookup between i1, i2, i3, i4 based on bits b0 // and b1. Returns i0 if b0=b1=0, i1 if b0=1 and b1=0, i2 if b0=0 and b1=1 // and i3 if b0=b1=1. 
-func (builder *scs) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { - - // vars, _ := builder.toVariables(b0, b1, i0, i1, i2, i3) - // s0, s1 := vars[0], vars[1] - // in0, in1, in2, in3 := vars[2], vars[3], vars[4], vars[5] - +func (builder *builder) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { // ensure that bits are actually bits. Adds no constraints if the variables // are already constrained. builder.AssertIsBoolean(b0) builder.AssertIsBoolean(b1) - c0, b0IsConstant := builder.ConstantValue(b0) - c1, b1IsConstant := builder.ConstantValue(b1) + c0, b0IsConstant := builder.constantValue(b0) + c1, b1IsConstant := builder.constantValue(b1) if b0IsConstant && b1IsConstant { - b0 := c0.Uint64() == 1 - b1 := c1.Uint64() == 1 + b0 := builder.cs.IsOne(c0) + b1 := builder.cs.IsOne(c1) if !b0 && !b1 { return i0 @@ -360,22 +428,22 @@ func (builder *scs) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Va } // IsZero returns 1 if a is zero, 0 otherwise -func (builder *scs) IsZero(i1 frontend.Variable) frontend.Variable { - if a, ok := builder.ConstantValue(i1); ok { - if !(a.IsUint64() && a.Uint64() == 0) { - return 0 +func (builder *builder) IsZero(i1 frontend.Variable) frontend.Variable { + if a, ok := builder.constantValue(i1); ok { + if a.IsZero() { + return 1 } - return 1 + return 0 } // x = 1/a // in a hint (x == 0 if a == 0) // m = -a*x + 1 // constrain m to be 1 if a == 0 // a * m = 0 // constrain m to be 0 if a != 0 - a := i1.(expr.TermToRefactor) + a := i1.(expr.Term) m := builder.newInternalVariable() // x = 1/a // in a hint (x == 0 if a == 0) - x, err := builder.NewHint(hint.InvZero, 1, a) + x, err := builder.NewHint(solver.InvZeroHint, 1, a) if err != nil { // the function errs only if the number of inputs is invalid. 
panic(err) @@ -383,24 +451,28 @@ func (builder *scs) IsZero(i1 frontend.Variable) frontend.Variable { // m = -a*x + 1 // constrain m to be 1 if a == 0 // a*x + m - 1 == 0 - builder.addPlonkConstraint(a, - x[0].(expr.TermToRefactor), - m, - constraint.CoeffIdZero, - constraint.CoeffIdZero, - constraint.CoeffIdOne, - constraint.CoeffIdOne, - constraint.CoeffIdOne, - constraint.CoeffIdMinusOne) + X := x[0].(expr.Term) + builder.addPlonkConstraint(sparseR1C{ + xa: a.VID, + xb: X.VID, + xc: m.VID, + qM: a.Coeff, + qO: builder.tOne, + qC: builder.tMinusOne, + }) // a * m = 0 // constrain m to be 0 if a != 0 - builder.addPlonkConstraint(a, m, builder.zero(), constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdOne, constraint.CoeffIdOne, constraint.CoeffIdZero, constraint.CoeffIdZero) + builder.addPlonkConstraint(sparseR1C{ + xa: a.VID, + xb: m.VID, + qM: a.Coeff, + }) return m } // Cmp returns 1 if i1>i2, 0 if i1=i2, -1 if i1= 0; i-- { - iszeroi1 := builder.IsZero(bi1[i]) iszeroi2 := builder.IsZero(bi2[i]) @@ -420,7 +491,6 @@ func (builder *scs) Cmp(i1, i2 frontend.Variable) frontend.Variable { m := builder.Select(i1i2, 1, n) res = builder.Select(builder.IsZero(res), m, res) - } return res } @@ -431,8 +501,8 @@ func (builder *scs) Cmp(i1, i2 frontend.Variable) frontend.Variable { // // the print will be done once the R1CS.Solve() method is executed // -// if one of the input is a variable, its value will be resolved avec R1CS.Solve() method is called -func (builder *scs) Println(a ...frontend.Variable) { +// if one of the input is a variable, its value will be resolved when R1CS.Solve() method is called +func (builder *builder) Println(a ...frontend.Variable) { var log constraint.LogEntry // prefix log line with file.go:line @@ -446,12 +516,12 @@ func (builder *scs) Println(a ...frontend.Variable) { if i > 0 { sbb.WriteByte(' ') } - if v, ok := arg.(expr.TermToRefactor); ok { + if v, ok := arg.(expr.Term); ok { sbb.WriteString("%s") // we set limits to the 
linear expression, so that the log printer // can evaluate it before printing it - log.ToResolve = append(log.ToResolve, constraint.LinearExpression{builder.TOREFACTORMakeTerm(&builder.st.Coeffs[v.CID], v.VID)}) + log.ToResolve = append(log.ToResolve, constraint.LinearExpression{builder.cs.MakeTerm(v.Coeff, v.VID)}) } else { builder.printArg(&log, &sbb, arg) } @@ -463,7 +533,7 @@ func (builder *scs) Println(a ...frontend.Variable) { builder.cs.AddLog(log) } -func (builder *scs) printArg(log *constraint.LogEntry, sbb *strings.Builder, a frontend.Variable) { +func (builder *builder) printArg(log *constraint.LogEntry, sbb *strings.Builder, a frontend.Variable) { leafCount, err := schema.Walk(a, tVariable, nil) count := leafCount.Public + leafCount.Secret @@ -484,10 +554,10 @@ func (builder *scs) printArg(log *constraint.LogEntry, sbb *strings.Builder, a f sbb.WriteString(", ") } - v := tValue.Interface().(expr.TermToRefactor) + v := tValue.Interface().(expr.Term) // we set limits to the linear expression, so that the log printer // can evaluate it before printing it - log.ToResolve = append(log.ToResolve, constraint.LinearExpression{builder.TOREFACTORMakeTerm(&builder.st.Coeffs[v.CID], v.VID)}) + log.ToResolve = append(log.ToResolve, constraint.LinearExpression{builder.cs.MakeTerm(v.Coeff, v.VID)}) return nil } // ignoring error, printer() doesn't return errors @@ -495,6 +565,57 @@ func (builder *scs) printArg(log *constraint.LogEntry, sbb *strings.Builder, a f sbb.WriteByte('}') } -func (builder *scs) Compiler() frontend.Compiler { +func (builder *builder) Compiler() frontend.Compiler { return builder } + +func (builder *builder) Commit(v ...frontend.Variable) (frontend.Variable, error) { + + commitments := builder.cs.GetCommitments().(constraint.PlonkCommitments) + v = filterConstants(v) // TODO: @Tabaie Settle on a way to represent even constants; conventional hash? 
+ + committed := make([]int, len(v)) + + for i, vI := range v { // TODO @Tabaie Perf; If public, just hash it + vINeg := builder.Neg(vI).(expr.Term) + committed[i] = builder.cs.GetNbConstraints() + // a constraint to enforce consistency between the commitment and committed value + // - v + comm(n) = 0 + builder.addPlonkConstraint(sparseR1C{xa: vINeg.VID, qL: vINeg.Coeff, commitment: constraint.COMMITTED}) + } + + hintId, err := cs.RegisterBsb22CommitmentComputePlaceholder(len(commitments)) + if err != nil { + return nil, err + } + + var outs []frontend.Variable + if outs, err = builder.NewHintForId(hintId, 1, v...); err != nil { + return nil, err + } + + commitmentVar := builder.Neg(outs[0]).(expr.Term) + commitmentConstraintIndex := builder.cs.GetNbConstraints() + // RHS will be provided by both prover and verifier independently, as for a public wire + builder.addPlonkConstraint(sparseR1C{xa: commitmentVar.VID, qL: commitmentVar.Coeff, commitment: constraint.COMMITMENT}) // value will be injected later + + return outs[0], builder.cs.AddCommitment(constraint.PlonkCommitment{ + HintID: hintId, + CommitmentIndex: commitmentConstraintIndex, + Committed: committed, + }) +} + +func filterConstants(v []frontend.Variable) []frontend.Variable { + res := make([]frontend.Variable, 0, len(v)) + for _, vI := range v { + if _, ok := vI.(expr.Term); ok { + res = append(res, vI) + } + } + return res +} + +func (*builder) FrontendType() frontendtype.Type { + return frontendtype.SCS +} diff --git a/frontend/cs/scs/api_assertions.go b/frontend/cs/scs/api_assertions.go index f30b177666..a4e4447b17 100644 --- a/frontend/cs/scs/api_assertions.go +++ b/frontend/cs/scs/api_assertions.go @@ -20,7 +20,7 @@ import ( "fmt" "math/big" - "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/internal/expr" "github.com/consensys/gnark/internal/utils" @@ -28,13 +28,13 @@ import ( ) // AssertIsEqual 
fails if i1 != i2 -func (builder *scs) AssertIsEqual(i1, i2 frontend.Variable) { +func (builder *builder) AssertIsEqual(i1, i2 frontend.Variable) { - c1, i1Constant := builder.ConstantValue(i1) - c2, i2Constant := builder.ConstantValue(i2) + c1, i1Constant := builder.constantValue(i1) + c2, i2Constant := builder.constantValue(i2) if i1Constant && i2Constant { - if c1.Cmp(c2) != 0 { + if c1 != c2 { panic("i1, i2 should be equal") } return @@ -45,62 +45,100 @@ func (builder *scs) AssertIsEqual(i1, i2 frontend.Variable) { c2 = c1 } if i2Constant { - l := i1.(expr.TermToRefactor) - lc, _ := l.Unpack() - k := c2 - debug := builder.newDebugInfo("assertIsEqual", l, "+", i2, " == 0") - k.Neg(k) - _k := builder.st.CoeffID(k) - builder.addPlonkConstraint(l, builder.zero(), builder.zero(), lc, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdZero, _k, debug) + xa := i1.(expr.Term) + c2 := builder.cs.Neg(c2) + + // xa - i2 == 0 + toAdd := sparseR1C{ + xa: xa.VID, + qL: xa.Coeff, + qC: c2, + } + + if debug.Debug { + debug := builder.newDebugInfo("assertIsEqual", xa, "==", i2) + builder.addPlonkConstraint(toAdd, debug) + } else { + builder.addPlonkConstraint(toAdd) + } return } - l := i1.(expr.TermToRefactor) - r := builder.Neg(i2).(expr.TermToRefactor) - lc, _ := l.Unpack() - rc, _ := r.Unpack() + xa := i1.(expr.Term) + xb := i2.(expr.Term) + + xb.Coeff = builder.cs.Neg(xb.Coeff) + // xa - xb == 0 + toAdd := sparseR1C{ + xa: xa.VID, + xb: xb.VID, + qL: xa.Coeff, + qR: xb.Coeff, + } - debug := builder.newDebugInfo("assertIsEqual", l, " + ", r, " == 0") - builder.addPlonkConstraint(l, r, builder.zero(), lc, rc, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdZero, debug) + if debug.Debug { + debug := builder.newDebugInfo("assertIsEqual", xa, " == ", xb) + builder.addPlonkConstraint(toAdd, debug) + } else { + builder.addPlonkConstraint(toAdd) + } } // AssertIsDifferent fails if i1 == i2 -func 
(builder *scs) AssertIsDifferent(i1, i2 frontend.Variable) { - builder.Inverse(builder.Sub(i1, i2)) +func (builder *builder) AssertIsDifferent(i1, i2 frontend.Variable) { + s := builder.Sub(i1, i2) + if c, ok := builder.constantValue(s); ok && c.IsZero() { + panic("AssertIsDifferent(x,x) will never be satisfied") + } else if t := s.(expr.Term); t.Coeff.IsZero() { + panic("AssertIsDifferent(x,x) will never be satisfied") + } + builder.Inverse(s) } // AssertIsBoolean fails if v != 0 ∥ v != 1 -func (builder *scs) AssertIsBoolean(i1 frontend.Variable) { - if c, ok := builder.ConstantValue(i1); ok { - if !(c.IsUint64() && (c.Uint64() == 0 || c.Uint64() == 1)) { - panic(fmt.Sprintf("assertIsBoolean failed: constant(%s)", c.String())) +func (builder *builder) AssertIsBoolean(i1 frontend.Variable) { + if c, ok := builder.constantValue(i1); ok { + if !(c.IsZero() || builder.cs.IsOne(c)) { + panic(fmt.Sprintf("assertIsBoolean failed: constant(%s)", builder.cs.String(c))) } return } - t := i1.(expr.TermToRefactor) - if builder.IsBoolean(t) { + + v := i1.(expr.Term) + if builder.IsBoolean(v) { return } - builder.MarkBoolean(t) - builder.mtBooleans[int(t.CID)|t.VID<<32] = struct{}{} // TODO @gbotrel smelly fix me - debug := builder.newDebugInfo("assertIsBoolean", t, " == (0|1)") - cID, _ := t.Unpack() - var mCoef big.Int - mCoef.Neg(&builder.st.Coeffs[cID]) - mcID := builder.st.CoeffID(&mCoef) - builder.addPlonkConstraint(t, t, builder.zero(), cID, constraint.CoeffIdZero, mcID, cID, constraint.CoeffIdZero, constraint.CoeffIdZero, debug) + builder.MarkBoolean(v) + + // ensure v * (1 - v) == 0 + // that is v + -v*v == 0 + // qM = -v.Coeff*v.Coeff + qM := builder.cs.Neg(v.Coeff) + qM = builder.cs.Mul(qM, v.Coeff) + toAdd := sparseR1C{ + xa: v.VID, + qL: v.Coeff, + qM: qM, + } + if debug.Debug { + debug := builder.newDebugInfo("assertIsBoolean", v, " == (0|1)") + builder.addBoolGate(toAdd, debug) + } else { + builder.addBoolGate(toAdd) + } + } // AssertIsLessOrEqual fails if v > 
bound -func (builder *scs) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) { +func (builder *builder) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) { switch b := bound.(type) { - case expr.TermToRefactor: - builder.mustBeLessOrEqVar(v.(expr.TermToRefactor), b) + case expr.Term: + builder.mustBeLessOrEqVar(v, b) default: - builder.mustBeLessOrEqCst(v.(expr.TermToRefactor), utils.FromInterface(b)) + builder.mustBeLessOrEqCst(v, utils.FromInterface(b)) } } -func (builder *scs) mustBeLessOrEqVar(a expr.TermToRefactor, bound expr.TermToRefactor) { +func (builder *builder) mustBeLessOrEqVar(a frontend.Variable, bound expr.Term) { debug := builder.newDebugInfo("mustBeLessOrEq", a, " <= ", bound) @@ -126,28 +164,33 @@ func (builder *scs) mustBeLessOrEqVar(a expr.TermToRefactor, bound expr.TermToRe t := builder.Select(boundBits[i], 0, p[i+1]) // (1 - t - ai) * ai == 0 - l := builder.Sub(1, t, aBits[i]) + l := builder.Sub(1, t, aBits[i]).(expr.Term) // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 // → this is a boolean constraint // if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too - builder.MarkBoolean(aBits[i].(expr.TermToRefactor)) // this does not create a constraint - builder.addPlonkConstraint( - l.(expr.TermToRefactor), - aBits[i].(expr.TermToRefactor), - builder.zero(), - constraint.CoeffIdZero, - constraint.CoeffIdZero, - constraint.CoeffIdOne, - constraint.CoeffIdOne, - constraint.CoeffIdZero, - constraint.CoeffIdZero, debug) + if ai, ok := builder.constantValue(aBits[i]); ok { + // a is constant; ensure l == 0 + l.Coeff = builder.cs.Mul(l.Coeff, ai) + builder.addPlonkConstraint(sparseR1C{ + xa: l.VID, + qL: l.Coeff, + }, debug) + } else { + // l * a[i] == 0 + builder.addPlonkConstraint(sparseR1C{ + xa: l.VID, + xb: aBits[i].(expr.Term).VID, + qM: l.Coeff, + }, debug) + } + } } -func (builder *scs) mustBeLessOrEqCst(a expr.TermToRefactor, bound big.Int) { +func (builder *builder) mustBeLessOrEqCst(a 
frontend.Variable, bound big.Int) { nbBits := builder.cs.FieldBitLen() @@ -159,6 +202,14 @@ func (builder *scs) mustBeLessOrEqCst(a expr.TermToRefactor, bound big.Int) { panic("AssertIsLessOrEqual: bound is too large, constraint will never be satisfied") } + if ca, ok := builder.constantValue(a); ok { + // a is constant, compare the big int values + ba := builder.cs.ToBigInt(ca) + if ba.Cmp(&bound) == 1 { + panic(fmt.Sprintf("AssertIsLessOrEqual: %s > %s", ba.String(), bound.String())) + } + } + // debug info debug := builder.newDebugInfo("mustBeLessOrEq", a, " <= ", bound) @@ -191,21 +242,14 @@ func (builder *scs) mustBeLessOrEqCst(a expr.TermToRefactor, bound big.Int) { if bound.Bit(i) == 0 { // (1 - p(i+1) - ai) * ai == 0 - l := builder.Sub(1, p[i+1], aBits[i]).(expr.TermToRefactor) + l := builder.Sub(1, p[i+1], aBits[i]).(expr.Term) //l = builder.Sub(l, ).(term) - builder.addPlonkConstraint( - l, - aBits[i].(expr.TermToRefactor), - builder.zero(), - constraint.CoeffIdZero, - constraint.CoeffIdZero, - constraint.CoeffIdOne, - constraint.CoeffIdOne, - constraint.CoeffIdZero, - constraint.CoeffIdZero, - debug) - // builder.markBoolean(aBits[i].(term)) + builder.addPlonkConstraint(sparseR1C{ + xa: l.VID, + xb: aBits[i].(expr.Term).VID, + qM: builder.tOne, + }, debug) } else { builder.AssertIsBoolean(aBits[i]) } diff --git a/frontend/cs/scs/builder.go b/frontend/cs/scs/builder.go index a49547bcaa..eee785663a 100644 --- a/frontend/cs/scs/builder.go +++ b/frontend/cs/scs/builder.go @@ -17,18 +17,18 @@ limitations under the License. 
package scs import ( - "fmt" "math/big" "reflect" "sort" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs" "github.com/consensys/gnark/frontend/internal/expr" "github.com/consensys/gnark/frontend/schema" + "github.com/consensys/gnark/internal/circuitdefer" + "github.com/consensys/gnark/internal/kvstore" "github.com/consensys/gnark/internal/tinyfield" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" @@ -40,6 +40,7 @@ import ( bn254r1cs "github.com/consensys/gnark/constraint/bn254" bw6633r1cs "github.com/consensys/gnark/constraint/bw6-633" bw6761r1cs "github.com/consensys/gnark/constraint/bw6-761" + "github.com/consensys/gnark/constraint/solver" tinyfieldr1cs "github.com/consensys/gnark/constraint/tinyfield" ) @@ -47,135 +48,204 @@ func NewBuilder(field *big.Int, config frontend.CompileConfig) (frontend.Builder return newBuilder(field, config), nil } -type scs struct { - cs constraint.SparseR1CS - - st cs.CoeffTable +type builder struct { + cs constraint.SparseR1CS config frontend.CompileConfig + kvstore.Store // map for recording boolean constrained variables (to not constrain them twice) - mtBooleans map[int]struct{} + mtBooleans map[expr.Term]struct{} + + // records multiplications constraint to avoid duplicate. + // see mulConstraintExist(...) + mMulInstructions map[uint64]int + + // same thing for addition gates + // see addConstraintExist(...) 
+ mAddInstructions map[uint64]int - q *big.Int + // frequently used coefficients + tOne, tMinusOne constraint.Element + + genericGate constraint.BlueprintID + mulGate, addGate, boolGate constraint.BlueprintID + + // used to avoid repeated allocations + bufL expr.LinearExpression + bufH []constraint.LinearExpression } // initialCapacity has quite some impact on frontend performance, especially on large circuits size // we may want to add build tags to tune that -// TODO @gbotrel restore capacity option! -func newBuilder(field *big.Int, config frontend.CompileConfig) *scs { - builder := scs{ - mtBooleans: make(map[int]struct{}), - st: cs.NewCoeffTable(), - config: config, +func newBuilder(field *big.Int, config frontend.CompileConfig) *builder { + b := builder{ + mtBooleans: make(map[expr.Term]struct{}), + mMulInstructions: make(map[uint64]int, config.Capacity/2), + mAddInstructions: make(map[uint64]int, config.Capacity/2), + config: config, + Store: kvstore.New(), + bufL: make(expr.LinearExpression, 20), } + // init hint buffer. 
+ _ = b.hintBuffer(256) curve := utils.FieldToCurve(field) switch curve { case ecc.BLS12_377: - builder.cs = bls12377r1cs.NewSparseR1CS(config.Capacity) + b.cs = bls12377r1cs.NewSparseR1CS(config.Capacity) case ecc.BLS12_381: - builder.cs = bls12381r1cs.NewSparseR1CS(config.Capacity) + b.cs = bls12381r1cs.NewSparseR1CS(config.Capacity) case ecc.BN254: - builder.cs = bn254r1cs.NewSparseR1CS(config.Capacity) + b.cs = bn254r1cs.NewSparseR1CS(config.Capacity) case ecc.BW6_761: - builder.cs = bw6761r1cs.NewSparseR1CS(config.Capacity) + b.cs = bw6761r1cs.NewSparseR1CS(config.Capacity) case ecc.BW6_633: - builder.cs = bw6633r1cs.NewSparseR1CS(config.Capacity) + b.cs = bw6633r1cs.NewSparseR1CS(config.Capacity) case ecc.BLS24_315: - builder.cs = bls24315r1cs.NewSparseR1CS(config.Capacity) + b.cs = bls24315r1cs.NewSparseR1CS(config.Capacity) case ecc.BLS24_317: - builder.cs = bls24317r1cs.NewSparseR1CS(config.Capacity) + b.cs = bls24317r1cs.NewSparseR1CS(config.Capacity) default: if field.Cmp(tinyfield.Modulus()) == 0 { - builder.cs = tinyfieldr1cs.NewSparseR1CS(config.Capacity) + b.cs = tinyfieldr1cs.NewSparseR1CS(config.Capacity) break } - panic("not implemtented") + panic("not implemented") } - builder.q = builder.cs.Field() - if builder.q.Cmp(field) != 0 { - panic("invalid modulus on cs impl") // sanity check - } + b.tOne = b.cs.One() + b.tMinusOne = b.cs.FromInterface(-1) - return &builder + b.genericGate = b.cs.AddBlueprint(&constraint.BlueprintGenericSparseR1C{}) + b.mulGate = b.cs.AddBlueprint(&constraint.BlueprintSparseR1CMul{}) + b.addGate = b.cs.AddBlueprint(&constraint.BlueprintSparseR1CAdd{}) + b.boolGate = b.cs.AddBlueprint(&constraint.BlueprintSparseR1CBool{}) + + return &b } -func (builder *scs) Field() *big.Int { +func (builder *builder) Field() *big.Int { return builder.cs.Field() } -func (builder *scs) FieldBitLen() int { +func (builder *builder) FieldBitLen() int { return builder.cs.FieldBitLen() } -// addPlonkConstraint creates a constraint of the for 
al+br+clr+k=0 +// TODO @gbotrel doing a 2-step refactoring for now, frontend only. need to update constraint/SparseR1C. // qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC == 0 -func (builder *scs) addPlonkConstraint(xa, xb, xc expr.TermToRefactor, qL, qR, qM1, qM2, qO, qC int, debug ...constraint.DebugInfo) { - // TODO @gbotrel the signature of this function is odd.. and confusing. need refactor. - // TODO @gbotrel restore debug info - // if len(debugID) > 0 { - // builder.MDebug[len(builder.Constraints)] = debugID[0] - // } else if debug.Debug { - // builder.MDebug[len(builder.Constraints)] = constraint.NewDebugInfo("") - // } - - xa.SetCoeffID(qL) - xb.SetCoeffID(qR) - xc.SetCoeffID(qO) - - u := xa - v := xb - u.SetCoeffID(qM1) - v.SetCoeffID(qM2) - L := builder.TOREFACTORMakeTerm(&builder.st.Coeffs[xa.CID], xa.VID) - R := builder.TOREFACTORMakeTerm(&builder.st.Coeffs[xb.CID], xb.VID) - O := builder.TOREFACTORMakeTerm(&builder.st.Coeffs[xc.CID], xc.VID) - U := builder.TOREFACTORMakeTerm(&builder.st.Coeffs[u.CID], u.VID) - V := builder.TOREFACTORMakeTerm(&builder.st.Coeffs[v.CID], v.VID) - K := builder.TOREFACTORMakeTerm(&builder.st.Coeffs[qC], 0) - K.MarkConstant() - builder.cs.AddConstraint(constraint.SparseR1C{L: L, R: R, O: O, M: [2]constraint.Term{U, V}, K: K.CoeffID()}, debug...) 
+type sparseR1C struct { + xa, xb, xc int // wires + qL, qR, qO, qM, qC constraint.Element // coefficients + commitment constraint.CommitmentConstraint +} + +// a * b == c +func (builder *builder) addMulGate(a, b, c expr.Term) { + qM := builder.cs.Mul(a.Coeff, b.Coeff) + QM := builder.cs.AddCoeff(qM) + + builder.cs.AddSparseR1C(constraint.SparseR1C{ + XA: uint32(a.VID), + XB: uint32(b.VID), + XC: uint32(c.VID), + QM: QM, + QO: constraint.CoeffIdMinusOne, + }, builder.mulGate) +} + +// a + b + k == c +func (builder *builder) addAddGate(a, b expr.Term, xc uint32, k constraint.Element) { + qL := builder.cs.AddCoeff(a.Coeff) + qR := builder.cs.AddCoeff(b.Coeff) + qC := builder.cs.AddCoeff(k) + + builder.cs.AddSparseR1C(constraint.SparseR1C{ + XA: uint32(a.VID), + XB: uint32(b.VID), + XC: xc, + QL: qL, + QR: qR, + QC: qC, + QO: constraint.CoeffIdMinusOne, + }, builder.addGate) +} + +func (builder *builder) addBoolGate(c sparseR1C, debugInfo ...constraint.DebugInfo) { + QL := builder.cs.AddCoeff(c.qL) + QM := builder.cs.AddCoeff(c.qM) + + cID := builder.cs.AddSparseR1C(constraint.SparseR1C{ + XA: uint32(c.xa), + QL: QL, + QM: QM}, + builder.boolGate) + if debug.Debug && len(debugInfo) == 1 { + builder.cs.AttachDebugInfo(debugInfo[0], []int{cID}) + } +} + +// addPlonkConstraint adds a sparseR1C to the underlying constraint system +func (builder *builder) addPlonkConstraint(c sparseR1C, debugInfo ...constraint.DebugInfo) { + if !c.qM.IsZero() && (c.xa == 0 || c.xb == 0) { + // TODO this is internal but not easy to detect; if qM is set, but one or both of xa / xb is not, + // since wireID == 0 is a valid wire, it may trigger unexpected behavior. 
+ log := logger.Logger() + log.Warn().Msg("adding a plonk constraint with qM set but xa or xb == 0 (wire 0)") + } + QL := builder.cs.AddCoeff(c.qL) + QR := builder.cs.AddCoeff(c.qR) + QO := builder.cs.AddCoeff(c.qO) + QM := builder.cs.AddCoeff(c.qM) + QC := builder.cs.AddCoeff(c.qC) + + cID := builder.cs.AddSparseR1C(constraint.SparseR1C{ + XA: uint32(c.xa), + XB: uint32(c.xb), + XC: uint32(c.xc), + QL: QL, + QR: QR, + QO: QO, + QM: QM, + QC: QC, Commitment: c.commitment}, builder.genericGate) + if debug.Debug && len(debugInfo) == 1 { + builder.cs.AttachDebugInfo(debugInfo[0], []int{cID}) + } } // newInternalVariable creates a new wire, appends it on the list of wires of the circuit, sets // the wire's id to the number of wires, and returns it -func (builder *scs) newInternalVariable() expr.TermToRefactor { +func (builder *builder) newInternalVariable() expr.Term { idx := builder.cs.AddInternalVariable() - return expr.NewTermToRefactor(idx, constraint.CoeffIdOne) + return expr.NewTerm(idx, builder.tOne) } // PublicVariable creates a new Public Variable -func (builder *scs) PublicVariable(f schema.LeafInfo) frontend.Variable { +func (builder *builder) PublicVariable(f schema.LeafInfo) frontend.Variable { idx := builder.cs.AddPublicVariable(f.FullName()) - return expr.NewTermToRefactor(idx, constraint.CoeffIdOne) + return expr.NewTerm(idx, builder.tOne) } // SecretVariable creates a new Secret Variable -func (builder *scs) SecretVariable(f schema.LeafInfo) frontend.Variable { +func (builder *builder) SecretVariable(f schema.LeafInfo) frontend.Variable { idx := builder.cs.AddSecretVariable(f.FullName()) - return expr.NewTermToRefactor(idx, constraint.CoeffIdOne) + return expr.NewTerm(idx, builder.tOne) } // reduces redundancy in linear expression // It factorizes Variable that appears multiple times with != coeff Ids // To ensure the determinism in the compile process, Variables are stored as public∥secret∥internal∥unset // for each visibility, the Variables are 
sorted from lowest ID to highest ID -func (builder *scs) reduce(l expr.LinearExpressionToRefactor) expr.LinearExpressionToRefactor { +func (builder *builder) reduce(l expr.LinearExpression) expr.LinearExpression { // ensure our linear expression is sorted, by visibility and by Variable ID sort.Sort(l) - c := new(big.Int) for i := 1; i < len(l); i++ { - pcID, pvID := l[i-1].Unpack() - ccID, cvID := l[i].Unpack() - if pvID == cvID { + if l[i-1].VID == l[i].VID { // we have redundancy - c.Add(&builder.st.Coeffs[pcID], &builder.st.Coeffs[ccID]) - c.Mod(c, builder.q) - l[i-1].SetCoeffID(builder.st.CoeffID(c)) + l[i-1].Coeff = builder.cs.Add(l[i-1].Coeff, l[i].Coeff) l = append(l[:i], l[i+1:]...) i-- } @@ -183,33 +253,28 @@ func (builder *scs) reduce(l expr.LinearExpressionToRefactor) expr.LinearExpress return l } -// to handle wires that don't exist (=coef 0) in a sparse constraint -func (builder *scs) zero() expr.TermToRefactor { - var a expr.TermToRefactor - return a -} - // IsBoolean returns true if given variable was marked as boolean in the compiler (see MarkBoolean) // Use with care; variable may not have been **constrained** to be boolean // This returns true if the v is a constant and v == 0 || v == 1. -func (builder *scs) IsBoolean(v frontend.Variable) bool { - if b, ok := builder.ConstantValue(v); ok { - return b.IsUint64() && b.Uint64() <= 1 +func (builder *builder) IsBoolean(v frontend.Variable) bool { + if b, ok := builder.constantValue(v); ok { + return (b.IsZero() || builder.cs.IsOne(b)) } - _, ok := builder.mtBooleans[int(v.(expr.TermToRefactor).CID|(int(v.(expr.TermToRefactor).VID)<<32))] // TODO @gbotrel fixme this is sketchy + _, ok := builder.mtBooleans[v.(expr.Term)] return ok } // MarkBoolean sets (but do not constraint!) v to be boolean // This is useful in scenarios where a variable is known to be boolean through a constraint // that is not api.AssertIsBoolean. If v is a constant, this is a no-op. 
-func (builder *scs) MarkBoolean(v frontend.Variable) { - if b, ok := builder.ConstantValue(v); ok { - if !(b.IsUint64() && b.Uint64() <= 1) { +func (builder *builder) MarkBoolean(v frontend.Variable) { + if _, ok := builder.constantValue(v); ok { + if !builder.IsBoolean(v) { panic("MarkBoolean called a non-boolean constant") } + return } - builder.mtBooleans[int(v.(expr.TermToRefactor).CID|(int(v.(expr.TermToRefactor).VID)<<32))] = struct{}{} // TODO @gbotrel fixme this is sketchy + builder.mtBooleans[v.(expr.Term)] = struct{}{} } var tVariable reflect.Type @@ -218,7 +283,7 @@ func init() { tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() } -func (builder *scs) Compile() (constraint.ConstraintSystem, error) { +func (builder *builder) Compile() (constraint.ConstraintSystem, error) { log := logger.Logger() log.Info(). Int("nbConstraints", builder.cs.GetNbConstraints()). @@ -236,21 +301,32 @@ func (builder *scs) Compile() (constraint.ConstraintSystem, error) { return builder.cs, nil } -// ConstantValue returns the big.Int value of v. It -// panics if v.IsConstant() == false -func (builder *scs) ConstantValue(v frontend.Variable) (*big.Int, bool) { - switch t := v.(type) { - case expr.TermToRefactor: +// ConstantValue returns the big.Int value of v. 
+// Will panic if v.IsConstant() == false +func (builder *builder) ConstantValue(v frontend.Variable) (*big.Int, bool) { + coeff, ok := builder.constantValue(v) + if !ok { return nil, false - default: - res := utils.FromInterface(t) - return &res, true } + return builder.cs.ToBigInt(coeff), true } -func (builder *scs) TOREFACTORMakeTerm(c *big.Int, vID int) constraint.Term { - cc := builder.cs.FromInterface(c) - return builder.cs.MakeTerm(&cc, vID) +func (builder *builder) constantValue(v frontend.Variable) (constraint.Element, bool) { + if _, ok := v.(expr.Term); ok { + return constraint.Element{}, false + } + return builder.cs.FromInterface(v), true +} + +func (builder *builder) hintBuffer(size int) []constraint.LinearExpression { + if cap(builder.bufH) < size { + builder.bufH = make([]constraint.LinearExpression, 2*size) + for i := 0; i < len(builder.bufH); i++ { + builder.bufH[i] = make(constraint.LinearExpression, 1) + } + } + + return builder.bufH[:size] } // NewHint initializes internal variables whose value will be evaluated using @@ -265,24 +341,31 @@ func (builder *scs) TOREFACTORMakeTerm(c *big.Int, vID int) constraint.Term { // // No new constraints are added to the newly created wire and must be added // manually in the circuit. Failing to do so leads to solver failure. -func (builder *scs) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { +func (builder *builder) NewHint(f solver.Hint, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { + return builder.newHint(f, solver.GetHintID(f), nbOutputs, inputs...) +} - hintInputs := make([]constraint.LinearExpression, len(inputs)) +func (builder *builder) NewHintForId(id solver.HintID, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { + return builder.newHint(nil, id, nbOutputs, inputs...) 
+} + +func (builder *builder) newHint(f solver.Hint, id solver.HintID, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { + hintInputs := builder.hintBuffer(len(inputs)) // ensure inputs are set and pack them in a []uint64 for i, in := range inputs { switch t := in.(type) { - case expr.TermToRefactor: - hintInputs[i] = constraint.LinearExpression{builder.TOREFACTORMakeTerm(&builder.st.Coeffs[t.CID], t.VID)} + case expr.Term: + hintInputs[i][0] = builder.cs.MakeTerm(t.Coeff, t.VID) default: - c := utils.FromInterface(in) - term := builder.TOREFACTORMakeTerm(&c, 0) + c := builder.cs.FromInterface(in) + term := builder.cs.MakeTerm(c, 0) term.MarkConstant() - hintInputs[i] = constraint.LinearExpression{term} + hintInputs[i][0] = term } } - internalVariables, err := builder.cs.AddSolverHint(f, hintInputs, nbOutputs) + internalVariables, err := builder.cs.AddSolverHint(f, id, hintInputs, nbOutputs) if err != nil { return nil, err } @@ -290,104 +373,317 @@ func (builder *scs) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.V // make the variables res := make([]frontend.Variable, len(internalVariables)) for i, idx := range internalVariables { - res[i] = expr.NewTermToRefactor(idx, constraint.CoeffIdOne) + res[i] = expr.NewTerm(idx, builder.tOne) } return res, nil } // returns in split into a slice of compiledTerm and the sum of all constants in in as a bigInt -func (builder *scs) filterConstantSum(in []frontend.Variable) (expr.LinearExpressionToRefactor, big.Int) { - res := make(expr.LinearExpressionToRefactor, 0, len(in)) - var b big.Int +func (builder *builder) filterConstantSum(in []frontend.Variable) (expr.LinearExpression, constraint.Element) { + var res expr.LinearExpression + if len(in) <= cap(builder.bufL) { + // we can use the temp buffer + res = builder.bufL[:0] + } else { + res = make(expr.LinearExpression, 0, len(in)) + } + + b := constraint.Element{} for i := 0; i < len(in); i++ { - switch t := in[i].(type) { - case 
expr.TermToRefactor: - res = append(res, t) - default: - n := utils.FromInterface(t) - b.Add(&b, &n) + if c, ok := builder.constantValue(in[i]); ok { + b = builder.cs.Add(b, c) + } else { + res = append(res, in[i].(expr.Term)) } } return res, b } -// returns in split into a slice of compiledTerm and the product of all constants in in as a bigInt -func (builder *scs) filterConstantProd(in []frontend.Variable) (expr.LinearExpressionToRefactor, big.Int) { - res := make(expr.LinearExpressionToRefactor, 0, len(in)) - var b big.Int - b.SetInt64(1) +// returns in split into a slice of compiledTerm and the product of all constants in in as a coeff +func (builder *builder) filterConstantProd(in []frontend.Variable) (expr.LinearExpression, constraint.Element) { + var res expr.LinearExpression + if len(in) <= cap(builder.bufL) { + // we can use the temp buffer + res = builder.bufL[:0] + } else { + res = make(expr.LinearExpression, 0, len(in)) + } + + b := builder.tOne for i := 0; i < len(in); i++ { - switch t := in[i].(type) { - case expr.TermToRefactor: - res = append(res, t) - default: - n := utils.FromInterface(t) - b.Mul(&b, &n).Mod(&b, builder.q) + if c, ok := builder.constantValue(in[i]); ok { + b = builder.cs.Mul(b, c) + } else { + res = append(res, in[i].(expr.Term)) } } return res, b } -func (builder *scs) splitSum(acc expr.TermToRefactor, r expr.LinearExpressionToRefactor) expr.TermToRefactor { - +func (builder *builder) splitSum(acc expr.Term, r expr.LinearExpression, k *constraint.Element) expr.Term { // floor case if len(r) == 0 { + if k != nil { + // we need to return acc + k + o, found := builder.addConstraintExist(acc, expr.Term{}, *k) + if !found { + o = builder.newInternalVariable() + builder.addAddGate(acc, expr.Term{}, uint32(o.VID), *k) + } + + return o + } return acc } - cl, _ := acc.Unpack() - cr, _ := r[0].Unpack() - o := builder.newInternalVariable() - builder.addPlonkConstraint(acc, r[0], o, cl, cr, constraint.CoeffIdZero, constraint.CoeffIdZero, 
constraint.CoeffIdMinusOne, constraint.CoeffIdZero) - return builder.splitSum(o, r[1:]) + // constraint to add: acc + r[0] (+ k) == o + qC := constraint.Element{} + if k != nil { + qC = *k + } + o, found := builder.addConstraintExist(acc, r[0], qC) + if !found { + o = builder.newInternalVariable() + builder.addAddGate(acc, r[0], uint32(o.VID), qC) + } + + return builder.splitSum(o, r[1:], nil) +} + +// addConstraintExist check if we recorded a constraint in the form +// q1*xa + q2*xb + qC - xc == 0 +// +// if we find one, this function returns the xc wire with the correct coefficients. +// if we don't, and no previous addition was recorded with xa and xb, add an entry in the map +// (this assumes that the caller will add a constraint just after this call if it's not found!) +// +// idea: +// 1. take (xa | (xb << 32)) as a identifier of an addition that used wires xa and xb. +// 2. look for an entry in builder.mAddConstraints for a previously added constraint that matches. +// 3. if so, check that the coefficients matches and we can re-use xc wire. +// +// limitations: +// 1. for efficiency, we just store the first addition that occurred with with xa and xb; +// so if we do 2*xa + 3*xb == c, then want to compute xa + xb == d multiple times, the compiler is +// not going to catch these duplicates. +// 2. this piece of code assumes some behavior from constraint/ package (like coeffIDs, or append-style +// constraint management) +func (builder *builder) addConstraintExist(a, b expr.Term, k constraint.Element) (expr.Term, bool) { + // ensure deterministic combined identifier; + if a.VID < b.VID { + a, b = b, a + } + h := uint64(a.WireID()) | uint64(b.WireID()<<32) + + if iID, ok := builder.mAddInstructions[h]; ok { + // if we do custom gates with slices in the constraint + // should use a shared object here to avoid allocs. 
+ var c constraint.SparseR1C + + // seems likely we have a fit, let's double check + inst := builder.cs.GetInstruction(iID) + // we know the blueprint we added it. + blueprint := constraint.BlueprintSparseR1CAdd{} + blueprint.DecompressSparseR1C(&c, inst) + + // qO == -1 + if a.WireID() == int(c.XB) { + a, b = b, a // ensure a is in qL + } + + tk := builder.cs.MakeTerm(k, 0) + if tk.CoeffID() != int(c.QC) { + // the constant part of the addition differs, no point going forward + // since we will need to add a new constraint anyway. + return expr.Term{}, false + } + + // check that the coeff matches + qL := a.Coeff + qR := b.Coeff + ta := builder.cs.MakeTerm(qL, 0) + tb := builder.cs.MakeTerm(qR, 0) + if int(c.QL) != ta.CoeffID() || int(c.QR) != tb.CoeffID() { + if !k.IsZero() { + // may be for some edge cases we could avoid adding a constraint here. + return expr.Term{}, false + } + // we recorded an addition in the form q1*a + q2*b == c + // we want to record a new one in the form q3*a + q4*b == n*c + // question is; can we re-use c to avoid introducing a new wire & new constraint + // this is possible only if n == q3/q1 == q4/q2, that is, q3q2 == q1q4 + q1 := builder.cs.GetCoefficient(int(c.QL)) + q2 := builder.cs.GetCoefficient(int(c.QR)) + q3 := qL + q4 := qR + q3 = builder.cs.Mul(q3, q2) + q1 = builder.cs.Mul(q1, q4) + if q1 == q3 { + // no need to introduce a new constraint; + // compute n, the coefficient for the output wire + q2, ok = builder.cs.Inverse(q2) + if !ok { + panic("div by 0") // shouldn't happen + } + q2 = builder.cs.Mul(q2, q4) + return expr.NewTerm(int(c.XC), q2), true + } + // we will need an additional constraint + return expr.Term{}, false + } + + // we found the same constraint! + return expr.NewTerm(int(c.XC), builder.tOne), true + } + // we are going to add this constraint, so we mark it. + // ! 
assumes the caller add an instruction immediately after the call to this function + builder.mAddInstructions[h] = builder.cs.GetNbInstructions() + return expr.Term{}, false } -func (builder *scs) splitProd(acc expr.TermToRefactor, r expr.LinearExpressionToRefactor) expr.TermToRefactor { +// mulConstraintExist check if we recorded a constraint in the form +// qM*xa*xb - xc == 0 +// +// if we find one, this function returns the xc wire with the correct coefficients. +// if we don't, and no previous multiplication was recorded with xa and xb, add an entry in the map +// (this assumes that the caller will add a constraint just after this call if it's not found!) +// +// idea: +// 1. take (xa | (xb << 32)) as a identifier of a multiplication that used wires xa and xb. +// 2. look for an entry in builder.mMulConstraints for a previously added constraint that matches. +// 3. if so, compute correct coefficient N for xc wire that matches qM'*xa*xb - N*xc == 0 +// +// limitations: +// 1. this piece of code assumes some behavior from constraint/ package (like coeffIDs, or append-style +// constraint management) +func (builder *builder) mulConstraintExist(a, b expr.Term) (expr.Term, bool) { + // ensure deterministic combined identifier; + if a.VID < b.VID { + a, b = b, a + } + h := uint64(a.WireID()) | uint64(b.WireID()<<32) + + if iID, ok := builder.mMulInstructions[h]; ok { + // if we do custom gates with slices in the constraint + // should use a shared object here to avoid allocs. + var c constraint.SparseR1C + + // seems likely we have a fit, let's double check + inst := builder.cs.GetInstruction(iID) + // we know the blueprint we added it. 
+ blueprint := constraint.BlueprintSparseR1CMul{} + blueprint.DecompressSparseR1C(&c, inst) + + // qO == -1 + + if a.WireID() == int(c.XB) { + a, b = b, a // ensure a is in qL + } + + // recompute the qM coeff and check that it matches; + qM := builder.cs.Mul(a.Coeff, b.Coeff) + tm := builder.cs.MakeTerm(qM, 0) + if int(c.QM) != tm.CoeffID() { + // so we wanted to compute + // N * xC == qM*xA*xB + // but found a constraint + // xC == qM'*xA*xB + // the coefficient for our resulting wire is different; + // N = qM / qM' + N := builder.cs.GetCoefficient(int(c.QM)) + N, ok := builder.cs.Inverse(N) + if !ok { + panic("div by 0") // sanity check. + } + N = builder.cs.Mul(N, qM) + + return expr.NewTerm(int(c.XC), N), true + } + // we found the exact same constraint + return expr.NewTerm(int(c.XC), builder.tOne), true + } + + // we are going to add this constraint, so we mark it. + // ! assumes the caller add an instruction immediately after the call to this function + builder.mMulInstructions[h] = builder.cs.GetNbInstructions() + return expr.Term{}, false +} + +func (builder *builder) splitProd(acc expr.Term, r expr.LinearExpression) expr.Term { // floor case if len(r) == 0 { return acc } + // we want to add a constraint such that acc * r[0] == o + // let's check if we didn't already constrain a similar product + o, found := builder.mulConstraintExist(acc, r[0]) + + if !found { + // constraint to add: acc * r[0] == o + o = builder.newInternalVariable() + builder.addMulGate(acc, r[0], o) + } - cl, _ := acc.Unpack() - cr, _ := r[0].Unpack() - o := builder.newInternalVariable() - builder.addPlonkConstraint(acc, r[0], o, constraint.CoeffIdZero, constraint.CoeffIdZero, cl, cr, constraint.CoeffIdMinusOne, constraint.CoeffIdZero) return builder.splitProd(o, r[1:]) } -func (builder *scs) Commit(v ...frontend.Variable) (frontend.Variable, error) { - return nil, fmt.Errorf("not implemented") -} - -func (builder *scs) SetGkrInfo(info constraint.GkrInfo) error { - return 
fmt.Errorf("not implemented") -} - // newDebugInfo this is temporary to restore debug logs // something more like builder.sprintf("my message %le %lv", l0, l1) // to build logs for both debug and println // and append some program location.. (see other todo in debug_info.go) -func (builder *scs) newDebugInfo(errName string, in ...interface{}) constraint.DebugInfo { +func (builder *builder) newDebugInfo(errName string, in ...interface{}) constraint.DebugInfo { for i := 0; i < len(in); i++ { // for inputs that are LinearExpressions or Term, we need to "Make" them in the backend. // TODO @gbotrel this is a duplicate effort with adding a constraint and should be taken care off switch t := in[i].(type) { - case *expr.LinearExpressionToRefactor, expr.LinearExpressionToRefactor: + case *expr.LinearExpression, expr.LinearExpression: // shouldn't happen - case expr.TermToRefactor: - in[i] = builder.TOREFACTORMakeTerm(&builder.st.Coeffs[t.CID], t.VID) - case *expr.TermToRefactor: - in[i] = builder.TOREFACTORMakeTerm(&builder.st.Coeffs[t.CID], t.VID) - case constraint.Coeff: - in[i] = builder.cs.String(&t) - case *constraint.Coeff: + case expr.Term: + in[i] = builder.cs.MakeTerm(t.Coeff, t.VID) + case *expr.Term: + in[i] = builder.cs.MakeTerm(t.Coeff, t.VID) + case constraint.Element: in[i] = builder.cs.String(t) + case *constraint.Element: + in[i] = builder.cs.String(*t) } } return builder.cs.NewDebugInfo(errName, in...) } + +func (builder *builder) Defer(cb func(frontend.API) error) { + circuitdefer.Put(builder, cb) +} + +// AddInstruction is used to add custom instructions to the constraint system. +func (builder *builder) AddInstruction(bID constraint.BlueprintID, calldata []uint32) []uint32 { + return builder.cs.AddInstruction(bID, calldata) +} + +// AddBlueprint adds a custom blueprint to the constraint system. 
+func (builder *builder) AddBlueprint(b constraint.Blueprint) constraint.BlueprintID { + return builder.cs.AddBlueprint(b) +} + +func (builder *builder) InternalVariable(wireID uint32) frontend.Variable { + return expr.NewTerm(int(wireID), builder.tOne) +} + +// ToCanonicalVariable converts a frontend.Variable to a constraint system specific Variable +// ! Experimental: use in conjunction with constraint.CustomizableSystem +func (builder *builder) ToCanonicalVariable(v frontend.Variable) frontend.CanonicalVariable { + switch t := v.(type) { + case expr.Term: + return builder.cs.MakeTerm(t.Coeff, t.VID) + default: + c := builder.cs.FromInterface(v) + term := builder.cs.MakeTerm(c, 0) + term.MarkConstant() + return term + } +} diff --git a/frontend/cs/scs/duplicate_test.go b/frontend/cs/scs/duplicate_test.go new file mode 100644 index 0000000000..e06051bb86 --- /dev/null +++ b/frontend/cs/scs/duplicate_test.go @@ -0,0 +1,94 @@ +package scs_test + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/stretchr/testify/require" +) + +type circuitDupAdd struct { + A, B frontend.Variable + R frontend.Variable +} + +func (c *circuitDupAdd) Define(api frontend.API) error { + + f := api.Add(c.A, c.B) // 1 constraint + f = api.Add(c.A, c.B, f) // 1 constraint + f = api.Add(c.A, c.B, f) // 1 constraint + + d := api.Add(api.Mul(c.A, 3), api.Mul(3, c.B)) // 3a + 3b --> 3 (a + b) shouldn't add a constraint. 
+ e := api.Mul(api.Add(c.A, c.B), 3) // no constraints + + api.AssertIsEqual(f, e) // 1 constraint + api.AssertIsEqual(d, f) // 1 constraint + api.AssertIsEqual(c.R, e) // 1 constraint + + return nil +} + +func TestDuplicateAdd(t *testing.T) { + assert := require.New(t) + + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &circuitDupAdd{}) + assert.NoError(err) + + assert.Equal(6, ccs.GetNbConstraints(), "comparing expected number of constraints") + + w, err := frontend.NewWitness(&circuitDupAdd{ + A: 13, + B: 42, + R: 165, + }, ecc.BN254.ScalarField()) + assert.NoError(err) + + _, err = ccs.Solve(w) + assert.NoError(err, "solving failed") +} + +type circuitDupMul struct { + A, B frontend.Variable + R1, R2 frontend.Variable +} + +func (c *circuitDupMul) Define(api frontend.API) error { + + f := api.Mul(c.A, c.B) // 1 constraint + f = api.Mul(c.A, c.B, f) // 1 constraint + f = api.Mul(c.A, c.B, f) // 1 constraint + // f == (a*b)**3 + + d := api.Mul(api.Mul(c.A, 2), api.Mul(3, c.B)) // no constraints + e := api.Mul(api.Mul(c.A, c.B), 1) // no constraints + e = api.Mul(e, e) // e**2 (no constraints) + e = api.Mul(e, api.Mul(c.A, c.B), 1) // e**3 (no constraints) + + api.AssertIsEqual(f, e) // 1 constraint + api.AssertIsEqual(d, c.R1) // 1 constraint + api.AssertIsEqual(c.R2, e) // 1 constraint + + return nil +} + +func TestDuplicateMul(t *testing.T) { + assert := require.New(t) + + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &circuitDupMul{}) + assert.NoError(err) + + assert.Equal(6, ccs.GetNbConstraints(), "comparing expected number of constraints") + + w, err := frontend.NewWitness(&circuitDupMul{ + A: 13, + B: 42, + R1: (13 * 2) * (42 * 3), + R2: (13 * 42) * (13 * 42) * (13 * 42), + }, ecc.BN254.ScalarField()) + assert.NoError(err) + + _, err = ccs.Solve(w) + assert.NoError(err, "solving failed") +} diff --git a/frontend/internal/expr/linear_expression.go b/frontend/internal/expr/linear_expression.go index 
23efcbc146..682bf0985d 100644 --- a/frontend/internal/expr/linear_expression.go +++ b/frontend/internal/expr/linear_expression.go @@ -4,44 +4,20 @@ import ( "github.com/consensys/gnark/constraint" ) -// TODO @gbotrel --> storing a UUID in the linear expressions would enable better perf -// in the frontends -> check a linear expression is boolean, or has been converted to a -// "backend" constraint.LinearExpresion ... and avoid duplicating work would be interesting. - type LinearExpression []Term -func (l LinearExpression) Clone() LinearExpression { - res := make(LinearExpression, len(l)) - copy(res, l) - return res -} - // NewLinearExpression helper to initialize a linear expression with one term -func NewLinearExpression(vID int, cID constraint.Coeff) LinearExpression { +func NewLinearExpression(vID int, cID constraint.Element) LinearExpression { return LinearExpression{Term{Coeff: cID, VID: vID}} } -func NewTerm(vID int, cID constraint.Coeff) Term { - return Term{Coeff: cID, VID: vID} -} - -type Term struct { - VID int - Coeff constraint.Coeff -} - -func (t *Term) SetCoeff(c constraint.Coeff) { - t.Coeff = c -} -func (t Term) WireID() int { - return t.VID -} - -func (t Term) HashCode() uint64 { - return t.Coeff[0]*29 + uint64(t.VID<<12) +func (l LinearExpression) Clone() LinearExpression { + res := make(LinearExpression, len(l)) + copy(res, l) + return res } -// Len return the lenght of the Variable (implements Sort interface) +// Len return the length of the Variable (implements Sort interface) func (l LinearExpression) Len() int { return len(l) } diff --git a/frontend/internal/expr/linear_expression_scs_torefactor.go b/frontend/internal/expr/linear_expression_scs_torefactor.go deleted file mode 100644 index 6547c63bb5..0000000000 --- a/frontend/internal/expr/linear_expression_scs_torefactor.go +++ /dev/null @@ -1,78 +0,0 @@ -package expr - -type LinearExpressionToRefactor []TermToRefactor - -func (l LinearExpressionToRefactor) Clone() LinearExpressionToRefactor 
{ - res := make(LinearExpressionToRefactor, len(l)) - copy(res, l) - return res -} - -func NewTermToRefactor(vID, cID int) TermToRefactor { - return TermToRefactor{CID: cID, VID: vID} -} - -type TermToRefactor struct { - CID int - VID int -} - -func (t TermToRefactor) Unpack() (cID, vID int) { - return t.CID, t.VID -} - -func (t *TermToRefactor) SetCoeffID(cID int) { - t.CID = cID -} -func (t TermToRefactor) WireID() int { - return t.VID -} - -func (t TermToRefactor) HashCode() uint64 { - return uint64(t.CID) + uint64(t.VID<<32) -} - -// Len return the lenght of the Variable (implements Sort interface) -func (l LinearExpressionToRefactor) Len() int { - return len(l) -} - -// Equals returns true if both SORTED expressions are the same -// -// pre conditions: l and o are sorted -func (l LinearExpressionToRefactor) Equal(o LinearExpressionToRefactor) bool { - if len(l) != len(o) { - return false - } - if (l == nil) != (o == nil) { - return false - } - for i := 0; i < len(l); i++ { - if l[i] != o[i] { - return false - } - } - return true -} - -// Swap swaps terms in the Variable (implements Sort interface) -func (l LinearExpressionToRefactor) Swap(i, j int) { - l[i], l[j] = l[j], l[i] -} - -// Less returns true if variableID for term at i is less than variableID for term at j (implements Sort interface) -func (l LinearExpressionToRefactor) Less(i, j int) bool { - iID := l[i].WireID() - jID := l[j].WireID() - return iID < jID -} - -// HashCode returns a fast-to-compute but NOT collision resistant hash code identifier for the linear -// expression -func (l LinearExpressionToRefactor) HashCode() uint64 { - h := uint64(17) - for _, val := range l { - h = h*23 + val.HashCode() // TODO @gbotrel revisit - } - return h -} diff --git a/frontend/internal/expr/term.go b/frontend/internal/expr/term.go new file mode 100644 index 0000000000..5c9ba5877c --- /dev/null +++ b/frontend/internal/expr/term.go @@ -0,0 +1,25 @@ +package expr + +import "github.com/consensys/gnark/constraint" 
+ +type Term struct { + VID int + Coeff constraint.Element +} + +func NewTerm(vID int, cID constraint.Element) Term { + return Term{Coeff: cID, VID: vID} +} + +func (t *Term) SetCoeff(c constraint.Element) { + t.Coeff = c +} + +// TODO @gbotrel make that return a uint32 +func (t Term) WireID() int { + return t.VID +} + +func (t Term) HashCode() uint64 { + return t.Coeff[0]*29 + uint64(t.VID<<12) +} diff --git a/frontend/schema/schema.go b/frontend/schema/schema.go index 0f41a7ad22..f2903ba65a 100644 --- a/frontend/schema/schema.go +++ b/frontend/schema/schema.go @@ -361,7 +361,11 @@ func parse(r []Field, input interface{}, target reflect.Type, parentFullName, pa val := tValue.Index(j) if val.CanAddr() && val.Addr().CanInterface() { fqn := getFullName(parentFullName, strconv.Itoa(j), "") - subFields, err = parse(subFields, val.Addr().Interface(), target, fqn, fqn, parentTagName, parentVisibility, nbPublic, nbSecret) + ival := val.Addr().Interface() + if ih, hasInitHook := ival.(InitHook); hasInitHook { + ih.GnarkInitHook() + } + subFields, err = parse(subFields, ival, target, fqn, fqn, parentTagName, parentVisibility, nbPublic, nbSecret) if err != nil { return nil, err } diff --git a/frontend/schema/schema_test.go b/frontend/schema/schema_test.go index 929f1cbb51..23a3bfb27e 100644 --- a/frontend/schema/schema_test.go +++ b/frontend/schema/schema_test.go @@ -165,6 +165,31 @@ func TestSchemaInherit(t *testing.T) { } } +type initableVariable struct { + Val []variable +} + +func (iv *initableVariable) GnarkInitHook() { + if iv.Val == nil { + iv.Val = make([]variable, 2) + } +} + +type initableCircuit struct { + X [2]initableVariable + Y []initableVariable + Z initableVariable +} + +func TestVariableInitHook(t *testing.T) { + assert := require.New(t) + + witness := &initableCircuit{Y: make([]initableVariable, 2)} + s, err := New(witness, tVariable) + assert.NoError(err) + assert.Equal(s.NbSecret, 10) // X: 2*2, Y: 2*2, Z: 2 +} + func BenchmarkLargeSchema(b *testing.B) { 
const n1 = 1 << 12 const n2 = 1 << 12 diff --git a/frontend/schema/walk.go b/frontend/schema/walk.go index 07d56b880a..e62a18db0e 100644 --- a/frontend/schema/walk.go +++ b/frontend/schema/walk.go @@ -91,11 +91,23 @@ func (w *walker) Slice(value reflect.Value) error { return nil } -func (w *walker) SliceElem(index int, _ reflect.Value) error { +func (w *walker) arraySliceElem(index int, v reflect.Value) error { w.path.push(LeafInfo{Visibility: w.visibility(), name: strconv.Itoa(index)}) + if v.CanAddr() && v.Addr().CanInterface() { + // TODO @gbotrel don't like that hook, undesirable side effects + // will be hard to detect; (for example calling Parse multiple times will init multiple times!) + value := v.Addr().Interface() + if ih, hasInitHook := value.(InitHook); hasInitHook { + ih.GnarkInitHook() + } + } return nil } +func (w *walker) SliceElem(index int, v reflect.Value) error { + return w.arraySliceElem(index, v) +} + // Array handles array elements found within complex structures. 
func (w *walker) Array(value reflect.Value) error { if value.Type() == reflect.ArrayOf(value.Len(), w.target) { @@ -103,9 +115,8 @@ func (w *walker) Array(value reflect.Value) error { } return nil } -func (w *walker) ArrayElem(index int, _ reflect.Value) error { - w.path.push(LeafInfo{Visibility: w.visibility(), name: strconv.Itoa(index)}) - return nil +func (w *walker) ArrayElem(index int, v reflect.Value) error { + return w.arraySliceElem(index, v) } // process an array or slice of leaves; since it's quite common to have large array/slices diff --git a/frontend/variable.go b/frontend/variable.go index 9f00ba5167..82d33fbc90 100644 --- a/frontend/variable.go +++ b/frontend/variable.go @@ -32,8 +32,6 @@ func IsCanonical(v Variable) bool { switch v.(type) { case expr.LinearExpression, *expr.LinearExpression, expr.Term, *expr.Term: return true - case expr.LinearExpressionToRefactor, *expr.LinearExpressionToRefactor, expr.TermToRefactor, *expr.TermToRefactor: - return true } return false } diff --git a/go.mod b/go.mod index 3a98e61a85..afac14d9ee 100644 --- a/go.mod +++ b/go.mod @@ -1,31 +1,31 @@ module github.com/consensys/gnark -go 1.18 +go 1.19 require ( - github.com/bits-and-blooms/bitset v1.5.0 + github.com/bits-and-blooms/bitset v1.7.0 github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.1.13 - github.com/consensys/gnark-crypto v0.9.1 + github.com/consensys/gnark-crypto v0.11.1-0.20230609175512-0ee617fa6d43 github.com/fxamacker/cbor/v2 v2.4.0 github.com/google/go-cmp v0.5.9 - github.com/google/pprof v0.0.0-20230207041349-798e818bf904 + github.com/google/pprof v0.0.0-20230309165930-d61513b1440d github.com/leanovate/gopter v0.2.9 github.com/rs/zerolog v1.29.0 - github.com/stretchr/testify v1.8.1 + github.com/stretchr/testify v1.8.2 + golang.org/x/crypto v0.10.0 golang.org/x/exp v0.0.0-20230213192124-5e25df0256eb ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/kr/pretty v0.3.1 // indirect - github.com/mattn/go-colorable v0.1.12 // 
indirect - github.com/mattn/go-isatty v0.0.14 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.16 // indirect github.com/mmcloughlin/addchain v0.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/x448/float16 v0.8.4 // indirect - golang.org/x/crypto v0.6.0 // indirect - golang.org/x/sys v0.5.0 // indirect + golang.org/x/sys v0.9.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect rsc.io/tmplfunc v0.0.3 // indirect diff --git a/go.sum b/go.sum index 05bbfa6f95..cfb27fb3eb 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,11 @@ -github.com/bits-and-blooms/bitset v1.5.0 h1:NpE8frKRLGHIcEzkR+gZhiioW1+WbYV6fKwD6ZIpQT8= -github.com/bits-and-blooms/bitset v1.5.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= +github.com/bits-and-blooms/bitset v1.7.0 h1:YjAGVd3XmtK9ktAbX8Zg2g2PwLIMjGREZJHlV4j7NEo= +github.com/bits-and-blooms/bitset v1.7.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= github.com/consensys/bavard v0.1.13 h1:oLhMLOFGTLdlda/kma4VOJazblc7IM5y5QPd2A/YjhQ= github.com/consensys/bavard v0.1.13/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= -github.com/consensys/gnark-crypto v0.9.1 h1:mru55qKdWl3E035hAoh1jj9d7hVnYY5pfb6tmovSmII= -github.com/consensys/gnark-crypto v0.9.1/go.mod h1:a2DQL4+5ywF6safEeZFEPGRiiGbjzGFRUN2sg06VuU4= +github.com/consensys/gnark-crypto v0.11.1-0.20230609175512-0ee617fa6d43 h1:6VCNdjn2RmxgG2ZklMmSGov9BtCNfVF4VjqAngysiPU= +github.com/consensys/gnark-crypto v0.11.1-0.20230609175512-0ee617fa6d43/go.mod h1:6C2ytC8zmP8uH2GKVfPOjf0Vw3KwMAaUxlCPK5WQqmw= github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.9/go.mod 
h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -16,8 +16,8 @@ github.com/fxamacker/cbor/v2 v2.4.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrt github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U= -github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg= +github.com/google/pprof v0.0.0-20230309165930-d61513b1440d h1:um9/pc7tKMINFfP1eE7Wv6PRGXlcCSJkVajF7KJw3uQ= +github.com/google/pprof v0.0.0-20230309165930-d61513b1440d/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -28,10 +28,12 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leanovate/gopter v0.2.9 h1:fQjYxZaynp97ozCzfOyOuAGOU4aU/z37zf/tOujFk7c= github.com/leanovate/gopter v0.2.9/go.mod h1:U2L/78B+KVFIx2VmW6onHJQzXtFb+p5y3y2Sh+Jxxv8= -github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= -github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= 
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mmcloughlin/addchain v0.4.0 h1:SobOdjm2xLj1KkXN5/n0xTIWyZA2+s99UCY1iPfkHRY= github.com/mmcloughlin/addchain v0.4.0/go.mod h1:A86O+tHqZLMNO4w6ZZ4FlVQEadcoqkyU72HC5wJ4RlU= github.com/mmcloughlin/profile v0.1.1/go.mod h1:IhHD7q1ooxgwTgjxQYkACGA77oFTDdFVejUS1/tS/qU= @@ -49,21 +51,23 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= -golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= -golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= +golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= +golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/exp v0.0.0-20230213192124-5e25df0256eb h1:PaBZQdo+iSDyHT053FjUCgZQ/9uqVwPOcl7KSWhKn6w= golang.org/x/exp v0.0.0-20230213192124-5e25df0256eb/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/sys 
v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= +golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/integration_test.go b/integration_test.go index 3024e003b2..598d9f9165 100644 --- a/integration_test.go +++ b/integration_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/backend/circuits" "github.com/consensys/gnark/test" ) @@ -55,9 +56,10 @@ func TestIntegrationAPI(t *testing.T) { assert.Run(func(assert *test.Assert) { assert.ProverSucceeded( tData.Circuit, tData.ValidAssignments[i], - test.WithProverOpts(backend.WithHints(tData.HintFunctions...)), + test.WithSolverOpts(solver.WithHints(tData.HintFunctions...)), test.WithCurves(tData.Curves[0], tData.Curves[1:]...), - test.WithBackends(backends[0], 
backends[1:]...)) + test.WithBackends(backends[0], backends[1:]...), + test.WithSolidity()) }, fmt.Sprintf("valid-%d", i)) } @@ -66,7 +68,7 @@ func TestIntegrationAPI(t *testing.T) { assert.ProverFailed( tData.Circuit, tData.InvalidAssignments[i], - test.WithProverOpts(backend.WithHints(tData.HintFunctions...)), + test.WithSolverOpts(solver.WithHints(tData.HintFunctions...)), test.WithCurves(tData.Curves[0], tData.Curves[1:]...), test.WithBackends(backends[0], backends[1:]...)) }, fmt.Sprintf("invalid-%d", i)) diff --git a/internal/backend/bls12-377/plonk/prove.go b/internal/backend/bls12-377/plonk/prove.go deleted file mode 100644 index 489ab0dc61..0000000000 --- a/internal/backend/bls12-377/plonk/prove.go +++ /dev/null @@ -1,634 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "crypto/sha256" - "math/big" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - - curve "github.com/consensys/gnark-crypto/ecc/bls12-377" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/kzg" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/iop" - "github.com/consensys/gnark/constraint/bls12-377" - - "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/logger" -) - -type Proof struct { - - // Commitments to the solution vectors - LRO [3]kzg.Digest - - // Commitment to Z, the permutation polynomial - Z kzg.Digest - - // Commitments to h1, h2, h3 such that h = h1 + Xh2 + X**2h3 is the quotient polynomial - H [3]kzg.Digest - - // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2 - BatchedProof kzg.BatchOpeningProof - - // Opening proof of Z at zeta*mu - ZShiftedOpening kzg.OpeningProof -} - -// Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - - log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", len(spr.Constraints)).Str("backend", "plonk").Logger() - start := time.Now() - // pick a hash function that will be used to derive the challenges - hFunc := sha256.New() - - // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") - - // result - proof := &Proof{} - - // compute the constraint system solution - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - 
for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } - } - - // query l, r, o in Lagrange basis, not blinded - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) - - lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - liop := iop.NewPolynomial(&evaluationLDomainSmall, lagReg) - riop := iop.NewPolynomial(&evaluationRDomainSmall, lagReg) - oiop := iop.NewPolynomial(&evaluationODomainSmall, lagReg) - wliop := liop.ShallowClone() - wriop := riop.ShallowClone() - woiop := oiop.ShallowClone() - wliop.ToCanonical(&pk.Domain[0]).ToRegular() - wriop.ToCanonical(&pk.Domain[0]).ToRegular() - woiop.ToCanonical(&pk.Domain[0]).ToRegular() - - // Blind l, r, o before committing - // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. - bwliop := wliop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwriop := wriop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwoiop := woiop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // The first challenge is derived using the public data: the commitments to the permutation, - // the coefficients of the circuit, and the public inputs. - // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *pk.Vk, fullWitness[:len(spr.Public)]); err != nil { - return nil, err - } - gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) - if err != nil { - return nil, err - } - - // Fiat Shamir this - bbeta, err := fs.ComputeChallenge("beta") - if err != nil { - return nil, err - } - var beta fr.Element - beta.SetBytes(bbeta) - - // compute the copy constraint's ratio - // We copy liop, riop, oiop because they are fft'ed in the process. 
- // We could have not copied them at the cost of doing one more bit reverse - // per poly... - ziop, err := iop.BuildRatioCopyConstraint( - []*iop.Polynomial{ - liop.Clone(), - riop.Clone(), - oiop.Clone(), - }, - pk.Permutation, - beta, - gamma, - iop.Form{Basis: iop.Canonical, Layout: iop.Regular}, - &pk.Domain[0], - ) - if err != nil { - return proof, err - } - - // commit to the blinded version of z - bwziop := ziop // iop.NewWrappedPolynomial(&ziop) - bwziop.Blind(2) - proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Vk.KZGSRS, runtime.NumCPU()*2) - if err != nil { - return proof, err - } - - // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) - alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) - if err != nil { - return proof, err - } - - // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) - copy(qkCompletedCanonical, fullWitness[:len(spr.Public)]) - copy(qkCompletedCanonical[len(spr.Public):], pk.LQk[len(spr.Public):]) - pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) - fft.BitReverse(qkCompletedCanonical) - - // l, r, o are blinded here - bwliop.ToLagrangeCoset(&pk.Domain[1]) - bwriop.ToLagrangeCoset(&pk.Domain[1]) - bwoiop.ToLagrangeCoset(&pk.Domain[1]) - - lagrangeCosetBitReversed := iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse} - - // we don't mutate so no need to clone the coefficients from the proving key. 
- wqliop := iop.NewPolynomial(&pk.lQl, lagrangeCosetBitReversed) - wqriop := iop.NewPolynomial(&pk.lQr, lagrangeCosetBitReversed) - wqmiop := iop.NewPolynomial(&pk.lQm, lagrangeCosetBitReversed) - wqoiop := iop.NewPolynomial(&pk.lQo, lagrangeCosetBitReversed) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqkiop := iop.NewPolynomial(&qkCompletedCanonical, canReg) - wqkiop.ToLagrangeCoset(&pk.Domain[1]) - - // storing Id - id := make([]fr.Element, pk.Domain[1].Cardinality) - id[1].SetOne() - widiop := iop.NewPolynomial(&id, canReg) - widiop.ToLagrangeCoset(&pk.Domain[1]) - - // permutations in LagrangeCoset: we don't mutate so no need to clone the coefficients from the - // proving key. - ws1 := iop.NewPolynomial(&pk.lS1LagrangeCoset, lagrangeCosetBitReversed) - ws2 := iop.NewPolynomial(&pk.lS2LagrangeCoset, lagrangeCosetBitReversed) - ws3 := iop.NewPolynomial(&pk.lS3LagrangeCoset, lagrangeCosetBitReversed) - - // Store z(g*x), without reallocating a slice - bwsziop := bwziop.ShallowClone().Shift(1) - bwsziop.ToLagrangeCoset(&pk.Domain[1]) - - // L_{g^{0}} - cap := pk.Domain[1].Cardinality - if cap < pk.Domain[0].Cardinality { - cap = pk.Domain[0].Cardinality // sanity check - } - lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) - lone[0].SetOne() - loneiop := iop.NewPolynomial(&lone, lagReg) - wloneiop := loneiop.ToCanonical(&pk.Domain[0]). - ToRegular(). - ToLagrangeCoset(&pk.Domain[1]) - - // Full capture using latest gnark crypto... 
- fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element) fr.Element { - - var ic, tmp fr.Element - - ic.Mul(&fql, &l) - tmp.Mul(&fqr, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqm, &l).Mul(&tmp, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqo, &o) - ic.Add(&ic, &tmp).Add(&ic, &fqk) - - return ic - } - - fo := func(l, r, o, fid, fs1, fs2, fs3, fz, fzs fr.Element) fr.Element { - var uu fr.Element - u := pk.Domain[0].FrMultiplicativeGen - uu.Mul(&u, &u) - - var a, b, tmp fr.Element - a.Mul(&beta, &fid).Add(&a, &l).Add(&a, &gamma) - tmp.Mul(&beta, &u).Mul(&tmp, &fid).Add(&tmp, &r).Add(&tmp, &gamma) - a.Mul(&a, &tmp) - tmp.Mul(&beta, &uu).Mul(&tmp, &fid).Add(&tmp, &o).Add(&tmp, &gamma) - a.Mul(&a, &tmp).Mul(&a, &fz) - - b.Mul(&beta, &fs1).Add(&b, &l).Add(&b, &gamma) - tmp.Mul(&beta, &fs2).Add(&tmp, &r).Add(&tmp, &gamma) - b.Mul(&b, &tmp) - tmp.Mul(&beta, &fs3).Add(&tmp, &o).Add(&tmp, &gamma) - b.Mul(&b, &tmp).Mul(&b, &fzs) - - b.Sub(&b, &a) - - return b - } - - fone := func(fz, flone fr.Element) fr.Element { - one := fr.One() - one.Sub(&fz, &one).Mul(&one, &flone) - return one - } - - // 0 , 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 - // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk,lone - fm := func(x ...fr.Element) fr.Element { - - a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2]) - b := fo(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8]) - c := fone(x[7], x[14]) - - c.Mul(&c, &alpha).Add(&c, &b).Mul(&c, &alpha).Add(&c, &a) - - return c - } - testEval, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, - bwliop, - bwriop, - bwoiop, - widiop, - ws1, - ws2, - ws3, - bwziop, - bwsziop, - wqliop, - wqriop, - wqmiop, - wqoiop, - wqkiop, - wloneiop, - ) - if err != nil { - return nil, err - } - h, err := iop.DivideByXMinusOne(testEval, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) - if err != nil { - return nil, err - } - - // compute kzg commitments of h1, h2 and h3 - if err := commitToQuotient( - 
h.Coefficients()[:pk.Domain[0].Cardinality+2], - h.Coefficients()[pk.Domain[0].Cardinality+2:2*(pk.Domain[0].Cardinality+2)], - h.Coefficients()[2*(pk.Domain[0].Cardinality+2):3*(pk.Domain[0].Cardinality+2)], - proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // derive zeta - zeta, err := deriveRandomness(&fs, "zeta", &proof.H[0], &proof.H[1], &proof.H[2]) - if err != nil { - return nil, err - } - - // compute evaluations of (blinded version of) l, r, o, z at zeta - var blzeta, brzeta, bozeta fr.Element - - var wgEvals sync.WaitGroup - wgEvals.Add(3) - - go func() { - bwliop.ToCanonical(&pk.Domain[1]).ToRegular() - blzeta = bwliop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwriop.ToCanonical(&pk.Domain[1]).ToRegular() - brzeta = bwriop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwoiop.ToCanonical(&pk.Domain[1]).ToRegular() - bozeta = bwoiop.Evaluate(zeta) - wgEvals.Done() - }() - - // open blinded Z at zeta*z - bwziop.ToCanonical(&pk.Domain[1]).ToRegular() - var zetaShifted fr.Element - zetaShifted.Mul(&zeta, &pk.Vk.Generator) - proof.ZShiftedOpening, err = kzg.Open( - bwziop.Coefficients()[:bwziop.BlindedSize()], - zetaShifted, - pk.Vk.KZGSRS, - ) - if err != nil { - return nil, err - } - - // blinded z evaluated at u*zeta - bzuzeta := proof.ZShiftedOpening.ClaimedValue - - var ( - linearizedPolynomialCanonical []fr.Element - linearizedPolynomialDigest curve.G1Affine - errLPoly error - ) - - wgEvals.Wait() // wait for the evaluations - - // compute the linearization polynomial r at zeta - // (goal: save committing separately to z, ql, qr, qm, qo, k - linearizedPolynomialCanonical = computeLinearizedPolynomial( - blzeta, - brzeta, - bozeta, - alpha, - beta, - gamma, - zeta, - bzuzeta, - bwziop.Coefficients()[:bwziop.BlindedSize()], - pk, - ) - - // TODO this commitment is only necessary to derive the challenge, we should - // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = 
kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) - - // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) - var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) - var zetaPowerm fr.Element - zetaPowerm.Exp(zeta, &bSize) - zetaPowerm.BigInt(&bZetaPowerm) - foldedHDigest := proof.H[2] - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - - // foldedH = h1 + ζ*h2 + ζ²*h3 - foldedH := h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] - h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] - utils.Parallelize(len(foldedH), func(start, end int) { - for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² - foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 - } - }) - - if errLPoly != nil { - return nil, errLPoly - } - - // Batch open the first list of polynomials - proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - [][]fr.Element{ - foldedH, - linearizedPolynomialCanonical, - bwliop.Coefficients()[:bwliop.BlindedSize()], - bwriop.Coefficients()[:bwriop.BlindedSize()], - bwoiop.Coefficients()[:bwoiop.BlindedSize()], - pk.S1Canonical, - pk.S2Canonical, - }, - []kzg.Digest{ - foldedHDigest, - linearizedPolynomialDigest, - proof.LRO[0], - proof.LRO[1], - proof.LRO[2], - pk.Vk.S[0], - pk.Vk.S[1], - }, - zeta, - hFunc, - pk.Vk.KZGSRS, - ) - - log.Debug().Dur("took", time.Since(start)).Msg("prover done") - - if err != nil { - return 
nil, err - } - - return proof, nil - -} - -// fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.LRO[0], err0 = kzg.Commit(bcl, srs, n) - close(chCommit0) - }() - go func() { - proof.LRO[1], err1 = kzg.Commit(bcr, srs, n) - close(chCommit1) - }() - if proof.LRO[2], err2 = kzg.Commit(bco, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.H[0], err0 = kzg.Commit(h1, srs, n) - close(chCommit0) - }() - go func() { - proof.H[1], err1 = kzg.Commit(h2, srs, n) - close(chCommit1) - }() - if proof.H[2], err2 = kzg.Commit(h3, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. -// The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z -// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. 
-// -// The Linearized polynomial is: -// -// α²*L₁(ζ)*Z(X) -// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) -// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { - - // first part: individual constraints - var rl fr.Element - rl.Mul(&rZeta, &lZeta) - - // second part: - // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) - var s1, s2 fr.Element - chS1 := make(chan struct{}, 1) - go func() { - ps1 := iop.NewPolynomial(&pk.S1Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - s1 = ps1.Evaluate(zeta) // s1(ζ) - s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) - close(chS1) - }() - ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := ps2.Evaluate(zeta) // s2(ζ) - tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) - <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) - - var uzeta, uuzeta fr.Element - uzeta.Mul(&zeta, &pk.Vk.CosetShift) - uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) - - s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) - tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) - tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - // third part L₁(ζ)*α²*Z - var lagrangeZeta, one, den, frNbElmt fr.Element - one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) - lagrangeZeta.Set(&zeta). - Exp(lagrangeZeta, big.NewInt(nbElmt)). 
- Sub(&lagrangeZeta, &one) - frNbElmt.SetUint64(uint64(nbElmt)) - den.Sub(&zeta, &one). - Inverse(&den) - lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - - linPol := make([]fr.Element, len(blindedZCanonical)) - copy(linPol, blindedZCanonical) - - utils.Parallelize(len(linPol), func(start, end int) { - - var t0, t1 fr.Element - - for i := start; i < end; i++ { - - linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - if i < len(pk.S3Canonical) { - - t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) - - linPol[i].Add(&linPol[i], &t0) - } - - linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) - - if i < len(pk.Qm) { - - t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) - t0.Mul(&pk.Ql[i], &lZeta) - t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - - t0.Mul(&pk.Qr[i], &rZeta) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - - t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) - } - - t0.Mul(&blindedZCanonical[i], &lagrangeZeta) - linPol[i].Add(&linPol[i], &t0) // finish the computation - } - }) - return linPol -} diff --git a/internal/backend/bls12-377/plonk/setup.go b/internal/backend/bls12-377/plonk/setup.go deleted file mode 100644 index e7f0b7a553..0000000000 --- a/internal/backend/bls12-377/plonk/setup.go +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/iop" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/kzg" - "github.com/consensys/gnark/constraint/bls12-377" - - kzgg "github.com/consensys/gnark-crypto/kzg" -) - -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation -type ProvingKey struct { - // Verifying Key is embedded into the proving key (needed by Prove) - Vk *VerifyingKey - - // TODO store iop.Polynomial here, not []fr.Element for more "type safety" - - // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo []fr.Element - - // qr,ql,qm,qo (in lagrange coset basis) --> these are not serialized, but computed from Ql, Qr, Qm, Qo once. - lQl, lQr, lQm, lQo []fr.Element - - // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. - // Storing LQk in Lagrange basis saves a fft... - CQk, LQk []fr.Element - - // Domains used for the FFTs. 
- // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain - // Domain[0], Domain[1] fft.Domain - - // Permutation polynomials - S1Canonical, S2Canonical, S3Canonical []fr.Element - - // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. - lS1LagrangeCoset, lS2LagrangeCoset, lS3LagrangeCoset []fr.Element - - // position -> permuted position (position in [0,3*sizeSystem-1]) - Permutation []int64 -} - -// VerifyingKey stores the data needed to verify a proof: -// * The commitment scheme -// * Commitments of ql prepended with as many ones as there are public inputs -// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs -// * Commitments to S1, S2, S3 -type VerifyingKey struct { - // Size circuit - Size uint64 - SizeInv fr.Element - Generator fr.Element - NbPublicVariables uint64 - - // Commitment scheme that is used for an instantiation of PLONK - KZGSRS *kzg.SRS - - // cosetShift generator of the coset on the small domain - CosetShift fr.Element - - // S commitments to S1, S2, S3 - S [3]kzg.Digest - - // Commitments to ql, qr, qm, qo prepended with as many zeroes (ones for l) as there are public inputs. - // In particular Qk is not complete. - Ql, Qr, Qm, Qo, Qk kzg.Digest -} - -// Setup sets proving and verifying keys -func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) { - var pk ProvingKey - var vk VerifyingKey - - // The verifying key shares data with the proving key - pk.Vk = &vk - - nbConstraints := len(spr.Constraints) - - // fft domains - sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - - vk.Size = pk.Domain[0].Cardinality - vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) - vk.NbPublicVariables = uint64(len(spr.Public)) - - if err := pk.InitKZG(srs); err != nil { - return nil, nil, err - } - - // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) - pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) - pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) - - for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant - pk.Ql[i].SetOne().Neg(&pk.Ql[i]) - pk.Qr[i].SetZero() - pk.Qm[i].SetZero() - pk.Qo[i].SetZero() - pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // → to be completed by the prover - } - offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.Ql[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.Qr[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.Qm[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.Qm[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.Qo[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.CQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - } - - pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) - pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) - fft.BitReverse(pk.Ql) - fft.BitReverse(pk.Qr) - fft.BitReverse(pk.Qm) - fft.BitReverse(pk.Qo) - fft.BitReverse(pk.CQk) - - // build permutation. Note: at this stage, the permutation takes in account the placeholders - buildPermutation(spr, &pk) - - // set s1, s2, s3 - ccomputePermutationPolynomials(&pk) - - // compute the lagrange coset basis versions (not serialized) - pk.computeLagrangeCosetPolys() - - // Commit to the polynomials to set up the verifying key - var err error - if vk.Ql, err = kzg.Commit(pk.Ql, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qr, err = kzg.Commit(pk.Qr, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qm, err = kzg.Commit(pk.Qm, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qo, err = kzg.Commit(pk.Qo, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - - return &pk, &vk, nil - -} - -// buildPermutation builds the Permutation associated with a circuit. -// -// The permutation s is composed of cycles of maximum length such that -// -// s. 
(l∥r∥o) = (l∥r∥o) -// -// , where l∥r∥o is the concatenation of the indices of l, r, o in -// ql.l+qr.r+qm.l.r+qo.O+k = 0. -// -// The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab -// like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { - - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := int(pk.Domain[0].Cardinality) - - // init permutation - pk.Permutation = make([]int64, 3*sizeSolution) - for i := 0; i < len(pk.Permutation); i++ { - pk.Permutation[i] = -1 - } - - // init LRO position -> variable_ID - lro := make([]int, 3*sizeSolution) // position -> variable_ID - for i := 0; i < len(spr.Public); i++ { - lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) - } - - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() - } - - // init cycle: - // map ID -> last position the ID was seen - cycle := make([]int64, nbVariables) - for i := 0; i < len(cycle); i++ { - cycle[i] = -1 - } - - for i := 0; i < len(lro); i++ { - if cycle[lro[i]] != -1 { - // if != -1, it means we already encountered this value - // so we need to set the corresponding permutation index. 
- pk.Permutation[i] = cycle[lro[i]] - } - cycle[lro[i]] = int64(i) - } - - // complete the Permutation by filling the first IDs encountered - for i := 0; i < len(pk.Permutation); i++ { - if pk.Permutation[i] == -1 { - pk.Permutation[i] = cycle[lro[i]] - } - } -} - -func (pk *ProvingKey) computeLagrangeCosetPolys() { - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqliop := iop.NewPolynomial(clone(pk.Ql, pk.Domain[1].Cardinality), canReg) - wqriop := iop.NewPolynomial(clone(pk.Qr, pk.Domain[1].Cardinality), canReg) - wqmiop := iop.NewPolynomial(clone(pk.Qm, pk.Domain[1].Cardinality), canReg) - wqoiop := iop.NewPolynomial(clone(pk.Qo, pk.Domain[1].Cardinality), canReg) - - ws1 := iop.NewPolynomial(clone(pk.S1Canonical, pk.Domain[1].Cardinality), canReg) - ws2 := iop.NewPolynomial(clone(pk.S2Canonical, pk.Domain[1].Cardinality), canReg) - ws3 := iop.NewPolynomial(clone(pk.S3Canonical, pk.Domain[1].Cardinality), canReg) - - wqliop.ToLagrangeCoset(&pk.Domain[1]) - wqriop.ToLagrangeCoset(&pk.Domain[1]) - wqmiop.ToLagrangeCoset(&pk.Domain[1]) - wqoiop.ToLagrangeCoset(&pk.Domain[1]) - - ws1.ToLagrangeCoset(&pk.Domain[1]) - ws2.ToLagrangeCoset(&pk.Domain[1]) - ws3.ToLagrangeCoset(&pk.Domain[1]) - - pk.lQl = wqliop.Coefficients() - pk.lQr = wqriop.Coefficients() - pk.lQm = wqmiop.Coefficients() - pk.lQo = wqoiop.Coefficients() - - pk.lS1LagrangeCoset = ws1.Coefficients() - pk.lS2LagrangeCoset = ws2.Coefficients() - pk.lS3LagrangeCoset = ws3.Coefficients() -} - -func clone(input []fr.Element, capacity uint64) *[]fr.Element { - res := make([]fr.Element, len(input), capacity) - copy(res, input) - return &res -} - -// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations -// s1, s2, s3. -// -// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | -// -// | -// | Permutation -// -// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v -// \---------------/ \--------------------/ \------------------------/ -// -// s1 (LDE) s2 (LDE) s3 (LDE) -func ccomputePermutationPolynomials(pk *ProvingKey) { - - nbElmts := int(pk.Domain[0].Cardinality) - - // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) - - // Lagrange form of S1, S2, S3 - pk.S1Canonical = make([]fr.Element, nbElmts) - pk.S2Canonical = make([]fr.Element, nbElmts) - pk.S3Canonical = make([]fr.Element, nbElmts) - for i := 0; i < nbElmts; i++ { - pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) - pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) - pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) - } - - // Canonical form of S1, S2, S3 - pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) - fft.BitReverse(pk.S1Canonical) - fft.BitReverse(pk.S2Canonical) - fft.BitReverse(pk.S3Canonical) -} - -// getIDSmallDomain returns the Lagrange form of ID on the small domain -func getIDSmallDomain(domain *fft.Domain) []fr.Element { - - res := make([]fr.Element, 3*domain.Cardinality) - - res[0].SetOne() - res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) - res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) - - for i := uint64(1); i < domain.Cardinality; i++ { - res[i].Mul(&res[i-1], &domain.Generator) - res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) - res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) - } - - return res -} - -// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS -// -// This should be used after deserializing a ProvingKey -// as pk.Vk.KZG is NOT serialized -func (pk *ProvingKey) InitKZG(srs kzgg.SRS) error { - return pk.Vk.InitKZG(srs) -} - -// InitKZG inits vk.KZG using provided SRS -// -// This should be used after deserializing a 
VerifyingKey -// as vk.KZG is NOT serialized -// -// Note that this instantiate a new FFT domain using vk.Size -func (vk *VerifyingKey) InitKZG(srs kzgg.SRS) error { - _srs := srs.(*kzg.SRS) - - if len(_srs.G1) < int(vk.Size) { - return errors.New("kzg srs is too small") - } - vk.KZGSRS = _srs - - return nil -} - -// NbPublicWitness returns the expected public witness size (number of field elements) -func (vk *VerifyingKey) NbPublicWitness() int { - return int(vk.NbPublicVariables) -} - -// VerifyingKey returns pk.Vk -func (pk *ProvingKey) VerifyingKey() interface{} { - return pk.Vk -} diff --git a/internal/backend/bls12-381/plonk/prove.go b/internal/backend/bls12-381/plonk/prove.go deleted file mode 100644 index 8d1402e02f..0000000000 --- a/internal/backend/bls12-381/plonk/prove.go +++ /dev/null @@ -1,634 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "crypto/sha256" - "math/big" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - - curve "github.com/consensys/gnark-crypto/ecc/bls12-381" - - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/kzg" - - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" - - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/iop" - "github.com/consensys/gnark/constraint/bls12-381" - - "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/logger" -) - -type Proof struct { - - // Commitments to the solution vectors - LRO [3]kzg.Digest - - // Commitment to Z, the permutation polynomial - Z kzg.Digest - - // Commitments to h1, h2, h3 such that h = h1 + Xh2 + X**2h3 is the quotient polynomial - H [3]kzg.Digest - - // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2 - BatchedProof kzg.BatchOpeningProof - - // Opening proof of Z at zeta*mu - ZShiftedOpening kzg.OpeningProof -} - -// Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - - log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", len(spr.Constraints)).Str("backend", "plonk").Logger() - start := time.Now() - // pick a hash function that will be used to derive the challenges - hFunc := sha256.New() - - // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") - - // result - proof := &Proof{} - - // compute the constraint system solution - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - 
for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } - } - - // query l, r, o in Lagrange basis, not blinded - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) - - lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - liop := iop.NewPolynomial(&evaluationLDomainSmall, lagReg) - riop := iop.NewPolynomial(&evaluationRDomainSmall, lagReg) - oiop := iop.NewPolynomial(&evaluationODomainSmall, lagReg) - wliop := liop.ShallowClone() - wriop := riop.ShallowClone() - woiop := oiop.ShallowClone() - wliop.ToCanonical(&pk.Domain[0]).ToRegular() - wriop.ToCanonical(&pk.Domain[0]).ToRegular() - woiop.ToCanonical(&pk.Domain[0]).ToRegular() - - // Blind l, r, o before committing - // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. - bwliop := wliop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwriop := wriop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwoiop := woiop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // The first challenge is derived using the public data: the commitments to the permutation, - // the coefficients of the circuit, and the public inputs. - // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *pk.Vk, fullWitness[:len(spr.Public)]); err != nil { - return nil, err - } - gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) - if err != nil { - return nil, err - } - - // Fiat Shamir this - bbeta, err := fs.ComputeChallenge("beta") - if err != nil { - return nil, err - } - var beta fr.Element - beta.SetBytes(bbeta) - - // compute the copy constraint's ratio - // We copy liop, riop, oiop because they are fft'ed in the process. 
- // We could have not copied them at the cost of doing one more bit reverse - // per poly... - ziop, err := iop.BuildRatioCopyConstraint( - []*iop.Polynomial{ - liop.Clone(), - riop.Clone(), - oiop.Clone(), - }, - pk.Permutation, - beta, - gamma, - iop.Form{Basis: iop.Canonical, Layout: iop.Regular}, - &pk.Domain[0], - ) - if err != nil { - return proof, err - } - - // commit to the blinded version of z - bwziop := ziop // iop.NewWrappedPolynomial(&ziop) - bwziop.Blind(2) - proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Vk.KZGSRS, runtime.NumCPU()*2) - if err != nil { - return proof, err - } - - // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) - alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) - if err != nil { - return proof, err - } - - // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) - copy(qkCompletedCanonical, fullWitness[:len(spr.Public)]) - copy(qkCompletedCanonical[len(spr.Public):], pk.LQk[len(spr.Public):]) - pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) - fft.BitReverse(qkCompletedCanonical) - - // l, r, o are blinded here - bwliop.ToLagrangeCoset(&pk.Domain[1]) - bwriop.ToLagrangeCoset(&pk.Domain[1]) - bwoiop.ToLagrangeCoset(&pk.Domain[1]) - - lagrangeCosetBitReversed := iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse} - - // we don't mutate so no need to clone the coefficients from the proving key. 
- wqliop := iop.NewPolynomial(&pk.lQl, lagrangeCosetBitReversed) - wqriop := iop.NewPolynomial(&pk.lQr, lagrangeCosetBitReversed) - wqmiop := iop.NewPolynomial(&pk.lQm, lagrangeCosetBitReversed) - wqoiop := iop.NewPolynomial(&pk.lQo, lagrangeCosetBitReversed) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqkiop := iop.NewPolynomial(&qkCompletedCanonical, canReg) - wqkiop.ToLagrangeCoset(&pk.Domain[1]) - - // storing Id - id := make([]fr.Element, pk.Domain[1].Cardinality) - id[1].SetOne() - widiop := iop.NewPolynomial(&id, canReg) - widiop.ToLagrangeCoset(&pk.Domain[1]) - - // permutations in LagrangeCoset: we don't mutate so no need to clone the coefficients from the - // proving key. - ws1 := iop.NewPolynomial(&pk.lS1LagrangeCoset, lagrangeCosetBitReversed) - ws2 := iop.NewPolynomial(&pk.lS2LagrangeCoset, lagrangeCosetBitReversed) - ws3 := iop.NewPolynomial(&pk.lS3LagrangeCoset, lagrangeCosetBitReversed) - - // Store z(g*x), without reallocating a slice - bwsziop := bwziop.ShallowClone().Shift(1) - bwsziop.ToLagrangeCoset(&pk.Domain[1]) - - // L_{g^{0}} - cap := pk.Domain[1].Cardinality - if cap < pk.Domain[0].Cardinality { - cap = pk.Domain[0].Cardinality // sanity check - } - lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) - lone[0].SetOne() - loneiop := iop.NewPolynomial(&lone, lagReg) - wloneiop := loneiop.ToCanonical(&pk.Domain[0]). - ToRegular(). - ToLagrangeCoset(&pk.Domain[1]) - - // Full capture using latest gnark crypto... 
- fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element) fr.Element { - - var ic, tmp fr.Element - - ic.Mul(&fql, &l) - tmp.Mul(&fqr, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqm, &l).Mul(&tmp, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqo, &o) - ic.Add(&ic, &tmp).Add(&ic, &fqk) - - return ic - } - - fo := func(l, r, o, fid, fs1, fs2, fs3, fz, fzs fr.Element) fr.Element { - var uu fr.Element - u := pk.Domain[0].FrMultiplicativeGen - uu.Mul(&u, &u) - - var a, b, tmp fr.Element - a.Mul(&beta, &fid).Add(&a, &l).Add(&a, &gamma) - tmp.Mul(&beta, &u).Mul(&tmp, &fid).Add(&tmp, &r).Add(&tmp, &gamma) - a.Mul(&a, &tmp) - tmp.Mul(&beta, &uu).Mul(&tmp, &fid).Add(&tmp, &o).Add(&tmp, &gamma) - a.Mul(&a, &tmp).Mul(&a, &fz) - - b.Mul(&beta, &fs1).Add(&b, &l).Add(&b, &gamma) - tmp.Mul(&beta, &fs2).Add(&tmp, &r).Add(&tmp, &gamma) - b.Mul(&b, &tmp) - tmp.Mul(&beta, &fs3).Add(&tmp, &o).Add(&tmp, &gamma) - b.Mul(&b, &tmp).Mul(&b, &fzs) - - b.Sub(&b, &a) - - return b - } - - fone := func(fz, flone fr.Element) fr.Element { - one := fr.One() - one.Sub(&fz, &one).Mul(&one, &flone) - return one - } - - // 0 , 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 - // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk,lone - fm := func(x ...fr.Element) fr.Element { - - a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2]) - b := fo(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8]) - c := fone(x[7], x[14]) - - c.Mul(&c, &alpha).Add(&c, &b).Mul(&c, &alpha).Add(&c, &a) - - return c - } - testEval, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, - bwliop, - bwriop, - bwoiop, - widiop, - ws1, - ws2, - ws3, - bwziop, - bwsziop, - wqliop, - wqriop, - wqmiop, - wqoiop, - wqkiop, - wloneiop, - ) - if err != nil { - return nil, err - } - h, err := iop.DivideByXMinusOne(testEval, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) - if err != nil { - return nil, err - } - - // compute kzg commitments of h1, h2 and h3 - if err := commitToQuotient( - 
h.Coefficients()[:pk.Domain[0].Cardinality+2], - h.Coefficients()[pk.Domain[0].Cardinality+2:2*(pk.Domain[0].Cardinality+2)], - h.Coefficients()[2*(pk.Domain[0].Cardinality+2):3*(pk.Domain[0].Cardinality+2)], - proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // derive zeta - zeta, err := deriveRandomness(&fs, "zeta", &proof.H[0], &proof.H[1], &proof.H[2]) - if err != nil { - return nil, err - } - - // compute evaluations of (blinded version of) l, r, o, z at zeta - var blzeta, brzeta, bozeta fr.Element - - var wgEvals sync.WaitGroup - wgEvals.Add(3) - - go func() { - bwliop.ToCanonical(&pk.Domain[1]).ToRegular() - blzeta = bwliop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwriop.ToCanonical(&pk.Domain[1]).ToRegular() - brzeta = bwriop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwoiop.ToCanonical(&pk.Domain[1]).ToRegular() - bozeta = bwoiop.Evaluate(zeta) - wgEvals.Done() - }() - - // open blinded Z at zeta*z - bwziop.ToCanonical(&pk.Domain[1]).ToRegular() - var zetaShifted fr.Element - zetaShifted.Mul(&zeta, &pk.Vk.Generator) - proof.ZShiftedOpening, err = kzg.Open( - bwziop.Coefficients()[:bwziop.BlindedSize()], - zetaShifted, - pk.Vk.KZGSRS, - ) - if err != nil { - return nil, err - } - - // blinded z evaluated at u*zeta - bzuzeta := proof.ZShiftedOpening.ClaimedValue - - var ( - linearizedPolynomialCanonical []fr.Element - linearizedPolynomialDigest curve.G1Affine - errLPoly error - ) - - wgEvals.Wait() // wait for the evaluations - - // compute the linearization polynomial r at zeta - // (goal: save committing separately to z, ql, qr, qm, qo, k - linearizedPolynomialCanonical = computeLinearizedPolynomial( - blzeta, - brzeta, - bozeta, - alpha, - beta, - gamma, - zeta, - bzuzeta, - bwziop.Coefficients()[:bwziop.BlindedSize()], - pk, - ) - - // TODO this commitment is only necessary to derive the challenge, we should - // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = 
kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) - - // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) - var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) - var zetaPowerm fr.Element - zetaPowerm.Exp(zeta, &bSize) - zetaPowerm.BigInt(&bZetaPowerm) - foldedHDigest := proof.H[2] - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - - // foldedH = h1 + ζ*h2 + ζ²*h3 - foldedH := h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] - h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] - utils.Parallelize(len(foldedH), func(start, end int) { - for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² - foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 - } - }) - - if errLPoly != nil { - return nil, errLPoly - } - - // Batch open the first list of polynomials - proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - [][]fr.Element{ - foldedH, - linearizedPolynomialCanonical, - bwliop.Coefficients()[:bwliop.BlindedSize()], - bwriop.Coefficients()[:bwriop.BlindedSize()], - bwoiop.Coefficients()[:bwoiop.BlindedSize()], - pk.S1Canonical, - pk.S2Canonical, - }, - []kzg.Digest{ - foldedHDigest, - linearizedPolynomialDigest, - proof.LRO[0], - proof.LRO[1], - proof.LRO[2], - pk.Vk.S[0], - pk.Vk.S[1], - }, - zeta, - hFunc, - pk.Vk.KZGSRS, - ) - - log.Debug().Dur("took", time.Since(start)).Msg("prover done") - - if err != nil { - return 
nil, err - } - - return proof, nil - -} - -// fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.LRO[0], err0 = kzg.Commit(bcl, srs, n) - close(chCommit0) - }() - go func() { - proof.LRO[1], err1 = kzg.Commit(bcr, srs, n) - close(chCommit1) - }() - if proof.LRO[2], err2 = kzg.Commit(bco, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.H[0], err0 = kzg.Commit(h1, srs, n) - close(chCommit0) - }() - go func() { - proof.H[1], err1 = kzg.Commit(h2, srs, n) - close(chCommit1) - }() - if proof.H[2], err2 = kzg.Commit(h3, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. -// The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z -// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. 
-// -// The Linearized polynomial is: -// -// α²*L₁(ζ)*Z(X) -// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) -// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { - - // first part: individual constraints - var rl fr.Element - rl.Mul(&rZeta, &lZeta) - - // second part: - // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) - var s1, s2 fr.Element - chS1 := make(chan struct{}, 1) - go func() { - ps1 := iop.NewPolynomial(&pk.S1Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - s1 = ps1.Evaluate(zeta) // s1(ζ) - s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) - close(chS1) - }() - ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := ps2.Evaluate(zeta) // s2(ζ) - tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) - <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) - - var uzeta, uuzeta fr.Element - uzeta.Mul(&zeta, &pk.Vk.CosetShift) - uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) - - s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) - tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) - tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - // third part L₁(ζ)*α²*Z - var lagrangeZeta, one, den, frNbElmt fr.Element - one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) - lagrangeZeta.Set(&zeta). - Exp(lagrangeZeta, big.NewInt(nbElmt)). 
- Sub(&lagrangeZeta, &one) - frNbElmt.SetUint64(uint64(nbElmt)) - den.Sub(&zeta, &one). - Inverse(&den) - lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - - linPol := make([]fr.Element, len(blindedZCanonical)) - copy(linPol, blindedZCanonical) - - utils.Parallelize(len(linPol), func(start, end int) { - - var t0, t1 fr.Element - - for i := start; i < end; i++ { - - linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - if i < len(pk.S3Canonical) { - - t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) - - linPol[i].Add(&linPol[i], &t0) - } - - linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) - - if i < len(pk.Qm) { - - t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) - t0.Mul(&pk.Ql[i], &lZeta) - t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - - t0.Mul(&pk.Qr[i], &rZeta) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - - t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) - } - - t0.Mul(&blindedZCanonical[i], &lagrangeZeta) - linPol[i].Add(&linPol[i], &t0) // finish the computation - } - }) - return linPol -} diff --git a/internal/backend/bls12-381/plonk/setup.go b/internal/backend/bls12-381/plonk/setup.go deleted file mode 100644 index e86f945c71..0000000000 --- a/internal/backend/bls12-381/plonk/setup.go +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/iop" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/kzg" - "github.com/consensys/gnark/constraint/bls12-381" - - kzgg "github.com/consensys/gnark-crypto/kzg" -) - -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation -type ProvingKey struct { - // Verifying Key is embedded into the proving key (needed by Prove) - Vk *VerifyingKey - - // TODO store iop.Polynomial here, not []fr.Element for more "type safety" - - // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo []fr.Element - - // qr,ql,qm,qo (in lagrange coset basis) --> these are not serialized, but computed from Ql, Qr, Qm, Qo once. - lQl, lQr, lQm, lQo []fr.Element - - // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. - // Storing LQk in Lagrange basis saves a fft... - CQk, LQk []fr.Element - - // Domains used for the FFTs. 
- // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain - // Domain[0], Domain[1] fft.Domain - - // Permutation polynomials - S1Canonical, S2Canonical, S3Canonical []fr.Element - - // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. - lS1LagrangeCoset, lS2LagrangeCoset, lS3LagrangeCoset []fr.Element - - // position -> permuted position (position in [0,3*sizeSystem-1]) - Permutation []int64 -} - -// VerifyingKey stores the data needed to verify a proof: -// * The commitment scheme -// * Commitments of ql prepended with as many ones as there are public inputs -// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs -// * Commitments to S1, S2, S3 -type VerifyingKey struct { - // Size circuit - Size uint64 - SizeInv fr.Element - Generator fr.Element - NbPublicVariables uint64 - - // Commitment scheme that is used for an instantiation of PLONK - KZGSRS *kzg.SRS - - // cosetShift generator of the coset on the small domain - CosetShift fr.Element - - // S commitments to S1, S2, S3 - S [3]kzg.Digest - - // Commitments to ql, qr, qm, qo prepended with as many zeroes (ones for l) as there are public inputs. - // In particular Qk is not complete. - Ql, Qr, Qm, Qo, Qk kzg.Digest -} - -// Setup sets proving and verifying keys -func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) { - var pk ProvingKey - var vk VerifyingKey - - // The verifying key shares data with the proving key - pk.Vk = &vk - - nbConstraints := len(spr.Constraints) - - // fft domains - sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - - vk.Size = pk.Domain[0].Cardinality - vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) - vk.NbPublicVariables = uint64(len(spr.Public)) - - if err := pk.InitKZG(srs); err != nil { - return nil, nil, err - } - - // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) - pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) - pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) - - for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant - pk.Ql[i].SetOne().Neg(&pk.Ql[i]) - pk.Qr[i].SetZero() - pk.Qm[i].SetZero() - pk.Qo[i].SetZero() - pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // → to be completed by the prover - } - offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.Ql[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.Qr[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.Qm[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.Qm[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.Qo[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.CQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - } - - pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) - pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) - fft.BitReverse(pk.Ql) - fft.BitReverse(pk.Qr) - fft.BitReverse(pk.Qm) - fft.BitReverse(pk.Qo) - fft.BitReverse(pk.CQk) - - // build permutation. Note: at this stage, the permutation takes in account the placeholders - buildPermutation(spr, &pk) - - // set s1, s2, s3 - ccomputePermutationPolynomials(&pk) - - // compute the lagrange coset basis versions (not serialized) - pk.computeLagrangeCosetPolys() - - // Commit to the polynomials to set up the verifying key - var err error - if vk.Ql, err = kzg.Commit(pk.Ql, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qr, err = kzg.Commit(pk.Qr, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qm, err = kzg.Commit(pk.Qm, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qo, err = kzg.Commit(pk.Qo, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - - return &pk, &vk, nil - -} - -// buildPermutation builds the Permutation associated with a circuit. -// -// The permutation s is composed of cycles of maximum length such that -// -// s. 
(l∥r∥o) = (l∥r∥o) -// -// , where l∥r∥o is the concatenation of the indices of l, r, o in -// ql.l+qr.r+qm.l.r+qo.O+k = 0. -// -// The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab -// like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { - - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := int(pk.Domain[0].Cardinality) - - // init permutation - pk.Permutation = make([]int64, 3*sizeSolution) - for i := 0; i < len(pk.Permutation); i++ { - pk.Permutation[i] = -1 - } - - // init LRO position -> variable_ID - lro := make([]int, 3*sizeSolution) // position -> variable_ID - for i := 0; i < len(spr.Public); i++ { - lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) - } - - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() - } - - // init cycle: - // map ID -> last position the ID was seen - cycle := make([]int64, nbVariables) - for i := 0; i < len(cycle); i++ { - cycle[i] = -1 - } - - for i := 0; i < len(lro); i++ { - if cycle[lro[i]] != -1 { - // if != -1, it means we already encountered this value - // so we need to set the corresponding permutation index. 
- pk.Permutation[i] = cycle[lro[i]] - } - cycle[lro[i]] = int64(i) - } - - // complete the Permutation by filling the first IDs encountered - for i := 0; i < len(pk.Permutation); i++ { - if pk.Permutation[i] == -1 { - pk.Permutation[i] = cycle[lro[i]] - } - } -} - -func (pk *ProvingKey) computeLagrangeCosetPolys() { - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqliop := iop.NewPolynomial(clone(pk.Ql, pk.Domain[1].Cardinality), canReg) - wqriop := iop.NewPolynomial(clone(pk.Qr, pk.Domain[1].Cardinality), canReg) - wqmiop := iop.NewPolynomial(clone(pk.Qm, pk.Domain[1].Cardinality), canReg) - wqoiop := iop.NewPolynomial(clone(pk.Qo, pk.Domain[1].Cardinality), canReg) - - ws1 := iop.NewPolynomial(clone(pk.S1Canonical, pk.Domain[1].Cardinality), canReg) - ws2 := iop.NewPolynomial(clone(pk.S2Canonical, pk.Domain[1].Cardinality), canReg) - ws3 := iop.NewPolynomial(clone(pk.S3Canonical, pk.Domain[1].Cardinality), canReg) - - wqliop.ToLagrangeCoset(&pk.Domain[1]) - wqriop.ToLagrangeCoset(&pk.Domain[1]) - wqmiop.ToLagrangeCoset(&pk.Domain[1]) - wqoiop.ToLagrangeCoset(&pk.Domain[1]) - - ws1.ToLagrangeCoset(&pk.Domain[1]) - ws2.ToLagrangeCoset(&pk.Domain[1]) - ws3.ToLagrangeCoset(&pk.Domain[1]) - - pk.lQl = wqliop.Coefficients() - pk.lQr = wqriop.Coefficients() - pk.lQm = wqmiop.Coefficients() - pk.lQo = wqoiop.Coefficients() - - pk.lS1LagrangeCoset = ws1.Coefficients() - pk.lS2LagrangeCoset = ws2.Coefficients() - pk.lS3LagrangeCoset = ws3.Coefficients() -} - -func clone(input []fr.Element, capacity uint64) *[]fr.Element { - res := make([]fr.Element, len(input), capacity) - copy(res, input) - return &res -} - -// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations -// s1, s2, s3. -// -// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | -// -// | -// | Permutation -// -// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v -// \---------------/ \--------------------/ \------------------------/ -// -// s1 (LDE) s2 (LDE) s3 (LDE) -func ccomputePermutationPolynomials(pk *ProvingKey) { - - nbElmts := int(pk.Domain[0].Cardinality) - - // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) - - // Lagrange form of S1, S2, S3 - pk.S1Canonical = make([]fr.Element, nbElmts) - pk.S2Canonical = make([]fr.Element, nbElmts) - pk.S3Canonical = make([]fr.Element, nbElmts) - for i := 0; i < nbElmts; i++ { - pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) - pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) - pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) - } - - // Canonical form of S1, S2, S3 - pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) - fft.BitReverse(pk.S1Canonical) - fft.BitReverse(pk.S2Canonical) - fft.BitReverse(pk.S3Canonical) -} - -// getIDSmallDomain returns the Lagrange form of ID on the small domain -func getIDSmallDomain(domain *fft.Domain) []fr.Element { - - res := make([]fr.Element, 3*domain.Cardinality) - - res[0].SetOne() - res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) - res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) - - for i := uint64(1); i < domain.Cardinality; i++ { - res[i].Mul(&res[i-1], &domain.Generator) - res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) - res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) - } - - return res -} - -// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS -// -// This should be used after deserializing a ProvingKey -// as pk.Vk.KZG is NOT serialized -func (pk *ProvingKey) InitKZG(srs kzgg.SRS) error { - return pk.Vk.InitKZG(srs) -} - -// InitKZG inits vk.KZG using provided SRS -// -// This should be used after deserializing a 
VerifyingKey -// as vk.KZG is NOT serialized -// -// Note that this instantiate a new FFT domain using vk.Size -func (vk *VerifyingKey) InitKZG(srs kzgg.SRS) error { - _srs := srs.(*kzg.SRS) - - if len(_srs.G1) < int(vk.Size) { - return errors.New("kzg srs is too small") - } - vk.KZGSRS = _srs - - return nil -} - -// NbPublicWitness returns the expected public witness size (number of field elements) -func (vk *VerifyingKey) NbPublicWitness() int { - return int(vk.NbPublicVariables) -} - -// VerifyingKey returns pk.Vk -func (pk *ProvingKey) VerifyingKey() interface{} { - return pk.Vk -} diff --git a/internal/backend/bls24-315/plonk/prove.go b/internal/backend/bls24-315/plonk/prove.go deleted file mode 100644 index eccaefd544..0000000000 --- a/internal/backend/bls24-315/plonk/prove.go +++ /dev/null @@ -1,634 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "crypto/sha256" - "math/big" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - - curve "github.com/consensys/gnark-crypto/ecc/bls24-315" - - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/kzg" - - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" - - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/iop" - "github.com/consensys/gnark/constraint/bls24-315" - - "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/logger" -) - -type Proof struct { - - // Commitments to the solution vectors - LRO [3]kzg.Digest - - // Commitment to Z, the permutation polynomial - Z kzg.Digest - - // Commitments to h1, h2, h3 such that h = h1 + Xh2 + X**2h3 is the quotient polynomial - H [3]kzg.Digest - - // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2 - BatchedProof kzg.BatchOpeningProof - - // Opening proof of Z at zeta*mu - ZShiftedOpening kzg.OpeningProof -} - -// Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - - log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", len(spr.Constraints)).Str("backend", "plonk").Logger() - start := time.Now() - // pick a hash function that will be used to derive the challenges - hFunc := sha256.New() - - // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") - - // result - proof := &Proof{} - - // compute the constraint system solution - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - 
for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } - } - - // query l, r, o in Lagrange basis, not blinded - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) - - lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - liop := iop.NewPolynomial(&evaluationLDomainSmall, lagReg) - riop := iop.NewPolynomial(&evaluationRDomainSmall, lagReg) - oiop := iop.NewPolynomial(&evaluationODomainSmall, lagReg) - wliop := liop.ShallowClone() - wriop := riop.ShallowClone() - woiop := oiop.ShallowClone() - wliop.ToCanonical(&pk.Domain[0]).ToRegular() - wriop.ToCanonical(&pk.Domain[0]).ToRegular() - woiop.ToCanonical(&pk.Domain[0]).ToRegular() - - // Blind l, r, o before committing - // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. - bwliop := wliop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwriop := wriop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwoiop := woiop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // The first challenge is derived using the public data: the commitments to the permutation, - // the coefficients of the circuit, and the public inputs. - // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *pk.Vk, fullWitness[:len(spr.Public)]); err != nil { - return nil, err - } - gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) - if err != nil { - return nil, err - } - - // Fiat Shamir this - bbeta, err := fs.ComputeChallenge("beta") - if err != nil { - return nil, err - } - var beta fr.Element - beta.SetBytes(bbeta) - - // compute the copy constraint's ratio - // We copy liop, riop, oiop because they are fft'ed in the process. 
- // We could have not copied them at the cost of doing one more bit reverse - // per poly... - ziop, err := iop.BuildRatioCopyConstraint( - []*iop.Polynomial{ - liop.Clone(), - riop.Clone(), - oiop.Clone(), - }, - pk.Permutation, - beta, - gamma, - iop.Form{Basis: iop.Canonical, Layout: iop.Regular}, - &pk.Domain[0], - ) - if err != nil { - return proof, err - } - - // commit to the blinded version of z - bwziop := ziop // iop.NewWrappedPolynomial(&ziop) - bwziop.Blind(2) - proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Vk.KZGSRS, runtime.NumCPU()*2) - if err != nil { - return proof, err - } - - // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) - alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) - if err != nil { - return proof, err - } - - // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) - copy(qkCompletedCanonical, fullWitness[:len(spr.Public)]) - copy(qkCompletedCanonical[len(spr.Public):], pk.LQk[len(spr.Public):]) - pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) - fft.BitReverse(qkCompletedCanonical) - - // l, r, o are blinded here - bwliop.ToLagrangeCoset(&pk.Domain[1]) - bwriop.ToLagrangeCoset(&pk.Domain[1]) - bwoiop.ToLagrangeCoset(&pk.Domain[1]) - - lagrangeCosetBitReversed := iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse} - - // we don't mutate so no need to clone the coefficients from the proving key. 
- wqliop := iop.NewPolynomial(&pk.lQl, lagrangeCosetBitReversed) - wqriop := iop.NewPolynomial(&pk.lQr, lagrangeCosetBitReversed) - wqmiop := iop.NewPolynomial(&pk.lQm, lagrangeCosetBitReversed) - wqoiop := iop.NewPolynomial(&pk.lQo, lagrangeCosetBitReversed) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqkiop := iop.NewPolynomial(&qkCompletedCanonical, canReg) - wqkiop.ToLagrangeCoset(&pk.Domain[1]) - - // storing Id - id := make([]fr.Element, pk.Domain[1].Cardinality) - id[1].SetOne() - widiop := iop.NewPolynomial(&id, canReg) - widiop.ToLagrangeCoset(&pk.Domain[1]) - - // permutations in LagrangeCoset: we don't mutate so no need to clone the coefficients from the - // proving key. - ws1 := iop.NewPolynomial(&pk.lS1LagrangeCoset, lagrangeCosetBitReversed) - ws2 := iop.NewPolynomial(&pk.lS2LagrangeCoset, lagrangeCosetBitReversed) - ws3 := iop.NewPolynomial(&pk.lS3LagrangeCoset, lagrangeCosetBitReversed) - - // Store z(g*x), without reallocating a slice - bwsziop := bwziop.ShallowClone().Shift(1) - bwsziop.ToLagrangeCoset(&pk.Domain[1]) - - // L_{g^{0}} - cap := pk.Domain[1].Cardinality - if cap < pk.Domain[0].Cardinality { - cap = pk.Domain[0].Cardinality // sanity check - } - lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) - lone[0].SetOne() - loneiop := iop.NewPolynomial(&lone, lagReg) - wloneiop := loneiop.ToCanonical(&pk.Domain[0]). - ToRegular(). - ToLagrangeCoset(&pk.Domain[1]) - - // Full capture using latest gnark crypto... 
- fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element) fr.Element { - - var ic, tmp fr.Element - - ic.Mul(&fql, &l) - tmp.Mul(&fqr, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqm, &l).Mul(&tmp, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqo, &o) - ic.Add(&ic, &tmp).Add(&ic, &fqk) - - return ic - } - - fo := func(l, r, o, fid, fs1, fs2, fs3, fz, fzs fr.Element) fr.Element { - var uu fr.Element - u := pk.Domain[0].FrMultiplicativeGen - uu.Mul(&u, &u) - - var a, b, tmp fr.Element - a.Mul(&beta, &fid).Add(&a, &l).Add(&a, &gamma) - tmp.Mul(&beta, &u).Mul(&tmp, &fid).Add(&tmp, &r).Add(&tmp, &gamma) - a.Mul(&a, &tmp) - tmp.Mul(&beta, &uu).Mul(&tmp, &fid).Add(&tmp, &o).Add(&tmp, &gamma) - a.Mul(&a, &tmp).Mul(&a, &fz) - - b.Mul(&beta, &fs1).Add(&b, &l).Add(&b, &gamma) - tmp.Mul(&beta, &fs2).Add(&tmp, &r).Add(&tmp, &gamma) - b.Mul(&b, &tmp) - tmp.Mul(&beta, &fs3).Add(&tmp, &o).Add(&tmp, &gamma) - b.Mul(&b, &tmp).Mul(&b, &fzs) - - b.Sub(&b, &a) - - return b - } - - fone := func(fz, flone fr.Element) fr.Element { - one := fr.One() - one.Sub(&fz, &one).Mul(&one, &flone) - return one - } - - // 0 , 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 - // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk,lone - fm := func(x ...fr.Element) fr.Element { - - a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2]) - b := fo(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8]) - c := fone(x[7], x[14]) - - c.Mul(&c, &alpha).Add(&c, &b).Mul(&c, &alpha).Add(&c, &a) - - return c - } - testEval, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, - bwliop, - bwriop, - bwoiop, - widiop, - ws1, - ws2, - ws3, - bwziop, - bwsziop, - wqliop, - wqriop, - wqmiop, - wqoiop, - wqkiop, - wloneiop, - ) - if err != nil { - return nil, err - } - h, err := iop.DivideByXMinusOne(testEval, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) - if err != nil { - return nil, err - } - - // compute kzg commitments of h1, h2 and h3 - if err := commitToQuotient( - 
h.Coefficients()[:pk.Domain[0].Cardinality+2], - h.Coefficients()[pk.Domain[0].Cardinality+2:2*(pk.Domain[0].Cardinality+2)], - h.Coefficients()[2*(pk.Domain[0].Cardinality+2):3*(pk.Domain[0].Cardinality+2)], - proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // derive zeta - zeta, err := deriveRandomness(&fs, "zeta", &proof.H[0], &proof.H[1], &proof.H[2]) - if err != nil { - return nil, err - } - - // compute evaluations of (blinded version of) l, r, o, z at zeta - var blzeta, brzeta, bozeta fr.Element - - var wgEvals sync.WaitGroup - wgEvals.Add(3) - - go func() { - bwliop.ToCanonical(&pk.Domain[1]).ToRegular() - blzeta = bwliop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwriop.ToCanonical(&pk.Domain[1]).ToRegular() - brzeta = bwriop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwoiop.ToCanonical(&pk.Domain[1]).ToRegular() - bozeta = bwoiop.Evaluate(zeta) - wgEvals.Done() - }() - - // open blinded Z at zeta*z - bwziop.ToCanonical(&pk.Domain[1]).ToRegular() - var zetaShifted fr.Element - zetaShifted.Mul(&zeta, &pk.Vk.Generator) - proof.ZShiftedOpening, err = kzg.Open( - bwziop.Coefficients()[:bwziop.BlindedSize()], - zetaShifted, - pk.Vk.KZGSRS, - ) - if err != nil { - return nil, err - } - - // blinded z evaluated at u*zeta - bzuzeta := proof.ZShiftedOpening.ClaimedValue - - var ( - linearizedPolynomialCanonical []fr.Element - linearizedPolynomialDigest curve.G1Affine - errLPoly error - ) - - wgEvals.Wait() // wait for the evaluations - - // compute the linearization polynomial r at zeta - // (goal: save committing separately to z, ql, qr, qm, qo, k - linearizedPolynomialCanonical = computeLinearizedPolynomial( - blzeta, - brzeta, - bozeta, - alpha, - beta, - gamma, - zeta, - bzuzeta, - bwziop.Coefficients()[:bwziop.BlindedSize()], - pk, - ) - - // TODO this commitment is only necessary to derive the challenge, we should - // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = 
kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) - - // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) - var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) - var zetaPowerm fr.Element - zetaPowerm.Exp(zeta, &bSize) - zetaPowerm.BigInt(&bZetaPowerm) - foldedHDigest := proof.H[2] - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - - // foldedH = h1 + ζ*h2 + ζ²*h3 - foldedH := h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] - h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] - utils.Parallelize(len(foldedH), func(start, end int) { - for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² - foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 - } - }) - - if errLPoly != nil { - return nil, errLPoly - } - - // Batch open the first list of polynomials - proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - [][]fr.Element{ - foldedH, - linearizedPolynomialCanonical, - bwliop.Coefficients()[:bwliop.BlindedSize()], - bwriop.Coefficients()[:bwriop.BlindedSize()], - bwoiop.Coefficients()[:bwoiop.BlindedSize()], - pk.S1Canonical, - pk.S2Canonical, - }, - []kzg.Digest{ - foldedHDigest, - linearizedPolynomialDigest, - proof.LRO[0], - proof.LRO[1], - proof.LRO[2], - pk.Vk.S[0], - pk.Vk.S[1], - }, - zeta, - hFunc, - pk.Vk.KZGSRS, - ) - - log.Debug().Dur("took", time.Since(start)).Msg("prover done") - - if err != nil { - return 
nil, err - } - - return proof, nil - -} - -// fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.LRO[0], err0 = kzg.Commit(bcl, srs, n) - close(chCommit0) - }() - go func() { - proof.LRO[1], err1 = kzg.Commit(bcr, srs, n) - close(chCommit1) - }() - if proof.LRO[2], err2 = kzg.Commit(bco, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.H[0], err0 = kzg.Commit(h1, srs, n) - close(chCommit0) - }() - go func() { - proof.H[1], err1 = kzg.Commit(h2, srs, n) - close(chCommit1) - }() - if proof.H[2], err2 = kzg.Commit(h3, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. -// The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z -// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. 
-// -// The Linearized polynomial is: -// -// α²*L₁(ζ)*Z(X) -// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) -// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { - - // first part: individual constraints - var rl fr.Element - rl.Mul(&rZeta, &lZeta) - - // second part: - // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) - var s1, s2 fr.Element - chS1 := make(chan struct{}, 1) - go func() { - ps1 := iop.NewPolynomial(&pk.S1Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - s1 = ps1.Evaluate(zeta) // s1(ζ) - s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) - close(chS1) - }() - ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := ps2.Evaluate(zeta) // s2(ζ) - tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) - <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) - - var uzeta, uuzeta fr.Element - uzeta.Mul(&zeta, &pk.Vk.CosetShift) - uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) - - s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) - tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) - tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - // third part L₁(ζ)*α²*Z - var lagrangeZeta, one, den, frNbElmt fr.Element - one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) - lagrangeZeta.Set(&zeta). - Exp(lagrangeZeta, big.NewInt(nbElmt)). 
- Sub(&lagrangeZeta, &one) - frNbElmt.SetUint64(uint64(nbElmt)) - den.Sub(&zeta, &one). - Inverse(&den) - lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - - linPol := make([]fr.Element, len(blindedZCanonical)) - copy(linPol, blindedZCanonical) - - utils.Parallelize(len(linPol), func(start, end int) { - - var t0, t1 fr.Element - - for i := start; i < end; i++ { - - linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - if i < len(pk.S3Canonical) { - - t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) - - linPol[i].Add(&linPol[i], &t0) - } - - linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) - - if i < len(pk.Qm) { - - t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) - t0.Mul(&pk.Ql[i], &lZeta) - t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - - t0.Mul(&pk.Qr[i], &rZeta) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - - t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) - } - - t0.Mul(&blindedZCanonical[i], &lagrangeZeta) - linPol[i].Add(&linPol[i], &t0) // finish the computation - } - }) - return linPol -} diff --git a/internal/backend/bls24-315/plonk/setup.go b/internal/backend/bls24-315/plonk/setup.go deleted file mode 100644 index 0bdfcb2ff9..0000000000 --- a/internal/backend/bls24-315/plonk/setup.go +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/iop" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/kzg" - "github.com/consensys/gnark/constraint/bls24-315" - - kzgg "github.com/consensys/gnark-crypto/kzg" -) - -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation -type ProvingKey struct { - // Verifying Key is embedded into the proving key (needed by Prove) - Vk *VerifyingKey - - // TODO store iop.Polynomial here, not []fr.Element for more "type safety" - - // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo []fr.Element - - // qr,ql,qm,qo (in lagrange coset basis) --> these are not serialized, but computed from Ql, Qr, Qm, Qo once. - lQl, lQr, lQm, lQo []fr.Element - - // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. - // Storing LQk in Lagrange basis saves a fft... - CQk, LQk []fr.Element - - // Domains used for the FFTs. 
- // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain - // Domain[0], Domain[1] fft.Domain - - // Permutation polynomials - S1Canonical, S2Canonical, S3Canonical []fr.Element - - // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. - lS1LagrangeCoset, lS2LagrangeCoset, lS3LagrangeCoset []fr.Element - - // position -> permuted position (position in [0,3*sizeSystem-1]) - Permutation []int64 -} - -// VerifyingKey stores the data needed to verify a proof: -// * The commitment scheme -// * Commitments of ql prepended with as many ones as there are public inputs -// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs -// * Commitments to S1, S2, S3 -type VerifyingKey struct { - // Size circuit - Size uint64 - SizeInv fr.Element - Generator fr.Element - NbPublicVariables uint64 - - // Commitment scheme that is used for an instantiation of PLONK - KZGSRS *kzg.SRS - - // cosetShift generator of the coset on the small domain - CosetShift fr.Element - - // S commitments to S1, S2, S3 - S [3]kzg.Digest - - // Commitments to ql, qr, qm, qo prepended with as many zeroes (ones for l) as there are public inputs. - // In particular Qk is not complete. - Ql, Qr, Qm, Qo, Qk kzg.Digest -} - -// Setup sets proving and verifying keys -func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) { - var pk ProvingKey - var vk VerifyingKey - - // The verifying key shares data with the proving key - pk.Vk = &vk - - nbConstraints := len(spr.Constraints) - - // fft domains - sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - - vk.Size = pk.Domain[0].Cardinality - vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) - vk.NbPublicVariables = uint64(len(spr.Public)) - - if err := pk.InitKZG(srs); err != nil { - return nil, nil, err - } - - // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) - pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) - pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) - - for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant - pk.Ql[i].SetOne().Neg(&pk.Ql[i]) - pk.Qr[i].SetZero() - pk.Qm[i].SetZero() - pk.Qo[i].SetZero() - pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // → to be completed by the prover - } - offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.Ql[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.Qr[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.Qm[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.Qm[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.Qo[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.CQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - } - - pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) - pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) - fft.BitReverse(pk.Ql) - fft.BitReverse(pk.Qr) - fft.BitReverse(pk.Qm) - fft.BitReverse(pk.Qo) - fft.BitReverse(pk.CQk) - - // build permutation. Note: at this stage, the permutation takes in account the placeholders - buildPermutation(spr, &pk) - - // set s1, s2, s3 - ccomputePermutationPolynomials(&pk) - - // compute the lagrange coset basis versions (not serialized) - pk.computeLagrangeCosetPolys() - - // Commit to the polynomials to set up the verifying key - var err error - if vk.Ql, err = kzg.Commit(pk.Ql, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qr, err = kzg.Commit(pk.Qr, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qm, err = kzg.Commit(pk.Qm, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qo, err = kzg.Commit(pk.Qo, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - - return &pk, &vk, nil - -} - -// buildPermutation builds the Permutation associated with a circuit. -// -// The permutation s is composed of cycles of maximum length such that -// -// s. 
(l∥r∥o) = (l∥r∥o) -// -// , where l∥r∥o is the concatenation of the indices of l, r, o in -// ql.l+qr.r+qm.l.r+qo.O+k = 0. -// -// The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab -// like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { - - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := int(pk.Domain[0].Cardinality) - - // init permutation - pk.Permutation = make([]int64, 3*sizeSolution) - for i := 0; i < len(pk.Permutation); i++ { - pk.Permutation[i] = -1 - } - - // init LRO position -> variable_ID - lro := make([]int, 3*sizeSolution) // position -> variable_ID - for i := 0; i < len(spr.Public); i++ { - lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) - } - - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() - } - - // init cycle: - // map ID -> last position the ID was seen - cycle := make([]int64, nbVariables) - for i := 0; i < len(cycle); i++ { - cycle[i] = -1 - } - - for i := 0; i < len(lro); i++ { - if cycle[lro[i]] != -1 { - // if != -1, it means we already encountered this value - // so we need to set the corresponding permutation index. 
- pk.Permutation[i] = cycle[lro[i]] - } - cycle[lro[i]] = int64(i) - } - - // complete the Permutation by filling the first IDs encountered - for i := 0; i < len(pk.Permutation); i++ { - if pk.Permutation[i] == -1 { - pk.Permutation[i] = cycle[lro[i]] - } - } -} - -func (pk *ProvingKey) computeLagrangeCosetPolys() { - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqliop := iop.NewPolynomial(clone(pk.Ql, pk.Domain[1].Cardinality), canReg) - wqriop := iop.NewPolynomial(clone(pk.Qr, pk.Domain[1].Cardinality), canReg) - wqmiop := iop.NewPolynomial(clone(pk.Qm, pk.Domain[1].Cardinality), canReg) - wqoiop := iop.NewPolynomial(clone(pk.Qo, pk.Domain[1].Cardinality), canReg) - - ws1 := iop.NewPolynomial(clone(pk.S1Canonical, pk.Domain[1].Cardinality), canReg) - ws2 := iop.NewPolynomial(clone(pk.S2Canonical, pk.Domain[1].Cardinality), canReg) - ws3 := iop.NewPolynomial(clone(pk.S3Canonical, pk.Domain[1].Cardinality), canReg) - - wqliop.ToLagrangeCoset(&pk.Domain[1]) - wqriop.ToLagrangeCoset(&pk.Domain[1]) - wqmiop.ToLagrangeCoset(&pk.Domain[1]) - wqoiop.ToLagrangeCoset(&pk.Domain[1]) - - ws1.ToLagrangeCoset(&pk.Domain[1]) - ws2.ToLagrangeCoset(&pk.Domain[1]) - ws3.ToLagrangeCoset(&pk.Domain[1]) - - pk.lQl = wqliop.Coefficients() - pk.lQr = wqriop.Coefficients() - pk.lQm = wqmiop.Coefficients() - pk.lQo = wqoiop.Coefficients() - - pk.lS1LagrangeCoset = ws1.Coefficients() - pk.lS2LagrangeCoset = ws2.Coefficients() - pk.lS3LagrangeCoset = ws3.Coefficients() -} - -func clone(input []fr.Element, capacity uint64) *[]fr.Element { - res := make([]fr.Element, len(input), capacity) - copy(res, input) - return &res -} - -// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations -// s1, s2, s3. -// -// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | -// -// | -// | Permutation -// -// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v -// \---------------/ \--------------------/ \------------------------/ -// -// s1 (LDE) s2 (LDE) s3 (LDE) -func ccomputePermutationPolynomials(pk *ProvingKey) { - - nbElmts := int(pk.Domain[0].Cardinality) - - // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) - - // Lagrange form of S1, S2, S3 - pk.S1Canonical = make([]fr.Element, nbElmts) - pk.S2Canonical = make([]fr.Element, nbElmts) - pk.S3Canonical = make([]fr.Element, nbElmts) - for i := 0; i < nbElmts; i++ { - pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) - pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) - pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) - } - - // Canonical form of S1, S2, S3 - pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) - fft.BitReverse(pk.S1Canonical) - fft.BitReverse(pk.S2Canonical) - fft.BitReverse(pk.S3Canonical) -} - -// getIDSmallDomain returns the Lagrange form of ID on the small domain -func getIDSmallDomain(domain *fft.Domain) []fr.Element { - - res := make([]fr.Element, 3*domain.Cardinality) - - res[0].SetOne() - res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) - res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) - - for i := uint64(1); i < domain.Cardinality; i++ { - res[i].Mul(&res[i-1], &domain.Generator) - res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) - res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) - } - - return res -} - -// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS -// -// This should be used after deserializing a ProvingKey -// as pk.Vk.KZG is NOT serialized -func (pk *ProvingKey) InitKZG(srs kzgg.SRS) error { - return pk.Vk.InitKZG(srs) -} - -// InitKZG inits vk.KZG using provided SRS -// -// This should be used after deserializing a 
VerifyingKey -// as vk.KZG is NOT serialized -// -// Note that this instantiate a new FFT domain using vk.Size -func (vk *VerifyingKey) InitKZG(srs kzgg.SRS) error { - _srs := srs.(*kzg.SRS) - - if len(_srs.G1) < int(vk.Size) { - return errors.New("kzg srs is too small") - } - vk.KZGSRS = _srs - - return nil -} - -// NbPublicWitness returns the expected public witness size (number of field elements) -func (vk *VerifyingKey) NbPublicWitness() int { - return int(vk.NbPublicVariables) -} - -// VerifyingKey returns pk.Vk -func (pk *ProvingKey) VerifyingKey() interface{} { - return pk.Vk -} diff --git a/internal/backend/bls24-317/plonk/prove.go b/internal/backend/bls24-317/plonk/prove.go deleted file mode 100644 index 162b2a6e28..0000000000 --- a/internal/backend/bls24-317/plonk/prove.go +++ /dev/null @@ -1,634 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "crypto/sha256" - "math/big" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - - curve "github.com/consensys/gnark-crypto/ecc/bls24-317" - - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/kzg" - - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" - - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/iop" - "github.com/consensys/gnark/constraint/bls24-317" - - "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/logger" -) - -type Proof struct { - - // Commitments to the solution vectors - LRO [3]kzg.Digest - - // Commitment to Z, the permutation polynomial - Z kzg.Digest - - // Commitments to h1, h2, h3 such that h = h1 + Xh2 + X**2h3 is the quotient polynomial - H [3]kzg.Digest - - // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2 - BatchedProof kzg.BatchOpeningProof - - // Opening proof of Z at zeta*mu - ZShiftedOpening kzg.OpeningProof -} - -// Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - - log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", len(spr.Constraints)).Str("backend", "plonk").Logger() - start := time.Now() - // pick a hash function that will be used to derive the challenges - hFunc := sha256.New() - - // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") - - // result - proof := &Proof{} - - // compute the constraint system solution - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - 
for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } - } - - // query l, r, o in Lagrange basis, not blinded - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) - - lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - liop := iop.NewPolynomial(&evaluationLDomainSmall, lagReg) - riop := iop.NewPolynomial(&evaluationRDomainSmall, lagReg) - oiop := iop.NewPolynomial(&evaluationODomainSmall, lagReg) - wliop := liop.ShallowClone() - wriop := riop.ShallowClone() - woiop := oiop.ShallowClone() - wliop.ToCanonical(&pk.Domain[0]).ToRegular() - wriop.ToCanonical(&pk.Domain[0]).ToRegular() - woiop.ToCanonical(&pk.Domain[0]).ToRegular() - - // Blind l, r, o before committing - // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. - bwliop := wliop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwriop := wriop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwoiop := woiop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // The first challenge is derived using the public data: the commitments to the permutation, - // the coefficients of the circuit, and the public inputs. - // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *pk.Vk, fullWitness[:len(spr.Public)]); err != nil { - return nil, err - } - gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) - if err != nil { - return nil, err - } - - // Fiat Shamir this - bbeta, err := fs.ComputeChallenge("beta") - if err != nil { - return nil, err - } - var beta fr.Element - beta.SetBytes(bbeta) - - // compute the copy constraint's ratio - // We copy liop, riop, oiop because they are fft'ed in the process. 
- // We could have not copied them at the cost of doing one more bit reverse - // per poly... - ziop, err := iop.BuildRatioCopyConstraint( - []*iop.Polynomial{ - liop.Clone(), - riop.Clone(), - oiop.Clone(), - }, - pk.Permutation, - beta, - gamma, - iop.Form{Basis: iop.Canonical, Layout: iop.Regular}, - &pk.Domain[0], - ) - if err != nil { - return proof, err - } - - // commit to the blinded version of z - bwziop := ziop // iop.NewWrappedPolynomial(&ziop) - bwziop.Blind(2) - proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Vk.KZGSRS, runtime.NumCPU()*2) - if err != nil { - return proof, err - } - - // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) - alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) - if err != nil { - return proof, err - } - - // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) - copy(qkCompletedCanonical, fullWitness[:len(spr.Public)]) - copy(qkCompletedCanonical[len(spr.Public):], pk.LQk[len(spr.Public):]) - pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) - fft.BitReverse(qkCompletedCanonical) - - // l, r, o are blinded here - bwliop.ToLagrangeCoset(&pk.Domain[1]) - bwriop.ToLagrangeCoset(&pk.Domain[1]) - bwoiop.ToLagrangeCoset(&pk.Domain[1]) - - lagrangeCosetBitReversed := iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse} - - // we don't mutate so no need to clone the coefficients from the proving key. 
- wqliop := iop.NewPolynomial(&pk.lQl, lagrangeCosetBitReversed) - wqriop := iop.NewPolynomial(&pk.lQr, lagrangeCosetBitReversed) - wqmiop := iop.NewPolynomial(&pk.lQm, lagrangeCosetBitReversed) - wqoiop := iop.NewPolynomial(&pk.lQo, lagrangeCosetBitReversed) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqkiop := iop.NewPolynomial(&qkCompletedCanonical, canReg) - wqkiop.ToLagrangeCoset(&pk.Domain[1]) - - // storing Id - id := make([]fr.Element, pk.Domain[1].Cardinality) - id[1].SetOne() - widiop := iop.NewPolynomial(&id, canReg) - widiop.ToLagrangeCoset(&pk.Domain[1]) - - // permutations in LagrangeCoset: we don't mutate so no need to clone the coefficients from the - // proving key. - ws1 := iop.NewPolynomial(&pk.lS1LagrangeCoset, lagrangeCosetBitReversed) - ws2 := iop.NewPolynomial(&pk.lS2LagrangeCoset, lagrangeCosetBitReversed) - ws3 := iop.NewPolynomial(&pk.lS3LagrangeCoset, lagrangeCosetBitReversed) - - // Store z(g*x), without reallocating a slice - bwsziop := bwziop.ShallowClone().Shift(1) - bwsziop.ToLagrangeCoset(&pk.Domain[1]) - - // L_{g^{0}} - cap := pk.Domain[1].Cardinality - if cap < pk.Domain[0].Cardinality { - cap = pk.Domain[0].Cardinality // sanity check - } - lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) - lone[0].SetOne() - loneiop := iop.NewPolynomial(&lone, lagReg) - wloneiop := loneiop.ToCanonical(&pk.Domain[0]). - ToRegular(). - ToLagrangeCoset(&pk.Domain[1]) - - // Full capture using latest gnark crypto... 
- fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element) fr.Element { - - var ic, tmp fr.Element - - ic.Mul(&fql, &l) - tmp.Mul(&fqr, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqm, &l).Mul(&tmp, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqo, &o) - ic.Add(&ic, &tmp).Add(&ic, &fqk) - - return ic - } - - fo := func(l, r, o, fid, fs1, fs2, fs3, fz, fzs fr.Element) fr.Element { - var uu fr.Element - u := pk.Domain[0].FrMultiplicativeGen - uu.Mul(&u, &u) - - var a, b, tmp fr.Element - a.Mul(&beta, &fid).Add(&a, &l).Add(&a, &gamma) - tmp.Mul(&beta, &u).Mul(&tmp, &fid).Add(&tmp, &r).Add(&tmp, &gamma) - a.Mul(&a, &tmp) - tmp.Mul(&beta, &uu).Mul(&tmp, &fid).Add(&tmp, &o).Add(&tmp, &gamma) - a.Mul(&a, &tmp).Mul(&a, &fz) - - b.Mul(&beta, &fs1).Add(&b, &l).Add(&b, &gamma) - tmp.Mul(&beta, &fs2).Add(&tmp, &r).Add(&tmp, &gamma) - b.Mul(&b, &tmp) - tmp.Mul(&beta, &fs3).Add(&tmp, &o).Add(&tmp, &gamma) - b.Mul(&b, &tmp).Mul(&b, &fzs) - - b.Sub(&b, &a) - - return b - } - - fone := func(fz, flone fr.Element) fr.Element { - one := fr.One() - one.Sub(&fz, &one).Mul(&one, &flone) - return one - } - - // 0 , 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 - // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk,lone - fm := func(x ...fr.Element) fr.Element { - - a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2]) - b := fo(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8]) - c := fone(x[7], x[14]) - - c.Mul(&c, &alpha).Add(&c, &b).Mul(&c, &alpha).Add(&c, &a) - - return c - } - testEval, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, - bwliop, - bwriop, - bwoiop, - widiop, - ws1, - ws2, - ws3, - bwziop, - bwsziop, - wqliop, - wqriop, - wqmiop, - wqoiop, - wqkiop, - wloneiop, - ) - if err != nil { - return nil, err - } - h, err := iop.DivideByXMinusOne(testEval, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) - if err != nil { - return nil, err - } - - // compute kzg commitments of h1, h2 and h3 - if err := commitToQuotient( - 
h.Coefficients()[:pk.Domain[0].Cardinality+2], - h.Coefficients()[pk.Domain[0].Cardinality+2:2*(pk.Domain[0].Cardinality+2)], - h.Coefficients()[2*(pk.Domain[0].Cardinality+2):3*(pk.Domain[0].Cardinality+2)], - proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // derive zeta - zeta, err := deriveRandomness(&fs, "zeta", &proof.H[0], &proof.H[1], &proof.H[2]) - if err != nil { - return nil, err - } - - // compute evaluations of (blinded version of) l, r, o, z at zeta - var blzeta, brzeta, bozeta fr.Element - - var wgEvals sync.WaitGroup - wgEvals.Add(3) - - go func() { - bwliop.ToCanonical(&pk.Domain[1]).ToRegular() - blzeta = bwliop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwriop.ToCanonical(&pk.Domain[1]).ToRegular() - brzeta = bwriop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwoiop.ToCanonical(&pk.Domain[1]).ToRegular() - bozeta = bwoiop.Evaluate(zeta) - wgEvals.Done() - }() - - // open blinded Z at zeta*z - bwziop.ToCanonical(&pk.Domain[1]).ToRegular() - var zetaShifted fr.Element - zetaShifted.Mul(&zeta, &pk.Vk.Generator) - proof.ZShiftedOpening, err = kzg.Open( - bwziop.Coefficients()[:bwziop.BlindedSize()], - zetaShifted, - pk.Vk.KZGSRS, - ) - if err != nil { - return nil, err - } - - // blinded z evaluated at u*zeta - bzuzeta := proof.ZShiftedOpening.ClaimedValue - - var ( - linearizedPolynomialCanonical []fr.Element - linearizedPolynomialDigest curve.G1Affine - errLPoly error - ) - - wgEvals.Wait() // wait for the evaluations - - // compute the linearization polynomial r at zeta - // (goal: save committing separately to z, ql, qr, qm, qo, k - linearizedPolynomialCanonical = computeLinearizedPolynomial( - blzeta, - brzeta, - bozeta, - alpha, - beta, - gamma, - zeta, - bzuzeta, - bwziop.Coefficients()[:bwziop.BlindedSize()], - pk, - ) - - // TODO this commitment is only necessary to derive the challenge, we should - // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = 
kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) - - // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) - var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) - var zetaPowerm fr.Element - zetaPowerm.Exp(zeta, &bSize) - zetaPowerm.BigInt(&bZetaPowerm) - foldedHDigest := proof.H[2] - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - - // foldedH = h1 + ζ*h2 + ζ²*h3 - foldedH := h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] - h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] - utils.Parallelize(len(foldedH), func(start, end int) { - for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² - foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 - } - }) - - if errLPoly != nil { - return nil, errLPoly - } - - // Batch open the first list of polynomials - proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - [][]fr.Element{ - foldedH, - linearizedPolynomialCanonical, - bwliop.Coefficients()[:bwliop.BlindedSize()], - bwriop.Coefficients()[:bwriop.BlindedSize()], - bwoiop.Coefficients()[:bwoiop.BlindedSize()], - pk.S1Canonical, - pk.S2Canonical, - }, - []kzg.Digest{ - foldedHDigest, - linearizedPolynomialDigest, - proof.LRO[0], - proof.LRO[1], - proof.LRO[2], - pk.Vk.S[0], - pk.Vk.S[1], - }, - zeta, - hFunc, - pk.Vk.KZGSRS, - ) - - log.Debug().Dur("took", time.Since(start)).Msg("prover done") - - if err != nil { - return 
nil, err - } - - return proof, nil - -} - -// fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.LRO[0], err0 = kzg.Commit(bcl, srs, n) - close(chCommit0) - }() - go func() { - proof.LRO[1], err1 = kzg.Commit(bcr, srs, n) - close(chCommit1) - }() - if proof.LRO[2], err2 = kzg.Commit(bco, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.H[0], err0 = kzg.Commit(h1, srs, n) - close(chCommit0) - }() - go func() { - proof.H[1], err1 = kzg.Commit(h2, srs, n) - close(chCommit1) - }() - if proof.H[2], err2 = kzg.Commit(h3, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. -// The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z -// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. 
-// -// The Linearized polynomial is: -// -// α²*L₁(ζ)*Z(X) -// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) -// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { - - // first part: individual constraints - var rl fr.Element - rl.Mul(&rZeta, &lZeta) - - // second part: - // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) - var s1, s2 fr.Element - chS1 := make(chan struct{}, 1) - go func() { - ps1 := iop.NewPolynomial(&pk.S1Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - s1 = ps1.Evaluate(zeta) // s1(ζ) - s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) - close(chS1) - }() - ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := ps2.Evaluate(zeta) // s2(ζ) - tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) - <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) - - var uzeta, uuzeta fr.Element - uzeta.Mul(&zeta, &pk.Vk.CosetShift) - uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) - - s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) - tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) - tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - // third part L₁(ζ)*α²*Z - var lagrangeZeta, one, den, frNbElmt fr.Element - one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) - lagrangeZeta.Set(&zeta). - Exp(lagrangeZeta, big.NewInt(nbElmt)). 
- Sub(&lagrangeZeta, &one) - frNbElmt.SetUint64(uint64(nbElmt)) - den.Sub(&zeta, &one). - Inverse(&den) - lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - - linPol := make([]fr.Element, len(blindedZCanonical)) - copy(linPol, blindedZCanonical) - - utils.Parallelize(len(linPol), func(start, end int) { - - var t0, t1 fr.Element - - for i := start; i < end; i++ { - - linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - if i < len(pk.S3Canonical) { - - t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) - - linPol[i].Add(&linPol[i], &t0) - } - - linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) - - if i < len(pk.Qm) { - - t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) - t0.Mul(&pk.Ql[i], &lZeta) - t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - - t0.Mul(&pk.Qr[i], &rZeta) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - - t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) - } - - t0.Mul(&blindedZCanonical[i], &lagrangeZeta) - linPol[i].Add(&linPol[i], &t0) // finish the computation - } - }) - return linPol -} diff --git a/internal/backend/bls24-317/plonk/setup.go b/internal/backend/bls24-317/plonk/setup.go deleted file mode 100644 index 08c3689967..0000000000 --- a/internal/backend/bls24-317/plonk/setup.go +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/iop" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/kzg" - "github.com/consensys/gnark/constraint/bls24-317" - - kzgg "github.com/consensys/gnark-crypto/kzg" -) - -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation -type ProvingKey struct { - // Verifying Key is embedded into the proving key (needed by Prove) - Vk *VerifyingKey - - // TODO store iop.Polynomial here, not []fr.Element for more "type safety" - - // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo []fr.Element - - // qr,ql,qm,qo (in lagrange coset basis) --> these are not serialized, but computed from Ql, Qr, Qm, Qo once. - lQl, lQr, lQm, lQo []fr.Element - - // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. - // Storing LQk in Lagrange basis saves a fft... - CQk, LQk []fr.Element - - // Domains used for the FFTs. 
- // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain - // Domain[0], Domain[1] fft.Domain - - // Permutation polynomials - S1Canonical, S2Canonical, S3Canonical []fr.Element - - // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. - lS1LagrangeCoset, lS2LagrangeCoset, lS3LagrangeCoset []fr.Element - - // position -> permuted position (position in [0,3*sizeSystem-1]) - Permutation []int64 -} - -// VerifyingKey stores the data needed to verify a proof: -// * The commitment scheme -// * Commitments of ql prepended with as many ones as there are public inputs -// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs -// * Commitments to S1, S2, S3 -type VerifyingKey struct { - // Size circuit - Size uint64 - SizeInv fr.Element - Generator fr.Element - NbPublicVariables uint64 - - // Commitment scheme that is used for an instantiation of PLONK - KZGSRS *kzg.SRS - - // cosetShift generator of the coset on the small domain - CosetShift fr.Element - - // S commitments to S1, S2, S3 - S [3]kzg.Digest - - // Commitments to ql, qr, qm, qo prepended with as many zeroes (ones for l) as there are public inputs. - // In particular Qk is not complete. - Ql, Qr, Qm, Qo, Qk kzg.Digest -} - -// Setup sets proving and verifying keys -func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) { - var pk ProvingKey - var vk VerifyingKey - - // The verifying key shares data with the proving key - pk.Vk = &vk - - nbConstraints := len(spr.Constraints) - - // fft domains - sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - - vk.Size = pk.Domain[0].Cardinality - vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) - vk.NbPublicVariables = uint64(len(spr.Public)) - - if err := pk.InitKZG(srs); err != nil { - return nil, nil, err - } - - // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) - pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) - pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) - - for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant - pk.Ql[i].SetOne().Neg(&pk.Ql[i]) - pk.Qr[i].SetZero() - pk.Qm[i].SetZero() - pk.Qo[i].SetZero() - pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // → to be completed by the prover - } - offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.Ql[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.Qr[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.Qm[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.Qm[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.Qo[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.CQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - } - - pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) - pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) - fft.BitReverse(pk.Ql) - fft.BitReverse(pk.Qr) - fft.BitReverse(pk.Qm) - fft.BitReverse(pk.Qo) - fft.BitReverse(pk.CQk) - - // build permutation. Note: at this stage, the permutation takes in account the placeholders - buildPermutation(spr, &pk) - - // set s1, s2, s3 - ccomputePermutationPolynomials(&pk) - - // compute the lagrange coset basis versions (not serialized) - pk.computeLagrangeCosetPolys() - - // Commit to the polynomials to set up the verifying key - var err error - if vk.Ql, err = kzg.Commit(pk.Ql, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qr, err = kzg.Commit(pk.Qr, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qm, err = kzg.Commit(pk.Qm, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qo, err = kzg.Commit(pk.Qo, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - - return &pk, &vk, nil - -} - -// buildPermutation builds the Permutation associated with a circuit. -// -// The permutation s is composed of cycles of maximum length such that -// -// s. 
(l∥r∥o) = (l∥r∥o) -// -// , where l∥r∥o is the concatenation of the indices of l, r, o in -// ql.l+qr.r+qm.l.r+qo.O+k = 0. -// -// The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab -// like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { - - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := int(pk.Domain[0].Cardinality) - - // init permutation - pk.Permutation = make([]int64, 3*sizeSolution) - for i := 0; i < len(pk.Permutation); i++ { - pk.Permutation[i] = -1 - } - - // init LRO position -> variable_ID - lro := make([]int, 3*sizeSolution) // position -> variable_ID - for i := 0; i < len(spr.Public); i++ { - lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) - } - - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() - } - - // init cycle: - // map ID -> last position the ID was seen - cycle := make([]int64, nbVariables) - for i := 0; i < len(cycle); i++ { - cycle[i] = -1 - } - - for i := 0; i < len(lro); i++ { - if cycle[lro[i]] != -1 { - // if != -1, it means we already encountered this value - // so we need to set the corresponding permutation index. 
- pk.Permutation[i] = cycle[lro[i]] - } - cycle[lro[i]] = int64(i) - } - - // complete the Permutation by filling the first IDs encountered - for i := 0; i < len(pk.Permutation); i++ { - if pk.Permutation[i] == -1 { - pk.Permutation[i] = cycle[lro[i]] - } - } -} - -func (pk *ProvingKey) computeLagrangeCosetPolys() { - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqliop := iop.NewPolynomial(clone(pk.Ql, pk.Domain[1].Cardinality), canReg) - wqriop := iop.NewPolynomial(clone(pk.Qr, pk.Domain[1].Cardinality), canReg) - wqmiop := iop.NewPolynomial(clone(pk.Qm, pk.Domain[1].Cardinality), canReg) - wqoiop := iop.NewPolynomial(clone(pk.Qo, pk.Domain[1].Cardinality), canReg) - - ws1 := iop.NewPolynomial(clone(pk.S1Canonical, pk.Domain[1].Cardinality), canReg) - ws2 := iop.NewPolynomial(clone(pk.S2Canonical, pk.Domain[1].Cardinality), canReg) - ws3 := iop.NewPolynomial(clone(pk.S3Canonical, pk.Domain[1].Cardinality), canReg) - - wqliop.ToLagrangeCoset(&pk.Domain[1]) - wqriop.ToLagrangeCoset(&pk.Domain[1]) - wqmiop.ToLagrangeCoset(&pk.Domain[1]) - wqoiop.ToLagrangeCoset(&pk.Domain[1]) - - ws1.ToLagrangeCoset(&pk.Domain[1]) - ws2.ToLagrangeCoset(&pk.Domain[1]) - ws3.ToLagrangeCoset(&pk.Domain[1]) - - pk.lQl = wqliop.Coefficients() - pk.lQr = wqriop.Coefficients() - pk.lQm = wqmiop.Coefficients() - pk.lQo = wqoiop.Coefficients() - - pk.lS1LagrangeCoset = ws1.Coefficients() - pk.lS2LagrangeCoset = ws2.Coefficients() - pk.lS3LagrangeCoset = ws3.Coefficients() -} - -func clone(input []fr.Element, capacity uint64) *[]fr.Element { - res := make([]fr.Element, len(input), capacity) - copy(res, input) - return &res -} - -// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations -// s1, s2, s3. -// -// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | -// -// | -// | Permutation -// -// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v -// \---------------/ \--------------------/ \------------------------/ -// -// s1 (LDE) s2 (LDE) s3 (LDE) -func ccomputePermutationPolynomials(pk *ProvingKey) { - - nbElmts := int(pk.Domain[0].Cardinality) - - // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) - - // Lagrange form of S1, S2, S3 - pk.S1Canonical = make([]fr.Element, nbElmts) - pk.S2Canonical = make([]fr.Element, nbElmts) - pk.S3Canonical = make([]fr.Element, nbElmts) - for i := 0; i < nbElmts; i++ { - pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) - pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) - pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) - } - - // Canonical form of S1, S2, S3 - pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) - fft.BitReverse(pk.S1Canonical) - fft.BitReverse(pk.S2Canonical) - fft.BitReverse(pk.S3Canonical) -} - -// getIDSmallDomain returns the Lagrange form of ID on the small domain -func getIDSmallDomain(domain *fft.Domain) []fr.Element { - - res := make([]fr.Element, 3*domain.Cardinality) - - res[0].SetOne() - res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) - res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) - - for i := uint64(1); i < domain.Cardinality; i++ { - res[i].Mul(&res[i-1], &domain.Generator) - res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) - res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) - } - - return res -} - -// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS -// -// This should be used after deserializing a ProvingKey -// as pk.Vk.KZG is NOT serialized -func (pk *ProvingKey) InitKZG(srs kzgg.SRS) error { - return pk.Vk.InitKZG(srs) -} - -// InitKZG inits vk.KZG using provided SRS -// -// This should be used after deserializing a 
VerifyingKey -// as vk.KZG is NOT serialized -// -// Note that this instantiate a new FFT domain using vk.Size -func (vk *VerifyingKey) InitKZG(srs kzgg.SRS) error { - _srs := srs.(*kzg.SRS) - - if len(_srs.G1) < int(vk.Size) { - return errors.New("kzg srs is too small") - } - vk.KZGSRS = _srs - - return nil -} - -// NbPublicWitness returns the expected public witness size (number of field elements) -func (vk *VerifyingKey) NbPublicWitness() int { - return int(vk.NbPublicVariables) -} - -// VerifyingKey returns pk.Vk -func (pk *ProvingKey) VerifyingKey() interface{} { - return pk.Vk -} diff --git a/internal/backend/bn254/groth16/utils_test.go b/internal/backend/bn254/groth16/utils_test.go deleted file mode 100644 index 4552aa5e37..0000000000 --- a/internal/backend/bn254/groth16/utils_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package groth16 - -import ( - "testing" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "github.com/stretchr/testify/assert" -) - -func assertSliceEquals[T any](t *testing.T, expected []T, seen []T) { - assert.Equal(t, len(expected), len(seen)) - for i := range expected { - assert.Equal(t, expected[i], seen[i]) - } -} - -func TestRemoveIndex(t *testing.T) { - elems := []fr.Element{{0}, {1}, {2}, {3}} - r := filter(elems, []int{1, 2}) - expected := []fr.Element{{0}, {3}} - assertSliceEquals(t, expected, r) -} diff --git a/internal/backend/bn254/plonk/prove.go b/internal/backend/bn254/plonk/prove.go deleted file mode 100644 index 0f5042889d..0000000000 --- a/internal/backend/bn254/plonk/prove.go +++ /dev/null @@ -1,634 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "crypto/sha256" - "math/big" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr" - - curve "github.com/consensys/gnark-crypto/ecc/bn254" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr/kzg" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr/iop" - "github.com/consensys/gnark/constraint/bn254" - - "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/logger" -) - -type Proof struct { - - // Commitments to the solution vectors - LRO [3]kzg.Digest - - // Commitment to Z, the permutation polynomial - Z kzg.Digest - - // Commitments to h1, h2, h3 such that h = h1 + Xh2 + X**2h3 is the quotient polynomial - H [3]kzg.Digest - - // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2 - BatchedProof kzg.BatchOpeningProof - - // Opening proof of Z at zeta*mu - ZShiftedOpening kzg.OpeningProof -} - -// Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - - log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", len(spr.Constraints)).Str("backend", "plonk").Logger() - start := time.Now() - // pick a hash function that will be used to derive the challenges - hFunc := sha256.New() - - // create a transcript manager to apply Fiat Shamir - 
fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") - - // result - proof := &Proof{} - - // compute the constraint system solution - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } - } - - // query l, r, o in Lagrange basis, not blinded - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) - - lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - liop := iop.NewPolynomial(&evaluationLDomainSmall, lagReg) - riop := iop.NewPolynomial(&evaluationRDomainSmall, lagReg) - oiop := iop.NewPolynomial(&evaluationODomainSmall, lagReg) - wliop := liop.ShallowClone() - wriop := riop.ShallowClone() - woiop := oiop.ShallowClone() - wliop.ToCanonical(&pk.Domain[0]).ToRegular() - wriop.ToCanonical(&pk.Domain[0]).ToRegular() - woiop.ToCanonical(&pk.Domain[0]).ToRegular() - - // Blind l, r, o before committing - // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. - bwliop := wliop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwriop := wriop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwoiop := woiop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // The first challenge is derived using the public data: the commitments to the permutation, - // the coefficients of the circuit, and the public inputs. 
- // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *pk.Vk, fullWitness[:len(spr.Public)]); err != nil { - return nil, err - } - gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) - if err != nil { - return nil, err - } - - // Fiat Shamir this - bbeta, err := fs.ComputeChallenge("beta") - if err != nil { - return nil, err - } - var beta fr.Element - beta.SetBytes(bbeta) - - // compute the copy constraint's ratio - // We copy liop, riop, oiop because they are fft'ed in the process. - // We could have not copied them at the cost of doing one more bit reverse - // per poly... - ziop, err := iop.BuildRatioCopyConstraint( - []*iop.Polynomial{ - liop.Clone(), - riop.Clone(), - oiop.Clone(), - }, - pk.Permutation, - beta, - gamma, - iop.Form{Basis: iop.Canonical, Layout: iop.Regular}, - &pk.Domain[0], - ) - if err != nil { - return proof, err - } - - // commit to the blinded version of z - bwziop := ziop // iop.NewWrappedPolynomial(&ziop) - bwziop.Blind(2) - proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Vk.KZGSRS, runtime.NumCPU()*2) - if err != nil { - return proof, err - } - - // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) - alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) - if err != nil { - return proof, err - } - - // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) - copy(qkCompletedCanonical, fullWitness[:len(spr.Public)]) - copy(qkCompletedCanonical[len(spr.Public):], pk.LQk[len(spr.Public):]) - pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) - fft.BitReverse(qkCompletedCanonical) - - // l, r, o are blinded here - bwliop.ToLagrangeCoset(&pk.Domain[1]) - bwriop.ToLagrangeCoset(&pk.Domain[1]) - bwoiop.ToLagrangeCoset(&pk.Domain[1]) - - lagrangeCosetBitReversed := iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse} - - // we don't mutate 
so no need to clone the coefficients from the proving key. - wqliop := iop.NewPolynomial(&pk.lQl, lagrangeCosetBitReversed) - wqriop := iop.NewPolynomial(&pk.lQr, lagrangeCosetBitReversed) - wqmiop := iop.NewPolynomial(&pk.lQm, lagrangeCosetBitReversed) - wqoiop := iop.NewPolynomial(&pk.lQo, lagrangeCosetBitReversed) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqkiop := iop.NewPolynomial(&qkCompletedCanonical, canReg) - wqkiop.ToLagrangeCoset(&pk.Domain[1]) - - // storing Id - id := make([]fr.Element, pk.Domain[1].Cardinality) - id[1].SetOne() - widiop := iop.NewPolynomial(&id, canReg) - widiop.ToLagrangeCoset(&pk.Domain[1]) - - // permutations in LagrangeCoset: we don't mutate so no need to clone the coefficients from the - // proving key. - ws1 := iop.NewPolynomial(&pk.lS1LagrangeCoset, lagrangeCosetBitReversed) - ws2 := iop.NewPolynomial(&pk.lS2LagrangeCoset, lagrangeCosetBitReversed) - ws3 := iop.NewPolynomial(&pk.lS3LagrangeCoset, lagrangeCosetBitReversed) - - // Store z(g*x), without reallocating a slice - bwsziop := bwziop.ShallowClone().Shift(1) - bwsziop.ToLagrangeCoset(&pk.Domain[1]) - - // L_{g^{0}} - cap := pk.Domain[1].Cardinality - if cap < pk.Domain[0].Cardinality { - cap = pk.Domain[0].Cardinality // sanity check - } - lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) - lone[0].SetOne() - loneiop := iop.NewPolynomial(&lone, lagReg) - wloneiop := loneiop.ToCanonical(&pk.Domain[0]). - ToRegular(). - ToLagrangeCoset(&pk.Domain[1]) - - // Full capture using latest gnark crypto... 
- fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element) fr.Element { - - var ic, tmp fr.Element - - ic.Mul(&fql, &l) - tmp.Mul(&fqr, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqm, &l).Mul(&tmp, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqo, &o) - ic.Add(&ic, &tmp).Add(&ic, &fqk) - - return ic - } - - fo := func(l, r, o, fid, fs1, fs2, fs3, fz, fzs fr.Element) fr.Element { - var uu fr.Element - u := pk.Domain[0].FrMultiplicativeGen - uu.Mul(&u, &u) - - var a, b, tmp fr.Element - a.Mul(&beta, &fid).Add(&a, &l).Add(&a, &gamma) - tmp.Mul(&beta, &u).Mul(&tmp, &fid).Add(&tmp, &r).Add(&tmp, &gamma) - a.Mul(&a, &tmp) - tmp.Mul(&beta, &uu).Mul(&tmp, &fid).Add(&tmp, &o).Add(&tmp, &gamma) - a.Mul(&a, &tmp).Mul(&a, &fz) - - b.Mul(&beta, &fs1).Add(&b, &l).Add(&b, &gamma) - tmp.Mul(&beta, &fs2).Add(&tmp, &r).Add(&tmp, &gamma) - b.Mul(&b, &tmp) - tmp.Mul(&beta, &fs3).Add(&tmp, &o).Add(&tmp, &gamma) - b.Mul(&b, &tmp).Mul(&b, &fzs) - - b.Sub(&b, &a) - - return b - } - - fone := func(fz, flone fr.Element) fr.Element { - one := fr.One() - one.Sub(&fz, &one).Mul(&one, &flone) - return one - } - - // 0 , 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 - // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk,lone - fm := func(x ...fr.Element) fr.Element { - - a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2]) - b := fo(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8]) - c := fone(x[7], x[14]) - - c.Mul(&c, &alpha).Add(&c, &b).Mul(&c, &alpha).Add(&c, &a) - - return c - } - testEval, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, - bwliop, - bwriop, - bwoiop, - widiop, - ws1, - ws2, - ws3, - bwziop, - bwsziop, - wqliop, - wqriop, - wqmiop, - wqoiop, - wqkiop, - wloneiop, - ) - if err != nil { - return nil, err - } - h, err := iop.DivideByXMinusOne(testEval, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) - if err != nil { - return nil, err - } - - // compute kzg commitments of h1, h2 and h3 - if err := commitToQuotient( - 
h.Coefficients()[:pk.Domain[0].Cardinality+2], - h.Coefficients()[pk.Domain[0].Cardinality+2:2*(pk.Domain[0].Cardinality+2)], - h.Coefficients()[2*(pk.Domain[0].Cardinality+2):3*(pk.Domain[0].Cardinality+2)], - proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // derive zeta - zeta, err := deriveRandomness(&fs, "zeta", &proof.H[0], &proof.H[1], &proof.H[2]) - if err != nil { - return nil, err - } - - // compute evaluations of (blinded version of) l, r, o, z at zeta - var blzeta, brzeta, bozeta fr.Element - - var wgEvals sync.WaitGroup - wgEvals.Add(3) - - go func() { - bwliop.ToCanonical(&pk.Domain[1]).ToRegular() - blzeta = bwliop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwriop.ToCanonical(&pk.Domain[1]).ToRegular() - brzeta = bwriop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwoiop.ToCanonical(&pk.Domain[1]).ToRegular() - bozeta = bwoiop.Evaluate(zeta) - wgEvals.Done() - }() - - // open blinded Z at zeta*z - bwziop.ToCanonical(&pk.Domain[1]).ToRegular() - var zetaShifted fr.Element - zetaShifted.Mul(&zeta, &pk.Vk.Generator) - proof.ZShiftedOpening, err = kzg.Open( - bwziop.Coefficients()[:bwziop.BlindedSize()], - zetaShifted, - pk.Vk.KZGSRS, - ) - if err != nil { - return nil, err - } - - // blinded z evaluated at u*zeta - bzuzeta := proof.ZShiftedOpening.ClaimedValue - - var ( - linearizedPolynomialCanonical []fr.Element - linearizedPolynomialDigest curve.G1Affine - errLPoly error - ) - - wgEvals.Wait() // wait for the evaluations - - // compute the linearization polynomial r at zeta - // (goal: save committing separately to z, ql, qr, qm, qo, k - linearizedPolynomialCanonical = computeLinearizedPolynomial( - blzeta, - brzeta, - bozeta, - alpha, - beta, - gamma, - zeta, - bzuzeta, - bwziop.Coefficients()[:bwziop.BlindedSize()], - pk, - ) - - // TODO this commitment is only necessary to derive the challenge, we should - // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = 
kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) - - // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) - var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) - var zetaPowerm fr.Element - zetaPowerm.Exp(zeta, &bSize) - zetaPowerm.BigInt(&bZetaPowerm) - foldedHDigest := proof.H[2] - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - - // foldedH = h1 + ζ*h2 + ζ²*h3 - foldedH := h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] - h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] - utils.Parallelize(len(foldedH), func(start, end int) { - for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² - foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 - } - }) - - if errLPoly != nil { - return nil, errLPoly - } - - // Batch open the first list of polynomials - proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - [][]fr.Element{ - foldedH, - linearizedPolynomialCanonical, - bwliop.Coefficients()[:bwliop.BlindedSize()], - bwriop.Coefficients()[:bwriop.BlindedSize()], - bwoiop.Coefficients()[:bwoiop.BlindedSize()], - pk.S1Canonical, - pk.S2Canonical, - }, - []kzg.Digest{ - foldedHDigest, - linearizedPolynomialDigest, - proof.LRO[0], - proof.LRO[1], - proof.LRO[2], - pk.Vk.S[0], - pk.Vk.S[1], - }, - zeta, - hFunc, - pk.Vk.KZGSRS, - ) - - log.Debug().Dur("took", time.Since(start)).Msg("prover done") - - if err != nil { - return 
nil, err - } - - return proof, nil - -} - -// fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.LRO[0], err0 = kzg.Commit(bcl, srs, n) - close(chCommit0) - }() - go func() { - proof.LRO[1], err1 = kzg.Commit(bcr, srs, n) - close(chCommit1) - }() - if proof.LRO[2], err2 = kzg.Commit(bco, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.H[0], err0 = kzg.Commit(h1, srs, n) - close(chCommit0) - }() - go func() { - proof.H[1], err1 = kzg.Commit(h2, srs, n) - close(chCommit1) - }() - if proof.H[2], err2 = kzg.Commit(h3, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. -// The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z -// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. 
-// -// The Linearized polynomial is: -// -// α²*L₁(ζ)*Z(X) -// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) -// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { - - // first part: individual constraints - var rl fr.Element - rl.Mul(&rZeta, &lZeta) - - // second part: - // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) - var s1, s2 fr.Element - chS1 := make(chan struct{}, 1) - go func() { - ps1 := iop.NewPolynomial(&pk.S1Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - s1 = ps1.Evaluate(zeta) // s1(ζ) - s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) - close(chS1) - }() - ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := ps2.Evaluate(zeta) // s2(ζ) - tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) - <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) - - var uzeta, uuzeta fr.Element - uzeta.Mul(&zeta, &pk.Vk.CosetShift) - uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) - - s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) - tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) - tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - // third part L₁(ζ)*α²*Z - var lagrangeZeta, one, den, frNbElmt fr.Element - one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) - lagrangeZeta.Set(&zeta). - Exp(lagrangeZeta, big.NewInt(nbElmt)). 
- Sub(&lagrangeZeta, &one) - frNbElmt.SetUint64(uint64(nbElmt)) - den.Sub(&zeta, &one). - Inverse(&den) - lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - - linPol := make([]fr.Element, len(blindedZCanonical)) - copy(linPol, blindedZCanonical) - - utils.Parallelize(len(linPol), func(start, end int) { - - var t0, t1 fr.Element - - for i := start; i < end; i++ { - - linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - if i < len(pk.S3Canonical) { - - t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) - - linPol[i].Add(&linPol[i], &t0) - } - - linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) - - if i < len(pk.Qm) { - - t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) - t0.Mul(&pk.Ql[i], &lZeta) - t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - - t0.Mul(&pk.Qr[i], &rZeta) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - - t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) - } - - t0.Mul(&blindedZCanonical[i], &lagrangeZeta) - linPol[i].Add(&linPol[i], &t0) // finish the computation - } - }) - return linPol -} diff --git a/internal/backend/bn254/plonk/setup.go b/internal/backend/bn254/plonk/setup.go deleted file mode 100644 index 1cefa3c469..0000000000 --- a/internal/backend/bn254/plonk/setup.go +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/iop" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/kzg" - "github.com/consensys/gnark/constraint/bn254" - - kzgg "github.com/consensys/gnark-crypto/kzg" -) - -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation -type ProvingKey struct { - // Verifying Key is embedded into the proving key (needed by Prove) - Vk *VerifyingKey - - // TODO store iop.Polynomial here, not []fr.Element for more "type safety" - - // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo []fr.Element - - // qr,ql,qm,qo (in lagrange coset basis) --> these are not serialized, but computed from Ql, Qr, Qm, Qo once. - lQl, lQr, lQm, lQo []fr.Element - - // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. - // Storing LQk in Lagrange basis saves a fft... - CQk, LQk []fr.Element - - // Domains used for the FFTs. 
- // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain - // Domain[0], Domain[1] fft.Domain - - // Permutation polynomials - S1Canonical, S2Canonical, S3Canonical []fr.Element - - // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. - lS1LagrangeCoset, lS2LagrangeCoset, lS3LagrangeCoset []fr.Element - - // position -> permuted position (position in [0,3*sizeSystem-1]) - Permutation []int64 -} - -// VerifyingKey stores the data needed to verify a proof: -// * The commitment scheme -// * Commitments of ql prepended with as many ones as there are public inputs -// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs -// * Commitments to S1, S2, S3 -type VerifyingKey struct { - // Size circuit - Size uint64 - SizeInv fr.Element - Generator fr.Element - NbPublicVariables uint64 - - // Commitment scheme that is used for an instantiation of PLONK - KZGSRS *kzg.SRS - - // cosetShift generator of the coset on the small domain - CosetShift fr.Element - - // S commitments to S1, S2, S3 - S [3]kzg.Digest - - // Commitments to ql, qr, qm, qo prepended with as many zeroes (ones for l) as there are public inputs. - // In particular Qk is not complete. - Ql, Qr, Qm, Qo, Qk kzg.Digest -} - -// Setup sets proving and verifying keys -func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) { - var pk ProvingKey - var vk VerifyingKey - - // The verifying key shares data with the proving key - pk.Vk = &vk - - nbConstraints := len(spr.Constraints) - - // fft domains - sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - - vk.Size = pk.Domain[0].Cardinality - vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) - vk.NbPublicVariables = uint64(len(spr.Public)) - - if err := pk.InitKZG(srs); err != nil { - return nil, nil, err - } - - // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) - pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) - pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) - - for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant - pk.Ql[i].SetOne().Neg(&pk.Ql[i]) - pk.Qr[i].SetZero() - pk.Qm[i].SetZero() - pk.Qo[i].SetZero() - pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // → to be completed by the prover - } - offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.Ql[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.Qr[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.Qm[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.Qm[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.Qo[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.CQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - } - - pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) - pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) - fft.BitReverse(pk.Ql) - fft.BitReverse(pk.Qr) - fft.BitReverse(pk.Qm) - fft.BitReverse(pk.Qo) - fft.BitReverse(pk.CQk) - - // build permutation. Note: at this stage, the permutation takes in account the placeholders - buildPermutation(spr, &pk) - - // set s1, s2, s3 - ccomputePermutationPolynomials(&pk) - - // compute the lagrange coset basis versions (not serialized) - pk.computeLagrangeCosetPolys() - - // Commit to the polynomials to set up the verifying key - var err error - if vk.Ql, err = kzg.Commit(pk.Ql, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qr, err = kzg.Commit(pk.Qr, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qm, err = kzg.Commit(pk.Qm, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qo, err = kzg.Commit(pk.Qo, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - - return &pk, &vk, nil - -} - -// buildPermutation builds the Permutation associated with a circuit. -// -// The permutation s is composed of cycles of maximum length such that -// -// s. 
(l∥r∥o) = (l∥r∥o) -// -// , where l∥r∥o is the concatenation of the indices of l, r, o in -// ql.l+qr.r+qm.l.r+qo.O+k = 0. -// -// The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab -// like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { - - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := int(pk.Domain[0].Cardinality) - - // init permutation - pk.Permutation = make([]int64, 3*sizeSolution) - for i := 0; i < len(pk.Permutation); i++ { - pk.Permutation[i] = -1 - } - - // init LRO position -> variable_ID - lro := make([]int, 3*sizeSolution) // position -> variable_ID - for i := 0; i < len(spr.Public); i++ { - lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) - } - - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() - } - - // init cycle: - // map ID -> last position the ID was seen - cycle := make([]int64, nbVariables) - for i := 0; i < len(cycle); i++ { - cycle[i] = -1 - } - - for i := 0; i < len(lro); i++ { - if cycle[lro[i]] != -1 { - // if != -1, it means we already encountered this value - // so we need to set the corresponding permutation index. 
- pk.Permutation[i] = cycle[lro[i]] - } - cycle[lro[i]] = int64(i) - } - - // complete the Permutation by filling the first IDs encountered - for i := 0; i < len(pk.Permutation); i++ { - if pk.Permutation[i] == -1 { - pk.Permutation[i] = cycle[lro[i]] - } - } -} - -func (pk *ProvingKey) computeLagrangeCosetPolys() { - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqliop := iop.NewPolynomial(clone(pk.Ql, pk.Domain[1].Cardinality), canReg) - wqriop := iop.NewPolynomial(clone(pk.Qr, pk.Domain[1].Cardinality), canReg) - wqmiop := iop.NewPolynomial(clone(pk.Qm, pk.Domain[1].Cardinality), canReg) - wqoiop := iop.NewPolynomial(clone(pk.Qo, pk.Domain[1].Cardinality), canReg) - - ws1 := iop.NewPolynomial(clone(pk.S1Canonical, pk.Domain[1].Cardinality), canReg) - ws2 := iop.NewPolynomial(clone(pk.S2Canonical, pk.Domain[1].Cardinality), canReg) - ws3 := iop.NewPolynomial(clone(pk.S3Canonical, pk.Domain[1].Cardinality), canReg) - - wqliop.ToLagrangeCoset(&pk.Domain[1]) - wqriop.ToLagrangeCoset(&pk.Domain[1]) - wqmiop.ToLagrangeCoset(&pk.Domain[1]) - wqoiop.ToLagrangeCoset(&pk.Domain[1]) - - ws1.ToLagrangeCoset(&pk.Domain[1]) - ws2.ToLagrangeCoset(&pk.Domain[1]) - ws3.ToLagrangeCoset(&pk.Domain[1]) - - pk.lQl = wqliop.Coefficients() - pk.lQr = wqriop.Coefficients() - pk.lQm = wqmiop.Coefficients() - pk.lQo = wqoiop.Coefficients() - - pk.lS1LagrangeCoset = ws1.Coefficients() - pk.lS2LagrangeCoset = ws2.Coefficients() - pk.lS3LagrangeCoset = ws3.Coefficients() -} - -func clone(input []fr.Element, capacity uint64) *[]fr.Element { - res := make([]fr.Element, len(input), capacity) - copy(res, input) - return &res -} - -// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations -// s1, s2, s3. -// -// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | -// -// | -// | Permutation -// -// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v -// \---------------/ \--------------------/ \------------------------/ -// -// s1 (LDE) s2 (LDE) s3 (LDE) -func ccomputePermutationPolynomials(pk *ProvingKey) { - - nbElmts := int(pk.Domain[0].Cardinality) - - // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) - - // Lagrange form of S1, S2, S3 - pk.S1Canonical = make([]fr.Element, nbElmts) - pk.S2Canonical = make([]fr.Element, nbElmts) - pk.S3Canonical = make([]fr.Element, nbElmts) - for i := 0; i < nbElmts; i++ { - pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) - pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) - pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) - } - - // Canonical form of S1, S2, S3 - pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) - fft.BitReverse(pk.S1Canonical) - fft.BitReverse(pk.S2Canonical) - fft.BitReverse(pk.S3Canonical) -} - -// getIDSmallDomain returns the Lagrange form of ID on the small domain -func getIDSmallDomain(domain *fft.Domain) []fr.Element { - - res := make([]fr.Element, 3*domain.Cardinality) - - res[0].SetOne() - res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) - res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) - - for i := uint64(1); i < domain.Cardinality; i++ { - res[i].Mul(&res[i-1], &domain.Generator) - res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) - res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) - } - - return res -} - -// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS -// -// This should be used after deserializing a ProvingKey -// as pk.Vk.KZG is NOT serialized -func (pk *ProvingKey) InitKZG(srs kzgg.SRS) error { - return pk.Vk.InitKZG(srs) -} - -// InitKZG inits vk.KZG using provided SRS -// -// This should be used after deserializing a 
VerifyingKey -// as vk.KZG is NOT serialized -// -// Note that this instantiate a new FFT domain using vk.Size -func (vk *VerifyingKey) InitKZG(srs kzgg.SRS) error { - _srs := srs.(*kzg.SRS) - - if len(_srs.G1) < int(vk.Size) { - return errors.New("kzg srs is too small") - } - vk.KZGSRS = _srs - - return nil -} - -// NbPublicWitness returns the expected public witness size (number of field elements) -func (vk *VerifyingKey) NbPublicWitness() int { - return int(vk.NbPublicVariables) -} - -// VerifyingKey returns pk.Vk -func (pk *ProvingKey) VerifyingKey() interface{} { - return pk.Vk -} diff --git a/internal/backend/bn254/plonk/solidity.go b/internal/backend/bn254/plonk/solidity.go deleted file mode 100644 index 7fddbd204f..0000000000 --- a/internal/backend/bn254/plonk/solidity.go +++ /dev/null @@ -1,907 +0,0 @@ -package plonk - -const solidityTemplate = ` -// Warning this code was contributed into gnark here: -// https://github.com/ConsenSys/gnark/pull/358 -// -// It has not been audited and is provided as-is, we make no guarantees or warranties to its safety and reliability. 
-// -// According to https://eprint.iacr.org/archive/2019/953/1585767119.pdf -pragma solidity ^0.8.0; -pragma experimental ABIEncoderV2; - -library PairingsBn254 { - uint256 constant q_mod = 21888242871839275222246405745257275088696311157297823662689037894645226208583; - uint256 constant r_mod = 21888242871839275222246405745257275088548364400416034343698204186575808495617; - uint256 constant bn254_b_coeff = 3; - - struct G1Point { - uint256 X; - uint256 Y; - } - - struct Fr { - uint256 value; - } - - function new_fr(uint256 fr) internal pure returns (Fr memory) { - require(fr < r_mod); - return Fr({value: fr}); - } - - function copy(Fr memory self) internal pure returns (Fr memory n) { - n.value = self.value; - } - - function assign(Fr memory self, Fr memory other) internal pure { - self.value = other.value; - } - - function inverse(Fr memory fr) internal view returns (Fr memory) { - require(fr.value != 0); - return pow(fr, r_mod-2); - } - - function add_assign(Fr memory self, Fr memory other) internal pure { - self.value = addmod(self.value, other.value, r_mod); - } - - function sub_assign(Fr memory self, Fr memory other) internal pure { - self.value = addmod(self.value, r_mod - other.value, r_mod); - } - - function mul_assign(Fr memory self, Fr memory other) internal pure { - self.value = mulmod(self.value, other.value, r_mod); - } - - function pow(Fr memory self, uint256 power) internal view returns (Fr memory) { - uint256[6] memory input = [32, 32, 32, self.value, power, r_mod]; - uint256[1] memory result; - bool success; - assembly { - success := staticcall(gas(), 0x05, input, 0xc0, result, 0x20) - } - require(success); - return Fr({value: result[0]}); - } - - // Encoding of field elements is: X[0] * z + X[1] - struct G2Point { - uint[2] X; - uint[2] Y; - } - - function P1() internal pure returns (G1Point memory) { - return G1Point(1, 2); - } - - function new_g1(uint256 x, uint256 y) internal pure returns (G1Point memory) { - return G1Point(x, y); - } - - 
function new_g1_checked(uint256 x, uint256 y) internal pure returns (G1Point memory) { - if (x == 0 && y == 0) { - // point of infinity is (0,0) - return G1Point(x, y); - } - - // check encoding - require(x < q_mod); - require(y < q_mod); - // check on curve - uint256 lhs = mulmod(y, y, q_mod); // y^2 - uint256 rhs = mulmod(x, x, q_mod); // x^2 - rhs = mulmod(rhs, x, q_mod); // x^3 - rhs = addmod(rhs, bn254_b_coeff, q_mod); // x^3 + b - require(lhs == rhs); - - return G1Point(x, y); - } - - function new_g2(uint256[2] memory x, uint256[2] memory y) internal pure returns (G2Point memory) { - return G2Point(x, y); - } - - function copy_g1(G1Point memory self) internal pure returns (G1Point memory result) { - result.X = self.X; - result.Y = self.Y; - } - - function P2() internal pure returns (G2Point memory) { - // for some reason ethereum expects to have c1*v + c0 form - - return G2Point( - [0x198e9393920d483a7260bfb731fb5d25f1aa493335a9e71297e485b7aef312c2, - 0x1800deef121f1e76426a00665e5c4479674322d4f75edadd46debd5cd992f6ed], - [0x090689d0585ff075ec9e99ad690c3395bc4b313370b38ef355acdadcd122975b, - 0x12c85ea5db8c6deb4aab71808dcb408fe3d1e7690c43d37b4ce6cc0166fa7daa] - ); - } - - function negate(G1Point memory self) internal pure { - // The prime q in the base field F_q for G1 - if (self.Y == 0) { - require(self.X == 0); - return; - } - - self.Y = q_mod - self.Y; - } - - function point_add(G1Point memory p1, G1Point memory p2) - internal view returns (G1Point memory r) - { - point_add_into_dest(p1, p2, r); - return r; - } - - function point_add_assign(G1Point memory p1, G1Point memory p2) - internal view - { - point_add_into_dest(p1, p2, p1); - } - - function point_add_into_dest(G1Point memory p1, G1Point memory p2, G1Point memory dest) - internal view - { - if (p2.X == 0 && p2.Y == 0) { - // we add zero, nothing happens - dest.X = p1.X; - dest.Y = p1.Y; - return; - } else if (p1.X == 0 && p1.Y == 0) { - // we add into zero, and we add non-zero point - dest.X = p2.X; - 
dest.Y = p2.Y; - return; - } else { - uint256[4] memory input; - - input[0] = p1.X; - input[1] = p1.Y; - input[2] = p2.X; - input[3] = p2.Y; - - bool success = false; - assembly { - success := staticcall(gas(), 6, input, 0x80, dest, 0x40) - } - require(success); - } - } - - function point_sub_assign(G1Point memory p1, G1Point memory p2) - internal view - { - point_sub_into_dest(p1, p2, p1); - } - - function point_sub_into_dest(G1Point memory p1, G1Point memory p2, G1Point memory dest) - internal view - { - if (p2.X == 0 && p2.Y == 0) { - // we subtracted zero, nothing happens - dest.X = p1.X; - dest.Y = p1.Y; - return; - } else if (p1.X == 0 && p1.Y == 0) { - // we subtract from zero, and we subtract non-zero point - dest.X = p2.X; - dest.Y = q_mod - p2.Y; - return; - } else { - uint256[4] memory input; - - input[0] = p1.X; - input[1] = p1.Y; - input[2] = p2.X; - input[3] = q_mod - p2.Y; - - bool success = false; - assembly { - success := staticcall(gas(), 6, input, 0x80, dest, 0x40) - } - require(success); - } - } - - function point_mul(G1Point memory p, Fr memory s) - internal view returns (G1Point memory r) - { - point_mul_into_dest(p, s, r); - return r; - } - - function point_mul_assign(G1Point memory p, Fr memory s) - internal view - { - point_mul_into_dest(p, s, p); - } - - function point_mul_into_dest(G1Point memory p, Fr memory s, G1Point memory dest) - internal view - { - uint[3] memory input; - input[0] = p.X; - input[1] = p.Y; - input[2] = s.value; - bool success; - assembly { - success := staticcall(gas(), 7, input, 0x60, dest, 0x40) - } - require(success); - } - - function pairing(G1Point[] memory p1, G2Point[] memory p2) - internal view returns (bool) - { - require(p1.length == p2.length); - uint elements = p1.length; - uint inputSize = elements * 6; - uint[] memory input = new uint[](inputSize); - for (uint i = 0; i < elements; i++) - { - input[i * 6 + 0] = p1[i].X; - input[i * 6 + 1] = p1[i].Y; - input[i * 6 + 2] = p2[i].X[0]; - input[i * 6 + 3] = 
p2[i].X[1]; - input[i * 6 + 4] = p2[i].Y[0]; - input[i * 6 + 5] = p2[i].Y[1]; - } - uint[1] memory out; - bool success; - assembly { - success := staticcall(gas(), 8, add(input, 0x20), mul(inputSize, 0x20), out, 0x20) - } - require(success); - return out[0] != 0; - } - - /// Convenience method for a pairing check for two pairs. - function pairingProd2(G1Point memory a1, G2Point memory a2, G1Point memory b1, G2Point memory b2) - internal view returns (bool) - { - G1Point[] memory p1 = new G1Point[](2); - G2Point[] memory p2 = new G2Point[](2); - p1[0] = a1; - p1[1] = b1; - p2[0] = a2; - p2[1] = b2; - return pairing(p1, p2); - } -} - -library TranscriptLibrary { - uint32 constant DST_0 = 0; - uint32 constant DST_1 = 1; - uint32 constant DST_CHALLENGE = 2; - - struct Transcript { - bytes32 previous_randomness; - bytes bindings; - string name; - uint32 challenge_counter; - } - - function new_transcript() internal pure returns (Transcript memory t) { - t.challenge_counter = 0; - } - - function set_challenge_name(Transcript memory self, string memory name) internal pure { - self.name = name; - } - - function update_with_u256(Transcript memory self, uint256 value) internal pure { - self.bindings = abi.encodePacked(self.bindings, value); - } - - function update_with_fr(Transcript memory self, PairingsBn254.Fr memory value) internal pure { - self.bindings = abi.encodePacked(self.bindings, value.value); - } - - function update_with_g1(Transcript memory self, PairingsBn254.G1Point memory p) internal pure { - self.bindings = abi.encodePacked(self.bindings, p.X, p.Y); - } - - function get_encode(Transcript memory self) internal pure returns(bytes memory query) { - if (self.challenge_counter != 0) { - query = abi.encodePacked(self.name, self.previous_randomness, self.bindings); - } else { - query = abi.encodePacked(self.name, self.bindings); - } - return query; - } - function get_challenge(Transcript memory self) internal pure returns(PairingsBn254.Fr memory challenge) { - 
bytes32 query; - if (self.challenge_counter != 0) { - query = sha256(abi.encodePacked(self.name, self.previous_randomness, self.bindings)); - } else { - query = sha256(abi.encodePacked(self.name, self.bindings)); - } - self.challenge_counter += 1; - self.previous_randomness = query; - challenge = PairingsBn254.Fr({value: uint256(query) % PairingsBn254.r_mod}); - self.bindings = ""; - } -} - -contract PlonkVerifier { - using PairingsBn254 for PairingsBn254.G1Point; - using PairingsBn254 for PairingsBn254.G2Point; - using PairingsBn254 for PairingsBn254.Fr; - - using TranscriptLibrary for TranscriptLibrary.Transcript; - - uint256 constant STATE_WIDTH = 3; - - struct VerificationKey { - uint256 domain_size; - uint256 num_inputs; - PairingsBn254.Fr omega; // w - PairingsBn254.G1Point[STATE_WIDTH+2] selector_commitments; // STATE_WIDTH for witness + multiplication + constant - PairingsBn254.G1Point[STATE_WIDTH] permutation_commitments; // [Sσ1(x)],[Sσ2(x)],[Sσ3(x)] - PairingsBn254.Fr[STATE_WIDTH-1] permutation_non_residues; // k1, k2 - PairingsBn254.G2Point g2_x; - } - - struct Proof { - uint256[] input_values; - PairingsBn254.G1Point[STATE_WIDTH] wire_commitments; // [a(x)]/[b(x)]/[c(x)] - PairingsBn254.G1Point grand_product_commitment; // [z(x)] - PairingsBn254.G1Point[STATE_WIDTH] quotient_poly_commitments; // [t_lo]/[t_mid]/[t_hi] - PairingsBn254.Fr[STATE_WIDTH] wire_values_at_zeta; // a(zeta)/b(zeta)/c(zeta) - PairingsBn254.Fr grand_product_at_zeta_omega; // z(w*zeta) - PairingsBn254.Fr quotient_polynomial_at_zeta; // t(zeta) - PairingsBn254.Fr linearization_polynomial_at_zeta; // r(zeta) - PairingsBn254.Fr[STATE_WIDTH-1] permutation_polynomials_at_zeta; // Sσ1(zeta),Sσ2(zeta) - - PairingsBn254.G1Point opening_at_zeta_proof; // [Wzeta] - PairingsBn254.G1Point opening_at_zeta_omega_proof; // [Wzeta*omega] - } - - struct PartialVerifierState { - PairingsBn254.Fr alpha; - PairingsBn254.Fr beta; - PairingsBn254.Fr gamma; - PairingsBn254.Fr v; - PairingsBn254.Fr u; - 
PairingsBn254.Fr zeta; - PairingsBn254.Fr[] cached_lagrange_evals; - - PairingsBn254.G1Point cached_fold_quotient_ploy_commitments; - } - - function verify_initial( - PartialVerifierState memory state, - Proof memory proof, - VerificationKey memory vk) internal view returns (bool) { - - require(proof.input_values.length == vk.num_inputs, "not match"); - require(vk.num_inputs >= 1, "inv input"); - - TranscriptLibrary.Transcript memory t = TranscriptLibrary.new_transcript(); - t.set_challenge_name("gamma"); - for (uint256 i = 0; i < vk.permutation_commitments.length; i++) { - t.update_with_g1(vk.permutation_commitments[i]); - } - // this is gnark order: Ql, Qr, Qm, Qo, Qk - // - t.update_with_g1(vk.selector_commitments[0]); - t.update_with_g1(vk.selector_commitments[1]); - t.update_with_g1(vk.selector_commitments[3]); - t.update_with_g1(vk.selector_commitments[2]); - t.update_with_g1(vk.selector_commitments[4]); - - for (uint256 i = 0; i < proof.input_values.length; i++) { - t.update_with_u256(proof.input_values[i]); - } - state.gamma = t.get_challenge(); - - t.set_challenge_name("beta"); - state.beta = t.get_challenge(); - - t.set_challenge_name("alpha"); - t.update_with_g1(proof.grand_product_commitment); - state.alpha = t.get_challenge(); - - t.set_challenge_name("zeta"); - for (uint256 i = 0; i < proof.quotient_poly_commitments.length; i++) { - t.update_with_g1(proof.quotient_poly_commitments[i]); - } - state.zeta = t.get_challenge(); - - uint256[] memory lagrange_poly_numbers = new uint256[](vk.num_inputs); - for (uint256 i = 0; i < lagrange_poly_numbers.length; i++) { - lagrange_poly_numbers[i] = i; - } - state.cached_lagrange_evals = batch_evaluate_lagrange_poly_out_of_domain( - lagrange_poly_numbers, - vk.domain_size, - vk.omega, state.zeta - ); - - bool valid = verify_quotient_poly_eval_at_zeta(state, proof, vk); - return valid; - } - - function verify_commitments( - PartialVerifierState memory state, - Proof memory proof, - VerificationKey memory vk - ) 
internal view returns (bool) { - PairingsBn254.G1Point memory d = reconstruct_d(state, proof, vk); - - PairingsBn254.G1Point memory tmp_g1 = PairingsBn254.P1(); - - PairingsBn254.Fr memory aggregation_challenge = PairingsBn254.new_fr(1); - - PairingsBn254.G1Point memory commitment_aggregation = PairingsBn254.copy_g1(state.cached_fold_quotient_ploy_commitments); - PairingsBn254.Fr memory tmp_fr = PairingsBn254.new_fr(1); - - aggregation_challenge.mul_assign(state.v); - commitment_aggregation.point_add_assign(d); - - for (uint i = 0; i < proof.wire_commitments.length; i++) { - aggregation_challenge.mul_assign(state.v); - tmp_g1 = proof.wire_commitments[i].point_mul(aggregation_challenge); - commitment_aggregation.point_add_assign(tmp_g1); - } - - for (uint i = 0; i < vk.permutation_commitments.length - 1; i++) { - aggregation_challenge.mul_assign(state.v); - tmp_g1 = vk.permutation_commitments[i].point_mul(aggregation_challenge); - commitment_aggregation.point_add_assign(tmp_g1); - } - - // collect opening values - aggregation_challenge = PairingsBn254.new_fr(1); - - PairingsBn254.Fr memory aggregated_value = PairingsBn254.copy(proof.quotient_polynomial_at_zeta); - - aggregation_challenge.mul_assign(state.v); - - tmp_fr.assign(proof.linearization_polynomial_at_zeta); - tmp_fr.mul_assign(aggregation_challenge); - aggregated_value.add_assign(tmp_fr); - - for (uint i = 0; i < proof.wire_values_at_zeta.length; i++) { - aggregation_challenge.mul_assign(state.v); - - tmp_fr.assign(proof.wire_values_at_zeta[i]); - tmp_fr.mul_assign(aggregation_challenge); - aggregated_value.add_assign(tmp_fr); - } - - for (uint i = 0; i < proof.permutation_polynomials_at_zeta.length; i++) { - aggregation_challenge.mul_assign(state.v); - - tmp_fr.assign(proof.permutation_polynomials_at_zeta[i]); - tmp_fr.mul_assign(aggregation_challenge); - aggregated_value.add_assign(tmp_fr); - } - tmp_fr.assign(proof.grand_product_at_zeta_omega); - tmp_fr.mul_assign(state.u); - 
aggregated_value.add_assign(tmp_fr); - - commitment_aggregation.point_sub_assign(PairingsBn254.P1().point_mul(aggregated_value)); - - PairingsBn254.G1Point memory pair_with_generator = commitment_aggregation; - pair_with_generator.point_add_assign(proof.opening_at_zeta_proof.point_mul(state.zeta)); - - tmp_fr.assign(state.zeta); - tmp_fr.mul_assign(vk.omega); - tmp_fr.mul_assign(state.u); - pair_with_generator.point_add_assign(proof.opening_at_zeta_omega_proof.point_mul(tmp_fr)); - - PairingsBn254.G1Point memory pair_with_x = proof.opening_at_zeta_omega_proof.point_mul(state.u); - pair_with_x.point_add_assign(proof.opening_at_zeta_proof); - pair_with_x.negate(); - - return PairingsBn254.pairingProd2(pair_with_generator, PairingsBn254.P2(), pair_with_x, vk.g2_x); - } - - function reconstruct_d( - PartialVerifierState memory state, - Proof memory proof, - VerificationKey memory vk - ) internal view returns (PairingsBn254.G1Point memory res) { - res = PairingsBn254.copy_g1(vk.selector_commitments[STATE_WIDTH + 1]); - - PairingsBn254.G1Point memory tmp_g1 = PairingsBn254.P1(); - PairingsBn254.Fr memory tmp_fr = PairingsBn254.new_fr(0); - - // addition gates - for (uint256 i = 0; i < STATE_WIDTH; i++) { - tmp_g1 = vk.selector_commitments[i].point_mul(proof.wire_values_at_zeta[i]); - res.point_add_assign(tmp_g1); - } - - // multiplication gate - tmp_fr.assign(proof.wire_values_at_zeta[0]); - tmp_fr.mul_assign(proof.wire_values_at_zeta[1]); - tmp_g1 = vk.selector_commitments[STATE_WIDTH].point_mul(tmp_fr); - res.point_add_assign(tmp_g1); - - // z * non_res * beta + gamma + a - PairingsBn254.Fr memory grand_product_part_at_z = PairingsBn254.copy(state.zeta); - grand_product_part_at_z.mul_assign(state.beta); - grand_product_part_at_z.add_assign(proof.wire_values_at_zeta[0]); - grand_product_part_at_z.add_assign(state.gamma); - for (uint256 i = 0; i < vk.permutation_non_residues.length; i++) { - tmp_fr.assign(state.zeta); - tmp_fr.mul_assign(vk.permutation_non_residues[i]); 
- tmp_fr.mul_assign(state.beta); - tmp_fr.add_assign(state.gamma); - tmp_fr.add_assign(proof.wire_values_at_zeta[i+1]); - - grand_product_part_at_z.mul_assign(tmp_fr); - } - - grand_product_part_at_z.mul_assign(state.alpha); - - tmp_fr.assign(state.cached_lagrange_evals[0]); - tmp_fr.mul_assign(state.alpha); - tmp_fr.mul_assign(state.alpha); - // NOTICE - grand_product_part_at_z.sub_assign(tmp_fr); - PairingsBn254.Fr memory last_permutation_part_at_z = PairingsBn254.new_fr(1); - for (uint256 i = 0; i < proof.permutation_polynomials_at_zeta.length; i++) { - tmp_fr.assign(state.beta); - tmp_fr.mul_assign(proof.permutation_polynomials_at_zeta[i]); - tmp_fr.add_assign(state.gamma); - tmp_fr.add_assign(proof.wire_values_at_zeta[i]); - - last_permutation_part_at_z.mul_assign(tmp_fr); - } - - last_permutation_part_at_z.mul_assign(state.beta); - last_permutation_part_at_z.mul_assign(proof.grand_product_at_zeta_omega); - last_permutation_part_at_z.mul_assign(state.alpha); - - // gnark implementation: add third part and sub second second part - // plonk paper implementation: add second part and sub third part - /* - tmp_g1 = proof.grand_product_commitment.point_mul(grand_product_part_at_z); - tmp_g1.point_sub_assign(vk.permutation_commitments[STATE_WIDTH - 1].point_mul(last_permutation_part_at_z)); - */ - // add to the linearization - - tmp_g1 = vk.permutation_commitments[STATE_WIDTH - 1].point_mul(last_permutation_part_at_z); - tmp_g1.point_sub_assign(proof.grand_product_commitment.point_mul(grand_product_part_at_z)); - res.point_add_assign(tmp_g1); - - generate_uv_challenge(state, proof, vk, res); - - res.point_mul_assign(state.v); - res.point_add_assign(proof.grand_product_commitment.point_mul(state.u)); - } - - // gnark v generation process: - // sha256(zeta, proof.quotient_poly_commitments, linearizedPolynomialDigest, proof.wire_commitments, vk.permutation_commitments[0..1], ) - // NOTICE: gnark use "gamma" name for v, it's not reasonable - // NOTICE: gnark use 
zeta^(n+2) which is a bit different with plonk paper - // generate_v_challenge(); - function generate_uv_challenge( - PartialVerifierState memory state, - Proof memory proof, - VerificationKey memory vk, - PairingsBn254.G1Point memory linearization_point) view internal { - TranscriptLibrary.Transcript memory transcript = TranscriptLibrary.new_transcript(); - transcript.set_challenge_name("gamma"); - transcript.update_with_fr(state.zeta); - PairingsBn254.Fr memory zeta_plus_two = PairingsBn254.copy(state.zeta); - PairingsBn254.Fr memory n_plus_two = PairingsBn254.new_fr(vk.domain_size); - n_plus_two.add_assign(PairingsBn254.new_fr(2)); - zeta_plus_two = zeta_plus_two.pow(n_plus_two.value); - state.cached_fold_quotient_ploy_commitments = PairingsBn254.copy_g1(proof.quotient_poly_commitments[STATE_WIDTH-1]); - for (uint256 i = 0; i < STATE_WIDTH - 1; i++) { - state.cached_fold_quotient_ploy_commitments.point_mul_assign(zeta_plus_two); - state.cached_fold_quotient_ploy_commitments.point_add_assign(proof.quotient_poly_commitments[STATE_WIDTH - 2 - i]); - } - transcript.update_with_g1(state.cached_fold_quotient_ploy_commitments); - transcript.update_with_g1(linearization_point); - - for (uint256 i = 0; i < proof.wire_commitments.length; i++) { - transcript.update_with_g1(proof.wire_commitments[i]); - } - for (uint256 i = 0; i < vk.permutation_commitments.length - 1; i++) { - transcript.update_with_g1(vk.permutation_commitments[i]); - } - state.v = transcript.get_challenge(); - // gnark use local randomness to generate u - // we use opening_at_zeta_proof and opening_at_zeta_omega_proof - transcript.set_challenge_name("u"); - transcript.update_with_g1(proof.opening_at_zeta_proof); - transcript.update_with_g1(proof.opening_at_zeta_omega_proof); - state.u = transcript.get_challenge(); - } - - function batch_evaluate_lagrange_poly_out_of_domain( - uint256[] memory poly_nums, - uint256 domain_size, - PairingsBn254.Fr memory omega, - PairingsBn254.Fr memory at - ) internal view 
returns (PairingsBn254.Fr[] memory res) { - PairingsBn254.Fr memory one = PairingsBn254.new_fr(1); - PairingsBn254.Fr memory tmp_1 = PairingsBn254.new_fr(0); - PairingsBn254.Fr memory tmp_2 = PairingsBn254.new_fr(domain_size); - PairingsBn254.Fr memory vanishing_at_zeta = at.pow(domain_size); - vanishing_at_zeta.sub_assign(one); - // we can not have random point z be in domain - require(vanishing_at_zeta.value != 0); - PairingsBn254.Fr[] memory nums = new PairingsBn254.Fr[](poly_nums.length); - PairingsBn254.Fr[] memory dens = new PairingsBn254.Fr[](poly_nums.length); - // numerators in a form omega^i * (z^n - 1) - // denoms in a form (z - omega^i) * N - for (uint i = 0; i < poly_nums.length; i++) { - tmp_1 = omega.pow(poly_nums[i]); // power of omega - nums[i].assign(vanishing_at_zeta); - nums[i].mul_assign(tmp_1); - - dens[i].assign(at); // (X - omega^i) * N - dens[i].sub_assign(tmp_1); - dens[i].mul_assign(tmp_2); // mul by domain size - } - - PairingsBn254.Fr[] memory partial_products = new PairingsBn254.Fr[](poly_nums.length); - partial_products[0].assign(PairingsBn254.new_fr(1)); - for (uint i = 1; i < dens.length; i++) { - partial_products[i].assign(dens[i-1]); - partial_products[i].mul_assign(partial_products[i-1]); - } - - tmp_2.assign(partial_products[partial_products.length - 1]); - tmp_2.mul_assign(dens[dens.length - 1]); - tmp_2 = tmp_2.inverse(); // tmp_2 contains a^-1 * b^-1 (with! 
the last one) - - for (uint i = dens.length; i > 0; i--) { - tmp_1.assign(tmp_2); // all inversed - tmp_1.mul_assign(partial_products[i-1]); // clear lowest terms - tmp_2.mul_assign(dens[i-1]); - dens[i-1].assign(tmp_1); - } - - for (uint i = 0; i < nums.length; i++) { - nums[i].mul_assign(dens[i]); - } - - return nums; - } - - // plonk paper verify process step8: Compute quotient polynomial evaluation - function verify_quotient_poly_eval_at_zeta( - PartialVerifierState memory state, - Proof memory proof, - VerificationKey memory vk - ) internal view returns (bool) { - PairingsBn254.Fr memory lhs = evaluate_vanishing(vk.domain_size, state.zeta); - require(lhs.value != 0); // we can not check a polynomial relationship if point z is in the domain - lhs.mul_assign(proof.quotient_polynomial_at_zeta); - - PairingsBn254.Fr memory quotient_challenge = PairingsBn254.new_fr(1); - PairingsBn254.Fr memory rhs = PairingsBn254.copy(proof.linearization_polynomial_at_zeta); - - // public inputs - PairingsBn254.Fr memory tmp = PairingsBn254.new_fr(0); - for (uint256 i = 0; i < proof.input_values.length; i++) { - tmp.assign(state.cached_lagrange_evals[i]); - tmp.mul_assign(PairingsBn254.new_fr(proof.input_values[i])); - rhs.add_assign(tmp); - } - - quotient_challenge.mul_assign(state.alpha); - - PairingsBn254.Fr memory z_part = PairingsBn254.copy(proof.grand_product_at_zeta_omega); - for (uint256 i = 0; i < proof.permutation_polynomials_at_zeta.length; i++) { - tmp.assign(proof.permutation_polynomials_at_zeta[i]); - tmp.mul_assign(state.beta); - tmp.add_assign(state.gamma); - tmp.add_assign(proof.wire_values_at_zeta[i]); - - z_part.mul_assign(tmp); - } - - tmp.assign(state.gamma); - // we need a wire value of the last polynomial in enumeration - tmp.add_assign(proof.wire_values_at_zeta[STATE_WIDTH - 1]); - - z_part.mul_assign(tmp); - z_part.mul_assign(quotient_challenge); - - // NOTICE: this is different with plonk paper - // plonk paper should be: rhs.sub_assign(z_part); - 
rhs.add_assign(z_part); - - quotient_challenge.mul_assign(state.alpha); - - tmp.assign(state.cached_lagrange_evals[0]); - tmp.mul_assign(quotient_challenge); - - rhs.sub_assign(tmp); - - return lhs.value == rhs.value; - } - - function evaluate_vanishing( - uint256 domain_size, - PairingsBn254.Fr memory at - ) internal view returns (PairingsBn254.Fr memory res) { - res = at.pow(domain_size); - res.sub_assign(PairingsBn254.new_fr(1)); - } - - // This verifier is for a PLONK with a state width 3 - // and main gate equation - // q_a(X) * a(X) + - // q_b(X) * b(X) + - // q_c(X) * c(X) + - // q_m(X) * a(X) * b(X) + - // q_constants(X)+ - // where q_{}(X) are selectors a, b, c - state (witness) polynomials - - function verify(Proof memory proof, VerificationKey memory vk) internal view returns (bool) { - PartialVerifierState memory state; - - bool valid = verify_initial(state, proof, vk); - - if (valid == false) { - return false; - } - - valid = verify_commitments(state, proof, vk); - - return valid; - } -} - -contract KeyedPlonkVerifier is PlonkVerifier { - uint256 constant SERIALIZED_PROOF_LENGTH = 26; - using PairingsBn254 for PairingsBn254.Fr; - function get_verification_key() internal pure returns(VerificationKey memory vk) { - vk.domain_size = {{.Size}}; - vk.num_inputs = {{.NbPublicVariables}}; - vk.omega = PairingsBn254.new_fr(uint256({{.Generator.String}})); - vk.selector_commitments[0] = PairingsBn254.new_g1( - uint256({{.Ql.X.String}}), - uint256({{.Ql.Y.String}}) - ); - vk.selector_commitments[1] = PairingsBn254.new_g1( - uint256({{.Qr.X.String}}), - uint256({{.Qr.Y.String}}) - ); - vk.selector_commitments[2] = PairingsBn254.new_g1( - uint256({{.Qo.X.String}}), - uint256({{.Qo.Y.String}}) - ); - vk.selector_commitments[3] = PairingsBn254.new_g1( - uint256({{.Qm.X.String}}), - uint256({{.Qm.Y.String}}) - ); - vk.selector_commitments[4] = PairingsBn254.new_g1( - uint256({{.Qk.X.String}}), - uint256({{.Qk.Y.String}}) - ); - - vk.permutation_commitments[0] = 
PairingsBn254.new_g1( - uint256({{(index .S 0).X.String}}), - uint256({{(index .S 0).Y.String}}) - ); - vk.permutation_commitments[1] = PairingsBn254.new_g1( - uint256({{(index .S 1).X.String}}), - uint256({{(index .S 1).Y.String}}) - ); - vk.permutation_commitments[2] = PairingsBn254.new_g1( - uint256({{(index .S 2).X.String}}), - uint256({{(index .S 2).Y.String}}) - ); - - vk.permutation_non_residues[0] = PairingsBn254.new_fr( - uint256({{.CosetShift.String}}) - ); - vk.permutation_non_residues[1] = PairingsBn254.copy( - vk.permutation_non_residues[0] - ); - vk.permutation_non_residues[1].mul_assign(vk.permutation_non_residues[0]); - - vk.g2_x = PairingsBn254.new_g2( - [uint256({{(index .KZGSRS.G2 1).X.A1.String}}), - uint256({{(index .KZGSRS.G2 1).X.A0.String}})], - [uint256({{(index .KZGSRS.G2 1).Y.A1.String}}), - uint256({{(index .KZGSRS.G2 1).Y.A0.String}})] - ); - } - - - function deserialize_proof( - uint256[] memory public_inputs, - uint256[] memory serialized_proof - ) internal pure returns(Proof memory proof) { - require(serialized_proof.length == SERIALIZED_PROOF_LENGTH); - proof.input_values = new uint256[](public_inputs.length); - for (uint256 i = 0; i < public_inputs.length; i++) { - proof.input_values[i] = public_inputs[i]; - } - - uint256 j = 0; - for (uint256 i = 0; i < STATE_WIDTH; i++) { - proof.wire_commitments[i] = PairingsBn254.new_g1_checked( - serialized_proof[j], - serialized_proof[j+1] - ); - - j += 2; - } - - proof.grand_product_commitment = PairingsBn254.new_g1_checked( - serialized_proof[j], - serialized_proof[j+1] - ); - j += 2; - - for (uint256 i = 0; i < STATE_WIDTH; i++) { - proof.quotient_poly_commitments[i] = PairingsBn254.new_g1_checked( - serialized_proof[j], - serialized_proof[j+1] - ); - - j += 2; - } - - for (uint256 i = 0; i < STATE_WIDTH; i++) { - proof.wire_values_at_zeta[i] = PairingsBn254.new_fr( - serialized_proof[j] - ); - - j += 1; - } - - proof.grand_product_at_zeta_omega = PairingsBn254.new_fr( - 
serialized_proof[j] - ); - - j += 1; - - proof.quotient_polynomial_at_zeta = PairingsBn254.new_fr( - serialized_proof[j] - ); - - j += 1; - - proof.linearization_polynomial_at_zeta = PairingsBn254.new_fr( - serialized_proof[j] - ); - - j += 1; - - for (uint256 i = 0; i < proof.permutation_polynomials_at_zeta.length; i++) { - proof.permutation_polynomials_at_zeta[i] = PairingsBn254.new_fr( - serialized_proof[j] - ); - - j += 1; - } - - proof.opening_at_zeta_proof = PairingsBn254.new_g1_checked( - serialized_proof[j], - serialized_proof[j+1] - ); - j += 2; - - proof.opening_at_zeta_omega_proof = PairingsBn254.new_g1_checked( - serialized_proof[j], - serialized_proof[j+1] - ); - } - - function verify_serialized_proof( - uint256[] memory public_inputs, - uint256[] memory serialized_proof - ) public view returns (bool) { - VerificationKey memory vk = get_verification_key(); - require(vk.num_inputs == public_inputs.length); - Proof memory proof = deserialize_proof(public_inputs, serialized_proof); - bool valid = verify(proof, vk); - return valid; - } -} -` diff --git a/internal/backend/bw6-633/plonk/prove.go b/internal/backend/bw6-633/plonk/prove.go deleted file mode 100644 index bd6f3ff24b..0000000000 --- a/internal/backend/bw6-633/plonk/prove.go +++ /dev/null @@ -1,634 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "crypto/sha256" - "math/big" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - - curve "github.com/consensys/gnark-crypto/ecc/bw6-633" - - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/kzg" - - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" - - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/iop" - "github.com/consensys/gnark/constraint/bw6-633" - - "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/logger" -) - -type Proof struct { - - // Commitments to the solution vectors - LRO [3]kzg.Digest - - // Commitment to Z, the permutation polynomial - Z kzg.Digest - - // Commitments to h1, h2, h3 such that h = h1 + Xh2 + X**2h3 is the quotient polynomial - H [3]kzg.Digest - - // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2 - BatchedProof kzg.BatchOpeningProof - - // Opening proof of Z at zeta*mu - ZShiftedOpening kzg.OpeningProof -} - -// Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - - log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", len(spr.Constraints)).Str("backend", "plonk").Logger() - start := time.Now() - // pick a hash function that will be used to derive the challenges - hFunc := sha256.New() - - // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") - - // result - proof := &Proof{} - - // compute the constraint system solution - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := 
len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } - } - - // query l, r, o in Lagrange basis, not blinded - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) - - lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - liop := iop.NewPolynomial(&evaluationLDomainSmall, lagReg) - riop := iop.NewPolynomial(&evaluationRDomainSmall, lagReg) - oiop := iop.NewPolynomial(&evaluationODomainSmall, lagReg) - wliop := liop.ShallowClone() - wriop := riop.ShallowClone() - woiop := oiop.ShallowClone() - wliop.ToCanonical(&pk.Domain[0]).ToRegular() - wriop.ToCanonical(&pk.Domain[0]).ToRegular() - woiop.ToCanonical(&pk.Domain[0]).ToRegular() - - // Blind l, r, o before committing - // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. - bwliop := wliop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwriop := wriop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwoiop := woiop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // The first challenge is derived using the public data: the commitments to the permutation, - // the coefficients of the circuit, and the public inputs. - // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *pk.Vk, fullWitness[:len(spr.Public)]); err != nil { - return nil, err - } - gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) - if err != nil { - return nil, err - } - - // Fiat Shamir this - bbeta, err := fs.ComputeChallenge("beta") - if err != nil { - return nil, err - } - var beta fr.Element - beta.SetBytes(bbeta) - - // compute the copy constraint's ratio - // We copy liop, riop, oiop because they are fft'ed in the process. 
- // We could have not copied them at the cost of doing one more bit reverse - // per poly... - ziop, err := iop.BuildRatioCopyConstraint( - []*iop.Polynomial{ - liop.Clone(), - riop.Clone(), - oiop.Clone(), - }, - pk.Permutation, - beta, - gamma, - iop.Form{Basis: iop.Canonical, Layout: iop.Regular}, - &pk.Domain[0], - ) - if err != nil { - return proof, err - } - - // commit to the blinded version of z - bwziop := ziop // iop.NewWrappedPolynomial(&ziop) - bwziop.Blind(2) - proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Vk.KZGSRS, runtime.NumCPU()*2) - if err != nil { - return proof, err - } - - // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) - alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) - if err != nil { - return proof, err - } - - // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) - copy(qkCompletedCanonical, fullWitness[:len(spr.Public)]) - copy(qkCompletedCanonical[len(spr.Public):], pk.LQk[len(spr.Public):]) - pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) - fft.BitReverse(qkCompletedCanonical) - - // l, r, o are blinded here - bwliop.ToLagrangeCoset(&pk.Domain[1]) - bwriop.ToLagrangeCoset(&pk.Domain[1]) - bwoiop.ToLagrangeCoset(&pk.Domain[1]) - - lagrangeCosetBitReversed := iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse} - - // we don't mutate so no need to clone the coefficients from the proving key. 
- wqliop := iop.NewPolynomial(&pk.lQl, lagrangeCosetBitReversed) - wqriop := iop.NewPolynomial(&pk.lQr, lagrangeCosetBitReversed) - wqmiop := iop.NewPolynomial(&pk.lQm, lagrangeCosetBitReversed) - wqoiop := iop.NewPolynomial(&pk.lQo, lagrangeCosetBitReversed) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqkiop := iop.NewPolynomial(&qkCompletedCanonical, canReg) - wqkiop.ToLagrangeCoset(&pk.Domain[1]) - - // storing Id - id := make([]fr.Element, pk.Domain[1].Cardinality) - id[1].SetOne() - widiop := iop.NewPolynomial(&id, canReg) - widiop.ToLagrangeCoset(&pk.Domain[1]) - - // permutations in LagrangeCoset: we don't mutate so no need to clone the coefficients from the - // proving key. - ws1 := iop.NewPolynomial(&pk.lS1LagrangeCoset, lagrangeCosetBitReversed) - ws2 := iop.NewPolynomial(&pk.lS2LagrangeCoset, lagrangeCosetBitReversed) - ws3 := iop.NewPolynomial(&pk.lS3LagrangeCoset, lagrangeCosetBitReversed) - - // Store z(g*x), without reallocating a slice - bwsziop := bwziop.ShallowClone().Shift(1) - bwsziop.ToLagrangeCoset(&pk.Domain[1]) - - // L_{g^{0}} - cap := pk.Domain[1].Cardinality - if cap < pk.Domain[0].Cardinality { - cap = pk.Domain[0].Cardinality // sanity check - } - lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) - lone[0].SetOne() - loneiop := iop.NewPolynomial(&lone, lagReg) - wloneiop := loneiop.ToCanonical(&pk.Domain[0]). - ToRegular(). - ToLagrangeCoset(&pk.Domain[1]) - - // Full capture using latest gnark crypto... 
- fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element) fr.Element { - - var ic, tmp fr.Element - - ic.Mul(&fql, &l) - tmp.Mul(&fqr, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqm, &l).Mul(&tmp, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqo, &o) - ic.Add(&ic, &tmp).Add(&ic, &fqk) - - return ic - } - - fo := func(l, r, o, fid, fs1, fs2, fs3, fz, fzs fr.Element) fr.Element { - var uu fr.Element - u := pk.Domain[0].FrMultiplicativeGen - uu.Mul(&u, &u) - - var a, b, tmp fr.Element - a.Mul(&beta, &fid).Add(&a, &l).Add(&a, &gamma) - tmp.Mul(&beta, &u).Mul(&tmp, &fid).Add(&tmp, &r).Add(&tmp, &gamma) - a.Mul(&a, &tmp) - tmp.Mul(&beta, &uu).Mul(&tmp, &fid).Add(&tmp, &o).Add(&tmp, &gamma) - a.Mul(&a, &tmp).Mul(&a, &fz) - - b.Mul(&beta, &fs1).Add(&b, &l).Add(&b, &gamma) - tmp.Mul(&beta, &fs2).Add(&tmp, &r).Add(&tmp, &gamma) - b.Mul(&b, &tmp) - tmp.Mul(&beta, &fs3).Add(&tmp, &o).Add(&tmp, &gamma) - b.Mul(&b, &tmp).Mul(&b, &fzs) - - b.Sub(&b, &a) - - return b - } - - fone := func(fz, flone fr.Element) fr.Element { - one := fr.One() - one.Sub(&fz, &one).Mul(&one, &flone) - return one - } - - // 0 , 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 - // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk,lone - fm := func(x ...fr.Element) fr.Element { - - a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2]) - b := fo(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8]) - c := fone(x[7], x[14]) - - c.Mul(&c, &alpha).Add(&c, &b).Mul(&c, &alpha).Add(&c, &a) - - return c - } - testEval, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, - bwliop, - bwriop, - bwoiop, - widiop, - ws1, - ws2, - ws3, - bwziop, - bwsziop, - wqliop, - wqriop, - wqmiop, - wqoiop, - wqkiop, - wloneiop, - ) - if err != nil { - return nil, err - } - h, err := iop.DivideByXMinusOne(testEval, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) - if err != nil { - return nil, err - } - - // compute kzg commitments of h1, h2 and h3 - if err := commitToQuotient( - 
h.Coefficients()[:pk.Domain[0].Cardinality+2], - h.Coefficients()[pk.Domain[0].Cardinality+2:2*(pk.Domain[0].Cardinality+2)], - h.Coefficients()[2*(pk.Domain[0].Cardinality+2):3*(pk.Domain[0].Cardinality+2)], - proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // derive zeta - zeta, err := deriveRandomness(&fs, "zeta", &proof.H[0], &proof.H[1], &proof.H[2]) - if err != nil { - return nil, err - } - - // compute evaluations of (blinded version of) l, r, o, z at zeta - var blzeta, brzeta, bozeta fr.Element - - var wgEvals sync.WaitGroup - wgEvals.Add(3) - - go func() { - bwliop.ToCanonical(&pk.Domain[1]).ToRegular() - blzeta = bwliop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwriop.ToCanonical(&pk.Domain[1]).ToRegular() - brzeta = bwriop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwoiop.ToCanonical(&pk.Domain[1]).ToRegular() - bozeta = bwoiop.Evaluate(zeta) - wgEvals.Done() - }() - - // open blinded Z at zeta*z - bwziop.ToCanonical(&pk.Domain[1]).ToRegular() - var zetaShifted fr.Element - zetaShifted.Mul(&zeta, &pk.Vk.Generator) - proof.ZShiftedOpening, err = kzg.Open( - bwziop.Coefficients()[:bwziop.BlindedSize()], - zetaShifted, - pk.Vk.KZGSRS, - ) - if err != nil { - return nil, err - } - - // blinded z evaluated at u*zeta - bzuzeta := proof.ZShiftedOpening.ClaimedValue - - var ( - linearizedPolynomialCanonical []fr.Element - linearizedPolynomialDigest curve.G1Affine - errLPoly error - ) - - wgEvals.Wait() // wait for the evaluations - - // compute the linearization polynomial r at zeta - // (goal: save committing separately to z, ql, qr, qm, qo, k - linearizedPolynomialCanonical = computeLinearizedPolynomial( - blzeta, - brzeta, - bozeta, - alpha, - beta, - gamma, - zeta, - bzuzeta, - bwziop.Coefficients()[:bwziop.BlindedSize()], - pk, - ) - - // TODO this commitment is only necessary to derive the challenge, we should - // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = 
kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) - - // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) - var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) - var zetaPowerm fr.Element - zetaPowerm.Exp(zeta, &bSize) - zetaPowerm.BigInt(&bZetaPowerm) - foldedHDigest := proof.H[2] - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - - // foldedH = h1 + ζ*h2 + ζ²*h3 - foldedH := h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] - h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] - utils.Parallelize(len(foldedH), func(start, end int) { - for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² - foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 - } - }) - - if errLPoly != nil { - return nil, errLPoly - } - - // Batch open the first list of polynomials - proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - [][]fr.Element{ - foldedH, - linearizedPolynomialCanonical, - bwliop.Coefficients()[:bwliop.BlindedSize()], - bwriop.Coefficients()[:bwriop.BlindedSize()], - bwoiop.Coefficients()[:bwoiop.BlindedSize()], - pk.S1Canonical, - pk.S2Canonical, - }, - []kzg.Digest{ - foldedHDigest, - linearizedPolynomialDigest, - proof.LRO[0], - proof.LRO[1], - proof.LRO[2], - pk.Vk.S[0], - pk.Vk.S[1], - }, - zeta, - hFunc, - pk.Vk.KZGSRS, - ) - - log.Debug().Dur("took", time.Since(start)).Msg("prover done") - - if err != nil { - return 
nil, err - } - - return proof, nil - -} - -// fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.LRO[0], err0 = kzg.Commit(bcl, srs, n) - close(chCommit0) - }() - go func() { - proof.LRO[1], err1 = kzg.Commit(bcr, srs, n) - close(chCommit1) - }() - if proof.LRO[2], err2 = kzg.Commit(bco, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.H[0], err0 = kzg.Commit(h1, srs, n) - close(chCommit0) - }() - go func() { - proof.H[1], err1 = kzg.Commit(h2, srs, n) - close(chCommit1) - }() - if proof.H[2], err2 = kzg.Commit(h3, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. -// The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z -// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. 
-// -// The Linearized polynomial is: -// -// α²*L₁(ζ)*Z(X) -// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) -// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { - - // first part: individual constraints - var rl fr.Element - rl.Mul(&rZeta, &lZeta) - - // second part: - // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) - var s1, s2 fr.Element - chS1 := make(chan struct{}, 1) - go func() { - ps1 := iop.NewPolynomial(&pk.S1Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - s1 = ps1.Evaluate(zeta) // s1(ζ) - s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) - close(chS1) - }() - ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := ps2.Evaluate(zeta) // s2(ζ) - tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) - <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) - - var uzeta, uuzeta fr.Element - uzeta.Mul(&zeta, &pk.Vk.CosetShift) - uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) - - s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) - tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) - tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - // third part L₁(ζ)*α²*Z - var lagrangeZeta, one, den, frNbElmt fr.Element - one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) - lagrangeZeta.Set(&zeta). - Exp(lagrangeZeta, big.NewInt(nbElmt)). 
- Sub(&lagrangeZeta, &one) - frNbElmt.SetUint64(uint64(nbElmt)) - den.Sub(&zeta, &one). - Inverse(&den) - lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - - linPol := make([]fr.Element, len(blindedZCanonical)) - copy(linPol, blindedZCanonical) - - utils.Parallelize(len(linPol), func(start, end int) { - - var t0, t1 fr.Element - - for i := start; i < end; i++ { - - linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - if i < len(pk.S3Canonical) { - - t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) - - linPol[i].Add(&linPol[i], &t0) - } - - linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) - - if i < len(pk.Qm) { - - t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) - t0.Mul(&pk.Ql[i], &lZeta) - t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - - t0.Mul(&pk.Qr[i], &rZeta) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - - t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) - } - - t0.Mul(&blindedZCanonical[i], &lagrangeZeta) - linPol[i].Add(&linPol[i], &t0) // finish the computation - } - }) - return linPol -} diff --git a/internal/backend/bw6-633/plonk/setup.go b/internal/backend/bw6-633/plonk/setup.go deleted file mode 100644 index 4e6d650887..0000000000 --- a/internal/backend/bw6-633/plonk/setup.go +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/iop" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/kzg" - "github.com/consensys/gnark/constraint/bw6-633" - - kzgg "github.com/consensys/gnark-crypto/kzg" -) - -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation -type ProvingKey struct { - // Verifying Key is embedded into the proving key (needed by Prove) - Vk *VerifyingKey - - // TODO store iop.Polynomial here, not []fr.Element for more "type safety" - - // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo []fr.Element - - // qr,ql,qm,qo (in lagrange coset basis) --> these are not serialized, but computed from Ql, Qr, Qm, Qo once. - lQl, lQr, lQm, lQo []fr.Element - - // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. - // Storing LQk in Lagrange basis saves a fft... - CQk, LQk []fr.Element - - // Domains used for the FFTs. 
- // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain - // Domain[0], Domain[1] fft.Domain - - // Permutation polynomials - S1Canonical, S2Canonical, S3Canonical []fr.Element - - // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. - lS1LagrangeCoset, lS2LagrangeCoset, lS3LagrangeCoset []fr.Element - - // position -> permuted position (position in [0,3*sizeSystem-1]) - Permutation []int64 -} - -// VerifyingKey stores the data needed to verify a proof: -// * The commitment scheme -// * Commitments of ql prepended with as many ones as there are public inputs -// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs -// * Commitments to S1, S2, S3 -type VerifyingKey struct { - // Size circuit - Size uint64 - SizeInv fr.Element - Generator fr.Element - NbPublicVariables uint64 - - // Commitment scheme that is used for an instantiation of PLONK - KZGSRS *kzg.SRS - - // cosetShift generator of the coset on the small domain - CosetShift fr.Element - - // S commitments to S1, S2, S3 - S [3]kzg.Digest - - // Commitments to ql, qr, qm, qo prepended with as many zeroes (ones for l) as there are public inputs. - // In particular Qk is not complete. - Ql, Qr, Qm, Qo, Qk kzg.Digest -} - -// Setup sets proving and verifying keys -func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) { - var pk ProvingKey - var vk VerifyingKey - - // The verifying key shares data with the proving key - pk.Vk = &vk - - nbConstraints := len(spr.Constraints) - - // fft domains - sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - - vk.Size = pk.Domain[0].Cardinality - vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) - vk.NbPublicVariables = uint64(len(spr.Public)) - - if err := pk.InitKZG(srs); err != nil { - return nil, nil, err - } - - // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) - pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) - pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) - - for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant - pk.Ql[i].SetOne().Neg(&pk.Ql[i]) - pk.Qr[i].SetZero() - pk.Qm[i].SetZero() - pk.Qo[i].SetZero() - pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // → to be completed by the prover - } - offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.Ql[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.Qr[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.Qm[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.Qm[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.Qo[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.CQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - } - - pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) - pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) - fft.BitReverse(pk.Ql) - fft.BitReverse(pk.Qr) - fft.BitReverse(pk.Qm) - fft.BitReverse(pk.Qo) - fft.BitReverse(pk.CQk) - - // build permutation. Note: at this stage, the permutation takes in account the placeholders - buildPermutation(spr, &pk) - - // set s1, s2, s3 - ccomputePermutationPolynomials(&pk) - - // compute the lagrange coset basis versions (not serialized) - pk.computeLagrangeCosetPolys() - - // Commit to the polynomials to set up the verifying key - var err error - if vk.Ql, err = kzg.Commit(pk.Ql, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qr, err = kzg.Commit(pk.Qr, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qm, err = kzg.Commit(pk.Qm, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qo, err = kzg.Commit(pk.Qo, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - - return &pk, &vk, nil - -} - -// buildPermutation builds the Permutation associated with a circuit. -// -// The permutation s is composed of cycles of maximum length such that -// -// s. 
(l∥r∥o) = (l∥r∥o) -// -// , where l∥r∥o is the concatenation of the indices of l, r, o in -// ql.l+qr.r+qm.l.r+qo.O+k = 0. -// -// The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab -// like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { - - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := int(pk.Domain[0].Cardinality) - - // init permutation - pk.Permutation = make([]int64, 3*sizeSolution) - for i := 0; i < len(pk.Permutation); i++ { - pk.Permutation[i] = -1 - } - - // init LRO position -> variable_ID - lro := make([]int, 3*sizeSolution) // position -> variable_ID - for i := 0; i < len(spr.Public); i++ { - lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) - } - - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() - } - - // init cycle: - // map ID -> last position the ID was seen - cycle := make([]int64, nbVariables) - for i := 0; i < len(cycle); i++ { - cycle[i] = -1 - } - - for i := 0; i < len(lro); i++ { - if cycle[lro[i]] != -1 { - // if != -1, it means we already encountered this value - // so we need to set the corresponding permutation index. 
- pk.Permutation[i] = cycle[lro[i]] - } - cycle[lro[i]] = int64(i) - } - - // complete the Permutation by filling the first IDs encountered - for i := 0; i < len(pk.Permutation); i++ { - if pk.Permutation[i] == -1 { - pk.Permutation[i] = cycle[lro[i]] - } - } -} - -func (pk *ProvingKey) computeLagrangeCosetPolys() { - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqliop := iop.NewPolynomial(clone(pk.Ql, pk.Domain[1].Cardinality), canReg) - wqriop := iop.NewPolynomial(clone(pk.Qr, pk.Domain[1].Cardinality), canReg) - wqmiop := iop.NewPolynomial(clone(pk.Qm, pk.Domain[1].Cardinality), canReg) - wqoiop := iop.NewPolynomial(clone(pk.Qo, pk.Domain[1].Cardinality), canReg) - - ws1 := iop.NewPolynomial(clone(pk.S1Canonical, pk.Domain[1].Cardinality), canReg) - ws2 := iop.NewPolynomial(clone(pk.S2Canonical, pk.Domain[1].Cardinality), canReg) - ws3 := iop.NewPolynomial(clone(pk.S3Canonical, pk.Domain[1].Cardinality), canReg) - - wqliop.ToLagrangeCoset(&pk.Domain[1]) - wqriop.ToLagrangeCoset(&pk.Domain[1]) - wqmiop.ToLagrangeCoset(&pk.Domain[1]) - wqoiop.ToLagrangeCoset(&pk.Domain[1]) - - ws1.ToLagrangeCoset(&pk.Domain[1]) - ws2.ToLagrangeCoset(&pk.Domain[1]) - ws3.ToLagrangeCoset(&pk.Domain[1]) - - pk.lQl = wqliop.Coefficients() - pk.lQr = wqriop.Coefficients() - pk.lQm = wqmiop.Coefficients() - pk.lQo = wqoiop.Coefficients() - - pk.lS1LagrangeCoset = ws1.Coefficients() - pk.lS2LagrangeCoset = ws2.Coefficients() - pk.lS3LagrangeCoset = ws3.Coefficients() -} - -func clone(input []fr.Element, capacity uint64) *[]fr.Element { - res := make([]fr.Element, len(input), capacity) - copy(res, input) - return &res -} - -// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations -// s1, s2, s3. -// -// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | -// -// | -// | Permutation -// -// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v -// \---------------/ \--------------------/ \------------------------/ -// -// s1 (LDE) s2 (LDE) s3 (LDE) -func ccomputePermutationPolynomials(pk *ProvingKey) { - - nbElmts := int(pk.Domain[0].Cardinality) - - // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) - - // Lagrange form of S1, S2, S3 - pk.S1Canonical = make([]fr.Element, nbElmts) - pk.S2Canonical = make([]fr.Element, nbElmts) - pk.S3Canonical = make([]fr.Element, nbElmts) - for i := 0; i < nbElmts; i++ { - pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) - pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) - pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) - } - - // Canonical form of S1, S2, S3 - pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) - fft.BitReverse(pk.S1Canonical) - fft.BitReverse(pk.S2Canonical) - fft.BitReverse(pk.S3Canonical) -} - -// getIDSmallDomain returns the Lagrange form of ID on the small domain -func getIDSmallDomain(domain *fft.Domain) []fr.Element { - - res := make([]fr.Element, 3*domain.Cardinality) - - res[0].SetOne() - res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) - res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) - - for i := uint64(1); i < domain.Cardinality; i++ { - res[i].Mul(&res[i-1], &domain.Generator) - res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) - res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) - } - - return res -} - -// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS -// -// This should be used after deserializing a ProvingKey -// as pk.Vk.KZG is NOT serialized -func (pk *ProvingKey) InitKZG(srs kzgg.SRS) error { - return pk.Vk.InitKZG(srs) -} - -// InitKZG inits vk.KZG using provided SRS -// -// This should be used after deserializing a 
VerifyingKey -// as vk.KZG is NOT serialized -// -// Note that this instantiate a new FFT domain using vk.Size -func (vk *VerifyingKey) InitKZG(srs kzgg.SRS) error { - _srs := srs.(*kzg.SRS) - - if len(_srs.G1) < int(vk.Size) { - return errors.New("kzg srs is too small") - } - vk.KZGSRS = _srs - - return nil -} - -// NbPublicWitness returns the expected public witness size (number of field elements) -func (vk *VerifyingKey) NbPublicWitness() int { - return int(vk.NbPublicVariables) -} - -// VerifyingKey returns pk.Vk -func (pk *ProvingKey) VerifyingKey() interface{} { - return pk.Vk -} diff --git a/internal/backend/bw6-761/plonk/prove.go b/internal/backend/bw6-761/plonk/prove.go deleted file mode 100644 index 5a372eb87b..0000000000 --- a/internal/backend/bw6-761/plonk/prove.go +++ /dev/null @@ -1,634 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "crypto/sha256" - "math/big" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - - curve "github.com/consensys/gnark-crypto/ecc/bw6-761" - - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/kzg" - - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" - - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/iop" - "github.com/consensys/gnark/constraint/bw6-761" - - "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/logger" -) - -type Proof struct { - - // Commitments to the solution vectors - LRO [3]kzg.Digest - - // Commitment to Z, the permutation polynomial - Z kzg.Digest - - // Commitments to h1, h2, h3 such that h = h1 + Xh2 + X**2h3 is the quotient polynomial - H [3]kzg.Digest - - // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2 - BatchedProof kzg.BatchOpeningProof - - // Opening proof of Z at zeta*mu - ZShiftedOpening kzg.OpeningProof -} - -// Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - - log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", len(spr.Constraints)).Str("backend", "plonk").Logger() - start := time.Now() - // pick a hash function that will be used to derive the challenges - hFunc := sha256.New() - - // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") - - // result - proof := &Proof{} - - // compute the constraint system solution - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := 
len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } - } - - // query l, r, o in Lagrange basis, not blinded - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) - - lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - liop := iop.NewPolynomial(&evaluationLDomainSmall, lagReg) - riop := iop.NewPolynomial(&evaluationRDomainSmall, lagReg) - oiop := iop.NewPolynomial(&evaluationODomainSmall, lagReg) - wliop := liop.ShallowClone() - wriop := riop.ShallowClone() - woiop := oiop.ShallowClone() - wliop.ToCanonical(&pk.Domain[0]).ToRegular() - wriop.ToCanonical(&pk.Domain[0]).ToRegular() - woiop.ToCanonical(&pk.Domain[0]).ToRegular() - - // Blind l, r, o before committing - // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. - bwliop := wliop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwriop := wriop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwoiop := woiop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // The first challenge is derived using the public data: the commitments to the permutation, - // the coefficients of the circuit, and the public inputs. - // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *pk.Vk, fullWitness[:len(spr.Public)]); err != nil { - return nil, err - } - gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) - if err != nil { - return nil, err - } - - // Fiat Shamir this - bbeta, err := fs.ComputeChallenge("beta") - if err != nil { - return nil, err - } - var beta fr.Element - beta.SetBytes(bbeta) - - // compute the copy constraint's ratio - // We copy liop, riop, oiop because they are fft'ed in the process. 
- // We could have not copied them at the cost of doing one more bit reverse - // per poly... - ziop, err := iop.BuildRatioCopyConstraint( - []*iop.Polynomial{ - liop.Clone(), - riop.Clone(), - oiop.Clone(), - }, - pk.Permutation, - beta, - gamma, - iop.Form{Basis: iop.Canonical, Layout: iop.Regular}, - &pk.Domain[0], - ) - if err != nil { - return proof, err - } - - // commit to the blinded version of z - bwziop := ziop // iop.NewWrappedPolynomial(&ziop) - bwziop.Blind(2) - proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Vk.KZGSRS, runtime.NumCPU()*2) - if err != nil { - return proof, err - } - - // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) - alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) - if err != nil { - return proof, err - } - - // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) - copy(qkCompletedCanonical, fullWitness[:len(spr.Public)]) - copy(qkCompletedCanonical[len(spr.Public):], pk.LQk[len(spr.Public):]) - pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) - fft.BitReverse(qkCompletedCanonical) - - // l, r, o are blinded here - bwliop.ToLagrangeCoset(&pk.Domain[1]) - bwriop.ToLagrangeCoset(&pk.Domain[1]) - bwoiop.ToLagrangeCoset(&pk.Domain[1]) - - lagrangeCosetBitReversed := iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse} - - // we don't mutate so no need to clone the coefficients from the proving key. 
- wqliop := iop.NewPolynomial(&pk.lQl, lagrangeCosetBitReversed) - wqriop := iop.NewPolynomial(&pk.lQr, lagrangeCosetBitReversed) - wqmiop := iop.NewPolynomial(&pk.lQm, lagrangeCosetBitReversed) - wqoiop := iop.NewPolynomial(&pk.lQo, lagrangeCosetBitReversed) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqkiop := iop.NewPolynomial(&qkCompletedCanonical, canReg) - wqkiop.ToLagrangeCoset(&pk.Domain[1]) - - // storing Id - id := make([]fr.Element, pk.Domain[1].Cardinality) - id[1].SetOne() - widiop := iop.NewPolynomial(&id, canReg) - widiop.ToLagrangeCoset(&pk.Domain[1]) - - // permutations in LagrangeCoset: we don't mutate so no need to clone the coefficients from the - // proving key. - ws1 := iop.NewPolynomial(&pk.lS1LagrangeCoset, lagrangeCosetBitReversed) - ws2 := iop.NewPolynomial(&pk.lS2LagrangeCoset, lagrangeCosetBitReversed) - ws3 := iop.NewPolynomial(&pk.lS3LagrangeCoset, lagrangeCosetBitReversed) - - // Store z(g*x), without reallocating a slice - bwsziop := bwziop.ShallowClone().Shift(1) - bwsziop.ToLagrangeCoset(&pk.Domain[1]) - - // L_{g^{0}} - cap := pk.Domain[1].Cardinality - if cap < pk.Domain[0].Cardinality { - cap = pk.Domain[0].Cardinality // sanity check - } - lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) - lone[0].SetOne() - loneiop := iop.NewPolynomial(&lone, lagReg) - wloneiop := loneiop.ToCanonical(&pk.Domain[0]). - ToRegular(). - ToLagrangeCoset(&pk.Domain[1]) - - // Full capture using latest gnark crypto... 
- fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element) fr.Element { - - var ic, tmp fr.Element - - ic.Mul(&fql, &l) - tmp.Mul(&fqr, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqm, &l).Mul(&tmp, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqo, &o) - ic.Add(&ic, &tmp).Add(&ic, &fqk) - - return ic - } - - fo := func(l, r, o, fid, fs1, fs2, fs3, fz, fzs fr.Element) fr.Element { - var uu fr.Element - u := pk.Domain[0].FrMultiplicativeGen - uu.Mul(&u, &u) - - var a, b, tmp fr.Element - a.Mul(&beta, &fid).Add(&a, &l).Add(&a, &gamma) - tmp.Mul(&beta, &u).Mul(&tmp, &fid).Add(&tmp, &r).Add(&tmp, &gamma) - a.Mul(&a, &tmp) - tmp.Mul(&beta, &uu).Mul(&tmp, &fid).Add(&tmp, &o).Add(&tmp, &gamma) - a.Mul(&a, &tmp).Mul(&a, &fz) - - b.Mul(&beta, &fs1).Add(&b, &l).Add(&b, &gamma) - tmp.Mul(&beta, &fs2).Add(&tmp, &r).Add(&tmp, &gamma) - b.Mul(&b, &tmp) - tmp.Mul(&beta, &fs3).Add(&tmp, &o).Add(&tmp, &gamma) - b.Mul(&b, &tmp).Mul(&b, &fzs) - - b.Sub(&b, &a) - - return b - } - - fone := func(fz, flone fr.Element) fr.Element { - one := fr.One() - one.Sub(&fz, &one).Mul(&one, &flone) - return one - } - - // 0 , 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 - // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk,lone - fm := func(x ...fr.Element) fr.Element { - - a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2]) - b := fo(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8]) - c := fone(x[7], x[14]) - - c.Mul(&c, &alpha).Add(&c, &b).Mul(&c, &alpha).Add(&c, &a) - - return c - } - testEval, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, - bwliop, - bwriop, - bwoiop, - widiop, - ws1, - ws2, - ws3, - bwziop, - bwsziop, - wqliop, - wqriop, - wqmiop, - wqoiop, - wqkiop, - wloneiop, - ) - if err != nil { - return nil, err - } - h, err := iop.DivideByXMinusOne(testEval, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) - if err != nil { - return nil, err - } - - // compute kzg commitments of h1, h2 and h3 - if err := commitToQuotient( - 
h.Coefficients()[:pk.Domain[0].Cardinality+2], - h.Coefficients()[pk.Domain[0].Cardinality+2:2*(pk.Domain[0].Cardinality+2)], - h.Coefficients()[2*(pk.Domain[0].Cardinality+2):3*(pk.Domain[0].Cardinality+2)], - proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // derive zeta - zeta, err := deriveRandomness(&fs, "zeta", &proof.H[0], &proof.H[1], &proof.H[2]) - if err != nil { - return nil, err - } - - // compute evaluations of (blinded version of) l, r, o, z at zeta - var blzeta, brzeta, bozeta fr.Element - - var wgEvals sync.WaitGroup - wgEvals.Add(3) - - go func() { - bwliop.ToCanonical(&pk.Domain[1]).ToRegular() - blzeta = bwliop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwriop.ToCanonical(&pk.Domain[1]).ToRegular() - brzeta = bwriop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwoiop.ToCanonical(&pk.Domain[1]).ToRegular() - bozeta = bwoiop.Evaluate(zeta) - wgEvals.Done() - }() - - // open blinded Z at zeta*z - bwziop.ToCanonical(&pk.Domain[1]).ToRegular() - var zetaShifted fr.Element - zetaShifted.Mul(&zeta, &pk.Vk.Generator) - proof.ZShiftedOpening, err = kzg.Open( - bwziop.Coefficients()[:bwziop.BlindedSize()], - zetaShifted, - pk.Vk.KZGSRS, - ) - if err != nil { - return nil, err - } - - // blinded z evaluated at u*zeta - bzuzeta := proof.ZShiftedOpening.ClaimedValue - - var ( - linearizedPolynomialCanonical []fr.Element - linearizedPolynomialDigest curve.G1Affine - errLPoly error - ) - - wgEvals.Wait() // wait for the evaluations - - // compute the linearization polynomial r at zeta - // (goal: save committing separately to z, ql, qr, qm, qo, k - linearizedPolynomialCanonical = computeLinearizedPolynomial( - blzeta, - brzeta, - bozeta, - alpha, - beta, - gamma, - zeta, - bzuzeta, - bwziop.Coefficients()[:bwziop.BlindedSize()], - pk, - ) - - // TODO this commitment is only necessary to derive the challenge, we should - // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = 
kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) - - // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) - var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) - var zetaPowerm fr.Element - zetaPowerm.Exp(zeta, &bSize) - zetaPowerm.BigInt(&bZetaPowerm) - foldedHDigest := proof.H[2] - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - - // foldedH = h1 + ζ*h2 + ζ²*h3 - foldedH := h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] - h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] - utils.Parallelize(len(foldedH), func(start, end int) { - for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² - foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 - } - }) - - if errLPoly != nil { - return nil, errLPoly - } - - // Batch open the first list of polynomials - proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - [][]fr.Element{ - foldedH, - linearizedPolynomialCanonical, - bwliop.Coefficients()[:bwliop.BlindedSize()], - bwriop.Coefficients()[:bwriop.BlindedSize()], - bwoiop.Coefficients()[:bwoiop.BlindedSize()], - pk.S1Canonical, - pk.S2Canonical, - }, - []kzg.Digest{ - foldedHDigest, - linearizedPolynomialDigest, - proof.LRO[0], - proof.LRO[1], - proof.LRO[2], - pk.Vk.S[0], - pk.Vk.S[1], - }, - zeta, - hFunc, - pk.Vk.KZGSRS, - ) - - log.Debug().Dur("took", time.Since(start)).Msg("prover done") - - if err != nil { - return 
nil, err - } - - return proof, nil - -} - -// fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.LRO[0], err0 = kzg.Commit(bcl, srs, n) - close(chCommit0) - }() - go func() { - proof.LRO[1], err1 = kzg.Commit(bcr, srs, n) - close(chCommit1) - }() - if proof.LRO[2], err2 = kzg.Commit(bco, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.H[0], err0 = kzg.Commit(h1, srs, n) - close(chCommit0) - }() - go func() { - proof.H[1], err1 = kzg.Commit(h2, srs, n) - close(chCommit1) - }() - if proof.H[2], err2 = kzg.Commit(h3, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. -// The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z -// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. 
-// -// The Linearized polynomial is: -// -// α²*L₁(ζ)*Z(X) -// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) -// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { - - // first part: individual constraints - var rl fr.Element - rl.Mul(&rZeta, &lZeta) - - // second part: - // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) - var s1, s2 fr.Element - chS1 := make(chan struct{}, 1) - go func() { - ps1 := iop.NewPolynomial(&pk.S1Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - s1 = ps1.Evaluate(zeta) // s1(ζ) - s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) - close(chS1) - }() - ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := ps2.Evaluate(zeta) // s2(ζ) - tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) - <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) - - var uzeta, uuzeta fr.Element - uzeta.Mul(&zeta, &pk.Vk.CosetShift) - uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) - - s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) - tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) - tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - // third part L₁(ζ)*α²*Z - var lagrangeZeta, one, den, frNbElmt fr.Element - one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) - lagrangeZeta.Set(&zeta). - Exp(lagrangeZeta, big.NewInt(nbElmt)). 
- Sub(&lagrangeZeta, &one) - frNbElmt.SetUint64(uint64(nbElmt)) - den.Sub(&zeta, &one). - Inverse(&den) - lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - - linPol := make([]fr.Element, len(blindedZCanonical)) - copy(linPol, blindedZCanonical) - - utils.Parallelize(len(linPol), func(start, end int) { - - var t0, t1 fr.Element - - for i := start; i < end; i++ { - - linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - if i < len(pk.S3Canonical) { - - t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) - - linPol[i].Add(&linPol[i], &t0) - } - - linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) - - if i < len(pk.Qm) { - - t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) - t0.Mul(&pk.Ql[i], &lZeta) - t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - - t0.Mul(&pk.Qr[i], &rZeta) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - - t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) - } - - t0.Mul(&blindedZCanonical[i], &lagrangeZeta) - linPol[i].Add(&linPol[i], &t0) // finish the computation - } - }) - return linPol -} diff --git a/internal/backend/bw6-761/plonk/setup.go b/internal/backend/bw6-761/plonk/setup.go deleted file mode 100644 index 096aba6f95..0000000000 --- a/internal/backend/bw6-761/plonk/setup.go +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/iop" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/kzg" - "github.com/consensys/gnark/constraint/bw6-761" - - kzgg "github.com/consensys/gnark-crypto/kzg" -) - -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation -type ProvingKey struct { - // Verifying Key is embedded into the proving key (needed by Prove) - Vk *VerifyingKey - - // TODO store iop.Polynomial here, not []fr.Element for more "type safety" - - // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo []fr.Element - - // qr,ql,qm,qo (in lagrange coset basis) --> these are not serialized, but computed from Ql, Qr, Qm, Qo once. - lQl, lQr, lQm, lQo []fr.Element - - // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. - // Storing LQk in Lagrange basis saves a fft... - CQk, LQk []fr.Element - - // Domains used for the FFTs. 
- // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain - // Domain[0], Domain[1] fft.Domain - - // Permutation polynomials - S1Canonical, S2Canonical, S3Canonical []fr.Element - - // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. - lS1LagrangeCoset, lS2LagrangeCoset, lS3LagrangeCoset []fr.Element - - // position -> permuted position (position in [0,3*sizeSystem-1]) - Permutation []int64 -} - -// VerifyingKey stores the data needed to verify a proof: -// * The commitment scheme -// * Commitments of ql prepended with as many ones as there are public inputs -// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs -// * Commitments to S1, S2, S3 -type VerifyingKey struct { - // Size circuit - Size uint64 - SizeInv fr.Element - Generator fr.Element - NbPublicVariables uint64 - - // Commitment scheme that is used for an instantiation of PLONK - KZGSRS *kzg.SRS - - // cosetShift generator of the coset on the small domain - CosetShift fr.Element - - // S commitments to S1, S2, S3 - S [3]kzg.Digest - - // Commitments to ql, qr, qm, qo prepended with as many zeroes (ones for l) as there are public inputs. - // In particular Qk is not complete. - Ql, Qr, Qm, Qo, Qk kzg.Digest -} - -// Setup sets proving and verifying keys -func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) { - var pk ProvingKey - var vk VerifyingKey - - // The verifying key shares data with the proving key - pk.Vk = &vk - - nbConstraints := len(spr.Constraints) - - // fft domains - sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - - vk.Size = pk.Domain[0].Cardinality - vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) - vk.NbPublicVariables = uint64(len(spr.Public)) - - if err := pk.InitKZG(srs); err != nil { - return nil, nil, err - } - - // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) - pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) - pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) - - for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant - pk.Ql[i].SetOne().Neg(&pk.Ql[i]) - pk.Qr[i].SetZero() - pk.Qm[i].SetZero() - pk.Qo[i].SetZero() - pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // → to be completed by the prover - } - offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.Ql[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.Qr[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.Qm[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.Qm[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.Qo[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.CQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - } - - pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) - pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) - fft.BitReverse(pk.Ql) - fft.BitReverse(pk.Qr) - fft.BitReverse(pk.Qm) - fft.BitReverse(pk.Qo) - fft.BitReverse(pk.CQk) - - // build permutation. Note: at this stage, the permutation takes in account the placeholders - buildPermutation(spr, &pk) - - // set s1, s2, s3 - ccomputePermutationPolynomials(&pk) - - // compute the lagrange coset basis versions (not serialized) - pk.computeLagrangeCosetPolys() - - // Commit to the polynomials to set up the verifying key - var err error - if vk.Ql, err = kzg.Commit(pk.Ql, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qr, err = kzg.Commit(pk.Qr, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qm, err = kzg.Commit(pk.Qm, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qo, err = kzg.Commit(pk.Qo, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - - return &pk, &vk, nil - -} - -// buildPermutation builds the Permutation associated with a circuit. -// -// The permutation s is composed of cycles of maximum length such that -// -// s. 
(l∥r∥o) = (l∥r∥o) -// -// , where l∥r∥o is the concatenation of the indices of l, r, o in -// ql.l+qr.r+qm.l.r+qo.O+k = 0. -// -// The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab -// like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { - - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := int(pk.Domain[0].Cardinality) - - // init permutation - pk.Permutation = make([]int64, 3*sizeSolution) - for i := 0; i < len(pk.Permutation); i++ { - pk.Permutation[i] = -1 - } - - // init LRO position -> variable_ID - lro := make([]int, 3*sizeSolution) // position -> variable_ID - for i := 0; i < len(spr.Public); i++ { - lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) - } - - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() - } - - // init cycle: - // map ID -> last position the ID was seen - cycle := make([]int64, nbVariables) - for i := 0; i < len(cycle); i++ { - cycle[i] = -1 - } - - for i := 0; i < len(lro); i++ { - if cycle[lro[i]] != -1 { - // if != -1, it means we already encountered this value - // so we need to set the corresponding permutation index. 
- pk.Permutation[i] = cycle[lro[i]] - } - cycle[lro[i]] = int64(i) - } - - // complete the Permutation by filling the first IDs encountered - for i := 0; i < len(pk.Permutation); i++ { - if pk.Permutation[i] == -1 { - pk.Permutation[i] = cycle[lro[i]] - } - } -} - -func (pk *ProvingKey) computeLagrangeCosetPolys() { - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqliop := iop.NewPolynomial(clone(pk.Ql, pk.Domain[1].Cardinality), canReg) - wqriop := iop.NewPolynomial(clone(pk.Qr, pk.Domain[1].Cardinality), canReg) - wqmiop := iop.NewPolynomial(clone(pk.Qm, pk.Domain[1].Cardinality), canReg) - wqoiop := iop.NewPolynomial(clone(pk.Qo, pk.Domain[1].Cardinality), canReg) - - ws1 := iop.NewPolynomial(clone(pk.S1Canonical, pk.Domain[1].Cardinality), canReg) - ws2 := iop.NewPolynomial(clone(pk.S2Canonical, pk.Domain[1].Cardinality), canReg) - ws3 := iop.NewPolynomial(clone(pk.S3Canonical, pk.Domain[1].Cardinality), canReg) - - wqliop.ToLagrangeCoset(&pk.Domain[1]) - wqriop.ToLagrangeCoset(&pk.Domain[1]) - wqmiop.ToLagrangeCoset(&pk.Domain[1]) - wqoiop.ToLagrangeCoset(&pk.Domain[1]) - - ws1.ToLagrangeCoset(&pk.Domain[1]) - ws2.ToLagrangeCoset(&pk.Domain[1]) - ws3.ToLagrangeCoset(&pk.Domain[1]) - - pk.lQl = wqliop.Coefficients() - pk.lQr = wqriop.Coefficients() - pk.lQm = wqmiop.Coefficients() - pk.lQo = wqoiop.Coefficients() - - pk.lS1LagrangeCoset = ws1.Coefficients() - pk.lS2LagrangeCoset = ws2.Coefficients() - pk.lS3LagrangeCoset = ws3.Coefficients() -} - -func clone(input []fr.Element, capacity uint64) *[]fr.Element { - res := make([]fr.Element, len(input), capacity) - copy(res, input) - return &res -} - -// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations -// s1, s2, s3. -// -// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | -// -// | -// | Permutation -// -// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v -// \---------------/ \--------------------/ \------------------------/ -// -// s1 (LDE) s2 (LDE) s3 (LDE) -func ccomputePermutationPolynomials(pk *ProvingKey) { - - nbElmts := int(pk.Domain[0].Cardinality) - - // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) - - // Lagrange form of S1, S2, S3 - pk.S1Canonical = make([]fr.Element, nbElmts) - pk.S2Canonical = make([]fr.Element, nbElmts) - pk.S3Canonical = make([]fr.Element, nbElmts) - for i := 0; i < nbElmts; i++ { - pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) - pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) - pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) - } - - // Canonical form of S1, S2, S3 - pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) - fft.BitReverse(pk.S1Canonical) - fft.BitReverse(pk.S2Canonical) - fft.BitReverse(pk.S3Canonical) -} - -// getIDSmallDomain returns the Lagrange form of ID on the small domain -func getIDSmallDomain(domain *fft.Domain) []fr.Element { - - res := make([]fr.Element, 3*domain.Cardinality) - - res[0].SetOne() - res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) - res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) - - for i := uint64(1); i < domain.Cardinality; i++ { - res[i].Mul(&res[i-1], &domain.Generator) - res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) - res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) - } - - return res -} - -// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS -// -// This should be used after deserializing a ProvingKey -// as pk.Vk.KZG is NOT serialized -func (pk *ProvingKey) InitKZG(srs kzgg.SRS) error { - return pk.Vk.InitKZG(srs) -} - -// InitKZG inits vk.KZG using provided SRS -// -// This should be used after deserializing a 
VerifyingKey -// as vk.KZG is NOT serialized -// -// Note that this instantiate a new FFT domain using vk.Size -func (vk *VerifyingKey) InitKZG(srs kzgg.SRS) error { - _srs := srs.(*kzg.SRS) - - if len(_srs.G1) < int(vk.Size) { - return errors.New("kzg srs is too small") - } - vk.KZGSRS = _srs - - return nil -} - -// NbPublicWitness returns the expected public witness size (number of field elements) -func (vk *VerifyingKey) NbPublicWitness() int { - return int(vk.NbPublicVariables) -} - -// VerifyingKey returns pk.Vk -func (pk *ProvingKey) VerifyingKey() interface{} { - return pk.Vk -} diff --git a/internal/backend/circuits/circuits.go b/internal/backend/circuits/circuits.go index 51bdb7b9a9..92b306de73 100644 --- a/internal/backend/circuits/circuits.go +++ b/internal/backend/circuits/circuits.go @@ -3,7 +3,8 @@ package circuits import ( "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" + + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -11,7 +12,7 @@ import ( type TestCircuit struct { Circuit frontend.Circuit ValidAssignments, InvalidAssignments []frontend.Circuit // good and bad witness for the prover + public verifier data - HintFunctions []hint.Function + HintFunctions []solver.Hint Curves []ecc.ID } @@ -30,13 +31,14 @@ func addEntry(name string, circuit, proverGood, proverBad frontend.Circuit, curv Circuits[name] = TestCircuit{circuit, []frontend.Circuit{proverGood}, []frontend.Circuit{proverBad}, nil, curves} } -func addNewEntry(name string, circuit frontend.Circuit, proverGood, proverBad []frontend.Circuit, curves []ecc.ID, hintFunctions ...hint.Function) { +func addNewEntry(name string, circuit frontend.Circuit, proverGood, proverBad []frontend.Circuit, curves []ecc.ID, hintFunctions ...solver.Hint) { if Circuits == nil { Circuits = make(map[string]TestCircuit) } if _, ok := Circuits[name]; ok { panic("name " + name + "already taken by another test circuit ") } + 
solver.RegisterHint(hintFunctions...) Circuits[name] = TestCircuit{circuit, proverGood, proverBad, hintFunctions, curves} } diff --git a/internal/backend/circuits/hint.go b/internal/backend/circuits/hint.go index 5ed7a009de..1542438d5f 100644 --- a/internal/backend/circuits/hint.go +++ b/internal/backend/circuits/hint.go @@ -91,7 +91,7 @@ func init() { }, } - addNewEntry("recursive_hint", &recursiveHint{}, good, bad, gnark.Curves(), make3, bits.NBits) + addNewEntry("recursive_hint", &recursiveHint{}, good, bad, gnark.Curves(), make3, bits.GetHints()[1]) } { diff --git a/internal/circuitdefer/defer.go b/internal/circuitdefer/defer.go new file mode 100644 index 0000000000..4bda7383e7 --- /dev/null +++ b/internal/circuitdefer/defer.go @@ -0,0 +1,43 @@ +package circuitdefer + +import ( + "github.com/consensys/gnark/internal/kvstore" +) + +type deferKey struct{} + +func Put[T any](builder any, cb T) { + // we use generics for type safety but to avoid import cycles. + // TODO: compare with using any and type asserting at caller + kv, ok := builder.(kvstore.Store) + if !ok { + panic("builder does not implement kvstore.Store") + } + val := kv.GetKeyValue(deferKey{}) + var deferred []T + if val != nil { + var ok bool + deferred, ok = val.([]T) + if !ok { + panic("stored deferred functions not []func(frontend.API) error") + } + } + deferred = append(deferred, cb) + kv.SetKeyValue(deferKey{}, deferred) +} + +func GetAll[T any](builder any) []T { + kv, ok := builder.(kvstore.Store) + if !ok { + panic("builder does not implement kvstore.Store") + } + val := kv.GetKeyValue(deferKey{}) + if val == nil { + return nil + } + deferred, ok := val.([]T) + if !ok { + panic("stored deferred functions not []func(frontend.API) error") + } + return deferred +} diff --git a/internal/frontendtype/frontendtype.go b/internal/frontendtype/frontendtype.go new file mode 100644 index 0000000000..040af53c5b --- /dev/null +++ b/internal/frontendtype/frontendtype.go @@ -0,0 +1,13 @@ +// Package 
frontendtype allows to assert frontend type. +package frontendtype + +type Type int + +const ( + R1CS Type = iota + SCS +) + +type FrontendTyper interface { + FrontendType() Type +} diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index 8b6b03d055..b5c0bfe1a8 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -4,6 +4,7 @@ import ( "os" "os/exec" "path/filepath" + "strings" "sync" "github.com/consensys/bavard" @@ -19,43 +20,43 @@ var bgen = bavard.NewBatchGenerator(copyrightHolder, 2020, "gnark") func main() { bls12_377 := templateData{ - RootPath: "../../../internal/backend/bls12-377/", + RootPath: "../../../backend/{?}/bls12-377/", CSPath: "../../../constraint/bls12-377/", Curve: "BLS12-377", CurveID: "BLS12_377", } bls12_381 := templateData{ - RootPath: "../../../internal/backend/bls12-381/", + RootPath: "../../../backend/{?}/bls12-381/", CSPath: "../../../constraint/bls12-381/", Curve: "BLS12-381", CurveID: "BLS12_381", } bn254 := templateData{ - RootPath: "../../../internal/backend/bn254/", + RootPath: "../../../backend/{?}/bn254/", CSPath: "../../../constraint/bn254/", Curve: "BN254", CurveID: "BN254", } bw6_761 := templateData{ - RootPath: "../../../internal/backend/bw6-761/", + RootPath: "../../../backend/{?}/bw6-761/", CSPath: "../../../constraint/bw6-761/", Curve: "BW6-761", CurveID: "BW6_761", } bls24_315 := templateData{ - RootPath: "../../../internal/backend/bls24-315/", + RootPath: "../../../backend/{?}/bls24-315/", CSPath: "../../../constraint/bls24-315/", Curve: "BLS24-315", CurveID: "BLS24_315", } bls24_317 := templateData{ - RootPath: "../../../internal/backend/bls24-317/", + RootPath: "../../../backend/{?}/bls24-317/", CSPath: "../../../constraint/bls24-317/", Curve: "BLS24-317", CurveID: "BLS24_317", } bw6_633 := templateData{ - RootPath: "../../../internal/backend/bw6-633/", + RootPath: "../../../backend/{?}/bw6-633/", CSPath: "../../../constraint/bw6-633/", Curve: 
"BW6-633", CurveID: "BW6_633", @@ -97,16 +98,22 @@ func main() { wg.Add(1) go func(d templateData) { - defer wg.Done() - if err := os.MkdirAll(d.RootPath+"groth16", 0700); err != nil { + var ( + groth16Dir = strings.Replace(d.RootPath, "{?}", "groth16", 1) + groth16MpcSetupDir = filepath.Join(groth16Dir, "mpcsetup") + plonkDir = strings.Replace(d.RootPath, "{?}", "plonk", 1) + plonkFriDir = strings.Replace(d.RootPath, "{?}", "plonkfri", 1) + ) + + if err := os.MkdirAll(groth16Dir, 0700); err != nil { panic(err) } - if err := os.MkdirAll(d.RootPath+"plonk", 0700); err != nil { + if err := os.MkdirAll(plonkDir, 0700); err != nil { panic(err) } - if err := os.MkdirAll(d.RootPath+"plonkfri", 0700); err != nil { + if err := os.MkdirAll(plonkFriDir, 0700); err != nil { panic(err) } @@ -114,10 +121,9 @@ func main() { // constraint systems entries := []bavard.Entry{ - {File: filepath.Join(csDir, "r1cs.go"), Templates: []string{"r1cs.go.tmpl", importCurve}}, + {File: filepath.Join(csDir, "system.go"), Templates: []string{"system.go.tmpl", importCurve}}, {File: filepath.Join(csDir, "coeff.go"), Templates: []string{"coeff.go.tmpl", importCurve}}, - {File: filepath.Join(csDir, "r1cs_sparse.go"), Templates: []string{"r1cs.sparse.go.tmpl", importCurve}}, - {File: filepath.Join(csDir, "solution.go"), Templates: []string{"solution.go.tmpl", importCurve}}, + {File: filepath.Join(csDir, "solver.go"), Templates: []string{"solver.go.tmpl", importCurve}}, } if err := bgen.Generate(d, "cs", "./template/representations/", entries...); err != nil { panic(err) @@ -144,10 +150,6 @@ func main() { return } - plonkFriDir := filepath.Join(d.RootPath, "plonkfri") - groth16Dir := filepath.Join(d.RootPath, "groth16") - plonkDir := filepath.Join(d.RootPath, "plonk") - if err := os.MkdirAll(groth16Dir, 0700); err != nil { panic(err) } @@ -174,6 +176,22 @@ func main() { panic(err) // TODO handle } + // groth16 mpcsetup + entries = []bavard.Entry{ + {File: filepath.Join(groth16MpcSetupDir, 
"lagrange.go"), Templates: []string{"groth16/mpcsetup/lagrange.go.tmpl", importCurve}}, + {File: filepath.Join(groth16MpcSetupDir, "marshal.go"), Templates: []string{"groth16/mpcsetup/marshal.go.tmpl", importCurve}}, + {File: filepath.Join(groth16MpcSetupDir, "marshal_test.go"), Templates: []string{"groth16/mpcsetup/marshal_test.go.tmpl", importCurve}}, + {File: filepath.Join(groth16MpcSetupDir, "phase1.go"), Templates: []string{"groth16/mpcsetup/phase1.go.tmpl", importCurve}}, + {File: filepath.Join(groth16MpcSetupDir, "phase2.go"), Templates: []string{"groth16/mpcsetup/phase2.go.tmpl", importCurve}}, + {File: filepath.Join(groth16MpcSetupDir, "setup.go"), Templates: []string{"groth16/mpcsetup/setup.go.tmpl", importCurve}}, + {File: filepath.Join(groth16MpcSetupDir, "setup_test.go"), Templates: []string{"groth16/mpcsetup/setup_test.go.tmpl", importCurve}}, + {File: filepath.Join(groth16MpcSetupDir, "utils.go"), Templates: []string{"groth16/mpcsetup/utils.go.tmpl", importCurve}}, + } + + if err := bgen.Generate(d, "mpcsetup", "./template/zkpschemes/", entries...); err != nil { + panic(err) // TODO handle + } + // plonk entries = []bavard.Entry{ {File: filepath.Join(plonkDir, "verify.go"), Templates: []string{"plonk/plonk.verify.go.tmpl", importCurve}}, diff --git a/internal/generator/backend/template/representations/coeff.go.tmpl b/internal/generator/backend/template/representations/coeff.go.tmpl index 9ebe15314e..3ba357d1e6 100644 --- a/internal/generator/backend/template/representations/coeff.go.tmpl +++ b/internal/generator/backend/template/representations/coeff.go.tmpl @@ -1,7 +1,7 @@ import ( "github.com/consensys/gnark/constraint" - "math/big" "github.com/consensys/gnark/internal/utils" + "math/big" {{ template "import_fr" . 
}} ) @@ -27,8 +27,7 @@ func newCoeffTable(capacity int) CoeffTable { } - -func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constraint.Term { +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { c := (*fr.Element)(coeff[:]) var cID uint32 if c.IsZero() { @@ -51,7 +50,11 @@ func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constrai ct.mCoeffs[cc] = cID } } - + return cID +} + +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) return constraint.Term{VID: uint32(variableID), CID: cID} } @@ -60,8 +63,10 @@ func (ct *CoeffTable) CoeffToString(cID int) string { return ct.Coefficients[cID].String() } +// implements constraint.Field +type field struct{} -var _ constraint.CoeffEngine = &arithEngine{} +var _ constraint.Field = &field{} var ( two fr.Element @@ -78,11 +83,9 @@ func init() { } -// implements constraint.CoeffEngine -type arithEngine struct{} -func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { +func (engine *field) FromInterface(i interface{}) constraint.Element { var e fr.Element if _, err := e.SetInterface(i); err != nil { // need to clean that --> some code path are dissimilar @@ -91,55 +94,83 @@ func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { b := utils.FromInterface(i) e.SetBigInt(&b) } - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) ToBigInt(c *constraint.Coeff) *big.Int { +func (engine *field) ToBigInt(c constraint.Element) *big.Int { e := (*fr.Element)(c[:]) r := new(big.Int) e.BigInt(r) return r } -func (engine *arithEngine) Mul(a, b *constraint.Coeff) { +func (engine *field) Mul(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Mul(_a, _b) + return a } -func (engine *arithEngine) Add(a, b *constraint.Coeff) { + +func (engine *field) Add(a, b constraint.Element) 
constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Add(_a, _b) + return a } -func (engine *arithEngine) Sub(a, b *constraint.Coeff) { +func (engine *field) Sub(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Sub(_a, _b) + return a } -func (engine *arithEngine) Neg(a *constraint.Coeff) { +func (engine *field) Neg(a constraint.Element) constraint.Element { e := (*fr.Element)(a[:]) e.Neg(e) + return a } -func (engine *arithEngine) Inverse(a *constraint.Coeff) { +func (engine *field) Inverse(a constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + e.Inverse(e) + return a, true } -func (engine *arithEngine) IsOne(a *constraint.Coeff) bool { +func (engine *field) IsOne(a constraint.Element) bool { e := (*fr.Element)(a[:]) return e.IsOne() } -func (engine *arithEngine) One() constraint.Coeff { +func (engine *field) One() constraint.Element { e := fr.One() - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) String(a *constraint.Coeff) string { +func (engine *field) String(a constraint.Element) string { e := (*fr.Element)(a[:]) return e.String() +} + +func (engine *field) Uint64(a constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true } \ No newline at end of file diff --git a/internal/generator/backend/template/representations/r1cs.go.tmpl b/internal/generator/backend/template/representations/r1cs.go.tmpl deleted file mode 100644 index c21a5a71d7..0000000000 --- a/internal/generator/backend/template/representations/r1cs.go.tmpl +++ /dev/null @@ -1,455 +0,0 @@ -import ( - "errors" - "fmt" - "io" - "runtime" - "time" - "sync" - "github.com/fxamacker/cbor/v2" - 
- "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - "github.com/consensys/gnark/backend/witness" - - "math" - "github.com/consensys/gnark-crypto/ecc" - - {{ template "import_fr" . }} -) - -// R1CS describes a set of R1CS constraint -type R1CS struct { - constraint.R1CSCore - CoeffTable - arithEngine -} - -// NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values -// -// capacity pre-allocates memory for capacity nbConstraints -func NewR1CS(capacity int) *R1CS { - r := R1CS{ - R1CSCore: constraint.R1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.R1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - return &r -} - - -func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, r1c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - - cs.UpdateLevel(cID, &r1c) - - return cID -} - - - -// Solve sets all the wires and returns the a, b, c vectors. -// the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. -// a, b, c vectors: ab-c = hz -// witness = [publicWires | secretWires] (without the ONE_WIRE !) 
-// returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() - - nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables - - {{- $hintFunctions := "opt.HintFunctions"}} - {{- if ne .Curve "tinyfield"}} - hintFunctions := defineGkrHints(cs.GkrInfo, opt.HintFunctions) - {{- $hintFunctions = "hintFunctions"}} - {{- end}} - solution, err := newSolution(nbWires, {{$hintFunctions}}, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - - if err != nil { - return make(fr.Vector, nbWires), err - } - start := time.Now() - - if len(witness) != len(cs.Public)-1+len(cs.Secret) { // - 1 for ONE_WIRE - err = fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), int(len(cs.Public)-1+len(cs.Secret)), len(cs.Public)-1, len(cs.Secret)) - log.Err(err).Send() - return solution.values, err - } - - // compute the wires and the a, b, c polynomials - if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { - err = errors.New("invalid input size: len(a, b, c) == len(Constraints)") - log.Err(err).Send() - return solution.values, err - } - - solution.solved[0] = true // ONE_WIRE - solution.values[0].SetOne() - copy(solution.values[1:], witness) - for i := range witness { - solution.solved[i+1] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness) + 1) - - // now that we know all inputs are set, defer log printing once all solution.values are computed - // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - if err := cs.parallelSolve(a, b, c, &solution); err != nil { - if unsatisfiedErr, ok := 
err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil -} - - - -func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - chError <- &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - return &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// IsSolved returns nil if given witness solves the R1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) 
- if err != nil { - return err - } - - a := make(fr.Vector, len(cs.Constraints)) - b := make(fr.Vector, len(cs.Constraints)) - c := make(fr.Vector, len(cs.Constraints)) - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, a, b, c, opt) - return err -} - -// divByCoeff sets res = res / t.Coeff -func (cs *R1CS) divByCoeff(res *fr.Element, t constraint.Term) { - cID := t.CoeffID() - switch cID { - case constraint.CoeffIdOne: - return - case constraint.CoeffIdMinusOne: - res.Neg(res) - case constraint.CoeffIdZero: - panic("division by 0") - default: - // this is slow, but shouldn't happen as divByCoeff is called to - // remove the coeff of an unsolved wire - // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 - res.Div(res, &cs.Coefficients[cID]) - } -} - - - -// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly -// -// returns an error if the solver called a hint function that errored -// returns false, nil if there was no wire to solve -// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that -// the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r constraint.R1C, solution *solution, a,b,c *fr.Element) error { - - // the index of the non-zero entry shows if L, R or O has an uninstantiated wire - // the content is the ID of the wire non instantiated - var loc uint8 - - var termToCompute constraint.Term - - processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) error { - for _, t := range l { - vID := t.WireID() - - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - solution.accumulateInto(t, val) - continue - } - - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err - } - // now that the wire is saved, accumulate it into a, b or c - solution.accumulateInto(t, val) - continue - } - - if loc != 0 { - panic("found more than one wire to instantiate") - } - termToCompute = t - loc = locValue - } - return nil - } - - if err := processLExp(r.L, a, 1); err != nil { - return err - } - - if err := processLExp(r.R, b, 2); err != nil { - return err - } - - if err := processLExp(r.O, c, 3); err != nil { - return err - } - - if loc == 0 { - // there is nothing to solve, may happen if we have an assertion - // (ie a constraints that doesn't yield any output) - // or if we solved the unsolved wires with hint functions - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - return nil - } - - // we compute the wire value and instantiate it - wID := termToCompute.WireID() - - // solver result - var wire fr.Element - - - switch loc { - case 1: - if !b.IsZero() { - wire.Div(c, b). - Sub(&wire, a) - a.Add(a, &wire) - } else { - // we didn't actually ensure that a * b == c - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 2: - if !a.IsZero() { - wire.Div(c, a). 
- Sub(&wire, b) - b.Add(b, &wire) - } else { - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 3: - wire.Mul(a, b). - Sub(&wire, c) - - c.Add(c, &wire) - } - - // wire is the term (coeff * value) - // but in the solution we want to store the value only - // note that in gnark frontend, coeff here is always 1 or -1 - cs.divByCoeff(&wire, termToCompute) - solution.set(wID, wire) - - - return nil -} - -// GetConstraints return the list of R1C and a coefficient resolver -func (cs *R1CS) GetConstraints() ([]constraint.R1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *R1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto -func (cs *R1CS) CurveID() ecc.ID { - return ecc.{{.CurveID}} -} - -// WriteTo encodes R1CS into provided io.Writer using cbor -func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode R1CS from io.Reader using cbor -func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(&cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - - return int64(decoder.NumBytesRead()), nil -} - diff --git 
a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl deleted file mode 100644 index f5a35ea03e..0000000000 --- a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl +++ /dev/null @@ -1,453 +0,0 @@ -import ( - "fmt" - "io" - "github.com/fxamacker/cbor/v2" - "github.com/consensys/gnark-crypto/ecc" - "sync" - "runtime" - "math" - "errors" - "time" - - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - "github.com/consensys/gnark/backend/witness" - - {{ template "import_fr" . }} -) - -// SparseR1CS represents a Plonk like circuit -type SparseR1CS struct { - constraint.SparseR1CSCore - CoeffTable - arithEngine -} - -// NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values -func NewSparseR1CS(capacity int) *SparseR1CS { - cs := SparseR1CS{ - SparseR1CSCore: constraint.SparseR1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.SparseR1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - - return &cs -} - - -func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - cs.UpdateLevel(cID, &c) - - return cID -} - - -// Solve sets all the wires. 
-// solution.values = [publicInputs | secretInputs | internalVariables ] -// witness: contains the input variables -// it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() - - // set the slices holding the solution.values and monitoring which variables have been solved - nbVariables := cs.NbInternalVariables + len(cs.Secret) + len(cs.Public) - - start := time.Now() - - expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) - if len(witness) != expectedWitnessSize { - return make(fr.Vector, nbVariables), fmt.Errorf( - "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", - len(witness), - expectedWitnessSize, - len(cs.Public), - len(cs.Secret), - ) - } - - // keep track of wire that have a value - solution, err := newSolution( nbVariables, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return solution.values, err - } - - - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs - copy(solution.values, witness) - for i := 0; i < len(witness); i++ { - solution.solved[i] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness)) - - // defer log printing once all solution.values are computed - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - // batch invert the coefficients to avoid many divisions in the solver - coefficientsNegInv := fr.BatchInvert(cs.Coefficients) - for i:=0; i < len(coefficientsNegInv);i++ { - coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) - } - - if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - 
log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil - -} - - -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - wg.Done() - return - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - chError <- &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } else { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - } - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - return &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - - - -// computeHints computes wires associated with a hint function, if any -// if there is no remaining wire to solve, returns -1 -// else returns the wire position (L -> 0, R -> 1, O -> 2) -func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) ( int, error) { - r := -1 - lID, rID, oID := c.L.WireID(), c.R.WireID(), c.O.WireID() - - if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { - // check if it's a hint - if hint, ok := cs.MHints[lID]; ok { - if err := solution.solveWithHint(lID, hint); err != nil { - return -1, err - } - } else { - r = 0 - } - - } - - if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { - // check if it's a hint - if hint, ok := cs.MHints[rID]; ok { - if err := solution.solveWithHint(rID, hint); err != nil { - return -1, err - } - } else { - r = 1 - } - } - - if (c.O.CoeffID() != 0) && !solution.solved[oID] { - // check if it's a hint - if hint, ok := cs.MHints[oID]; ok { - if err := solution.solveWithHint(oID, hint); err != nil { - return -1, err - } - } else { - r = 2 - 
} - } - return r, nil -} - - -// solveConstraint solve any unsolved wire in given constraint and update the solution -// a SparseR1C may have up to one unsolved wire (excluding hints) -// if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { - - lro, err := cs.computeHints(c, solution) - if err != nil { - return err - } - if lro == -1 { - // no unsolved wire - // can happen if the constraint contained only hint wires. - return nil - } - if lro == 1 { // we solve for R: u1L+u2R+u3LR+u4O+k=0 => R(u2+u3L)+u1L+u4O+k = 0 - if !solution.solved[c.L.WireID()] { - panic("L wire should be instantiated when we solve R") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.L.WireID()]).Add(&den, &u2) - - v1 = solution.computeTerm(c.L) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - } - - if lro == 0 { // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 - if !solution.solved[c.R.WireID()] { - panic("R wire should be instantiated when we solve L") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.R.WireID()]).Add(&den, &u1) - - v1 = solution.computeTerm(c.R) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - 
- } - // O we solve for O - var o fr.Element - cID, vID := c.O.CoeffID(), c.O.WireID() - - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - - // o = - ((m0 * m1) + l + r + c.K) / c.O - o.Mul(&m0, &m1).Add(&o, &l).Add(&o, &r).Add(&o, &cs.Coefficients[c.K]) - o.Mul(&o, &coefficientsNegInv[cID]) - - solution.set(vID, o) - - return nil -} - -// IsSolved returns nil if given witness solves the SparseR1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return err - } - - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, opt) - return err -} - -// GetConstraints return the list of SparseR1C and a coefficient resolver -func (cs *SparseR1CS) GetConstraints() ([]constraint.SparseR1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// checkConstraint verifies that the constraint holds -func (cs *SparseR1CS) checkConstraint(c constraint.SparseR1C, solution *solution) error { - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - o := solution.computeTerm(c.O) - - // l + r + (m0 * m1) + o + c.K == 0 - var t fr.Element - t.Mul(&m0, &m1).Add(&t, &l).Add(&t, &r).Add(&t, &o).Add(&t, &cs.Coefficients[c.K]) - if !t.IsZero() { - return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + (%s × %s) + %s != 0", - l.String(), - r.String(), - o.String(), - m0.String(), - m1.String(), - cs.Coefficients[c.K].String(), - ) - } - return nil - -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto (ecc.{{.Curve}}) -func 
(cs *SparseR1CS) CurveID() ecc.ID { - return ecc.{{.CurveID}} -} - -// WriteTo encodes SparseR1CS into provided io.Writer using cbor -func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - - -// ReadFrom attempts to decode SparseR1CS from io.Reader using cbor -func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/internal/generator/backend/template/representations/solution.go.tmpl b/internal/generator/backend/template/representations/solution.go.tmpl deleted file mode 100644 index e323a614f5..0000000000 --- a/internal/generator/backend/template/representations/solution.go.tmpl +++ /dev/null @@ -1,279 +0,0 @@ -import ( - "errors" - "fmt" - "math/big" - "sync/atomic" - "strings" - "strconv" - "github.com/consensys/gnark/debug" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/constraint" - "github.com/rs/zerolog" - {{ template "import_fr" . 
}} -) - -// solution represents elements needed to compute -// a solution to a R1CS or SparseR1CS -type solution struct { - values, coefficients []fr.Element - solved []bool - nbSolved uint64 - mHintsFunctions map[hint.ID]hint.Function // maps hintID to hint function - mHints map[int]*constraint.Hint // maps wireID to hint - st *debug.SymbolTable -} - -func newSolution( nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, mHints map[int]*constraint.Hint, coefficients []fr.Element, st *debug.SymbolTable) (solution, error) { - - s := solution{ - st: st, - values: make([]fr.Element, nbWires), - coefficients: coefficients, - solved: make([]bool, nbWires), - mHintsFunctions: hintFunctions, - mHints: mHints, - } - - // hintsDependencies is from compile time; it contains the list of hints the solver **needs** - var missing []string - for hintUUID, hintID := range hintsDependencies { - if _, ok := s.mHintsFunctions[hintUUID]; !ok { - missing = append(missing, hintID) - } - } - - if len(missing) > 0 { - return s, fmt.Errorf("solver missing hint(s): %v", missing) - } - - return s, nil -} - -func (s *solution) set(id int, value fr.Element) { - if s.solved[id] { - panic("solving the same wire twice should never happen.") - } - s.values[id] = value - s.solved[id] = true - atomic.AddUint64(&s.nbSolved, 1) - // s.nbSolved++ -} - -func (s *solution) isValid() bool { - return int(s.nbSolved) == len(s.values) -} - -// computeTerm computes coeff*variable -func (s *solution) computeTerm(t constraint.Term) fr.Element { - cID, vID := t.CoeffID(), t.WireID() - if cID != 0 && !s.solved[vID] { - panic("computing a term with an unsolved wire") - } - switch cID { - case constraint.CoeffIdZero: - return fr.Element{} - case constraint.CoeffIdOne: - return s.values[vID] - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - return res - case constraint.CoeffIdMinusOne: - var res fr.Element - res.Neg(&s.values[vID]) - return res - 
default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - return res - } -} - -// r += (t.coeff*t.value) -func (s *solution) accumulateInto(t constraint.Term, r *fr.Element) { - cID := t.CoeffID() - - if t.IsConstant() { - // needed for logs, we may want to not put this in the hot path if we need to - // optimize constraint system solver further. - r.Add(r, &s.coefficients[cID]) - return - } - - vID := t.WireID() - switch cID { - case constraint.CoeffIdZero: - return - case constraint.CoeffIdOne: - r.Add(r, &s.values[vID]) - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - r.Add(r, &res) - case constraint.CoeffIdMinusOne: - r.Sub(r, &s.values[vID]) - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - r.Add(r, &res) - } -} - -// solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveWithHint(vID int, h *constraint.Hint) error { - // skip if the wire is already solved by a call to the same hint - // function on the same inputs - if s.solved[vID] { - return nil - } - // ensure hint function was provided - f, ok := s.mHintsFunctions[h.ID] - if !ok { - return errors.New("missing hint function") - } - - // tmp IO big int memory - nbInputs := len(h.Inputs) - nbOutputs := len(h.Wires) - inputs := make([]*big.Int, nbInputs) - outputs := make([]*big.Int, nbOutputs) - for i :=0; i < nbOutputs; i++ { - outputs[i] = big.NewInt(0) - } - - q := fr.Modulus() - - // for each input, we set its big int value, IF all the wires are solved - // the only case where all wires may not be solved, is if one of the input of this hint - // is the output of another hint. - // it is safe to recursively solve this with the parallel solver, since all hints-output wires - // that we can solve this way are marked to be solved with the current constraint we are processing. 
- recursiveSolve := func(t constraint.Term) error { - if t.IsConstant() { - return nil - } - wID := t.WireID() - if s.solved[wID] { - return nil - } - // unsolved dependency - if h, ok := s.mHints[wID]; ok { - // solve recursively. - return s.solveWithHint(wID, h) - } - - // it's not a hint, we panic. - panic("solver can't compute hint; one or more input wires are unsolved") - } - - for i := 0; i < nbInputs; i++ { - inputs[i] = big.NewInt(0) - - var v fr.Element - for _, term := range h.Inputs[i] { - if err := recursiveSolve(term); err != nil { - return err - } - s.accumulateInto(term, &v) - } - v.BigInt(inputs[i]) - } - - - err := f(q, inputs, outputs) - - var v fr.Element - for i := range outputs { - v.SetBigInt(outputs[i]) - s.set(h.Wires[i], v) - } - - return err -} - -func (s *solution) printLogs(log zerolog.Logger, logs []constraint.LogEntry) { - if log.GetLevel() == zerolog.Disabled { - return - } - - for i := 0; i < len(logs); i++ { - logLine := s.logValue(logs[i]) - log.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) - } -} - -const unsolvedVariable = "" - -func (s *solution) logValue(log constraint.LogEntry) string { - var toResolve []interface{} - var ( - eval fr.Element - missingValue bool - ) - for j := 0; j < len(log.ToResolve); j++ { - // before eval le - - missingValue = false - eval.SetZero() - - for _, t := range log.ToResolve[j] { - // for each term in the linear expression - - cID, vID := t.CoeffID(), t.WireID() - if t.IsConstant() { - // just add the constant - eval.Add(&eval, &s.coefficients[cID]) - continue - } - - if !s.solved[vID] { - missingValue = true - break // stop the loop we can't evaluate. 
- } - - tv := s.computeTerm(t) - eval.Add(&eval, &tv) - } - - - // after - if missingValue { - toResolve = append(toResolve, unsolvedVariable) - } else { - // we have to append our accumulator - toResolve = append(toResolve, eval.String()) - } - - } - if len(log.Stack) > 0 { - var sbb strings.Builder - for _, lID := range log.Stack { - location := s.st.Locations[lID] - function := s.st.Functions[location.FunctionID] - - sbb.WriteString(function.Name) - sbb.WriteByte('\n') - sbb.WriteByte('\t') - sbb.WriteString(function.Filename) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(int(location.Line))) - sbb.WriteByte('\n') - } - toResolve = append(toResolve, sbb.String()) - } - return fmt.Sprintf(log.Format, toResolve...) -} - -// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint -type UnsatisfiedConstraintError struct { - Err error - CID int // constraint ID - DebugInfo *string // optional debug info -} - -func (r *UnsatisfiedConstraintError) Error() string { - if r.DebugInfo != nil { - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) - } - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) -} \ No newline at end of file diff --git a/internal/generator/backend/template/representations/solver.go.tmpl b/internal/generator/backend/template/representations/solver.go.tmpl new file mode 100644 index 0000000000..aa7b60cd4d --- /dev/null +++ b/internal/generator/backend/template/representations/solver.go.tmpl @@ -0,0 +1,650 @@ +import ( + "errors" + "fmt" + "math/big" + "sync/atomic" + "strings" + "strconv" + "runtime" + "sync" + "math" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/rs/zerolog" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/pool" + {{ template "import_fr" . }} +) + +// solver represent the state of the solver during a call to System.Solve(...) 
+type solver struct { + *system + + // values and solved are index by the wire (variable) id + values []fr.Element + solved []bool + nbSolved uint64 + + // maps hintID to hint function + mHintsFunctions map[csolver.HintID]csolver.Hint + + // used to out api.Println + logger zerolog.Logger + + a,b,c fr.Vector // R1CS solver will compute the a,b,c matrices + + q *big.Int +} + +func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { + // parse options + opt, err := csolver.NewConfig(opts...) + if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public)-witnessOffset+len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + 
witnessOffset) + + + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + + +// computeTerm computes coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. 
+func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i :=0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start) + i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. 
+ } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + + + + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID:cID,VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. 
+func (s *solver) Read(calldata []uint32) (constraint.Element, int) { + if s.Type == constraint.SystemSparseR1CS { + if calldata[0] != 1 { + panic("invalid calldata") + } + return s.GetValue(calldata[1], calldata[2]), 3 + } + var r fr.Element + n := int(calldata[0]) + j := 1 + for k:= 0; k < n; k++ { + // we read k Terms + s.accumulateInto(constraint.Term{CID:calldata[j],VID: calldata[j+1]} , &r) + j+=2 + } + + var ret constraint.Element + copy(ret[:], r[:]) + return ret, j +} + + + +// processInstruction decodes the instruction and execute blueprint-defined logic. +// an instruction can encode a hint, a custom constraint or a generic constraint. +func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error { + // fetch the blueprint + blueprint := solver.Blueprints[pi.BlueprintID] + inst := pi.Unpack(&solver.System) + cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only + + if solver.Type == constraint.SystemR1CS { + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + // TODO @gbotrel we use the solveR1C method for now, having user-defined + // blueprint for R1CS would require constraint.Solver interface to add methods + // to set a,b,c since it's more efficient to compute these while we solve. + bc.DecompressR1C(&scratch.tR1C, inst) + return solver.solveR1C(cID, &scratch.tR1C) + } + } + + // blueprint declared "I know how to solve this." + if bc, ok := blueprint.(constraint.BlueprintSolvable); ok { + if err := bc.Solve(solver, inst); err != nil { + return solver.wrapErrWithDebugInfo(cID, err) + } + return nil + } + + // blueprint encodes a hint, we execute. + // TODO @gbotrel may be worth it to move hint logic in blueprint "solve" + if bc, ok := blueprint.(constraint.BlueprintHint); ok { + bc.DecompressHint(&scratch.tHint,inst) + return solver.solveWithHint(&scratch.tHint) + } + + + return nil +} + + +// run runs the solver. 
it return an error if a constraint is not satisfied or if not all wires +// were instantiated. +func (solver *solver) run() error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks. + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + var scratch scratch + for t := range chTasks { + for _, i := range t { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + var scratch scratch + + // for each level, we push the tasks + for _, level := range solver.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + return err + } + } + continue + } + + // number of tasks for this level is set to number of CPU + // but if we don't have enough work for all our CPU, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + if int(solver.nbSolved) != len(solver.values) { + return errors.New("solver didn't assign a value to all wires") + } + + return nil +} + + + +// solveR1C compute unsolved wires in the constraint, if any and set the solver accordingly +// +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. 
+func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error { + a, b, c := &solver.a[cID],&solver.b[cID], &solver.c[cID] + + // the index of the non-zero entry shows if L, R or O has an uninstantiated wire + // the content is the ID of the wire non instantiated + var loc uint8 + + var termToCompute constraint.Term + + processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) { + for _, t := range l { + vID := t.WireID() + + // wire is already computed, we just accumulate in val + if solver.solved[vID] { + solver.accumulateInto(t, val) + continue + } + + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue + } + } + + processLExp(r.L, a, 1) + processLExp(r.R, b, 2) + processLExp(r.O, c, 3) + + + + if loc == 0 { + // there is nothing to solve, may happen if we have an assertion + // (ie a constraints that doesn't yield any output) + // or if we solved the unsolved wires with hint functions + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + return nil + } + + // we compute the wire value and instantiate it + wID := termToCompute.WireID() + + // solver result + var wire fr.Element + + + switch loc { + case 1: + if !b.IsZero() { + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) + } else { + // we didn't actually ensure that a * b == c + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 2: + if !a.IsZero() { + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) + } else { + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 3: + wire.Mul(a, b). 
+ Sub(&wire, c) + + c.Add(c, &wire) + } + + // wire is the term (coeff * value) + // but in the solver we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + solver.divByCoeff(&wire, termToCompute.CID) + solver.set(wID, wire) + + return nil +} + +// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint +type UnsatisfiedConstraintError struct { + Err error + CID int // constraint ID + DebugInfo *string // optional debug info +} + +func (r *UnsatisfiedConstraintError) Error() string { + if r.DebugInfo != nil { + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) + } + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) +} + + +func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { + var debugInfo *string + if dID, ok := solver.MDebug[int(cID)]; ok { + debugInfo = new(string) + *debugInfo = solver.logValue(solver.DebugInfo[dID]) + } + return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo} +} + +// temporary variables to avoid memallocs in hotloop +type scratch struct { + tR1C constraint.R1C + tHint constraint.HintMapping +} + diff --git a/internal/generator/backend/template/representations/system.go.tmpl b/internal/generator/backend/template/representations/system.go.tmpl new file mode 100644 index 0000000000..7288fdb546 --- /dev/null +++ b/internal/generator/backend/template/representations/system.go.tmpl @@ -0,0 +1,371 @@ +import ( + "io" + "time" + "github.com/fxamacker/cbor/v2" + + "github.com/consensys/gnark/internal/backend/ioutils" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/logger" + "github.com/consensys/gnark/backend/witness" + "reflect" + + "github.com/consensys/gnark-crypto/ecc" + + {{ template "import_fr" . 
}} +) + +type R1CS = system +type SparseR1CS = system + +// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element) +type system struct { + constraint.System + CoeffTable + field +} + + +func NewR1CS(capacity int) *R1CS { + return newSystem(capacity, constraint.SystemR1CS) +} + +func NewSparseR1CS(capacity int) *SparseR1CS { + return newSystem(capacity, constraint.SystemSparseR1CS) +} + +func newSystem(capacity int, t constraint.SystemType) *system { + return &system{ + System: constraint.NewSystem(fr.Modulus(), capacity, t), + CoeffTable: newCoeffTable(capacity / 10), + } +} + + +// Solve solves the constraint system with provided witness. +// If it's a R1CS returns R1CSSolution +// If it's a SparseR1CS returns SparseR1CSSolution +func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) { + log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger() + start := time.Now() + + + v := witness.Vector().(fr.Vector) + + // init the solver + solver, err := newSolver(cs, v, opts...) + if err != nil { + log.Err(err).Send() + return nil, err + } + + // defer log printing once all solver.values are computed + // (or sooner, if a constraint is not satisfied) + defer solver.printLogs(cs.Logs) + + // run it. + if err := solver.run(); err != nil { + log.Err(err).Send() + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") + + // format the solution + // TODO @gbotrel revisit post-refactor + if cs.Type == constraint.SystemR1CS { + var res R1CSSolution + res.W = solver.values + res.A = solver.a + res.B = solver.b + res.C = solver.c + return &res, nil + } else { + // sparse R1CS + var res SparseR1CSSolution + // query l, r, o in Lagrange basis, not blinded + res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values) + + return &res, nil + } + +} + +// IsSolved +// Deprecated: use _, err := Solve(...) 
instead +func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error { + _, err := cs.Solve(witness, opts...) + return err +} + + +// GetR1Cs return the list of R1C +func (cs *system) GetR1Cs() []constraint.R1C { + toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + var r1c constraint.R1C + bc.DecompressR1C(&r1c, inst.Unpack(&cs.System)) + toReturn = append(toReturn, r1c) + } else { + panic("not implemented") + } + } + return toReturn +} + +// GetNbCoefficients return the number of unique coefficients needed in the R1CS +func (cs *system) GetNbCoefficients() int { + return len(cs.Coefficients) +} + +// CurveID returns curve ID as defined in gnark-crypto +func (cs *system) CurveID() ecc.ID { + return ecc.{{.CurveID}} +} + +// WriteTo encodes R1CS into provided io.Writer using cbor +func (cs *system) WriteTo(w io.Writer) (int64, error) { + _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written + ts := getTagSet() + enc, err := cbor.CoreDetEncOptions().EncModeWithTags(ts) + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) + + // encode our object + err = encoder.Encode(cs) + return _w.N, err +} + +// ReadFrom attempts to decode R1CS from io.Reader using cbor +func (cs *system) ReadFrom(r io.Reader) (int64, error) { + ts := getTagSet() + dm, err := cbor.DecOptions{ + MaxArrayElements: 134217728, + MaxMapPairs: 134217728, + }.DecModeWithTags(ts) + + if err != nil { + return 0, err + } + decoder := dm.NewDecoder(r) + + // initialize coeff table + cs.CoeffTable = newCoeffTable(0) + + if err := decoder.Decode(&cs); err != nil { + return int64(decoder.NumBytesRead()), err + } + + + if err := cs.CheckSerializationHeader(); err != nil { + return int64(decoder.NumBytesRead()), err + } + + switch v := cs.CommitmentInfo.(type) { + case 
*constraint.Groth16Commitments: + cs.CommitmentInfo = *v + case *constraint.PlonkCommitments: + cs.CommitmentInfo = *v + } + + return int64(decoder.NumBytesRead()), nil +} + +func (cs *system) GetCoefficient(i int) (r constraint.Element) { + copy(r[:], cs.Coefficients[i][:]) + return +} + + +// GetSparseR1Cs return the list of SparseR1C +func (cs *system) GetSparseR1Cs() []constraint.SparseR1C { + + toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + var sparseR1C constraint.SparseR1C + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + toReturn = append(toReturn, sparseR1C) + } else { + panic("not implemented") + } + } + return toReturn +} + + + +// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form. +// solver = [ public | secret | internal ] +// TODO @gbotrel refactor; this seems to be a small util function for plonk +func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { + + //s := int(pk.Domain[0].Cardinality) + s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints + s = int(ecc.NextPowerOfTwo(uint64(s))) + + var l, r, o []fr.Element + l = make([]fr.Element, s) + r = make([]fr.Element, s) + o = make([]fr.Element, s) + s0 := solution[0] + + for i := 0; i < len(cs.Public); i++ { // placeholders + l[i] = solution[i] + r[i] = s0 + o[i] = s0 + } + offset := len(cs.Public) + nbConstraints := cs.GetNbConstraints() + + + var sparseR1C constraint.SparseR1C + j := 0 + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + + l[offset+j] = solution[sparseR1C.XA] + r[offset+j] = solution[sparseR1C.XB] + 
o[offset+j] = solution[sparseR1C.XC] + j++ + } + } + + + offset += nbConstraints + + for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0]) + l[offset+i] = s0 + r[offset+i] = s0 + o[offset+i] = s0 + } + + return l, r, o + +} + + + +// R1CSSolution represent a valid assignment to all the variables in the constraint system. +// The vector W such that Aw o Bw - Cw = 0 +type R1CSSolution struct { + W fr.Vector + A, B, C fr.Vector +} + +func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.W.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.A.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.B.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.C.WriteTo(w) + n += a + return n, err +} + +func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.W.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.A.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.B.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.C.ReadFrom(r) + n += a + return n, err +} + + + +// SparseR1CSSolution represent a valid assignment to all the variables in the constraint system. 
+type SparseR1CSSolution struct { + L, R, O fr.Vector +} + +func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.L.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.R.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.O.WriteTo(w) + n += a + return n, err + +} + +func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.L.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.R.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.O.ReadFrom(r) + n += a + return n, err +} + + +func getTagSet() cbor.TagSet { + // temporary for refactor + ts := cbor.NewTagSet() + // https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml + // 65536-15309735 Unassigned + tagNum := uint64(5309735) + addType := func(t reflect.Type) { + if err := ts.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + t, + tagNum, + ); err != nil { + panic(err) + } + tagNum++ + } + + addType(reflect.TypeOf(constraint.BlueprintGenericHint{})) + addType(reflect.TypeOf(constraint.BlueprintGenericR1C{})) + addType(reflect.TypeOf(constraint.BlueprintGenericSparseR1C{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CAdd{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CMul{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CBool{})) + addType(reflect.TypeOf(constraint.Groth16Commitments{})) + addType(reflect.TypeOf(constraint.PlonkCommitments{})) + + return ts +} diff --git a/internal/generator/backend/template/representations/tests/r1cs.go.tmpl b/internal/generator/backend/template/representations/tests/r1cs.go.tmpl index c312e2f48b..81eebde6bc 100644 --- a/internal/generator/backend/template/representations/tests/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/tests/r1cs.go.tmpl @@ -5,6 +5,7 @@ import ( "reflect" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + 
"github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "github.com/google/go-cmp/cmp" @@ -39,7 +40,7 @@ func TestSerialization(t *testing.T) { return } - // copmpile a second time to ensure determinism + // compile a second time to ensure determinism r1cs2, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) @@ -67,10 +68,10 @@ func TestSerialization(t *testing.T) { if diff := cmp.Diff(r1cs1, &reconstructed, cmpopts.IgnoreFields(cs.R1CS{}, "System.q", - "arithEngine", + "field", "CoeffTable.mCoeffs", "System.lbWireLevel", - "System.lbHints", + "System.genericHint", "System.SymbolTable", "System.lbOutputs", "System.bitLen")); diff != "" { @@ -139,11 +140,6 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { - var c circuit - ccs, err := frontend.Compile(fr.Modulus(),r1cs.NewBuilder, &c) - if err != nil { - b.Fatal(err) - } var w circuit w.X = 1 @@ -153,9 +149,32 @@ func BenchmarkSolve(b *testing.B) { b.Fatal(err) } + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(),scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(),r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ccs.IsSolved(witness) - } } \ No newline at end of file diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.commitment.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.commitment.go.tmpl index da74e65321..aec745a7d6 
100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.commitment.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.commitment.go.tmpl @@ -5,7 +5,7 @@ import ( "math/big" ) -func solveCommitmentWire(commitmentInfo *constraint.Commitment, commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { - res, err := fr.Hash(commitmentInfo.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) +func solveCommitmentWire(commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { + res, err := fr.Hash(constraint.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) return res[0], err } \ No newline at end of file diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.marshal.go.tmpl index 7414dab596..eaa9835d95 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.marshal.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.marshal.go.tmpl @@ -1,5 +1,7 @@ import ( {{ template "import_curve" . }} + {{ template "import_pedersen" . }} + "github.com/consensys/gnark/internal/utils" "io" ) @@ -61,14 +63,24 @@ func (proof *Proof) ReadFrom(r io.Reader) (n int64, err error) { // points are compressed // use WriteRawTo(...) to encode the key without point compression func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, false) + if n, err = vk.writeTo(w, false); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteTo(w) + return m + n, err } // WriteRawTo writes binary encoding of the key elements to writer // points are not compressed // use WriteTo(...) 
to encode the key with point compression func (vk *VerifyingKey) WriteRawTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, true) + if n, err = vk.writeTo(w, true); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteRawTo(w) + return m + n, err } // writeTo serialization format: @@ -108,6 +120,14 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { if err := enc.Encode(vk.G1.K); err != nil { return enc.BytesWritten(), err } + + if vk.PublicAndCommitmentCommitted == nil { + vk.PublicAndCommitmentCommitted = [][]int{} // only matters in tests + } + if err := enc.Encode(utils.IntSliceSliceToUint64SliceSlice(vk.PublicAndCommitmentCommitted)); err != nil { + return enc.BytesWritten(), err + } + return enc.BytesWritten(), nil } @@ -117,13 +137,25 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { // https://github.com/zkcrypto/bellman/blob/fa9be45588227a8c6ec34957de3f68705f07bd92/src/groth16/mod.rs#L143 // [α]1,[β]1,[β]2,[γ]2,[δ]1,[δ]2,uint32(len(Kvk)),[Kvk]1 func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r) + n, err := vk.readFrom(r) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.ReadFrom(r) + return m + n, err } // UnsafeReadFrom has the same behavior as ReadFrom, except that it will not check that decode points // are on the curve and in the correct subgroup. 
func (vk *VerifyingKey) UnsafeReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r, curve.NoSubgroupChecks()) + n, err := vk.readFrom(r, curve.NoSubgroupChecks()) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.UnsafeReadFrom(r) + return m + n, err } func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { @@ -153,15 +185,16 @@ func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder) if err := dec.Decode(&vk.G1.K); err != nil { return dec.BytesRead(), err } + var publicCommitted [][]uint64 + if err := dec.Decode(&publicCommitted); err != nil { + return dec.BytesRead(), err + } + vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 - var err error - vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) - if err != nil { - return dec.BytesRead(), err + if err := vk.Precompute(); err != nil { + return dec.BytesRead(), err } - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) return dec.BytesRead(), nil } @@ -213,6 +246,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { pk.NbInfinityB, pk.InfinityA, pk.InfinityB, + uint32(len(pk.CommitmentKeys)), } for _, v := range toEncode { @@ -221,6 +255,23 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { } } + for i := range pk.CommitmentKeys { + var ( + n2 int64 + err error + ) + if raw { + n2, err = pk.CommitmentKeys[i].WriteRawTo(w) + } else { + n2, err = pk.CommitmentKeys[i].WriteTo(w) + } + + n += n2 + if err != nil { + return n, err + } + } + return n + enc.BytesWritten(), nil } @@ -248,6 +299,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) dec := curve.NewDecoder(r, decOptions...) 
var nbWires uint64 + var nbCommitments uint32 toDecode := []interface{}{ &pk.G1.Alpha, @@ -279,6 +331,18 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) if err := dec.Decode(&pk.InfinityB); err != nil { return n + dec.BytesRead(), err } + if err := dec.Decode(&nbCommitments); err != nil { + return n + dec.BytesRead(), err + } + + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) + for i := range pk.CommitmentKeys { + n2, err := pk.CommitmentKeys[i].ReadFrom(r) + n += n2 + if err != nil { + return n, err + } + } return n + dec.BytesRead(), nil } diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl index 64987ebd1b..760cf22d24 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl @@ -1,15 +1,19 @@ import ( - "fmt" {{- template "import_fr" . }} {{- template "import_curve" . }} {{- template "import_backend_cs" . }} {{- template "import_fft" . 
}} + {{- template "import_pedersen" .}} "runtime" "math/big" "time" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/logger" ) @@ -20,7 +24,8 @@ import ( type Proof struct { Ar, Krs curve.G1Affine Bs curve.G2Affine - Commitment, CommitmentPok curve.G1Affine + Commitments []curve.G1Affine // Pedersen commitments a la https://eprint.iacr.org/2022/1072 + CommitmentPok curve.G1Affine // Batched proof of knowledge of the above commitments } // isValid ensures proof elements are in the correct subgroup @@ -34,83 +39,82 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). -func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - // TODO @gbotrel witness size check is done by R1CS, doesn't mean we shouldn't sanitize here. 
- // if len(witness) != r1cs.NbPublicVariables-1+r1cs.NbSecretVariables { - // return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), r1cs.NbPublicVariables-1+r1cs.NbSecretVariables, r1cs.NbPublicVariables, r1cs.NbSecretVariables) - // } - - log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", len(r1cs.Constraints)).Str("backend", "groth16").Logger() - - // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - - proof := &Proof{} - if r1cs.CommitmentInfo.Is() { - opt.HintFunctions[r1cs.CommitmentInfo.HintID] = func(_ *big.Int, in []*big.Int, out []*big.Int) error { - // Perf-TODO: Converting these values to big.Int and back may be a performance bottleneck. - // If that is the case, figure out a way to feed the solution vector into this function - if len(in) != r1cs.CommitmentInfo.NbCommitted() { // TODO: Remove - return fmt.Errorf("unexpected number of committed variables") - } - values := make([]fr.Element, r1cs.CommitmentInfo.NbPrivateCommitted) - nbPublicCommitted := len(in) - len(values) - inPrivate := in[nbPublicCommitted:] - for i, inI := range inPrivate { - values[i].SetBigInt(inI) - } +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } + + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] - var err error - proof.Commitment, proof.CommitmentPok, err = pk.CommitmentKey.Commit(values) - if err != nil { + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + solverOpts = append(solverOpts, solver.OverrideHint(commitmentInfo[i].HintID, func(i int) solver.Hint { + return func(_ *big.Int, in []*big.Int, out []*big.Int) error { + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[len(hashed):] + for j, inJ := range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } + + var res fr.Element + res, err = solveCommitmentWire(&proof.Commitments[i], hashed) + res.BigInt(out[0]) return err } - - var res fr.Element - res, err = solveCommitmentWire(&r1cs.CommitmentInfo, &proof.Commitment, in[:r1cs.CommitmentInfo.NbPublicCommitted()]) - res.BigInt(out[0]) //Perf-TODO: Regular (non-mont) hashToField to obviate this conversion? 
- return err - } + }(i))) } - var wireValues []fr.Element - var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill wireValues with random values else multi exps don't do much - var r fr.Element - _, _ = r.SetRandom() - for i := r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables(); i < len(wireValues); i++ { - wireValues[i] = r - r.Double(&r) - } - } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + start := time.Now() + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + + if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil { + return nil, err + } + // H (witness reduction / FFT part) var h []fr.Element chHDone := make(chan struct{}, 1) go func() { - h = computeH(a, b, c, &pk.Domain) - a = nil - b = nil - c = nil + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + solution.A = nil + solution.B = nil + solution.C = nil chHDone <- struct{}{} }() // we need to copy and filter the wireValues for each multi exp // as pk.G1.A, pk.G1.B and pk.G2.B may have (a significant) number of point at infinity - var wireValuesA, wireValuesB []fr.Element - chWireValuesA, chWireValuesB := make(chan struct{}, 1) , make(chan struct{}, 1) + var wireValuesA, wireValuesB []fr.Element + chWireValuesA, chWireValuesB := make(chan struct{}, 1), make(chan struct{}, 1) go func() { - wireValuesA = make([]fr.Element , len(wireValues) - int(pk.NbInfinityA)) - for i,j :=0,0; j len(toRemove) -func filter(slice []fr.Element, toRemove []int) (r []fr.Element) { +// else, returns a new slice without the indexes in toRemove. 
The first value in the slice is taken as indexes as sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { if len(toRemove) == 0 { return slice } - r = make([]fr.Element, 0, len(slice)-len(toRemove)) + + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) - j := 0 // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) for i:=0; i < len(slice);i++ { - if j < len(toRemove) && i == toRemove[j] { - j++ + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } continue } r = append(r, slice[i]) } - return r + return } func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { @@ -317,9 +331,9 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { domain.FFTInverse(b, fft.DIF) domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, true) - domain.FFT(b, fft.DIT, true) - domain.FFT(c, fft.DIT, true) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) var den, one fr.Element one.SetOne() @@ -327,7 +341,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) - // reusing a to avoid unecessary memalloc + // reusing a to avoid unnecessary memory allocation utils.Parallelize(n, func(start, end int) { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). 
@@ -337,7 +351,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { }) // ifft_coset - domain.FFTInverse(a, fft.DIF, true) + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) return a } \ No newline at end of file diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl index 1c7ca7dbbb..1795c3987d 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl @@ -1,9 +1,11 @@ import ( + "errors" {{- template "import_fr" . }} {{- template "import_curve" . }} {{- template "import_backend_cs" . }} {{- template "import_fft" . }} {{- template "import_pedersen" .}} + "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/constraint" "math/big" @@ -16,15 +18,15 @@ type ProvingKey struct { // domain Domain fft.Domain - // [α]1, [β]1, [δ]1 - // [A(t)]1, [B(t)]1, [Kpk(t)]1, [Z(t)]1 + // [α]₁, [β]₁, [δ]₁ + // [A(t)]₁, [B(t)]₁, [Kpk(t)]₁, [Z(t)]₁ G1 struct { Alpha, Beta, Delta curve.G1Affine A, B, Z []curve.G1Affine K []curve.G1Affine // the indexes correspond to the private wires } - // [β]2, [δ]2, [B(t)]2 + // [β]₂, [δ]₂, [B(t)]₂ G2 struct { Beta, Delta curve.G2Affine B []curve.G2Affine @@ -34,21 +36,21 @@ type ProvingKey struct { InfinityA, InfinityB []bool NbInfinityA, NbInfinityB uint64 - CommitmentKey pedersen.Key + CommitmentKeys []pedersen.ProvingKey } // VerifyingKey is used by a Groth16 verifier to verify the validity of a proof and a statement // Notation follows Figure 4. 
in DIZK paper https://eprint.iacr.org/2018/691.pdf type VerifyingKey struct { - // [α]1, [Kvk]1 + // [α]₁, [Kvk]₁ G1 struct { Alpha curve.G1Affine Beta, Delta curve.G1Affine // unused, here for compatibility purposes K []curve.G1Affine // The indexes correspond to the public wires } - // [β]2, [δ]2, [γ]2, - // -[δ]2, -[γ]2: see proof.Verify() for more details + // [β]₂, [δ]₂, [γ]₂, + // -[δ]₂, -[γ]₂: see proof.Verify() for more details G2 struct { Beta, Delta, Gamma curve.G2Affine deltaNeg, gammaNeg curve.G2Affine // not serialized @@ -57,8 +59,8 @@ type VerifyingKey struct { // e(α, β) e curve.GT // not serialized - CommitmentKey pedersen.Key - CommitmentInfo constraint.Commitment // since the verifier doesn't input a constraint system, this needs to be provided here + CommitmentKey pedersen.VerifyingKey + PublicAndCommitmentCommitted [][]int // indexes of public/commitment committed variables } // Setup constructs the SRS @@ -75,17 +77,20 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbPrivateCommittedWires := r1cs.CommitmentInfo.NbPrivateCommitted - nbPublicWires := r1cs.GetNbPublicVariables() - nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - if r1cs.CommitmentInfo.Is() { // the commitment itself is defined by a hint so the prover considers it private - nbPublicWires++ // but the verifier will need to inject the value itself so on the groth16 - nbPrivateWires-- // level it must be considered public - } + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + commitmentWires := commitmentInfo.CommitmentIndexes() + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateCommittedWires := internal.NbElements(privateCommitted) + + // a commitment is itself defined by a hint so the prover considers it 
private + // but the verifier will need to inject the value itself so on the groth16 + // level it must be considered public + nbPublicWires := r1cs.GetNbPublicVariables() + len(commitmentInfo) + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - len(commitmentInfo) // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints))) + domain := fft.NewDomain(uint64(r1cs.GetNbConstraints())) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -119,7 +124,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // compute scalars for pkK, vkK and ckK pkK := make([]fr.Element, nbPrivateWires) vkK := make([]fr.Element, nbPublicWires) - ckK := make([]fr.Element, nbPrivateCommittedWires) + ckK := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + ckK[i] = make([]fr.Element, len(privateCommitted[i])) + } var t0, t1 fr.Element @@ -130,28 +138,42 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { Add(&t1, &C[i]). 
Mul(&t1, coeff) } - - vI, cI := 0, 0 - privateCommitted := r1cs.CommitmentInfo.PrivateCommitted() + vI := 0 // number of public wires seen so far + cI := make([]int, len(commitmentInfo)) // number of private committed wires seen so far for each commitment + nbPrivateCommittedSeen := 0 // = ∑ᵢ cI[i] + nbCommitmentsSeen := 0 for i := range A { - isCommittedPrivate := cI < len(privateCommitted) && i == privateCommitted[cI] - isCommitment := r1cs.CommitmentInfo.Is() && i == r1cs.CommitmentInfo.CommitmentIndex - isPublic := i < r1cs.GetNbPublicVariables() + commitment := -1 // index of the commitment that commits to this variable as a private or commitment value + var isCommitment, isPublic bool + if isPublic = i < r1cs.GetNbPublicVariables(); !isPublic { + if nbCommitmentsSeen < len(commitmentWires) && commitmentWires[nbCommitmentsSeen] == i { + isCommitment = true + nbCommitmentsSeen++ + } - if isPublic || isCommittedPrivate || isCommitment { + for j := range commitmentInfo { // does commitment j commit to i? + if cI[j] < len(privateCommitted[j]) && privateCommitted[j][cI[j]] == i { + commitment = j + break // frontend guarantees that no private variable is committed to more than once + } + } + } + + if isPublic || commitment != -1 || isCommitment { computeK(i, &toxicWaste.gammaInv) - if isCommittedPrivate { - ckK[cI] = t1 - cI++ - } else { + if isPublic || isCommitment { vkK[vI] = t1 vI++ + } else { // committed and private + ckK[commitment][cI[commitment]] = t1 + cI[commitment]++ + nbPrivateCommittedSeen++ } } else { computeK(i, &toxicWaste.deltaInv) - pkK[i-vI-cI] = t1 + pkK[i-vI-nbPrivateCommittedSeen] = t1 // vI = nbPublicSeen + nbCommitmentsSeen } } @@ -204,11 +226,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { g1Scalars = append(g1Scalars, Z...) g1Scalars = append(g1Scalars, vkK...) g1Scalars = append(g1Scalars, pkK...) - g1Scalars = append(g1Scalars, ckK...) + for i := range ckK { + g1Scalars = append(g1Scalars, ckK[i]...) 
+ } g1PointsAff := curve.BatchScalarMultiplicationG1(&g1, g1Scalars) - // sets pk: [α]1, [β]1, [δ]1 + // sets pk: [α]₁, [β]₁, [δ]₁ pk.G1.Alpha = g1PointsAff[0] pk.G1.Beta = g1PointsAff[1] pk.G1.Delta = g1PointsAff[2] @@ -220,8 +244,9 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G1.B = g1PointsAff[offset : offset+len(B)] offset += len(B) - pk.G1.Z = g1PointsAff[offset : offset+int(domain.Cardinality)] - bitReverse(pk.G1.Z) + bitReverse(g1PointsAff[offset : offset+int(domain.Cardinality)]) + sizeZ := int(domain.Cardinality)-1 // deg(H)=deg(A*B-C/X^n-1)=(n-1)+(n-1)-n=n-2 + pk.G1.Z = g1PointsAff[offset : offset+sizeZ] offset += int(domain.Cardinality) @@ -234,17 +259,22 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // --------------------------------------------------------------------------------------------- // Commitment setup - if nbPrivateCommittedWires != 0 { - commitmentBasis := g1PointsAff[offset:] + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(ckK[i]) + commitmentBases[i] = g1PointsAff[offset : offset+size] + offset += size + } + if offset != len(g1PointsAff) { + return errors.New("didn't consume all G1 points") // TODO @Tabaie Remove this + } - vk.CommitmentKey, err = pedersen.Setup(commitmentBasis) - if err != nil { - return err - } - pk.CommitmentKey = vk.CommitmentKey + pk.CommitmentKeys, vk.CommitmentKey, err = pedersen.Setup(commitmentBases...) 
+ if err != nil { + return err } - vk.CommitmentInfo = r1cs.CommitmentInfo // unfortunate but necessary + vk.PublicAndCommitmentCommitted = commitmentInfo.GetPublicAndCommitmentCommitted(commitmentWires, r1cs.GetNbPublicVariables()) // --------------------------------------------------------------------------------------------- // G2 scalars @@ -261,15 +291,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G2.B = g2PointsAff[:len(B)] - // sets pk: [β]2, [δ]2 + // sets pk: [β]₂, [δ]₂ pk.G2.Beta = g2PointsAff[len(B)+0] pk.G2.Delta = g2PointsAff[len(B)+1] - // sets vk: [δ]2, [γ]2, -[δ]2, -[γ]2 + // sets vk: [δ]₂, [γ]₂ vk.G2.Delta = g2PointsAff[len(B)+1] vk.G2.Gamma = g2PointsAff[len(B)+2] - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) // --------------------------------------------------------------------------------------------- // Pairing: vk.e @@ -280,16 +308,29 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { vk.G1.Beta = pk.G1.Beta vk.G1.Delta = pk.G1.Delta - vk.e, err = curve.Pair([]curve.G1Affine{pk.G1.Alpha}, []curve.G2Affine{pk.G2.Beta}) - if err != nil { - return err + if err := vk.Precompute(); err != nil { + return err } + // set domain pk.Domain = *domain return nil } +// Precompute sets e, -[δ]₂, -[γ]₂ +// This is meant to be called internally during setup or deserialization. +func (vk *VerifyingKey) Precompute() error { + var err error + vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) + if err != nil { + return err + } + vk.G2.deltaNeg.Neg(&vk.G2.Delta) + vk.G2.gammaNeg.Neg(&vk.G2.Gamma) + return nil +} + func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr.Element, B []fr.Element, C []fr.Element) { nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() @@ -304,7 +345,7 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. 
var w fr.Element w.Set(&domain.Generator) wi := fr.One() - t := make([]fr.Element, len(r1cs.Constraints)+1) + t := make([]fr.Element, r1cs.GetNbConstraints()+1) for i := 0; i < len(t); i++ { t[i].Sub(&toxicWaste.t, &wi) wi.Mul(&wi, &w) // TODO this is already pre computed in fft.Domain @@ -348,8 +389,10 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // for each term appearing in the linear expression, // we compute term.Coefficient * L, and cumulate it in // A, B or C at the index of the variable - for i, c := range r1cs.Constraints { + j := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c!=nil; c = it.Next() { for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } @@ -362,9 +405,12 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // Li+1 = w*Li*(t-w^i)/(t-w^(i+1)) L.Mul(&L, &w) - L.Mul(&L, &t[i]) - L.Mul(&L, &tInv[i+1]) + L.Mul(&L, &t[j]) + L.Mul(&L, &tInv[j+1]) + + j++ } + return } @@ -418,7 +464,10 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbConstraints := len(r1cs.Constraints) + nbConstraints := r1cs.GetNbConstraints() + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - internal.NbElements(privateCommitted) - len(commitmentInfo) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints)) @@ -430,8 +479,8 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // initialize proving key pk.G1.A = make([]curve.G1Affine, nbWires-nbZeroesA) pk.G1.B = make([]curve.G1Affine, nbWires-nbZeroesB) - pk.G1.K = make([]curve.G1Affine, nbWires-r1cs.GetNbPublicVariables()) - pk.G1.Z = make([]curve.G1Affine, 
domain.Cardinality) + pk.G1.K = make([]curve.G1Affine, nbPrivateWires) + pk.G1.Z = make([]curve.G1Affine, domain.Cardinality-1) pk.G2.B = make([]curve.G2Affine, nbWires-nbZeroesB) // set infinity markers @@ -485,6 +534,22 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { pk.Domain = *domain + // --------------------------------------------------------------------------------------------- + // Commitment setup + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(privateCommitted[i]) + commitmentBases[i] = make([]curve.G1Affine, size) + for j := range commitmentBases[i] { + commitmentBases[i][j] = r1Aff + } + } + + pk.CommitmentKeys,_, err = pedersen.Setup(commitmentBases...) + if err != nil { + return err + } + return nil } @@ -496,7 +561,9 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) - for _, c := range r1cs.Constraints { + + it := r1cs.GetR1CIterator() + for c := it.Next(); c!=nil; c = it.Next() { for _, t := range c.L { A[t.WireID()] = true } @@ -504,6 +571,8 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { B[t.WireID()] = true } } + + for i := 0; i < nbWires; i++ { if !A[i] { nbZeroesA++ @@ -588,7 +657,7 @@ func (pk *ProvingKey) NbG2() int { return 2 + len(pk.G2.B) } -// bitRerverse permutation as in fft.BitReverse , but with []curve.G1Affine +// bitReverse permutation as in fft.BitReverse , but with []curve.G1Affine func bitReverse(a []curve.G1Affine) { n := uint(len(a)) nn := uint(bits.UintSize - bits.TrailingZeros(n)) diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.verify.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.verify.go.tmpl index 29a7202c08..f87b9197dd 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.verify.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.verify.go.tmpl @@ 
-6,10 +6,13 @@ import ( "errors" "time" "io" - "math/big" + {{- if eq .Curve "BN254"}} "text/template" {{- end}} + {{- template "import_pedersen" .}} + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/logger" ) @@ -21,10 +24,8 @@ var ( // Verify verifies a proof with given VerifyingKey and publicWitness func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - nbPublicVars := len(vk.G1.K) - if vk.CommitmentInfo.Is() { - nbPublicVars-- - } + nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted) + if len(publicWitness) != nbPublicVars-1 { return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(publicWitness), len(vk.G1.K) - 1) } @@ -47,21 +48,32 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { close(chDone) }() - if vk.CommitmentInfo.Is() { - - if err := vk.CommitmentKey.VerifyKnowledgeProof(proof.Commitment, proof.CommitmentPok); err != nil { - return err + maxNbPublicCommitted := 0 + for _, s := range vk.PublicAndCommitmentCommitted { // iterate over commitments + maxNbPublicCommitted = utils.Max(maxNbPublicCommitted, len(s)) + } + commitmentsSerialized := make([]byte, len(vk.PublicAndCommitmentCommitted)*fr.Bytes) + commitmentPrehashSerialized := make([]byte, curve.SizeOfG1AffineUncompressed+maxNbPublicCommitted*fr.Bytes) + for i := range vk.PublicAndCommitmentCommitted { // solveCommitmentWire + copy(commitmentPrehashSerialized, proof.Commitments[i].Marshal()) + offset := curve.SizeOfG1AffineUncompressed + for j := range vk.PublicAndCommitmentCommitted[i] { + copy(commitmentPrehashSerialized[offset:], publicWitness[vk.PublicAndCommitmentCommitted[i][j]-1].Marshal()) + offset += fr.Bytes } - - publicCommitted := make([]*big.Int, vk.CommitmentInfo.NbPublicCommitted()) - for i := range publicCommitted { - var b big.Int - publicWitness[vk.CommitmentInfo.Committed[i]-1].BigInt(&b) - publicCommitted[i] = &b + if 
res, err := fr.Hash(commitmentPrehashSerialized[:offset], []byte(constraint.CommitmentDst), 1); err != nil { + return err + } else { + publicWitness = append(publicWitness, res[0]) + copy(commitmentsSerialized[i*fr.Bytes:], res[0].Marshal()) } + } - if res, err := solveCommitmentWire(&vk.CommitmentInfo, &proof.Commitment, publicCommitted); err == nil { - publicWitness = append(publicWitness, res) + if folded, err := pedersen.FoldCommitments(proof.Commitments, commitmentsSerialized); err != nil { + return err + } else { + if err = vk.CommitmentKey.Verify(folded, proof.CommitmentPok); err != nil { + return err } } @@ -72,8 +84,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { } kSum.AddMixed(&vk.G1.K[0]) - if vk.CommitmentInfo.Is() { - kSum.AddMixed(&proof.Commitment) + for i := range proof.Commitments { + kSum.AddMixed(&proof.Commitments[i]) } var kSumAff curve.G1Affine diff --git a/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/lagrange.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/lagrange.go.tmpl new file mode 100644 index 0000000000..89b75eb44f --- /dev/null +++ b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/lagrange.go.tmpl @@ -0,0 +1,201 @@ +import ( + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + + {{- template "import_fr" . }} + {{- template "import_curve" . }} + {{- template "import_fft" . 
}} + "github.com/consensys/gnark/internal/utils" +) + + + +func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { + coeffs := make([]curve.G1Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { + coeffs := make([]curve.G2Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func butterflyG1(a *curve.G1Affine, b *curve.G1Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +func butterflyG2(a *curve.G2Affine, b *curve.G2Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G1(a []curve.G1Affine, twiddles [][]fr.Element, stage int) { + butterflyG1(&a[0], &a[4]) + butterflyG1(&a[1], &a[5]) + butterflyG1(&a[2], &a[6]) + butterflyG1(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) 
+ butterflyG1(&a[0], &a[2]) + butterflyG1(&a[1], &a[3]) + butterflyG1(&a[4], &a[6]) + butterflyG1(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[1]) + butterflyG1(&a[2], &a[3]) + butterflyG1(&a[4], &a[5]) + butterflyG1(&a[6], &a[7]) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G2(a []curve.G2Affine, twiddles [][]fr.Element, stage int) { + butterflyG2(&a[0], &a[4]) + butterflyG2(&a[1], &a[5]) + butterflyG2(&a[2], &a[6]) + butterflyG2(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[2]) + butterflyG2(&a[1], &a[3]) + butterflyG2(&a[4], &a[6]) + butterflyG2(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[1]) + butterflyG2(&a[2], &a[3]) + butterflyG2(&a[4], &a[5]) + butterflyG2(&a[6], &a[7]) +} + +func difFFTG1(a []curve.G1Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G1(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG1(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + butterflyG1(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG1(a[0:m], twiddles, 
nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} +func difFFTG2(a []curve.G2Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G2(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG2(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + butterflyG2(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG2(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG2(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} diff --git a/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/marshal.go.tmpl new file mode 100644 index 0000000000..3af994ae04 --- /dev/null +++ b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/marshal.go.tmpl @@ -0,0 +1,165 @@ +import ( + "io" + + + {{- template "import_curve" . 
}} +) + +// WriteTo implements io.WriterTo +func (phase1 *Phase1) WriteTo(writer io.Writer) (int64, error) { + n, err := phase1.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase1.Hash) + return int64(nBytes) + n, err +} + +func (phase1 *Phase1) writeTo(writer io.Writer) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + phase1.Parameters.G1.Tau, + phase1.Parameters.G1.AlphaTau, + phase1.Parameters.G1.BetaTau, + phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + enc := curve.NewEncoder(writer) + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (phase1 *Phase1) ReadFrom(reader io.Reader) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + &phase1.Parameters.G1.Tau, + &phase1.Parameters.G1.AlphaTau, + &phase1.Parameters.G1.BetaTau, + &phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + dec := curve.NewDecoder(reader) + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + phase1.Hash = make([]byte, 32) + nBytes, err := reader.Read(phase1.Hash) + return dec.BytesRead() + int64(nBytes), err +} + +// WriteTo implements io.WriterTo +func (phase2 *Phase2) WriteTo(writer io.Writer) (int64, error) { + n, err := phase2.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := 
writer.Write(phase2.Hash) + return int64(nBytes) + n, err +} + +func (c *Phase2) writeTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + c.Parameters.G1.L, + c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c *Phase2) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + &c.Parameters.G1.L, + &c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + c.Hash = make([]byte, 32) + n, err := reader.Read(c.Hash) + return int64(n) + dec.BytesRead(), err + +} + +// WriteTo implements io.WriterTo +func (c *Phase2Evaluations) WriteTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + c.G1.A, + c.G1.B, + c.G2.B, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c *Phase2Evaluations) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.G1.A, + &c.G1.B, + &c.G2.B, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + return dec.BytesRead(), nil +} diff --git a/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/marshal_test.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/marshal_test.go.tmpl new file mode 100644 index 0000000000..803b1863d9 --- 
/dev/null +++ b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/marshal_test.go.tmpl @@ -0,0 +1,63 @@ +import ( + "bytes" + "io" + "reflect" + "testing" + + + {{- template "import_curve" . }} + {{- template "import_backend_cs" . }} + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/stretchr/testify/require" +) + +func TestContributionSerialization(t *testing.T) { + assert := require.New(t) + + // Phase 1 + srs1 := InitPhase1(9) + srs1.Contribute() + { + var reconstructed Phase1 + roundTripCheck(t, &srs1, &reconstructed) + } + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + r1cs := ccs.(*cs.R1CS) + + // Phase 2 + srs2, _ := InitPhase2(r1cs, &srs1) + srs2.Contribute() + + { + var reconstructed Phase2 + roundTripCheck(t, &srs2, &reconstructed) + } +} + +func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { + t.Helper() + + var buf bytes.Buffer + written, err := from.WriteTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} diff --git a/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/phase1.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/phase1.go.tmpl new file mode 100644 index 0000000000..7c8c6fa540 --- /dev/null +++ b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/phase1.go.tmpl @@ -0,0 +1,187 @@ +import ( + "crypto/sha256" + "errors" + "math" + "math/big" + + + {{- template "import_fr" . }} + {{- template "import_curve" . 
}} +) + +// Phase1 represents the Phase1 of the MPC described in +// https://eprint.iacr.org/2017/1050.pdf +// +// Also known as "Powers of Tau" +type Phase1 struct { + Parameters struct { + G1 struct { + Tau []curve.G1Affine // {[τ⁰]₁, [τ¹]₁, [τ²]₁, …, [τ²ⁿ⁻²]₁} + AlphaTau []curve.G1Affine // {α[τ⁰]₁, α[τ¹]₁, α[τ²]₁, …, α[τⁿ⁻¹]₁} + BetaTau []curve.G1Affine // {β[τ⁰]₁, β[τ¹]₁, β[τ²]₁, …, β[τⁿ⁻¹]₁} + } + G2 struct { + Tau []curve.G2Affine // {[τ⁰]₂, [τ¹]₂, [τ²]₂, …, [τⁿ⁻¹]₂} + Beta curve.G2Affine // [β]₂ + } + } + PublicKeys struct { + Tau, Alpha, Beta PublicKey + } + Hash []byte // sha256 hash +} + +// InitPhase1 initialize phase 1 of the MPC. This is called once by the coordinator before +// any randomness contribution is made (see Contribute()). +func InitPhase1(power int) (phase1 Phase1) { + N := int(math.Pow(2, float64(power))) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetOne() + alpha.SetOne() + beta.SetOne() + phase1.PublicKeys.Tau = newPublicKey(tau, nil, 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, nil, 2) + phase1.PublicKeys.Beta = newPublicKey(beta, nil, 3) + + // First contribution use generators + _, _, g1, g2 := curve.Generators() + phase1.Parameters.G2.Beta.Set(&g2) + phase1.Parameters.G1.Tau = make([]curve.G1Affine, 2*N-1) + phase1.Parameters.G2.Tau = make([]curve.G2Affine, N) + phase1.Parameters.G1.AlphaTau = make([]curve.G1Affine, N) + phase1.Parameters.G1.BetaTau = make([]curve.G1Affine, N) + for i := 0; i < len(phase1.Parameters.G1.Tau); i++ { + phase1.Parameters.G1.Tau[i].Set(&g1) + } + for i := 0; i < len(phase1.Parameters.G2.Tau); i++ { + phase1.Parameters.G2.Tau[i].Set(&g2) + phase1.Parameters.G1.AlphaTau[i].Set(&g1) + phase1.Parameters.G1.BetaTau[i].Set(&g1) + } + + phase1.Parameters.G2.Beta.Set(&g2) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() + + return +} + +// Contribute contributes randomness to the phase1 object. This mutates phase1. 
+func (phase1 *Phase1) Contribute() { + N := len(phase1.Parameters.G2.Tau) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetRandom() + alpha.SetRandom() + beta.SetRandom() + phase1.PublicKeys.Tau = newPublicKey(tau, phase1.Hash[:], 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, phase1.Hash[:], 2) + phase1.PublicKeys.Beta = newPublicKey(beta, phase1.Hash[:], 3) + + // Compute powers of τ, ατ, and βτ + taus := powers(tau, 2*N-1) + alphaTau := make([]fr.Element, N) + betaTau := make([]fr.Element, N) + for i := 0; i < N; i++ { + alphaTau[i].Mul(&taus[i], &alpha) + betaTau[i].Mul(&taus[i], &beta) + } + + // Update using previous parameters + // TODO @gbotrel working with jacobian points here will help with perf. + scaleG1InPlace(phase1.Parameters.G1.Tau, taus) + scaleG2InPlace(phase1.Parameters.G2.Tau, taus[0:N]) + scaleG1InPlace(phase1.Parameters.G1.AlphaTau, alphaTau) + scaleG1InPlace(phase1.Parameters.G1.BetaTau, betaTau) + var betaBI big.Int + beta.BigInt(&betaBI) + phase1.Parameters.G2.Beta.ScalarMultiplication(&phase1.Parameters.G2.Beta, &betaBI) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() +} + +func VerifyPhase1(c0, c1 *Phase1, c ...*Phase1) error { + contribs := append([]*Phase1{c0, c1}, c...) + for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase1(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +// verifyPhase1 checks that a contribution is based on a known previous Phase1 state. 
+func verifyPhase1(current, contribution *Phase1) error { + // Compute R for τ, α, β + tauR := genR(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, current.Hash[:], 1) + alphaR := genR(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, current.Hash[:], 2) + betaR := genR(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, current.Hash[:], 3) + + // Check for knowledge of toxic parameters + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.PublicKeys.Tau.XR, tauR) { + return errors.New("couldn't verify public key of τ") + } + if !sameRatio(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, contribution.PublicKeys.Alpha.XR, alphaR) { + return errors.New("couldn't verify public key of α") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.PublicKeys.Beta.XR, betaR) { + return errors.New("couldn't verify public key of β") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Tau[1], current.Parameters.G1.Tau[1], tauR, contribution.PublicKeys.Tau.XR) { + return errors.New("couldn't verify that [τ]₁ is based on previous contribution") + } + if !sameRatio(contribution.Parameters.G1.AlphaTau[0], current.Parameters.G1.AlphaTau[0], alphaR, contribution.PublicKeys.Alpha.XR) { + return errors.New("couldn't verify that [α]₁ is based on previous contribution") + } + if !sameRatio(contribution.Parameters.G1.BetaTau[0], current.Parameters.G1.BetaTau[0], betaR, contribution.PublicKeys.Beta.XR) { + return errors.New("couldn't verify that [β]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.Parameters.G2.Tau[1], current.Parameters.G2.Tau[1]) { + return errors.New("couldn't verify that [τ]₂ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, 
contribution.PublicKeys.Beta.SXG, contribution.Parameters.G2.Beta, current.Parameters.G2.Beta) { + return errors.New("couldn't verify that [β]₂ is based on previous contribution") + } + + // Check for valid updates using powers of τ + _, _, g1, g2 := curve.Generators() + tauL1, tauL2 := linearCombinationG1(contribution.Parameters.G1.Tau) + if !sameRatio(tauL1, tauL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of τ in G₁") + } + alphaL1, alphaL2 := linearCombinationG1(contribution.Parameters.G1.AlphaTau) + if !sameRatio(alphaL1, alphaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of α(τ) in G₁") + } + betaL1, betaL2 := linearCombinationG1(contribution.Parameters.G1.BetaTau) + if !sameRatio(betaL1, betaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of β(τ) in G₁") + } + tau2L1, tau2L2 := linearCombinationG2(contribution.Parameters.G2.Tau) + if !sameRatio(contribution.Parameters.G1.Tau[1], g1, tau2L1, tau2L2) { + return errors.New("couldn't verify valid powers of τ in G₂") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (phase1 *Phase1) hash() []byte { + sha := sha256.New() + phase1.writeTo(sha) + return sha.Sum(nil) +} diff --git a/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/phase2.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/phase2.go.tmpl new file mode 100644 index 0000000000..0afb32db79 --- /dev/null +++ b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/phase2.go.tmpl @@ -0,0 +1,248 @@ +import ( + "crypto/sha256" + "errors" + "math/big" + + "github.com/consensys/gnark/constraint" + + + {{- template "import_fr" . }} + {{- template "import_curve" . 
}} + {{- template "import_backend_cs" . }} +) + +type Phase2Evaluations struct { + G1 struct { + A, B, VKK []curve.G1Affine + } + G2 struct { + B []curve.G2Affine + } +} + +type Phase2 struct { + Parameters struct { + G1 struct { + Delta curve.G1Affine + L, Z []curve.G1Affine + } + G2 struct { + Delta curve.G2Affine + } + } + PublicKey PublicKey + Hash []byte +} + +func InitPhase2(r1cs *cs.R1CS, srs1 *Phase1) (Phase2, Phase2Evaluations) { + srs := srs1.Parameters + size := len(srs.G1.AlphaTau) + if size < r1cs.GetNbConstraints() { + panic("Number of constraints is larger than expected") + } + + c2 := Phase2{} + + accumulateG1 := func(res *curve.G1Affine, t constraint.Term, value *curve.G1Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G1Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + accumulateG2 := func(res *curve.G2Affine, t constraint.Term, value *curve.G2Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G2Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + // Prepare Lagrange coefficients of [τ...]₁, [τ...]₂, [ατ...]₁, [βτ...]₁ + coeffTau1 := lagrangeCoeffsG1(srs.G1.Tau, size) + coeffTau2 := lagrangeCoeffsG2(srs.G2.Tau, size) + coeffAlphaTau1 := lagrangeCoeffsG1(srs.G1.AlphaTau, size) + coeffBetaTau1 := lagrangeCoeffsG1(srs.G1.BetaTau, size) + + internal, secret, public := r1cs.GetNbVariables() + nWires := internal + 
secret + public + var evals Phase2Evaluations + evals.G1.A = make([]curve.G1Affine, nWires) + evals.G1.B = make([]curve.G1Affine, nWires) + evals.G2.B = make([]curve.G2Affine, nWires) + bA := make([]curve.G1Affine, nWires) + aB := make([]curve.G1Affine, nWires) + C := make([]curve.G1Affine, nWires) + + // TODO @gbotrel use constraint iterator when available. + + i := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c!=nil; c = it.Next() { + // A + for _, t := range c.L { + accumulateG1(&evals.G1.A[t.WireID()], t, &coeffTau1[i]) + accumulateG1(&bA[t.WireID()], t, &coeffBetaTau1[i]) + } + // B + for _, t := range c.R { + accumulateG1(&evals.G1.B[t.WireID()], t, &coeffTau1[i]) + accumulateG2(&evals.G2.B[t.WireID()], t, &coeffTau2[i]) + accumulateG1(&aB[t.WireID()], t, &coeffAlphaTau1[i]) + } + // C + for _, t := range c.O { + accumulateG1(&C[t.WireID()], t, &coeffTau1[i]) + } + i++ + } + + // Prepare default contribution + _, _, g1, g2 := curve.Generators() + c2.Parameters.G1.Delta = g1 + c2.Parameters.G2.Delta = g2 + + // Build Z in PK as τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + // τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + n := len(srs.G1.AlphaTau) + c2.Parameters.G1.Z = make([]curve.G1Affine, n) + for i := 0; i < n-1; i++ { + c2.Parameters.G1.Z[i].Sub(&srs.G1.Tau[i+n], &srs.G1.Tau[i]) + } + bitReverse(c2.Parameters.G1.Z) + c2.Parameters.G1.Z = c2.Parameters.G1.Z[:n-1] + + // Evaluate L + nPrivate := internal + secret + c2.Parameters.G1.L = make([]curve.G1Affine, nPrivate) + evals.G1.VKK = make([]curve.G1Affine, public) + offset := public + for i := 0; i < nWires; i++ { + var tmp curve.G1Affine + tmp.Add(&bA[i], &aB[i]) + tmp.Add(&tmp, &C[i]) + if i < public { + evals.G1.VKK[i].Set(&tmp) + } else { + c2.Parameters.G1.L[i-offset].Set(&tmp) + } + } + // Set δ public key + var delta fr.Element + delta.SetOne() + c2.PublicKey = newPublicKey(delta, nil, 1) + + // Hash initial contribution + c2.Hash = c2.hash() + return c2, evals +} + +func (c *Phase2) 
Contribute() { + // Sample toxic δ + var delta, deltaInv fr.Element + var deltaBI, deltaInvBI big.Int + delta.SetRandom() + deltaInv.Inverse(&delta) + + delta.BigInt(&deltaBI) + deltaInv.BigInt(&deltaInvBI) + + // Set δ public key + c.PublicKey = newPublicKey(delta, c.Hash, 1) + + // Update δ + c.Parameters.G1.Delta.ScalarMultiplication(&c.Parameters.G1.Delta, &deltaBI) + c.Parameters.G2.Delta.ScalarMultiplication(&c.Parameters.G2.Delta, &deltaBI) + + // Update Z using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.Z); i++ { + c.Parameters.G1.Z[i].ScalarMultiplication(&c.Parameters.G1.Z[i], &deltaInvBI) + } + + // Update L using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.L); i++ { + c.Parameters.G1.L[i].ScalarMultiplication(&c.Parameters.G1.L[i], &deltaInvBI) + } + + // 4. Hash contribution + c.Hash = c.hash() +} + +func VerifyPhase2(c0, c1 *Phase2, c ...*Phase2) error { + contribs := append([]*Phase2{c0, c1}, c...) + for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase2(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +func verifyPhase2(current, contribution *Phase2) error { + // Compute R for δ + deltaR := genR(contribution.PublicKey.SG, contribution.PublicKey.SXG, current.Hash[:], 1) + + // Check for knowledge of δ + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.PublicKey.XR, deltaR) { + return errors.New("couldn't verify knowledge of δ") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Delta, current.Parameters.G1.Delta, deltaR, contribution.PublicKey.XR) { + return errors.New("couldn't verify that [δ]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify that [δ]₂ is based on previous contribution") + } + + // Check for valid updates of L and Z using + L, prevL := 
merge(contribution.Parameters.G1.L, current.Parameters.G1.L) + if !sameRatio(L, prevL, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of L using δ⁻¹") + } + Z, prevZ := merge(contribution.Parameters.G1.Z, current.Parameters.G1.Z) + if !sameRatio(Z, prevZ, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of Z using δ⁻¹") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (c *Phase2) hash() []byte { + sha := sha256.New() + c.writeTo(sha) + return sha.Sum(nil) +} diff --git a/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/setup.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/setup.go.tmpl new file mode 100644 index 0000000000..e60410b467 --- /dev/null +++ b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/setup.go.tmpl @@ -0,0 +1,80 @@ +import ( + groth16 "github.com/consensys/gnark/backend/groth16/{{toLower .Curve}}" + + {{- template "import_curve" . }} + {{- template "import_fft" . 
}} +) + +func ExtractKeys(srs1 *Phase1, srs2 *Phase2, evals *Phase2Evaluations, nConstraints int) (pk groth16.ProvingKey, vk groth16.VerifyingKey) { + _, _, _, g2 := curve.Generators() + + // Initialize PK + pk.Domain = *fft.NewDomain(uint64(nConstraints)) + pk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + pk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + pk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + pk.G1.Z = srs2.Parameters.G1.Z + bitReverse(pk.G1.Z) + + pk.G1.K = srs2.Parameters.G1.L + pk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + pk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + + // Filter out infinity points + nWires := len(evals.G1.A) + pk.InfinityA = make([]bool, nWires) + A := make([]curve.G1Affine, nWires) + j := 0 + for i, e := range evals.G1.A { + if e.IsInfinity() { + pk.InfinityA[i] = true + continue + } + A[j] = evals.G1.A[i] + j++ + } + pk.G1.A = A[:j] + pk.NbInfinityA = uint64(nWires - j) + + pk.InfinityB = make([]bool, nWires) + B := make([]curve.G1Affine, nWires) + j = 0 + for i, e := range evals.G1.B { + if e.IsInfinity() { + pk.InfinityB[i] = true + continue + } + B[j] = evals.G1.B[i] + j++ + } + pk.G1.B = B[:j] + pk.NbInfinityB = uint64(nWires - j) + + B2 := make([]curve.G2Affine, nWires) + j = 0 + for i, e := range evals.G2.B { + if e.IsInfinity() { + // pk.InfinityB[i] = true should be the same as in B + continue + } + B2[j] = evals.G2.B[i] + j++ + } + pk.G2.B = B2[:j] + + // Initialize VK + vk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + vk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + vk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + vk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + vk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + vk.G2.Gamma.Set(&g2) + vk.G1.K = evals.G1.VKK + + // sets e, -[δ]2, -[γ]2 + if err := vk.Precompute(); err != nil { + panic(err) + } + + return pk, vk +} diff --git a/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/setup_test.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/setup_test.go.tmpl 
new file mode 100644 index 0000000000..a36c0c1b9d --- /dev/null +++ b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/setup_test.go.tmpl @@ -0,0 +1,184 @@ +import ( + "testing" + + {{- template "import_fr" . }} + {{- template "import_curve" . }} + {{- template "import_backend_cs" . }} + + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/stretchr/testify/require" + + native_mimc "github.com/consensys/gnark-crypto/ecc/{{toLower .Curve}}/fr/mimc" +) + +func TestSetupCircuit(t *testing.T) { + {{- if ne (toLower .Curve) "bn254" }} + if testing.Short() { + t.Skip() + } + {{- end}} + const ( + nContributionsPhase1 = 3 + nContributionsPhase2 = 3 + power = 9 + ) + + assert := require.New(t) + + srs1 := InitPhase1(power) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase1; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. + prev := srs1.clone() + + srs1.Contribute() + assert.NoError(VerifyPhase1(&prev, &srs1)) + } + + // Compile the circuit + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + var evals Phase2Evaluations + r1cs := ccs.(*cs.R1CS) + + // Prepare for phase-2 + srs2, evals := InitPhase2(r1cs, &srs1) + + // Make and verify contributions for phase2 + for i := 1; i < nContributionsPhase2; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. 
+ prev := srs2.clone() + + srs2.Contribute() + assert.NoError(VerifyPhase2(&prev, &srs2)) + } + + // Extract the proving and verifying keys + pk, vk := ExtractKeys(&srs1, &srs2, &evals, ccs.GetNbConstraints()) + + // Build the witness + var preImage, hash fr.Element + { + m := native_mimc.NewMiMC() + m.Write(preImage.Marshal()) + hash.SetBytes(m.Sum(nil)) + } + + witness, err := frontend.NewWitness(&Circuit{PreImage: preImage, Hash: hash}, curve.ID.ScalarField()) + assert.NoError(err) + + pubWitness, err := witness.Public() + assert.NoError(err) + + // groth16: ensure proof is verified + proof, err := groth16.Prove(ccs, &pk, witness) + assert.NoError(err) + + err = groth16.Verify(proof, &vk, pubWitness) + assert.NoError(err) +} + +func BenchmarkPhase1(b *testing.B) { + const power = 14 + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = InitPhase1(power) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs1 := InitPhase1(power) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs1.Contribute() + } + }) + +} + +func BenchmarkPhase2(b *testing.B) { + const power = 14 + srs1 := InitPhase1(power) + srs1.Contribute() + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + if err != nil { + b.Fatal(err) + } + + r1cs := ccs.(*cs.R1CS) + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = InitPhase2(r1cs, &srs1) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs2, _ := InitPhase2(r1cs, &srs1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs2.Contribute() + } + }) + +} + +// Circuit defines a pre-image knowledge proof +// mimc(secret preImage) = public hash +type Circuit struct { + PreImage frontend.Variable + Hash frontend.Variable `gnark:",public"` +} + +// Define declares the circuit's constraints +// Hash = mimc(PreImage) +func (circuit *Circuit) Define(api frontend.API) error { + // hash function + mimc, _ := 
mimc.NewMiMC(api) + + // specify constraints + mimc.Write(circuit.PreImage) + api.AssertIsEqual(circuit.Hash, mimc.Sum()) + + return nil +} + +func (phase1 *Phase1) clone() Phase1 { + r := Phase1{} + r.Parameters.G1.Tau = append(r.Parameters.G1.Tau, phase1.Parameters.G1.Tau...) + r.Parameters.G1.AlphaTau = append(r.Parameters.G1.AlphaTau, phase1.Parameters.G1.AlphaTau...) + r.Parameters.G1.BetaTau = append(r.Parameters.G1.BetaTau, phase1.Parameters.G1.BetaTau...) + + r.Parameters.G2.Tau = append(r.Parameters.G2.Tau, phase1.Parameters.G2.Tau...) + r.Parameters.G2.Beta = phase1.Parameters.G2.Beta + + r.PublicKeys = phase1.PublicKeys + r.Hash = append(r.Hash, phase1.Hash...) + + return r +} + +func (phase2 *Phase2) clone() Phase2 { + r := Phase2{} + r.Parameters.G1.Delta = phase2.Parameters.G1.Delta + r.Parameters.G1.L = append(r.Parameters.G1.L, phase2.Parameters.G1.L...) + r.Parameters.G1.Z = append(r.Parameters.G1.Z, phase2.Parameters.G1.Z...) + r.Parameters.G2.Delta = phase2.Parameters.G2.Delta + r.PublicKey = phase2.PublicKey + r.Hash = append(r.Hash, phase2.Hash...) + + return r +} diff --git a/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/utils.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/utils.go.tmpl new file mode 100644 index 0000000000..f7c67c036d --- /dev/null +++ b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/utils.go.tmpl @@ -0,0 +1,153 @@ +import ( + "bytes" + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + + {{- template "import_fr" . }} + {{- template "import_curve" . 
}} + "github.com/consensys/gnark/internal/utils" +) + +type PublicKey struct { + SG curve.G1Affine + SXG curve.G1Affine + XR curve.G2Affine +} + +func newPublicKey(x fr.Element, challenge []byte, dst byte) PublicKey { + var pk PublicKey + _, _, g1, _ := curve.Generators() + + var s fr.Element + var sBi big.Int + s.SetRandom() + s.BigInt(&sBi) + pk.SG.ScalarMultiplication(&g1, &sBi) + + // compute x*sG1 + var xBi big.Int + x.BigInt(&xBi) + pk.SXG.ScalarMultiplication(&pk.SG, &xBi) + + // generate R based on sG1, sxG1, challenge, and domain separation tag (tau, alpha or beta) + R := genR(pk.SG, pk.SXG, challenge, dst) + + // compute x*spG2 + pk.XR.ScalarMultiplication(&R, &xBi) + return pk +} + +func bitReverse[T any](a []T) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + irev := bits.Reverse64(i) >> nn + if irev > i { + a[i], a[irev] = a[irev], a[i] + } + } +} + +// Returns [1, a, a², ..., aⁿ⁻¹ ] in Montgomery form +func powers(a fr.Element, n int) []fr.Element { + result := make([]fr.Element, n) + result[0] = fr.NewElement(1) + for i := 1; i < n; i++ { + result[i].Mul(&result[i-1], &a) + } + return result +} + +// Returns [aᵢAᵢ, ...] in G1 +func scaleG1InPlace(A []curve.G1Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Returns [aᵢAᵢ, ...] 
in G2 +func scaleG2InPlace(A []curve.G2Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Check e(a₁, a₂) = e(b₁, b₂) +func sameRatio(a1, b1 curve.G1Affine, a2, b2 curve.G2Affine) bool { + if !a1.IsInSubGroup() || !b1.IsInSubGroup() || !a2.IsInSubGroup() || !b2.IsInSubGroup() { + panic("invalid point not in subgroup") + } + var na2 curve.G2Affine + na2.Neg(&a2) + res, err := curve.PairingCheck( + []curve.G1Affine{a1, b1}, + []curve.G2Affine{na2, b2}) + if err != nil { + panic(err) + } + return res +} + +// returns a = ∑ rᵢAᵢ, b = ∑ rᵢBᵢ +func merge(A, B []curve.G1Affine) (a, b curve.G1Affine) { + nc := runtime.NumCPU() + r := make([]fr.Element, len(A)) + for i := 0; i < len(A); i++ { + r[i].SetRandom() + } + a.MultiExp(A, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + b.MultiExp(B, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G1 +func linearCombinationG1(A []curve.G1Affine) (L1, L2 curve.G1Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G2 +func linearCombinationG2(A []curve.G2Affine) (L1, L2 curve.G2Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// Generate R in G₂ as Hash(gˢ, gˢˣ, challenge, dst) +func genR(sG1, sxG1 curve.G1Affine, challenge []byte, dst byte) curve.G2Affine { + var buf bytes.Buffer + buf.Grow(len(challenge) + curve.SizeOfG1AffineUncompressed*2) + buf.Write(sG1.Marshal()) + 
buf.Write(sxG1.Marshal()) + buf.Write(challenge) + spG2, err := curve.HashToG2(buf.Bytes(), []byte{dst}) + if err != nil { + panic(err) + } + return spG2 +} diff --git a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.commitment.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.commitment.go.tmpl index 94d202d2b4..cc94c21ef1 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.commitment.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.commitment.go.tmpl @@ -1,4 +1,7 @@ import ( + "testing" + "fmt" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/witness" @@ -6,7 +9,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/assert" - "testing" ) type singleSecretCommittedCircuit struct { @@ -15,7 +17,11 @@ type singleSecretCommittedCircuit struct { func (c *singleSecretCommittedCircuit) Define(api frontend.API) error { api.AssertIsEqual(c.One, 1) - commit, err := api.Compiler().Commit(c.One) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One) if err != nil { return err } @@ -101,8 +107,11 @@ type oneSecretOnePublicCommittedCircuit struct { } func (c *oneSecretOnePublicCommittedCircuit) Define(api frontend.API) error { - - commit, err := api.Compiler().Commit(c.One, c.Two) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One, c.Two) if err != nil { return err } diff --git a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl index e5ad79d370..1366e7ac92 100644 --- 
a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl @@ -2,7 +2,10 @@ import ( {{ template "import_curve" . }} {{ template "import_fft" . }} - + {{ template "import_pedersen" . }} + "github.com/consensys/gnark/backend/groth16/internal/test_utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "bytes" "math/big" @@ -10,6 +13,7 @@ import ( "github.com/leanovate/gopter" "github.com/leanovate/gopter/prop" + "github.com/leanovate/gopter/gen" "testing" ) @@ -73,13 +77,9 @@ func TestProofSerialization(t *testing.T) { func TestVerifyingKeySerialization(t *testing.T) { - parameters := gopter.DefaultTestParameters() - parameters.MinSuccessfulTests = 10 - - properties := gopter.NewProperties(parameters) - properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + roundTrip := func(withCommitment bool) func(curve.G1Affine, curve.G2Affine) bool { + return func(p1 curve.G1Affine, p2 curve.G2Affine) bool { var vk, vkCompressed, vkRaw VerifyingKey // create a random vk @@ -107,6 +107,20 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } + if withCommitment { + vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization + bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) + elem := p1 + for i := 0; i < len(vk.PublicAndCommitmentCommitted); i++ { + bases[i] = make([]curve.G1Affine, len(vk.PublicAndCommitmentCommitted[i])) + for j := range bases[i] { + bases[i][j] = elem + elem.Add(&elem, &p1) + } + } + _, vk.CommitmentKey, err = pedersen.Setup(bases...) 
+ assert.NoError(t, err) + } var bufCompressed bytes.Buffer written, err := vk.WriteTo(&bufCompressed) @@ -145,7 +159,22 @@ func TestVerifyingKeySerialization(t *testing.T) { } return reflect.DeepEqual(&vk, &vkCompressed) && reflect.DeepEqual(&vk, &vkRaw) - }, + } + } + + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 10 + + properties := gopter.NewProperties(parameters) + + properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(false), + GenG1(), + GenG2(), + )) + + properties.Property("VerifyingKey (with commitments) -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(true), GenG1(), GenG2(), )) @@ -162,7 +191,7 @@ func TestProvingKeySerialization(t *testing.T) { properties := gopter.NewProperties(parameters) properties.Property("ProvingKey -> writer -> reader -> ProvingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + func(p1 curve.G1Affine, p2 curve.G2Affine, nbCommitment int) bool { var pk, pkCompressed, pkRaw ProvingKey // create a random pk @@ -189,7 +218,20 @@ func TestProvingKeySerialization(t *testing.T) { pk.NbInfinityA = 1 pk.InfinityA = make([]bool, nbWires) pk.InfinityB = make([]bool, nbWires) - pk.InfinityA[2] = true + pk.InfinityA[2] = true + + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, nbCommitment) + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases...) 
+ require.NoError(t, err) + } var bufCompressed bytes.Buffer written, err := pk.WriteTo(&bufCompressed) @@ -231,6 +273,7 @@ func TestProvingKeySerialization(t *testing.T) { }, GenG1(), GenG2(), + gen.IntRange(0, 2), )) properties.TestingRun(t, gopter.ConsoleReporter(false)) diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl index 760f22b3d1..2dfba35d65 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl @@ -1,11 +1,13 @@ import ( {{ template "import_curve" . }} {{ template "import_fr" . }} + {{ template "import_kzg" . }} + "github.com/consensys/gnark-crypto/ecc/{{toLower .Curve}}/fr/iop" "io" "errors" ) -// WriteTo writes binary encoding of Proof to w without point compression +// WriteRawTo writes binary encoding of Proof to w without point compression func (proof *Proof) WriteRawTo(w io.Writer) (int64, error) { return proof.writeTo(w, curve.RawEncoding()) } @@ -30,6 +32,7 @@ func (proof *Proof) writeTo(w io.Writer, options ...func(*curve.Encoder)) (int64 proof.BatchedProof.ClaimedValues, &proof.ZShiftedOpening.H, &proof.ZShiftedOpening.ClaimedValue, + proof.Bsb22Commitments, } for _, v := range toEncode { @@ -38,7 +41,7 @@ func (proof *Proof) writeTo(w io.Writer, options ...func(*curve.Encoder)) (int64 } } - return enc.BytesWritten(), nil + return enc.BytesWritten(), nil } // ReadFrom reads binary representation of Proof from r @@ -56,6 +59,7 @@ func (proof *Proof) ReadFrom(r io.Reader) (int64, error) { &proof.BatchedProof.ClaimedValues, &proof.ZShiftedOpening.H, &proof.ZShiftedOpening.ClaimedValue, + &proof.Bsb22Commitments, } for _, v := range toDecode { @@ -64,7 +68,11 @@ func (proof *Proof) ReadFrom(r io.Reader) (int64, error) { } } - return dec.BytesRead(), nil + if proof.Bsb22Commitments == nil { + proof.Bsb22Commitments = 
[]kzg.Digest{} + } + + return dec.BytesRead(), nil } // WriteTo writes binary encoding of ProvingKey to w @@ -88,8 +96,15 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } n += n2 + // KZG key + n2, err = pk.Kzg.WriteTo(w) + if err != nil { + return + } + n += n2 + // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) - if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { + if len(pk.trace.S) != (3 * int(pk.Domain[0].Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } @@ -98,16 +113,17 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { // encode the size (nor does it convert from Montgomery to Regular form) // so we explicitly transmit []fr.Element toEncode := []interface{}{ - ([]fr.Element)(pk.Ql), - ([]fr.Element)(pk.Qr), - ([]fr.Element)(pk.Qm), - ([]fr.Element)(pk.Qo), - ([]fr.Element)(pk.CQk), - ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.S1Canonical), - ([]fr.Element)(pk.S2Canonical), - ([]fr.Element)(pk.S3Canonical), - pk.Permutation, + pk.trace.Ql.Coefficients(), + pk.trace.Qr.Coefficients(), + pk.trace.Qm.Coefficients(), + pk.trace.Qo.Coefficients(), + pk.trace.Qk.Coefficients(), + coefficients(pk.trace.Qcp), + pk.lQk.Coefficients(), + pk.trace.S1.Coefficients(), + pk.trace.S2.Coefficients(), + pk.trace.S3.Coefficients(), + pk.trace.S, } for _, v := range toEncode { @@ -139,20 +155,30 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { return n, err } - pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) + n2, err = pk.Kzg.ReadFrom(r) + n += n2 + if err != nil { + return n, err + } + + pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) dec := curve.NewDecoder(r) + + var ql, qr, qm, qo, qk, lqk, s1, s2, s3 []fr.Element + var qcp [][]fr.Element toDecode := []interface{}{ - (*[]fr.Element)(&pk.Ql), - (*[]fr.Element)(&pk.Qr), - (*[]fr.Element)(&pk.Qm), - (*[]fr.Element)(&pk.Qo), - (*[]fr.Element)(&pk.CQk), - 
(*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.S1Canonical), - (*[]fr.Element)(&pk.S2Canonical), - (*[]fr.Element)(&pk.S3Canonical), - &pk.Permutation, + &ql, + &qr, + &qm, + &qo, + &qk, + &qcp, + &lqk, + &s1, + &s2, + &s3, + &pk.trace.S, } for _, v := range toDecode { @@ -161,6 +187,23 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { } } + canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} + pk.trace.Ql = iop.NewPolynomial(&ql, canReg) + pk.trace.Qr = iop.NewPolynomial(&qr, canReg) + pk.trace.Qm = iop.NewPolynomial(&qm, canReg) + pk.trace.Qo = iop.NewPolynomial(&qo, canReg) + pk.trace.Qk = iop.NewPolynomial(&qk, canReg) + pk.trace.S1 = iop.NewPolynomial(&s1, canReg) + pk.trace.S2 = iop.NewPolynomial(&s2, canReg) + pk.trace.S3 = iop.NewPolynomial(&s3, canReg) + + pk.trace.Qcp = make([]*iop.Polynomial, len(qcp)) + for i := range qcp { + pk.trace.Qcp[i] = iop.NewPolynomial(&qcp[i], canReg) + } + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + pk.lQk = iop.NewPolynomial(&lqk, lagReg) + pk.computeLagrangeCosetPolys() return n + dec.BytesRead(), nil @@ -169,6 +212,15 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { // WriteTo writes binary encoding of VerifyingKey to w func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { + return vk.writeTo(w) +} + +// WriteRawTo writes binary encoding of VerifyingKey to w without point compression +func (vk *VerifyingKey) WriteRawTo(w io.Writer) (int64, error) { + return vk.writeTo(w, curve.RawEncoding()) +} + +func (vk *VerifyingKey) writeTo(w io.Writer, options ...func(*curve.Encoder)) (n int64, err error) { enc := curve.NewEncoder(w) toEncode := []interface{}{ @@ -185,6 +237,11 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { &vk.Qm, &vk.Qo, &vk.Qk, + vk.Qcp, + &vk.Kzg.G1, + &vk.Kzg.G2[0], + &vk.Kzg.G2[1], + vk.CommitmentConstraintIndexes, } for _, v := range toEncode { @@ -213,6 +270,11 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) 
(int64, error) { &vk.Qm, &vk.Qo, &vk.Qk, + &vk.Qcp, + &vk.Kzg.G1, + &vk.Kzg.G2[0], + &vk.Kzg.G2[1], + &vk.CommitmentConstraintIndexes, } for _, v := range toDecode { @@ -221,5 +283,9 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { } } + if vk.Qcp == nil { + vk.Qcp = []kzg.Digest{} + } + return dec.BytesRead(), nil } \ No newline at end of file diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl index af921fded2..a869528cf9 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl @@ -5,6 +5,8 @@ import ( "time" "sync" + "github.com/consensys/gnark/backend/witness" + {{ template "import_fr" . }} {{ template "import_curve" . }} {{ template "import_kzg" . }} @@ -12,10 +14,12 @@ import ( {{ template "import_backend_cs" . }} "github.com/consensys/gnark-crypto/ecc/{{toLower .Curve}}/fr/iop" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/logger" "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/constraint" ) type Proof struct { @@ -29,17 +33,59 @@ type Proof struct { // Commitments to h1, h2, h3 such that h = h1 + Xh2 + X**2h3 is the quotient polynomial H [3]kzg.Digest - // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2 + Bsb22Commitments []kzg.Digest + + // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2, qCPrime BatchedProof kzg.BatchOpeningProof // Opening proof of Z at zeta*mu ZShiftedOpening kzg.OpeningProof } -// Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { +// Computing and verifying Bsb22 multi-commits explained in 
https://hackmd.io/x8KsadW3RRyX7YTCFJIkHg +func bsb22ComputeCommitmentHint(spr *cs.SparseR1CS, pk *ProvingKey, proof *Proof, cCommitments []*iop.Polynomial, res *fr.Element, commDepth int) solver.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments)[commDepth] + committedValues := make([]fr.Element, pk.Domain[0].Cardinality) + offset := spr.GetNbPublicVariables() + for i := range ins { + committedValues[offset+commitmentInfo.Committed[i]].SetBigInt(ins[i]) + } + var ( + err error + hashRes []fr.Element + ) + if _, err = committedValues[offset+commitmentInfo.CommitmentIndex].SetRandom(); err != nil { // Commitment injection constraint has qcp = 0. Safe to use for blinding. + return err + } + if _, err = committedValues[offset+spr.GetNbConstraints()-1].SetRandom(); err != nil { // Last constraint has qcp = 0. Safe to use for blinding + return err + } + pi2iop := iop.NewPolynomial(&committedValues, iop.Form{Basis: iop.Lagrange, Layout: iop.Regular}) + cCommitments[commDepth] = pi2iop.ShallowClone() + cCommitments[commDepth].ToCanonical(&pk.Domain[0]).ToRegular() + if proof.Bsb22Commitments[commDepth], err = kzg.Commit(cCommitments[commDepth].Coefficients(), pk.Kzg); err != nil { + return err + } + if hashRes, err = fr.Hash(proof.Bsb22Commitments[commDepth].Marshal(), []byte("BSB22-Plonk"), 1); err != nil { + return err + } + res.Set(&hashRes[0]) // TODO @Tabaie use CommitmentIndex for this; create a new variable CommitmentConstraintIndex for other uses + res.BigInt(outs[0]) + + return nil + } +} + +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + + log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", spr.GetNbConstraints()).Str("backend", "plonk").Logger() + + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } - log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", len(spr.Constraints)).Str("backend", "plonk").Logger() start := time.Now() // pick a hash function that will be used to derive the challenges hFunc := sha256.New() @@ -50,53 +96,101 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen // result proof := &Proof{} - // compute the constraint system solution - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments) + commitmentVal := make([]fr.Element, len(commitmentInfo)) // TODO @Tabaie get rid of this + cCommitments := make([]*iop.Polynomial, len(commitmentInfo)) + proof.Bsb22Commitments = make([]kzg.Digest, len(commitmentInfo)) + for i := range commitmentInfo { + opt.SolverOpts = append(opt.SolverOpts, solver.OverrideHint(commitmentInfo[i].HintID, + bsb22ComputeCommitmentHint(spr, pk, proof, cCommitments, &commitmentVal[i], i))) } // query l, r, o in Lagrange basis, not blinded - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) + _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) 
+ if err != nil { + return nil, err + } + // TODO @gbotrel deal with that conversion lazily + lcCommitments := make([]*iop.Polynomial, len(cCommitments)) + for i := range cCommitments { + lcCommitments[i] = cCommitments[i].Clone(int(pk.Domain[1].Cardinality)).ToLagrangeCoset(&pk.Domain[1]) // lagrange coset form + } + solution := _solution.(*cs.SparseR1CSSolution) + evaluationLDomainSmall := []fr.Element(solution.L) + evaluationRDomainSmall := []fr.Element(solution.R) + evaluationODomainSmall := []fr.Element(solution.O) lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - liop := iop.NewPolynomial(&evaluationLDomainSmall,lagReg) - riop := iop.NewPolynomial(&evaluationRDomainSmall,lagReg) - oiop := iop.NewPolynomial(&evaluationODomainSmall,lagReg) - wliop := liop.ShallowClone() - wriop := riop.ShallowClone() - woiop := oiop.ShallowClone() - wliop.ToCanonical(&pk.Domain[0]).ToRegular() - wriop.ToCanonical(&pk.Domain[0]).ToRegular() - woiop.ToCanonical(&pk.Domain[0]).ToRegular() - - // Blind l, r, o before committing - // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. - bwliop := wliop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwriop := wriop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwoiop := woiop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Vk.KZGSRS); err != nil { - return nil, err + // l, r, o and blinded versions + var ( + wliop, + wriop, + woiop, + bwliop, + bwriop, + bwoiop *iop.Polynomial + ) + var wgLRO sync.WaitGroup + wgLRO.Add(3) + go func() { + // we keep in lagrange regular form since iop.BuildRatioCopyConstraint prefers it in this form. + wliop = iop.NewPolynomial(&evaluationLDomainSmall, lagReg) + // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. 
+ bwliop = wliop.Clone(int(pk.Domain[1].Cardinality)).ToCanonical(&pk.Domain[0]).ToRegular().Blind(1) + wgLRO.Done() + }() + go func() { + wriop = iop.NewPolynomial(&evaluationRDomainSmall, lagReg) + bwriop = wriop.Clone(int(pk.Domain[1].Cardinality)).ToCanonical(&pk.Domain[0]).ToRegular().Blind(1) + wgLRO.Done() + }() + go func() { + woiop = iop.NewPolynomial(&evaluationODomainSmall, lagReg) + bwoiop = woiop.Clone(int(pk.Domain[1].Cardinality)).ToCanonical(&pk.Domain[0]).ToRegular().Blind(1) + wgLRO.Done() + }() + + fw, ok := fullWitness.Vector().(fr.Vector) + if !ok { + return nil, witness.ErrInvalidWitness } + // start computing lcqk + var lcqk *iop.Polynomial + chLcqk := make(chan struct{}, 1) + go func() { + // compute qk in canonical basis, completed with the public inputs + // We copy the coeffs of qk to pk is not mutated + lqkcoef := pk.lQk.Coefficients() + qkCompletedCanonical := make([]fr.Element, len(lqkcoef)) + copy(qkCompletedCanonical, fw[:len(spr.Public)]) + copy(qkCompletedCanonical[len(spr.Public):], lqkcoef[len(spr.Public):]) + for i := range commitmentInfo { + qkCompletedCanonical[spr.GetNbPublicVariables()+commitmentInfo[i].CommitmentIndex] = commitmentVal[i] + } + pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) + fft.BitReverse(qkCompletedCanonical) + + canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} + lcqk = iop.NewPolynomial(&qkCompletedCanonical, canReg) + lcqk.ToLagrangeCoset(&pk.Domain[1]) + close(chLcqk) + }() + // The first challenge is derived using the public data: the commitments to the permutation, // the coefficients of the circuit, and the public inputs. 
// derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *pk.Vk, fullWitness[:len(spr.Public)]); err != nil { + if err := bindPublicData(&fs, "gamma", *pk.Vk, fw[:len(spr.Public)], proof.Bsb22Commitments); err != nil { return nil, err } - gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) + + // wait for polys to be blinded + wgLRO.Wait() + if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Kzg); err != nil { + return nil, err + } + + gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) // TODO @Tabaie @ThomasPiellard add BSB commitment here? if err != nil { return nil, err } @@ -109,17 +203,30 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen var beta fr.Element beta.SetBytes(bbeta) + // l, r, o are already blinded + wgLRO.Add(3) + go func() { + bwliop.ToLagrangeCoset(&pk.Domain[1]) + wgLRO.Done() + }() + go func() { + bwriop.ToLagrangeCoset(&pk.Domain[1]) + wgLRO.Done() + }() + go func() { + bwoiop.ToLagrangeCoset(&pk.Domain[1]) + wgLRO.Done() + }() + // compute the copy constraint's ratio - // We copy liop, riop, oiop because they are fft'ed in the process. - // We could have not copied them at the cost of doing one more bit reverse - // per poly... + // note that wliop, wriop and woiop are fft'ed (mutated) in the process. 
ziop, err := iop.BuildRatioCopyConstraint( []*iop.Polynomial{ - liop.Clone(), - riop.Clone(), - oiop.Clone(), + wliop, + wriop, + woiop, }, - pk.Permutation, + pk.trace.S, beta, gamma, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}, @@ -130,73 +237,32 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen } // commit to the blinded version of z - bwziop := ziop // iop.NewWrappedPolynomial(&ziop) - bwziop.Blind(2) - proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Vk.KZGSRS, runtime.NumCPU()*2) - if err != nil { - return proof, err - } + chZ := make(chan error, 1) + var bwziop, bwsziop *iop.Polynomial + var alpha fr.Element + go func() { + bwziop = ziop // iop.NewWrappedPolynomial(&ziop) + bwziop.Blind(2) + proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Kzg, runtime.NumCPU()*2) + if err != nil { + chZ <- err + } - // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) - alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) - if err != nil { - return proof, err - } + // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) + alpha, err = deriveRandomness(&fs, "alpha", &proof.Z) + if err != nil { + chZ <- err + } - // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) - copy(qkCompletedCanonical, fullWitness[:len(spr.Public)]) - copy(qkCompletedCanonical[len(spr.Public):], pk.LQk[len(spr.Public):]) - pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) - fft.BitReverse(qkCompletedCanonical) - - // l, r, o are blinded here - bwliop.ToLagrangeCoset(&pk.Domain[1]) - bwriop.ToLagrangeCoset(&pk.Domain[1]) - bwoiop.ToLagrangeCoset(&pk.Domain[1]) - - lagrangeCosetBitReversed := iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse} - - // we don't mutate so no need to clone the coefficients from the proving key. 
- wqliop := iop.NewPolynomial(&pk.lQl, lagrangeCosetBitReversed) - wqriop := iop.NewPolynomial(&pk.lQr, lagrangeCosetBitReversed) - wqmiop := iop.NewPolynomial(&pk.lQm, lagrangeCosetBitReversed) - wqoiop := iop.NewPolynomial(&pk.lQo, lagrangeCosetBitReversed) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqkiop := iop.NewPolynomial(&qkCompletedCanonical, canReg) - wqkiop.ToLagrangeCoset(&pk.Domain[1]) - - // storing Id - id := make([]fr.Element, pk.Domain[1].Cardinality) - id[1].SetOne() - widiop := iop.NewPolynomial(&id, canReg) - widiop.ToLagrangeCoset(&pk.Domain[1]) - - // permutations in LagrangeCoset: we don't mutate so no need to clone the coefficients from the - // proving key. - ws1 := iop.NewPolynomial(&pk.lS1LagrangeCoset, lagrangeCosetBitReversed) - ws2 := iop.NewPolynomial(&pk.lS2LagrangeCoset, lagrangeCosetBitReversed) - ws3 := iop.NewPolynomial(&pk.lS3LagrangeCoset, lagrangeCosetBitReversed) - - // Store z(g*x), without reallocating a slice - bwsziop := bwziop.ShallowClone().Shift(1) - bwsziop.ToLagrangeCoset(&pk.Domain[1]) - - // L_{g^{0}} - cap := pk.Domain[1].Cardinality - if cap < pk.Domain[0].Cardinality { - cap = pk.Domain[0].Cardinality // sanity check - } - lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) - lone[0].SetOne() - loneiop := iop.NewPolynomial(&lone, lagReg) - wloneiop := loneiop.ToCanonical(&pk.Domain[0]). - ToRegular(). - ToLagrangeCoset(&pk.Domain[1]) + // Store z(g*x), without reallocating a slice + bwsziop = bwziop.ShallowClone().Shift(1) + bwsziop.ToLagrangeCoset(&pk.Domain[1]) + chZ <- nil + close(chZ) + }() // Full capture using latest gnark crypto... 
- fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element) fr.Element { + fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element, pi2QcPrime []fr.Element) fr.Element { // TODO @Tabaie make use of the fact that qCPrime is a selector: sparse and binary var ic, tmp fr.Element @@ -207,20 +273,24 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen ic.Add(&ic, &tmp) tmp.Mul(&fqo, &o) ic.Add(&ic, &tmp).Add(&ic, &fqk) + nbComms := len(commitmentInfo) + for i := range commitmentInfo { + tmp.Mul(&pi2QcPrime[i], &pi2QcPrime[i+nbComms]) + ic.Add(&ic, &tmp) + } return ic } fo := func(l, r, o, fid, fs1, fs2, fs3, fz, fzs fr.Element) fr.Element { - var uu fr.Element - u := pk.Domain[0].FrMultiplicativeGen - uu.Mul(&u, &u) - + u := &pk.Domain[0].FrMultiplicativeGen var a, b, tmp fr.Element - a.Mul(&beta, &fid).Add(&a, &l).Add(&a, &gamma) - tmp.Mul(&beta, &u).Mul(&tmp, &fid).Add(&tmp, &r).Add(&tmp, &gamma) + b.Mul(&beta, &fid) + a.Add(&b, &l).Add(&a, &gamma) + b.Mul(&b, u) + tmp.Add(&b, &r).Add(&tmp, &gamma) a.Mul(&a, &tmp) - tmp.Mul(&beta, &uu).Mul(&tmp, &fid).Add(&tmp, &o).Add(&tmp, &gamma) + tmp.Mul(&b, u).Add(&tmp, &o).Add(&tmp, &gamma) a.Mul(&a, &tmp).Mul(&a, &fz) b.Mul(&beta, &fs1).Add(&b, &l).Add(&b, &gamma) @@ -240,11 +310,11 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen return one } - // 0 , 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 - // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk,lone + // 0 , 1 , 2, 3 , 4 , 5 , 6 , 7, 8 , 9 , 10, 11, 12, 13, 14, 15:15+nbComm , 15+nbComm:15+2×nbComm + // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk ,lone, Bsb22Commitments, qCPrime fm := func(x ...fr.Element) fr.Element { - a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2]) + a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2], x[15:]) b := fo(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8]) c := fone(x[7], x[14]) @@ -252,27 +322,49 @@ func Prove(spr *cs.SparseR1CS, pk 
*ProvingKey, fullWitness fr.Vector, opt backen return c } - testEval, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, + + // wait for lcqk + <-chLcqk + + // wait for Z part + if err := <-chZ; err != nil { + return proof, err + } + + // wait for l, r o lagrange coset conversion + wgLRO.Wait() + + toEval := []*iop.Polynomial{ bwliop, bwriop, bwoiop, - widiop, - ws1, - ws2, - ws3, + pk.lcIdIOP, + pk.lcS1, + pk.lcS2, + pk.lcS3, bwziop, bwsziop, - wqliop, - wqriop, - wqmiop, - wqoiop, - wqkiop, - wloneiop, - ) + pk.lcQl, + pk.lcQr, + pk.lcQm, + pk.lcQo, + lcqk, + pk.lLoneIOP, + } + toEval = append(toEval, lcCommitments...) // TODO: Add this at beginning + toEval = append(toEval, pk.lcQcp...) + systemEvaluation, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, toEval...) if err != nil { return nil, err } - h, err := iop.DivideByXMinusOne(testEval, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) + // open blinded Z at zeta*z + chbwzIOP := make(chan struct{}, 1) + go func() { + bwziop.ToCanonical(&pk.Domain[1]).ToRegular() + close(chbwzIOP) + }() + + h, err := iop.DivideByXMinusOne(systemEvaluation, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) // TODO Rename to DivideByXNMinusOne or DivideByVanishingPoly etc if err != nil { return nil, err } @@ -282,7 +374,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen h.Coefficients()[:pk.Domain[0].Cardinality+2], h.Coefficients()[pk.Domain[0].Cardinality+2:2*(pk.Domain[0].Cardinality+2)], h.Coefficients()[2*(pk.Domain[0].Cardinality+2):3*(pk.Domain[0].Cardinality+2)], - proof, pk.Vk.KZGSRS); err != nil { + proof, pk.Kzg); err != nil { return nil, err } @@ -292,45 +384,72 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen return nil, err } - // compute evaluations of (blinded version of) l, r, o, z at zeta + // compute evaluations of (blinded version of) l, r, o, z, qCPrime at zeta var blzeta, 
brzeta, bozeta fr.Element - + qcpzeta := make([]fr.Element, len(commitmentInfo)) + var wgEvals sync.WaitGroup wgEvals.Add(3) - - go func() { - bwliop.ToCanonical(&pk.Domain[1]).ToRegular() - blzeta = bwliop.Evaluate(zeta) + evalAtZeta := func(poly *iop.Polynomial, res *fr.Element) { + poly.ToCanonical(&pk.Domain[1]).ToRegular() + *res = poly.Evaluate(zeta) wgEvals.Done() - }() + } + go evalAtZeta(bwliop, &blzeta) + go evalAtZeta(bwriop, &brzeta) + go evalAtZeta(bwoiop, &bozeta) + evalQcpAtZeta := func(begin, end int) { + for i := begin; i < end; i++ { + qcpzeta[i] = pk.trace.Qcp[i].Evaluate(zeta) + } + } + utils.Parallelize(len(commitmentInfo), evalQcpAtZeta) - go func() { - bwriop.ToCanonical(&pk.Domain[1]).ToRegular() - brzeta = bwriop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwoiop.ToCanonical(&pk.Domain[1]).ToRegular() - bozeta = bwoiop.Evaluate(zeta) - wgEvals.Done() - }() - - // open blinded Z at zeta*z - bwziop.ToCanonical(&pk.Domain[1]).ToRegular() var zetaShifted fr.Element zetaShifted.Mul(&zeta, &pk.Vk.Generator) + <-chbwzIOP proof.ZShiftedOpening, err = kzg.Open( bwziop.Coefficients()[:bwziop.BlindedSize()], zetaShifted, - pk.Vk.KZGSRS, + pk.Kzg, ) if err != nil { return nil, err } - // blinded z evaluated at u*zeta - bzuzeta := proof.ZShiftedOpening.ClaimedValue + // start to compute foldedH and foldedHDigest while computeLinearizedPolynomial runs. 
+ computeFoldedH := make(chan struct{}, 1) + var foldedH []fr.Element + var foldedHDigest kzg.Digest + go func() { + // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) + var bZetaPowerm, bSize big.Int + bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + var zetaPowerm fr.Element + zetaPowerm.Exp(zeta, &bSize) + zetaPowerm.BigInt(&bZetaPowerm) + foldedHDigest = proof.H[2] + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) + + // foldedH = h1 + ζ*h2 + ζ²*h3 + foldedH = h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] + h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] + h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] + utils.Parallelize(len(foldedH), func(start, end int) { + for i := start; i < end; i++ { + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² + foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 + } + }) + close(computeFoldedH) + }() + + wgEvals.Wait() // wait for the evaluations var ( linearizedPolynomialCanonical []fr.Element @@ -338,10 +457,12 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen errLPoly error ) - wgEvals.Wait() // wait for the evaluations + // blinded z evaluated at u*zeta + bzuzeta := proof.ZShiftedOpening.ClaimedValue // compute the linearization polynomial r at zeta // (goal: save committing separately to z, ql, qr, qm, qo, k + // note: we linearizedPolynomialCanonical reuses bwziop memory linearizedPolynomialCanonical = computeLinearizedPolynomial( blzeta, 
brzeta, @@ -351,66 +472,52 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen gamma, zeta, bzuzeta, + qcpzeta, bwziop.Coefficients()[:bwziop.BlindedSize()], + coefficients(cCommitments), pk, ) // TODO this commitment is only necessary to derive the challenge, we should // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) - - // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) - var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) - var zetaPowerm fr.Element - zetaPowerm.Exp(zeta, &bSize) - zetaPowerm.BigInt(&bZetaPowerm) - foldedHDigest := proof.H[2] - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - - // foldedH = h1 + ζ*h2 + ζ²*h3 - foldedH := h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] - h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] - utils.Parallelize(len(foldedH), func(start, end int) { - for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² - foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 - } - }) - + linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Kzg, runtime.NumCPU()*2) if errLPoly != nil { return nil, errLPoly } + // wait for foldedH and foldedHDigest + <-computeFoldedH + // Batch open the first list of polynomials + 
polysQcp := coefficients(pk.trace.Qcp) + polysToOpen := make([][]fr.Element, 7+len(polysQcp)) + copy(polysToOpen[7:], polysQcp) + // offset := len(polysQcp) + polysToOpen[0] = foldedH + polysToOpen[1] = linearizedPolynomialCanonical + polysToOpen[2] = bwliop.Coefficients()[:bwliop.BlindedSize()] + polysToOpen[3] = bwriop.Coefficients()[:bwriop.BlindedSize()] + polysToOpen[4] = bwoiop.Coefficients()[:bwoiop.BlindedSize()] + polysToOpen[5] = pk.trace.S1.Coefficients() + polysToOpen[6] = pk.trace.S2.Coefficients() + + digestsToOpen := make([]curve.G1Affine, len(pk.Vk.Qcp)+7) + copy(digestsToOpen[7:], pk.Vk.Qcp) + // offset = len(pk.Vk.Qcp) + digestsToOpen[0] = foldedHDigest + digestsToOpen[1] = linearizedPolynomialDigest + digestsToOpen[2] = proof.LRO[0] + digestsToOpen[3] = proof.LRO[1] + digestsToOpen[4] = proof.LRO[2] + digestsToOpen[5] = pk.Vk.S[0] + digestsToOpen[6] = pk.Vk.S[1] + proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - [][]fr.Element{ - foldedH, - linearizedPolynomialCanonical, - bwliop.Coefficients()[:bwliop.BlindedSize()], - bwriop.Coefficients()[:bwriop.BlindedSize()], - bwoiop.Coefficients()[:bwoiop.BlindedSize()], - pk.S1Canonical, - pk.S2Canonical, - }, - []kzg.Digest{ - foldedHDigest, - linearizedPolynomialDigest, - proof.LRO[0], - proof.LRO[1], - proof.LRO[2], - pk.Vk.S[0], - pk.Vk.S[1], - }, + polysToOpen, + digestsToOpen, zeta, hFunc, - pk.Vk.KZGSRS, + pk.Kzg, ) log.Debug().Dur("took", time.Since(start)).Msg("prover done") @@ -423,21 +530,29 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen } +func coefficients(p []*iop.Polynomial) [][]fr.Element { + res := make([][]fr.Element, len(p)) + for i, pI := range p { + res[i] = pI.Coefficients() + } + return res +} + // fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 +func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, kzgPk kzg.ProvingKey) 
error { + n := runtime.NumCPU() var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) chCommit1 := make(chan struct{}, 1) go func() { - proof.LRO[0], err0 = kzg.Commit(bcl, srs, n) + proof.LRO[0], err0 = kzg.Commit(bcl, kzgPk, n) close(chCommit0) }() go func() { - proof.LRO[1], err1 = kzg.Commit(bcr, srs, n) + proof.LRO[1], err1 = kzg.Commit(bcr, kzgPk, n) close(chCommit1) }() - if proof.LRO[2], err2 = kzg.Commit(bco, srs, n); err2 != nil { + if proof.LRO[2], err2 = kzg.Commit(bco, kzgPk, n); err2 != nil { return err2 } <-chCommit0 @@ -450,20 +565,20 @@ func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { return err1 } -func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 +func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, kzgPk kzg.ProvingKey) error { + n := runtime.NumCPU() var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) chCommit1 := make(chan struct{}, 1) go func() { - proof.H[0], err0 = kzg.Commit(h1, srs, n) + proof.H[0], err0 = kzg.Commit(h1, kzgPk, n) close(chCommit0) }() go func() { - proof.H[1], err1 = kzg.Commit(h2, srs, n) + proof.H[1], err1 = kzg.Commit(h2, kzgPk, n) close(chCommit1) }() - if proof.H[2], err2 = kzg.Commit(h3, srs, n); err2 != nil { + if proof.H[2], err2 = kzg.Commit(h3, kzgPk, n); err2 != nil { return err2 } <-chCommit0 @@ -476,41 +591,6 @@ func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error return err1 } -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - // computeLinearizedPolynomial computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. 
// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta @@ -522,7 +602,7 @@ func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.El // α²*L₁(ζ)*Z(X) // + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) // + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { +func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { // first part: individual constraints var rl fr.Element @@ -533,13 +613,12 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - ps1 := iop.NewPolynomial(&pk.S1Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - s1 = ps1.Evaluate(zeta) // s1(ζ) + s1 = pk.trace.S1.Evaluate(zeta) // s1(ζ) s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() - ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := ps2.Evaluate(zeta) // s2(ζ) + // ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) + tmp := pk.trace.S2.Evaluate(zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) @@ -570,43 +649,51 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, Mul(&lagrangeZeta, &alpha). 
Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - linPol := make([]fr.Element, len(blindedZCanonical)) - copy(linPol, blindedZCanonical) - - utils.Parallelize(len(linPol), func(start, end int) { + s3canonical := pk.trace.S3.Coefficients() + utils.Parallelize(len(blindedZCanonical), func(start, end int) { - var t0, t1 fr.Element + var t, t0, t1 fr.Element for i := start; i < end; i++ { - linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + t.Mul(&blindedZCanonical[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - if i < len(pk.S3Canonical) { + if i < len(s3canonical) { - t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) + t0.Mul(&s3canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) - linPol[i].Add(&linPol[i], &t0) + t.Add(&t, &t0) } - linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) + t.Mul(&t, &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) - if i < len(pk.Qm) { + cql := pk.trace.Ql.Coefficients() + cqr := pk.trace.Qr.Coefficients() + cqm := pk.trace.Qm.Coefficients() + cqo := pk.trace.Qo.Coefficients() + cqk := pk.trace.Qk.Coefficients() + if i < len(cqm) { - t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) - t0.Mul(&pk.Ql[i], &lZeta) + t1.Mul(&cqm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) + t0.Mul(&cql[i], &lZeta) t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) + t.Add(&t, &t0) // linPol = linPol + l(ζ)*Ql(X) + + t0.Mul(&cqr[i], &rZeta) + t.Add(&t, &t0) // linPol = linPol + r(ζ)*Qr(X) - t0.Mul(&pk.Qr[i], &rZeta) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) + t0.Mul(&cqo[i], &oZeta).Add(&t0, &cqk[i]) + t.Add(&t, &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) - t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) - 
linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) + for j := range qcpZeta { + t0.Mul(&pi2Canonical[j][i], &qcpZeta[j]) + t.Add(&t, &t0) + } } t0.Mul(&blindedZCanonical[i], &lagrangeZeta) - linPol[i].Add(&linPol[i], &t0) // finish the computation + blindedZCanonical[i].Add(&t, &t0) // finish the computation } }) - return linPol + return blindedZCanonical } diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl index 810078c132..0bacab893c 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl @@ -5,48 +5,31 @@ import ( {{- template "import_fft" . }} {{- template "import_backend_cs" . }} "github.com/consensys/gnark-crypto/ecc/{{toLower .Curve}}/fr/iop" - - kzgg "github.com/consensys/gnark-crypto/kzg" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/plonk/internal" + "github.com/consensys/gnark/constraint" ) -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation -type ProvingKey struct { - // Verifying Key is embedded into the proving key (needed by Prove) - Vk *VerifyingKey - - // TODO store iop.Polynomial here, not []fr.Element for more "type safety" - - // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo []fr.Element - - // qr,ql,qm,qo (in lagrange coset basis) --> these are not serialized, but computed from Ql, Qr, Qm, Qo once. 
-	lQl, lQr, lQm, lQo []fr.Element
-
-	// LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs.
-	// Storing LQk in Lagrange basis saves a fft...
-	CQk, LQk []fr.Element
-
-	// Domains used for the FFTs.
-	// Domain[0] = small Domain
-	// Domain[1] = big Domain
-	Domain [2]fft.Domain
-	// Domain[0], Domain[1] fft.Domain
-
-	// Permutation polynomials
-	S1Canonical, S2Canonical, S3Canonical []fr.Element
-
-	// in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once.
-	lS1LagrangeCoset, lS2LagrangeCoset, lS3LagrangeCoset []fr.Element
-
-	// position -> permuted position (position in [0,3*sizeSystem-1])
-	Permutation []int64
+// Trace stores a plonk trace as columns
+type Trace struct {
+
+	// Constants describing a plonk circuit. The first entries
+	// of LQk (whose index correspond to the public inputs) are set to 0, and are to be
+	// completed by the prover. At those indices i (so from 0 to nb_public_variables), LQl[i]=-1
+	// so the first nb_public_variables constraints look like this:
+	// -1*Wire[i] + 0* + 0 . It is zero when the constant coefficient is replaced by Wire[i].
+	Ql, Qr, Qm, Qo, Qk *iop.Polynomial
+	Qcp []*iop.Polynomial
+
+	// Polynomials representing the split permutation. The full permutation's support is 3*N where N=nb wires.
+	// The set of interpolation is <g> of size N, so to represent the permutation S we let S act on the
+	// set A=(<g>, u*<g>, u^{2}*<g>) of size 3*N, where u is outside <g> (its use is to shift the set <g>).
+	// We obtain a permutation of A, A'. We split A' in 3 (A'_{1}, A'_{2}, A'_{3}), and S1, S2, S3 are
+	// respectively the interpolation of A'_{1}, A'_{2}, A'_{3} on <g>.
+ S1, S2, S3 *iop.Polynomial + + // S full permutation, i -> S[i] + S []int64 } // VerifyingKey stores the data needed to verify a proof: @@ -55,6 +38,7 @@ type ProvingKey struct { // * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs // * Commitments to S1, S2, S3 type VerifyingKey struct { + // Size circuit Size uint64 SizeInv fr.Element @@ -62,7 +46,7 @@ type VerifyingKey struct { NbPublicVariables uint64 // Commitment scheme that is used for an instantiation of PLONK - KZGSRS *kzg.SRS + Kzg kzg.VerifyingKey // cosetShift generator of the coset on the small domain CosetShift fr.Element @@ -70,120 +54,271 @@ type VerifyingKey struct { // S commitments to S1, S2, S3 S [3]kzg.Digest - // Commitments to ql, qr, qm, qo prepended with as many zeroes (ones for l) as there are public inputs. + // Commitments to ql, qr, qm, qo, qcp prepended with as many zeroes (ones for l) as there are public inputs. // In particular Qk is not complete. Ql, Qr, Qm, Qo, Qk kzg.Digest + Qcp []kzg.Digest + + CommitmentConstraintIndexes []uint64 } -// Setup sets proving and verifying keys -func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) { +// ProvingKey stores the data needed to generate a proof: +// * the commitment scheme +// * ql, prepended with as many ones as they are public inputs +// * qr, qm, qo prepended with as many zeroes as there are public inputs. +// * qk, prepended with as many zeroes as public inputs, to be completed by the prover +// with the list of public inputs. +// * sigma_1, sigma_2, sigma_3 in both basis +// * the copy constraint permutation +type ProvingKey struct { + + // stores ql, qr, qm, qo, qk (-> to be completed by the prover) + // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used + // for computing the opening proofs (hence the canonical form). The canonical version + // of qk incomplete is used in the linearisation polynomial. 
+ // The polynomials in trace are in canonical basis. + trace Trace + + Kzg kzg.ProvingKey + + // Verifying Key is embedded into the proving key (needed by Prove) + Vk *VerifyingKey + + // qr,ql,qm,qo,qcp in LagrangeCoset --> these are not serialized, but computed from Ql, Qr, Qm, Qo, Qcp once. + lcQl, lcQr, lcQm, lcQo *iop.Polynomial + lcQcp []*iop.Polynomial + + // LQk qk in Lagrange form -> to be completed by the prover. After being completed, + lQk *iop.Polynomial + + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + + // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. + lcS1, lcS2, lcS3 *iop.Polynomial + + // in lagrange coset basis --> not serialized id and L_{g^{0}} + lcIdIOP, lLoneIOP *iop.Polynomial +} + +func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { + var pk ProvingKey var vk VerifyingKey - - // The verifying key shares data with the proving key pk.Vk = &vk + vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) - nbConstraints := len(spr.Constraints) + // step 0: set the fft domains + pk.initDomains(spr) - // fft domains - sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) + // step 1: set the verifying key pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases - // except when n<6. 
- if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(len(spr.Public)) + if len(kzgSrs.Pk.G1) < int(vk.Size) { + return nil, nil, errors.New("kzg srs is too small") + } + pk.Kzg = kzgSrs.Pk + vk.Kzg = kzgSrs.Vk - if err := pk.InitKZG(srs); err != nil { + // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis + BuildTrace(spr, &pk.trace) + + // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. + // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &pk.trace, nbVariables) + s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) + pk.trace.S1 = s[0] + pk.trace.S2 = s[1] + pk.trace.S3 = s[2] + + // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. + // All the above polynomials are expressed in canonical basis afterwards. This is why + // we save lqk before, because the prover needs to complete it in Lagrange form, and + // then express it on the Lagrange coset basis. 
+ pk.lQk = pk.trace.Qk.Clone() // it will be completed by the prover, and the evaluated on the coset + err := commitTrace(&pk.trace, &pk) + if err != nil { return nil, nil, err } - // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) - pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) - pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) - - for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant - pk.Ql[i].SetOne().Neg(&pk.Ql[i]) - pk.Qr[i].SetZero() - pk.Qm[i].SetZero() - pk.Qo[i].SetZero() - pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // → to be completed by the prover + // step 5: evaluate ql, qr, qm, qo, s1, s2, s3 on LagrangeCoset (NOT qk) + // we clone them, because the canonical versions are going to be used in + // the opening proof + pk.computeLagrangeCosetPolys() + + return &pk, &vk, nil +} + +// computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset +// basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
+func (pk *ProvingKey) computeLagrangeCosetPolys() { + pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) + for i, qcpI := range pk.trace.Qcp { + pk.lcQcp[i] = qcpI.Clone().ToLagrangeCoset(&pk.Domain[1]) + } + pk.lcQl = pk.trace.Ql.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQr = pk.trace.Qr.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQm = pk.trace.Qm.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcQo = pk.trace.Qo.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS1 = pk.trace.S1.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS2 = pk.trace.S2.Clone().ToLagrangeCoset(&pk.Domain[1]) + pk.lcS3 = pk.trace.S3.Clone().ToLagrangeCoset(&pk.Domain[1]) + + // storing Id + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + id := make([]fr.Element, pk.Domain[1].Cardinality) + id[0].Set(&pk.Domain[1].FrMultiplicativeGen) + for i := 1; i < int(pk.Domain[1].Cardinality); i++ { + id[i].Mul(&id[i-1], &pk.Domain[1].Generator) + } + pk.lcIdIOP = iop.NewPolynomial(&id, lagReg) + + // L_{g^{0}} + cap := pk.Domain[1].Cardinality + if cap < pk.Domain[0].Cardinality { + cap = pk.Domain[0].Cardinality // sanity check + } + lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) + lone[0].SetOne() + pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). + ToRegular(). + ToLagrangeCoset(&pk.Domain[1]) +} + +// NbPublicWitness returns the expected public witness size (number of field elements) +func (vk *VerifyingKey) NbPublicWitness() int { + return int(vk.NbPublicVariables) +} + +// VerifyingKey returns pk.Vk +func (pk *ProvingKey) VerifyingKey() interface{} { + return pk.Vk +} + +// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. 
+// Size is the size of the system that is nb_constraints+nb_public_variables
+func BuildTrace(spr *cs.SparseR1CS, pt *Trace) {
+
+	nbConstraints := spr.GetNbConstraints()
+	sizeSystem := uint64(nbConstraints + len(spr.Public))
+	size := ecc.NextPowerOfTwo(sizeSystem)
+	commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments)
+
+	ql := make([]fr.Element, size)
+	qr := make([]fr.Element, size)
+	qm := make([]fr.Element, size)
+	qo := make([]fr.Element, size)
+	qk := make([]fr.Element, size)
+	qcp := make([][]fr.Element, len(commitmentInfo))
+
+	for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error if size is inconsistent
+		ql[i].SetOne().Neg(&ql[i])
+		qr[i].SetZero()
+		qm[i].SetZero()
+		qo[i].SetZero()
+		qk[i].SetZero() // → to be completed by the prover
 	}
 	offset := len(spr.Public)
-	for i := 0; i < nbConstraints; i++ { // constraints
-
-		pk.Ql[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()])
-		pk.Qr[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()])
-		pk.Qm[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]).
- Mul(&pk.Qm[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.Qo[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.CQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) + + + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c!=nil; c = it.Next() { + ql[offset+j].Set(&spr.Coefficients[c.QL]) + qr[offset+j].Set(&spr.Coefficients[c.QR]) + qm[offset+j].Set(&spr.Coefficients[c.QM]) + qo[offset+j].Set(&spr.Coefficients[c.QO]) + qk[offset+j].Set(&spr.Coefficients[c.QC]) + j++ } - pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) - pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) - fft.BitReverse(pk.Ql) - fft.BitReverse(pk.Qr) - fft.BitReverse(pk.Qm) - fft.BitReverse(pk.Qo) - fft.BitReverse(pk.CQk) + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - // build permutation. Note: at this stage, the permutation takes in account the placeholders - buildPermutation(spr, &pk) + pt.Ql = iop.NewPolynomial(&ql, lagReg) + pt.Qr = iop.NewPolynomial(&qr, lagReg) + pt.Qm = iop.NewPolynomial(&qm, lagReg) + pt.Qo = iop.NewPolynomial(&qo, lagReg) + pt.Qk = iop.NewPolynomial(&qk, lagReg) + pt.Qcp = make([]*iop.Polynomial, len(qcp)) - // set s1, s2, s3 - ccomputePermutationPolynomials(&pk) + for i := range commitmentInfo { + qcp[i] = make([]fr.Element, size) + for _, committed := range commitmentInfo[i].Committed { + qcp[i][offset+committed].SetOne() + } + pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + } +} - // compute the lagrange coset basis versions (not serialized) - pk.computeLagrangeCosetPolys() +// commitTrace commits to every polynomial in the trace, and put +// the commitments int the verifying key. 
+func commitTrace(trace *Trace, pk *ProvingKey) error { + + trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete + trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() - // Commit to the polynomials to set up the verifying key var err error - if vk.Ql, err = kzg.Commit(pk.Ql, vk.KZGSRS); err != nil { - return nil, nil, err + pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + for i := range trace.Qcp { + trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() + if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + return err + } } - if vk.Qr, err = kzg.Commit(pk.Qr, vk.KZGSRS); err != nil { - return nil, nil, err + if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + return err } - if vk.Qm, err = kzg.Commit(pk.Qm, vk.KZGSRS); err != nil { - return nil, nil, err + if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + return err } - if vk.Qo, err = kzg.Commit(pk.Qo, vk.KZGSRS); err != nil { - return nil, nil, err + if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + return err } - if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { - return nil, nil, err + if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + return err } - if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { - return nil, nil, err + if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + return err } - if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { - return nil, nil, err + if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + return err } - if 
vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { - return nil, nil, err + if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + return err } + if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + return err + } + return nil +} - return &pk, &vk, nil +func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints + pk.Domain[0] = *fft.NewDomain(sizeSystem) + + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) + } else { + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) + } } @@ -191,38 +326,45 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // // The permutation s is composed of cycles of maximum length such that // -// s. (l∥r∥o) = (l∥r∥o) +// s. (l∥r∥o) = (l∥r∥o) // -//, where l∥r∥o is the concatenation of the indices of l, r, o in +// , where l∥r∥o is the concatenation of the indices of l, r, o in // ql.l+qr.r+qm.l.r+qo.O+k = 0. 
// // The permutation is encoded as a slice s of size 3*size(l), where the // i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { +func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := int(pk.Domain[0].Cardinality) + // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + sizeSolution := len(pt.Ql.Coefficients()) + sizePermutation := 3 * sizeSolution // init permutation - pk.Permutation = make([]int64, 3*sizeSolution) - for i := 0; i < len(pk.Permutation); i++ { - pk.Permutation[i] = -1 + permutation := make([]int64, sizePermutation) + for i := 0; i < len(permutation); i++ { + permutation[i] = -1 } // init LRO position -> variable_ID - lro := make([]int, 3*sizeSolution) // position -> variable_ID + lro := make([]int, sizePermutation) // position -> variable_ID for i := 0; i < len(spr.Public); i++ { lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) } offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c!=nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + + j++ } + // init cycle: // map ID -> last position the ID was seen cycle := make([]int64, nbVariables) @@ -234,92 +376,54 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { if cycle[lro[i]] != -1 { // if != -1, it means we already encountered this value // so we need to set the corresponding permutation index. 
- pk.Permutation[i] = cycle[lro[i]] + permutation[i] = cycle[lro[i]] } cycle[lro[i]] = int64(i) } // complete the Permutation by filling the first IDs encountered - for i := 0; i < len(pk.Permutation); i++ { - if pk.Permutation[i] == -1 { - pk.Permutation[i] = cycle[lro[i]] + for i := 0; i < sizePermutation; i++ { + if permutation[i] == -1 { + permutation[i] = cycle[lro[i]] } } -} -func (pk *ProvingKey) computeLagrangeCosetPolys() { - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqliop := iop.NewPolynomial(clone(pk.Ql, pk.Domain[1].Cardinality), canReg) - wqriop := iop.NewPolynomial(clone(pk.Qr, pk.Domain[1].Cardinality), canReg) - wqmiop := iop.NewPolynomial(clone(pk.Qm, pk.Domain[1].Cardinality), canReg) - wqoiop := iop.NewPolynomial(clone(pk.Qo, pk.Domain[1].Cardinality), canReg) - - ws1 := iop.NewPolynomial(clone(pk.S1Canonical, pk.Domain[1].Cardinality), canReg) - ws2 := iop.NewPolynomial(clone(pk.S2Canonical, pk.Domain[1].Cardinality), canReg) - ws3 := iop.NewPolynomial(clone(pk.S3Canonical, pk.Domain[1].Cardinality), canReg) - - wqliop.ToLagrangeCoset(&pk.Domain[1]) - wqriop.ToLagrangeCoset(&pk.Domain[1]) - wqmiop.ToLagrangeCoset(&pk.Domain[1]) - wqoiop.ToLagrangeCoset(&pk.Domain[1]) - - ws1.ToLagrangeCoset(&pk.Domain[1]) - ws2.ToLagrangeCoset(&pk.Domain[1]) - ws3.ToLagrangeCoset(&pk.Domain[1]) - - pk.lQl = wqliop.Coefficients() - pk.lQr = wqriop.Coefficients() - pk.lQm = wqmiop.Coefficients() - pk.lQo = wqoiop.Coefficients() - - pk.lS1LagrangeCoset = ws1.Coefficients() - pk.lS2LagrangeCoset = ws2.Coefficients() - pk.lS3LagrangeCoset = ws3.Coefficients() + pt.S = permutation } -func clone(input []fr.Element, capacity uint64) *[]fr.Element { - res := make([]fr.Element, len(input), capacity) - copy(res, input) - return &res -} +// computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. +// We let the permutation act on || u || u^{2}, split the result in 3 parts, +// and interpolate each of the 3 parts on . 
+func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { -// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations -// s1, s2, s3. -// -// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | -// | -// | Permutation -// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. s3n v -// \---------------/ \--------------------/ \------------------------/ -// s1 (LDE) s2 (LDE) s3 (LDE) -func ccomputePermutationPolynomials(pk *ProvingKey) { + nbElmts := int(domain.Cardinality) - nbElmts := int(pk.Domain[0].Cardinality) + var res [3]*iop.Polynomial // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) + evaluationIDSmallDomain := getSupportPermutation(domain) // Lagrange form of S1, S2, S3 - pk.S1Canonical = make([]fr.Element, nbElmts) - pk.S2Canonical = make([]fr.Element, nbElmts) - pk.S3Canonical = make([]fr.Element, nbElmts) + s1Canonical := make([]fr.Element, nbElmts) + s2Canonical := make([]fr.Element, nbElmts) + s3Canonical := make([]fr.Element, nbElmts) for i := 0; i < nbElmts; i++ { - pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) - pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) - pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) + s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) } - // Canonical form of S1, S2, S3 - pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) - fft.BitReverse(pk.S1Canonical) - fft.BitReverse(pk.S2Canonical) - fft.BitReverse(pk.S3Canonical) + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + res[0] = iop.NewPolynomial(&s1Canonical, lagReg) + res[1] = iop.NewPolynomial(&s2Canonical, lagReg) + res[2] = iop.NewPolynomial(&s3Canonical, 
lagReg) + + return res } -// getIDSmallDomain returns the Lagrange form of ID on the small domain -func getIDSmallDomain(domain *fft.Domain) []fr.Element { +// getSupportPermutation returns the support on which the permutation acts, it is +// || u || u^{2} +func getSupportPermutation(domain *fft.Domain) []fr.Element { res := make([]fr.Element, 3*domain.Cardinality) @@ -334,39 +438,4 @@ func getIDSmallDomain(domain *fft.Domain) []fr.Element { } return res -} - -// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS -// -// This should be used after deserializing a ProvingKey -// as pk.Vk.KZG is NOT serialized -func (pk *ProvingKey) InitKZG(srs kzgg.SRS) error { - return pk.Vk.InitKZG(srs) -} - -// InitKZG inits vk.KZG using provided SRS -// -// This should be used after deserializing a VerifyingKey -// as vk.KZG is NOT serialized -// -// Note that this instantiate a new FFT domain using vk.Size -func (vk *VerifyingKey) InitKZG(srs kzgg.SRS) error { - _srs := srs.(*kzg.SRS) - - if len(_srs.G1) < int(vk.Size) { - return errors.New("kzg srs is too small") - } - vk.KZGSRS = _srs - - return nil -} - -// NbPublicWitness returns the expected public witness size (number of field elements) -func (vk *VerifyingKey) NbPublicWitness() int { - return int(vk.NbPublicVariables) -} - -// VerifyingKey returns pk.Vk -func (pk *ProvingKey) VerifyingKey() interface{} { - return pk.Vk -} +} \ No newline at end of file diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl index 1f7f8b83fc..43e5bd3c6f 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl @@ -5,6 +5,9 @@ import ( "time" "io" {{ template "import_fr" . }} + {{if eq .Curve "BN254"}} + "github.com/consensys/gnark-crypto/ecc/bn254/fp" + {{end}} {{ template "import_kzg" . 
}} {{ template "import_curve" . }} {{if eq .Curve "BN254"}} @@ -20,7 +23,7 @@ var ( ) func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - log := logger.Logger().With().Str("curve", "{{ toLower .CurveID }}").Str("backend", "plonk").Logger() + log := logger.Logger().With().Str("curve", "{{ toLower .Curve }}").Str("backend", "plonk").Logger() start := time.Now() // pick a hash function to derive the challenge (the same as in the prover) @@ -32,7 +35,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { // The first challenge is derived using the public data: the commitments to the permutation, // the coefficients of the circuit, and the public inputs. // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *vk, publicWitness); err != nil { + if err := bindPublicData(&fs, "gamma", *vk, publicWitness, proof.Bsb22Commitments); err != nil { return err } gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -66,25 +69,51 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = ∑_{i to be completed by the prover } offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.EvaluationQlDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.EvaluationQrDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.EvaluationQmDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.EvaluationQmDomainBigBitReversed[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.EvaluationQoDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.LQkIncompleteDomainSmall[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.CQkIncomplete[offset+i].Set(&pk.LQkIncompleteDomainSmall[offset+i]) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c!=nil; c = it.Next() { + pk.EvaluationQlDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QL]) + pk.EvaluationQrDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QR]) + pk.EvaluationQmDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QM]) + pk.EvaluationQoDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QO]) + pk.LQkIncompleteDomainSmall[offset+j].Set(&spr.Coefficients[c.QC]) + pk.CQkIncomplete[offset+j].Set(&pk.LQkIncompleteDomainSmall[offset+j]) + + j++ } + pk.Domain[0].FFTInverse(pk.EvaluationQlDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationQrDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationQmDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) @@ -192,10 +197,10 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { return &pk, &vk, err } - pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, fft.OnCoset()) // build permutation. 
Note: at this stage, the permutation takes in account the placeholders buildPermutation(spr, &pk) @@ -240,10 +245,14 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c!=nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + j++ } // init cycle: @@ -329,9 +338,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, fft.OnCoset()) pk.Domain[0].FFTInverse(pk.EvaluationS1BigDomain[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationS2BigDomain[:pk.Domain[0].Cardinality], fft.DIF) @@ -359,9 +368,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, fft.OnCoset()) return nil diff --git a/internal/kvstore/kvstore.go b/internal/kvstore/kvstore.go new file mode 100644 
index 0000000000..6c580cc2fd --- /dev/null +++ b/internal/kvstore/kvstore.go @@ -0,0 +1,38 @@ +// Package kvstore implements simple key-value store +// +// It is without synchronization and allows any comparable keys. The main use of +// this package is for sharing singletons when building a circuit. +package kvstore + +import ( + "reflect" +) + +type Store interface { + SetKeyValue(key, value any) + GetKeyValue(key any) (value any) +} + +type impl struct { + db map[any]any +} + +func New() Store { + return &impl{ + db: make(map[any]any), + } +} + +func (c *impl) SetKeyValue(key, value any) { + if !reflect.TypeOf(key).Comparable() { + panic("key type not comparable") + } + c.db[key] = value +} + +func (c *impl) GetKeyValue(key any) any { + if !reflect.TypeOf(key).Comparable() { + panic("key type not comparable") + } + return c.db[key] +} diff --git a/internal/stats/latest.stats b/internal/stats/latest.stats index be2ff23c98..dcd297d185 100644 Binary files a/internal/stats/latest.stats and b/internal/stats/latest.stats differ diff --git a/internal/stats/snippet.go b/internal/stats/snippet.go index 1f6ed7764c..88b147fb80 100644 --- a/internal/stats/snippet.go +++ b/internal/stats/snippet.go @@ -7,8 +7,8 @@ import ( "github.com/consensys/gnark" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/sw_bls12377" - "github.com/consensys/gnark/std/algebra/sw_bls24315" + "github.com/consensys/gnark/std/algebra/native/sw_bls12377" + "github.com/consensys/gnark/std/algebra/native/sw_bls24315" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/std/math/bits" "github.com/consensys/gnark/std/math/emulated" @@ -116,10 +116,8 @@ func initSnippets() { dummyG2.Y.A1 = newVariable() // e(psi0, -gamma)*e(-πC, -δ)*e(πA, πB) - resMillerLoop, _ := sw_bls12377.MillerLoop(api, []sw_bls12377.G1Affine{dummyG1}, []sw_bls12377.G2Affine{dummyG2}) + _, _ = sw_bls12377.Pair(api, []sw_bls12377.G1Affine{dummyG1}, 
[]sw_bls12377.G2Affine{dummyG2}) - // performs the final expo - _ = sw_bls12377.FinalExponentiation(api, resMillerLoop) }, ecc.BW6_761) registerSnippet("pairing_bls24315", func(api frontend.API, newVariable func() frontend.Variable) { @@ -138,10 +136,8 @@ func initSnippets() { dummyG2.Y.B1.A1 = newVariable() // e(psi0, -gamma)*e(-πC, -δ)*e(πA, πB) - resMillerLoop, _ := sw_bls24315.MillerLoop(api, []sw_bls24315.G1Affine{dummyG1}, []sw_bls24315.G2Affine{dummyG2}) + _, _ = sw_bls24315.Pair(api, []sw_bls24315.G1Affine{dummyG1}, []sw_bls24315.G2Affine{dummyG2}) - // performs the final expo - _ = sw_bls24315.FinalExponentiation(api, resMillerLoop) }, ecc.BW6_633) } diff --git a/internal/tinyfield/element.go b/internal/tinyfield/element.go index ec15e0514b..f987f62b2d 100644 --- a/internal/tinyfield/element.go +++ b/internal/tinyfield/element.go @@ -27,6 +27,7 @@ import ( "strconv" "strings" + "github.com/bits-and-blooms/bitset" "github.com/consensys/gnark-crypto/field/hash" "github.com/consensys/gnark-crypto/field/pool" ) @@ -299,7 +300,7 @@ func (z *Element) SetRandom() (*Element, error) { return nil, err } - // Clear unused bits in in the most signicant byte to increase probability + // Clear unused bits in in the most significant byte to increase probability // that the candidate is < q. 
bytes[k-1] &= uint8(int(1<= 0; i-- { - if zeroes[i] { + if zeroes.Test(uint(i)) { continue } res[i].Mul(&res[i], &accumulator) diff --git a/internal/tinyfield/element_test.go b/internal/tinyfield/element_test.go index d6c74b751d..93664bddb2 100644 --- a/internal/tinyfield/element_test.go +++ b/internal/tinyfield/element_test.go @@ -317,7 +317,8 @@ func TestElementReduce(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, s := range testValues { + for i := range testValues { + s := testValues[i] expected := s reduce(&s) _reduceGeneric(&expected) @@ -701,7 +702,8 @@ func TestElementAdd(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, r := range testValues { + for i := range testValues { + r := testValues[i] var d, e, rb big.Int r.BigInt(&rb) @@ -736,11 +738,12 @@ func TestElementAdd(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) - for _, b := range testValues { - + for j := range testValues { + b := testValues[j] var bBig, d, e big.Int b.BigInt(&bBig) @@ -810,7 +813,8 @@ func TestElementSub(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, r := range testValues { + for i := range testValues { + r := testValues[i] var d, e, rb big.Int r.BigInt(&rb) @@ -845,11 +849,12 @@ func TestElementSub(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) - for _, b := range testValues { - + for j := range testValues { + b := testValues[j] var bBig, d, e big.Int b.BigInt(&bBig) @@ -919,7 +924,8 @@ func TestElementMul(t *testing.T) { testValues := make([]Element, 
len(staticTestValues)) copy(testValues, staticTestValues) - for _, r := range testValues { + for i := range testValues { + r := testValues[i] var d, e, rb big.Int r.BigInt(&rb) @@ -973,11 +979,12 @@ func TestElementMul(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) - for _, b := range testValues { - + for j := range testValues { + b := testValues[j] var bBig, d, e big.Int b.BigInt(&bBig) @@ -1055,7 +1062,8 @@ func TestElementDiv(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, r := range testValues { + for i := range testValues { + r := testValues[i] var d, e, rb big.Int r.BigInt(&rb) @@ -1091,11 +1099,12 @@ func TestElementDiv(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) - for _, b := range testValues { - + for j := range testValues { + b := testValues[j] var bBig, d, e big.Int b.BigInt(&bBig) @@ -1166,7 +1175,8 @@ func TestElementExp(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, r := range testValues { + for i := range testValues { + r := testValues[i] var d, e, rb big.Int r.BigInt(&rb) @@ -1201,11 +1211,12 @@ func TestElementExp(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) - for _, b := range testValues { - + for j := range testValues { + b := testValues[j] var bBig, d, e big.Int b.BigInt(&bBig) @@ -1277,7 +1288,8 @@ func TestElementSquare(t *testing.T) { testValues := make([]Element, len(staticTestValues)) 
copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) var c Element @@ -1349,7 +1361,8 @@ func TestElementInverse(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) var c Element @@ -1421,7 +1434,8 @@ func TestElementSqrt(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) var c Element @@ -1493,7 +1507,8 @@ func TestElementDouble(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) var c Element @@ -1565,7 +1580,8 @@ func TestElementNeg(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) var c Element diff --git a/internal/utils/convert.go b/internal/utils/convert.go index 7ca9ba1674..3bc19c6758 100644 --- a/internal/utils/convert.go +++ b/internal/utils/convert.go @@ -15,6 +15,7 @@ package utils import ( + "math" "math/big" "reflect" ) @@ -60,7 +61,7 @@ func FromInterface(input interface{}) big.Int { case int32: r.SetInt64(int64(v)) case int64: - r.SetInt64(int64(v)) + r.SetInt64(v) case int: r.SetInt64(int64(v)) case string: @@ -87,3 +88,31 @@ func FromInterface(input interface{}) big.Int { return r } + +func IntSliceSliceToUint64SliceSlice(in [][]int) [][]uint64 { + res := make([][]uint64, len(in)) + for i := range in { + res[i] = make([]uint64, len(in[i])) + for j := range in[i] { + if in[i][j] < 0 { 
+ panic("negative value in int slice") + } + res[i][j] = uint64(in[i][j]) + } + } + return res +} + +func Uint64SliceSliceToIntSliceSlice(in [][]uint64) [][]int { + res := make([][]int, len(in)) + for i := range in { + res[i] = make([]int, len(in[i])) + for j := range in[i] { + if in[i][j] >= math.MaxInt { + panic("too large") + } + res[i][j] = int(in[i][j]) + } + } + return res +} diff --git a/internal/utils/heap.go b/internal/utils/heap.go new file mode 100644 index 0000000000..944bcb5e03 --- /dev/null +++ b/internal/utils/heap.go @@ -0,0 +1,63 @@ +package utils + +// An IntHeap is a min-heap of linear expressions. It facilitates merging k-linear expressions. +// +// The code is identical to https://pkg.go.dev/container/heap but replaces interfaces with concrete +// type to avoid memory overhead. +type IntHeap []int + +func (h *IntHeap) less(i, j int) bool { return (*h)[i] < (*h)[j] } +func (h *IntHeap) swap(i, j int) { (*h)[i], (*h)[j] = (*h)[j], (*h)[i] } + +// Heapify establishes the heap invariants required by the other routines in this package. +// Heapify is idempotent with respect to the heap invariants +// and may be called whenever the heap invariants may have been invalidated. +// The complexity is O(n) where n = len(*h). +func (h *IntHeap) Heapify() { + // heapify + n := len(*h) + for i := n/2 - 1; i >= 0; i-- { + h.down(i, n) + } +} + +// Pop removes and returns the minimum element (according to Less) from the heap. +// The complexity is O(log n) where n = len(*h). +// Pop is equivalent to Remove(h, 0). 
+func (h *IntHeap) Pop() { + n := len(*h) - 1 + h.swap(0, n) + h.down(0, n) + *h = (*h)[0:n] +} + +func (h *IntHeap) up(j int) { + for { + i := (j - 1) / 2 // parent + if i == j || !h.less(j, i) { + break + } + h.swap(i, j) + j = i + } +} + +func (h *IntHeap) down(i0, n int) bool { + i := i0 + for { + j1 := 2*i + 1 + if j1 >= n || j1 < 0 { // j1 < 0 after int overflow + break + } + j := j1 // left child + if j2 := j1 + 1; j2 < n && h.less(j2, j1) { + j = j2 // = 2*i + 2 // right child + } + if !h.less(j, i) { + break + } + h.swap(i, j) + i = j + } + return i > i0 +} diff --git a/internal/utils/search.go b/internal/utils/search.go new file mode 100644 index 0000000000..fd9bcb460b --- /dev/null +++ b/internal/utils/search.go @@ -0,0 +1,26 @@ +package utils + +import "sort" + +// FindInSlice attempts to find the target in increasing slice x. +// If not found, returns false and the index where the target would be inserted. +func FindInSlice(x []int, target int) (int, bool) { + return sort.Find(len(x), func(i int) int { + return target - x[i] + }) +} + +// MultiListSeeker looks up increasing integers in a list of increasing lists of integers. +type MultiListSeeker [][]int + +// Seek returns the index of the earliest list where n is found, or -1 if not found. 
+func (s MultiListSeeker) Seek(n int) int { + for i, l := range s { + j, found := FindInSlice(l, n) + s[i] = l[j:] + if found { + return i + } + } + return -1 +} diff --git a/profile/profile.go b/profile/profile.go index 3d3778d995..fe3e58f3c2 100644 --- a/profile/profile.go +++ b/profile/profile.go @@ -92,7 +92,7 @@ func Start(options ...Option) *Profile { log := logger.Logger() if p.filePath == "" { - log.Warn().Msg("gnark profiling enabled [not writting to disk]") + log.Warn().Msg("gnark profiling enabled [not writing to disk]") } else { log.Info().Str("path", p.filePath).Msg("gnark profiling enabled") } @@ -131,7 +131,7 @@ func (p *Profile) Stop() { f.Close() log.Info().Str("path", p.filePath).Msg("gnark profiling disabled") } else { - log.Warn().Msg("gnark profiling disabled [not writting to disk]") + log.Warn().Msg("gnark profiling disabled [not writing to disk]") } } @@ -144,7 +144,7 @@ func (p *Profile) NbConstraints() int { // Top return a similar output than pprof top command func (p *Profile) Top() string { r := report.NewDefault(&p.pprof, report.Options{ - OutputFormat: report.Text, + OutputFormat: report.Tree, CompactLabels: true, NodeFraction: 0.005, EdgeFraction: 0.001, diff --git a/profile/profile_test.go b/profile/profile_test.go index 0505331a90..f198f370e1 100644 --- a/profile/profile_test.go +++ b/profile/profile_test.go @@ -15,8 +15,18 @@ type Circuit struct { A frontend.Variable } +type obj struct { +} + func (circuit *Circuit) Define(api frontend.API) error { - api.AssertIsEqual(api.Mul(circuit.A, circuit.A), circuit.A) + var o obj + o.Define(api, circuit.A) + // api.AssertIsEqual(api.Mul(circuit.A, circuit.A), circuit.A) + return nil +} + +func (o *obj) Define(api frontend.API, A frontend.Variable) error { + api.AssertIsEqual(api.Mul(A, A), A) return nil } @@ -29,12 +39,23 @@ func Example() { _, _ = frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &Circuit{}) p.Stop() + // expected output fmt.Println(p.Top()) + const _ = `Showing 
nodes accounting for 2, 100% of 2 total +----------------------------------------------------------+------------- + flat flat% sum% cum cum% calls calls% + context +----------------------------------------------------------+------------- + 1 100% | profile_test.(*Circuit).Define profile/profile_test.go:21 + 1 50.00% 50.00% 1 50.00% | r1cs.(*builder).AssertIsEqual frontend/cs/r1cs/api_assertions.go:37 +----------------------------------------------------------+------------- + 1 100% | profile_test.(*Circuit).Define profile/profile_test.go:21 + 1 50.00% 100% 1 50.00% | r1cs.(*builder).Mul frontend/cs/r1cs/api.go:221 +----------------------------------------------------------+------------- + 0 0% 100% 2 100% | profile_test.(*Circuit).Define profile/profile_test.go:21 + 1 50.00% | r1cs.(*builder).AssertIsEqual frontend/cs/r1cs/api_assertions.go:37 + 1 50.00% | r1cs.(*builder).Mul frontend/cs/r1cs/api.go:221 +----------------------------------------------------------+-------------` + fmt.Println(p.NbConstraints()) - fmt.Println(p.Top()) // Output: // 2 - // Showing nodes accounting for 2, 100% of 2 total - // flat flat% sum% cum cum% - // 1 50.00% 50.00% 2 100% profile_test.(*Circuit).Define profile/profile_test.go:19 - // 1 50.00% 100% 1 50.00% r1cs.(*builder).AssertIsEqual frontend/cs/r1cs/api_assertions.go:37 } diff --git a/profile/profile_worker.go b/profile/profile_worker.go index b63499475b..97817fe898 100644 --- a/profile/profile_worker.go +++ b/profile/profile_worker.go @@ -5,6 +5,7 @@ import ( "strings" "sync" "sync/atomic" + "unicode" "github.com/google/pprof/profile" ) @@ -63,31 +64,18 @@ func collectSample(pc []uintptr) { for { frame, more := frames.Next() - if strings.HasSuffix(frame.Function, ".func1") { - // TODO @gbotrel filter anonymous func better - continue - } - - // to avoid aving a location that concentrates 99% of the calls, we transfer the "addConstraint" - // occuring in Mul to the previous level in the stack - if 
strings.Contains(frame.Function, "github.com/consensys/gnark/frontend/cs/r1cs.(*builder).Mul") { - continue + if strings.Contains(frame.Function, "frontend.parseCircuit") { + // we stop; previous frame was the .Define definition of the circuit + break } - if strings.HasPrefix(frame.Function, "github.com/consensys/gnark/frontend/cs/scs.(*scs).Mul") { + if strings.HasSuffix(frame.Function, ".func1") { + // TODO @gbotrel filter anonymous func better continue } - if strings.HasPrefix(frame.Function, "github.com/consensys/gnark/frontend/cs/scs.(*scs).split") { - continue - } - - // with scs.Builder (Plonk) Add and Sub always add a constraint --> we record the caller as the constraint adder - // but in the future we may record a different type of sample for these - if strings.HasPrefix(frame.Function, "github.com/consensys/gnark/frontend/cs/scs.(*scs).Add") { - continue - } - if strings.HasPrefix(frame.Function, "github.com/consensys/gnark/frontend/cs/scs.(*scs).Sub") { + // filter internal builder functions + if filterSCSPrivateFunc(frame.Function) || filterR1CSPrivateFunc(frame.Function) { continue } @@ -107,6 +95,7 @@ func collectSample(pc []uintptr) { sessions[i].onceSetName.Do(func() { // once per profile session, we set the "name of the binary" // here we grep the struct name where "Define" exist: hopefully the circuit Name + // note: this won't work well for nested Define calls. fe := strings.Split(frame.Function, "/") circuitName := strings.TrimSuffix(fe[len(fe)-1], ".Define") sessions[i].pprof.Mapping = []*profile.Mapping{ @@ -114,7 +103,7 @@ func collectSample(pc []uintptr) { } }) } - break + // break --> we break when we hit frontend.parseCircuit; in case we have nested Define calls in the stack. } } @@ -123,3 +112,27 @@ func collectSample(pc []uintptr) { } } + +func filterSCSPrivateFunc(f string) bool { + const scsPrefix = "github.com/consensys/gnark/frontend/cs/scs.(*builder)." 
+ if strings.HasPrefix(f, scsPrefix) && len(f) > len(scsPrefix) { + // filter plonk frontend private APIs from the trace. + c := []rune(f)[len(scsPrefix)] + if unicode.IsLower(c) { + return true + } + } + return false +} + +func filterR1CSPrivateFunc(f string) bool { + const r1csPrefix = "github.com/consensys/gnark/frontend/cs/r1cs.(*builder)." + if strings.HasPrefix(f, r1csPrefix) && len(f) > len(r1csPrefix) { + // filter r1cs frontend private APIs from the trace. + c := []rune(f)[len(r1csPrefix)] + if unicode.IsLower(c) { + return true + } + } + return false +} diff --git a/std/accumulator/merkle/verify.go b/std/accumulator/merkle/verify.go index 35caff2ce8..3ec92412d6 100644 --- a/std/accumulator/merkle/verify.go +++ b/std/accumulator/merkle/verify.go @@ -62,7 +62,7 @@ type MerkleProof struct { // leafSum returns the hash created from data inserted to form a leaf. // Without domain separation. -func leafSum(api frontend.API, h hash.Hash, data frontend.Variable) frontend.Variable { +func leafSum(api frontend.API, h hash.FieldHasher, data frontend.Variable) frontend.Variable { h.Reset() h.Write(data) @@ -73,7 +73,7 @@ func leafSum(api frontend.API, h hash.Hash, data frontend.Variable) frontend.Var // nodeSum returns the hash created from data inserted to form a leaf. // Without domain separation. -func nodeSum(api frontend.API, h hash.Hash, a, b frontend.Variable) frontend.Variable { +func nodeSum(api frontend.API, h hash.FieldHasher, a, b frontend.Variable) frontend.Variable { h.Reset() h.Write(a, b) @@ -86,7 +86,7 @@ func nodeSum(api frontend.API, h hash.Hash, a, b frontend.Variable) frontend.Var // true if the first element of the proof set is a leaf of data in the Merkle // root. False is returned if the proof set or Merkle root is nil, and if // 'numLeaves' equals 0. 
-func (mp *MerkleProof) VerifyProof(api frontend.API, h hash.Hash, leaf frontend.Variable) { +func (mp *MerkleProof) VerifyProof(api frontend.API, h hash.FieldHasher, leaf frontend.Variable) { depth := len(mp.Path) - 1 sum := leafSum(api, h, mp.Path[0]) diff --git a/std/accumulator/merkle/verify_test.go b/std/accumulator/merkle/verify_test.go index 2b650a2ca8..51582df727 100644 --- a/std/accumulator/merkle/verify_test.go +++ b/std/accumulator/merkle/verify_test.go @@ -19,17 +19,18 @@ package merkle import ( "bytes" "crypto/rand" + "os" + "testing" + "github.com/consensys/gnark-crypto/accumulator/merkletree" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/logger" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" - "os" - "testing" ) // MerkleProofTest used for testing only @@ -99,7 +100,7 @@ func TestVerify(t *testing.T) { os.Exit(-1) } - // verfiy the proof in plain go + // verify the proof in plain go verified := merkletree.VerifyProof(hGo, merkleRoot, proofPath, proofIndex, numLeaves) if !verified { t.Fatal("The merkle proof in plain go should pass") @@ -119,7 +120,7 @@ func TestVerify(t *testing.T) { t.Fatal(err) } logger.SetOutput(os.Stdout) - err = cc.IsSolved(w, backend.IgnoreSolverError(), backend.WithCircuitLogger(logger.Logger())) + err = cc.IsSolved(w, solver.WithLogger(logger.Logger())) if err != nil { t.Fatal(err) } diff --git a/std/algebra/doc.go b/std/algebra/doc.go new file mode 100644 index 0000000000..18e70aeac5 --- /dev/null +++ b/std/algebra/doc.go @@ -0,0 +1,14 @@ +// Package algebra implements: +// - base finite field 𝔽p arithmetic, +// - extension finite fields arithmetic (𝔽p², 𝔽p⁴, 𝔽p⁶, 𝔽p¹², 𝔽p²⁴), +// - short Weierstrass curve arithmetic over G1 (E/𝔽p) and G2 (Eₜ/𝔽p² or Eₜ/𝔽p⁴) +// - 
twisted Edwards curve arithmetic +// +// These arithmetic operations are implemented +// - using native field via the 2-chains BLS12-377/BW6-761 and BLS24-315/BW-633 +// (`native/`) or associated twisted Edwards (e.g. Jubjub/BLS12-381) and +// - using nonnative field via field emulation (`emulated/`). This allows to +// use any curve over any (SNARK) field (e.g. secp256k1 curve arithmetic over +// BN254 SNARK field or BN254 pairing over BN254 SNARK field). The drawback +// of this approach is the extreme cost of the operations. +package algebra diff --git a/std/algebra/emulated/fields_bls12381/doc.go b/std/algebra/emulated/fields_bls12381/doc.go new file mode 100644 index 0000000000..c94c08b683 --- /dev/null +++ b/std/algebra/emulated/fields_bls12381/doc.go @@ -0,0 +1,7 @@ +// Package fields_bls12381 implements the fields arithmetic of the Fp12 tower +// used to compute the pairing over the BLS12-381 curve. +// +// 𝔽p²[u] = 𝔽p/u²+1 +// 𝔽p⁶[v] = 𝔽p²/v³-1-u +// 𝔽p¹²[w] = 𝔽p⁶/w²-v +package fields_bls12381 diff --git a/std/algebra/emulated/fields_bls12381/e12.go b/std/algebra/emulated/fields_bls12381/e12.go new file mode 100644 index 0000000000..d452c4f0cb --- /dev/null +++ b/std/algebra/emulated/fields_bls12381/e12.go @@ -0,0 +1,199 @@ +package fields_bls12381 + +import ( + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" +) + +type E12 struct { + C0, C1 E6 +} + +type Ext12 struct { + *Ext6 +} + +func NewExt12(api frontend.API) *Ext12 { + return &Ext12{Ext6: NewExt6(api)} +} + +func (e Ext12) Add(x, y *E12) *E12 { + z0 := e.Ext6.Add(&x.C0, &y.C0) + z1 := e.Ext6.Add(&x.C1, &y.C1) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +func (e Ext12) Sub(x, y *E12) *E12 { + z0 := e.Ext6.Sub(&x.C0, &y.C0) + z1 := e.Ext6.Sub(&x.C1, &y.C1) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +func (e Ext12) Conjugate(x *E12) *E12 { + z1 := e.Ext6.Neg(&x.C1) + return &E12{ + C0: x.C0, + C1: *z1, + } +} + +func (e Ext12) Mul(x, y *E12) *E12 
{ + a := e.Ext6.Add(&x.C0, &x.C1) + b := e.Ext6.Add(&y.C0, &y.C1) + a = e.Ext6.Mul(a, b) + b = e.Ext6.Mul(&x.C0, &y.C0) + c := e.Ext6.Mul(&x.C1, &y.C1) + z1 := e.Ext6.Sub(a, b) + z1 = e.Ext6.Sub(z1, c) + z0 := e.Ext6.MulByNonResidue(c) + z0 = e.Ext6.Add(z0, b) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +func (e Ext12) Zero() *E12 { + zero := e.fp.Zero() + return &E12{ + C0: E6{ + B0: E2{A0: *zero, A1: *zero}, + B1: E2{A0: *zero, A1: *zero}, + B2: E2{A0: *zero, A1: *zero}, + }, + C1: E6{ + B0: E2{A0: *zero, A1: *zero}, + B1: E2{A0: *zero, A1: *zero}, + B2: E2{A0: *zero, A1: *zero}, + }, + } +} + +func (e Ext12) One() *E12 { + z000 := e.fp.One() + zero := e.fp.Zero() + return &E12{ + C0: E6{ + B0: E2{A0: *z000, A1: *zero}, + B1: E2{A0: *zero, A1: *zero}, + B2: E2{A0: *zero, A1: *zero}, + }, + C1: E6{ + B0: E2{A0: *zero, A1: *zero}, + B1: E2{A0: *zero, A1: *zero}, + B2: E2{A0: *zero, A1: *zero}, + }, + } +} + +func (e Ext12) IsZero(z *E12) frontend.Variable { + c0 := e.Ext6.IsZero(&z.C0) + c1 := e.Ext6.IsZero(&z.C1) + return e.api.And(c0, c1) +} + +func (e Ext12) Square(x *E12) *E12 { + c0 := e.Ext6.Sub(&x.C0, &x.C1) + c3 := e.Ext6.MulByNonResidue(&x.C1) + c3 = e.Ext6.Neg(c3) + c3 = e.Ext6.Add(&x.C0, c3) + c2 := e.Ext6.Mul(&x.C0, &x.C1) + c0 = e.Ext6.Mul(c0, c3) + c0 = e.Ext6.Add(c0, c2) + z1 := e.Ext6.Double(c2) + c2 = e.Ext6.MulByNonResidue(c2) + z0 := e.Ext6.Add(c0, c2) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +func (e Ext12) AssertIsEqual(x, y *E12) { + e.Ext6.AssertIsEqual(&x.C0, &y.C0) + e.Ext6.AssertIsEqual(&x.C1, &y.C1) +} + +func FromE12(y *bls12381.E12) E12 { + return E12{ + C0: FromE6(&y.C0), + C1: FromE6(&y.C1), + } + +} + +func (e Ext12) Inverse(x *E12) *E12 { + res, err := e.fp.NewHint(inverseE12Hint, 12, &x.C0.B0.A0, &x.C0.B0.A1, &x.C0.B1.A0, &x.C0.B1.A1, &x.C0.B2.A0, &x.C0.B2.A1, &x.C1.B0.A0, &x.C1.B0.A1, &x.C1.B1.A0, &x.C1.B1.A1, &x.C1.B2.A0, &x.C1.B2.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } 
+ + inv := E12{ + C0: E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + }, + C1: E6{ + B0: E2{A0: *res[6], A1: *res[7]}, + B1: E2{A0: *res[8], A1: *res[9]}, + B2: E2{A0: *res[10], A1: *res[11]}, + }, + } + + one := e.One() + + // 1 == inv * x + _one := e.Mul(&inv, x) + e.AssertIsEqual(one, _one) + + return &inv + +} + +func (e Ext12) DivUnchecked(x, y *E12) *E12 { + res, err := e.fp.NewHint(divE12Hint, 12, &x.C0.B0.A0, &x.C0.B0.A1, &x.C0.B1.A0, &x.C0.B1.A1, &x.C0.B2.A0, &x.C0.B2.A1, &x.C1.B0.A0, &x.C1.B0.A1, &x.C1.B1.A0, &x.C1.B1.A1, &x.C1.B2.A0, &x.C1.B2.A1, &y.C0.B0.A0, &y.C0.B0.A1, &y.C0.B1.A0, &y.C0.B1.A1, &y.C0.B2.A0, &y.C0.B2.A1, &y.C1.B0.A0, &y.C1.B0.A1, &y.C1.B1.A0, &y.C1.B1.A1, &y.C1.B2.A0, &y.C1.B2.A1) + + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + div := E12{ + C0: E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + }, + C1: E6{ + B0: E2{A0: *res[6], A1: *res[7]}, + B1: E2{A0: *res[8], A1: *res[9]}, + B2: E2{A0: *res[10], A1: *res[11]}, + }, + } + + // x == div * y + _x := e.Mul(&div, y) + e.AssertIsEqual(x, _x) + + return &div +} + +func (e Ext12) Select(selector frontend.Variable, z1, z0 *E12) *E12 { + c0 := e.Ext6.Select(selector, &z1.C0, &z0.C0) + c1 := e.Ext6.Select(selector, &z1.C1, &z0.C1) + return &E12{C0: *c0, C1: *c1} +} + +func (e Ext12) Lookup2(s1, s2 frontend.Variable, a, b, c, d *E12) *E12 { + c0 := e.Ext6.Lookup2(s1, s2, &a.C0, &b.C0, &c.C0, &d.C0) + c1 := e.Ext6.Lookup2(s1, s2, &a.C1, &b.C1, &c.C1, &d.C1) + return &E12{C0: *c0, C1: *c1} +} diff --git a/std/algebra/emulated/fields_bls12381/e12_pairing.go b/std/algebra/emulated/fields_bls12381/e12_pairing.go new file mode 100644 index 0000000000..c096b0d2fd --- /dev/null +++ b/std/algebra/emulated/fields_bls12381/e12_pairing.go @@ -0,0 +1,280 @@ +package fields_bls12381 + +import 
"github.com/consensys/gnark/std/math/emulated" + +func (e Ext12) nSquareTorus(z *E6, n int) *E6 { + for i := 0; i < n; i++ { + z = e.SquareTorus(z) + } + return z +} + +// ExptHalfTorus set z to x^(t/2) in E6 and return z +// const t/2 uint64 = 7566188111470821376 // negative +func (e Ext12) ExptHalfTorus(x *E6) *E6 { + // FixedExp computation is derived from the addition chain: + // + // _10 = 2*1 + // _11 = 1 + _10 + // _1100 = _11 << 2 + // _1101 = 1 + _1100 + // _1101000 = _1101 << 3 + // _1101001 = 1 + _1101000 + // return ((_1101001 << 9 + 1) << 32 + 1) << 15 + // + // Operations: 62 squares 5 multiplies + // + // Generated by github.com/mmcloughlin/addchain v0.4.0. + + // Step 1: z = x^0x2 + z := e.SquareTorus(x) + + // Step 2: z = x^0x3 + z = e.MulTorus(x, z) + + z = e.SquareTorus(z) + z = e.SquareTorus(z) + + // Step 5: z = x^0xd + z = e.MulTorus(x, z) + + // Step 8: z = x^0x68 + z = e.nSquareTorus(z, 3) + + // Step 9: z = x^0x69 + z = e.MulTorus(x, z) + + // Step 18: z = x^0xd200 + z = e.nSquareTorus(z, 9) + + // Step 19: z = x^0xd201 + z = e.MulTorus(x, z) + + // Step 51: z = x^0xd20100000000 + z = e.nSquareTorus(z, 32) + + // Step 52: z = x^0xd20100000001 + z = e.MulTorus(x, z) + + // Step 67: z = x^0x6900800000008000 + z = e.nSquareTorus(z, 15) + + z = e.InverseTorus(z) // because tAbsVal is negative + + return z +} + +// ExptTorus set z to xᵗ in E6 and return z +// const t uint64 = 15132376222941642752 // negative +func (e Ext12) ExptTorus(x *E6) *E6 { + z := e.ExptHalfTorus(x) + z = e.SquareTorus(z) + return z +} + +// MulBy014 multiplies z by an E12 sparse element of the form +// +// E12{ +// C0: E6{B0: c0, B1: c1, B2: 0}, +// C1: E6{B0: 0, B1: 1, B2: 0}, +// } +func (e *Ext12) MulBy014(z *E12, c0, c1 *E2) *E12 { + + a := z.C0 + a = *e.MulBy01(&a, c0, c1) + + var b E6 + // Mul by E6{0, 1, 0} + b.B0 = *e.Ext2.MulByNonResidue(&z.C1.B2) + b.B2 = z.C1.B1 + b.B1 = z.C1.B0 + + one := e.Ext2.One() + d := e.Ext2.Add(c1, one) + + zC1 := e.Ext6.Add(&z.C1, 
&z.C0) + zC1 = e.Ext6.MulBy01(zC1, c0, d) + zC1 = e.Ext6.Sub(zC1, &a) + zC1 = e.Ext6.Sub(zC1, &b) + zC0 := e.Ext6.MulByNonResidue(&b) + zC0 = e.Ext6.Add(zC0, &a) + + return &E12{ + C0: *zC0, + C1: *zC1, + } +} + +// multiplies two E12 sparse element of the form: +// +// E12{ +// C0: E6{B0: c0, B1: c1, B2: 0}, +// C1: E6{B0: 0, B1: 1, B2: 0}, +// } +// +// and +// +// E12{ +// C0: E6{B0: d0, B1: d1, B2: 0}, +// C1: E6{B0: 0, B1: 1, B2: 0}, +// } +func (e Ext12) Mul014By014(d0, d1, c0, c1 *E2) *[5]E2 { + one := e.Ext2.One() + x0 := e.Ext2.Mul(c0, d0) + x1 := e.Ext2.Mul(c1, d1) + tmp := e.Ext2.Add(c0, one) + x04 := e.Ext2.Add(d0, one) + x04 = e.Ext2.Mul(x04, tmp) + x04 = e.Ext2.Sub(x04, x0) + x04 = e.Ext2.Sub(x04, one) + tmp = e.Ext2.Add(c0, c1) + x01 := e.Ext2.Add(d0, d1) + x01 = e.Ext2.Mul(x01, tmp) + x01 = e.Ext2.Sub(x01, x0) + x01 = e.Ext2.Sub(x01, x1) + tmp = e.Ext2.Add(c1, one) + x14 := e.Ext2.Add(d1, one) + x14 = e.Ext2.Mul(x14, tmp) + x14 = e.Ext2.Sub(x14, x1) + x14 = e.Ext2.Sub(x14, one) + + zC0B0 := e.Ext2.NonResidue() + zC0B0 = e.Ext2.Add(zC0B0, x0) + + return &[5]E2{*zC0B0, *x01, *x1, *x04, *x14} +} + +// MulBy01245 multiplies z by an E12 sparse element of the form +// +// E12{ +// C0: E6{B0: c0, B1: c1, B2: c2}, +// C1: E6{B0: 0, B1: c4, B2: c5}, +// } +func (e *Ext12) MulBy01245(z *E12, x *[5]E2) *E12 { + c0 := &E6{B0: x[0], B1: x[1], B2: x[2]} + c1 := &E6{B0: *e.Ext2.Zero(), B1: x[3], B2: x[4]} + a := e.Ext6.Add(&z.C0, &z.C1) + b := e.Ext6.Add(c0, c1) + a = e.Ext6.Mul(a, b) + b = e.Ext6.Mul(&z.C0, c0) + c := e.Ext6.MulBy12(&z.C1, &x[3], &x[4]) + z1 := e.Ext6.Sub(a, b) + z1 = e.Ext6.Sub(z1, c) + z0 := e.Ext6.MulByNonResidue(c) + z0 = e.Ext6.Add(z0, b) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +// Torus-based arithmetic: +// +// After the easy part of the final exponentiation the elements are in a proper +// subgroup of Fpk (E12) that coincides with some algebraic tori. 
The elements +// are in the torus Tk(Fp) and thus in each torus Tk/d(Fp^d) for d|k, d≠k. We +// take d=6. So the elements are in T2(Fp6). +// Let G_{q,2} = {m ∈ Fq^2 | m^(q+1) = 1} where q = p^6. +// When m.C1 = 0, then m.C0 must be 1 or −1. +// +// We recall the tower construction: +// +// 𝔽p²[u] = 𝔽p/u²+1 +// 𝔽p⁶[v] = 𝔽p²/v³-1-u +// 𝔽p¹²[w] = 𝔽p⁶/w²-v + +// CompressTorus compresses x ∈ E12 to (x.C0 + 1)/x.C1 ∈ E6 +func (e Ext12) CompressTorus(x *E12) *E6 { + // x ∈ G_{q,2} \ {-1,1} + y := e.Ext6.Add(&x.C0, e.Ext6.One()) + y = e.Ext6.DivUnchecked(y, &x.C1) + return y +} + +// DecompressTorus decompresses y ∈ E6 to (y+w)/(y-w) ∈ E12 +func (e Ext12) DecompressTorus(y *E6) *E12 { + var n, d E12 + one := e.Ext6.One() + n.C0 = *y + n.C1 = *one + d.C0 = *y + d.C1 = *e.Ext6.Neg(one) + + x := e.DivUnchecked(&n, &d) + return x +} + +// MulTorus multiplies two compressed elements y1, y2 ∈ E6 +// and returns (y1 * y2 + v)/(y1 + y2) +// N.B.: we use MulTorus in the final exponentiation throughout y1 ≠ -y2 always. +func (e Ext12) MulTorus(y1, y2 *E6) *E6 { + n := e.Ext6.Mul(y1, y2) + n.B1 = *e.Ext2.Add(&n.B1, e.Ext2.One()) + d := e.Ext6.Add(y1, y2) + y3 := e.Ext6.DivUnchecked(n, d) + return y3 +} + +// InverseTorus inverses a compressed elements y ∈ E6 +// and returns -y +func (e Ext12) InverseTorus(y *E6) *E6 { + return e.Ext6.Neg(y) +} + +// SquareTorus squares a compressed elements y ∈ E6 +// and returns (y + v/y)/2 +// +// It uses a hint to verify that (2x-y)y = v saving one E6 AssertIsEqual. 
+func (e Ext12) SquareTorus(y *E6) *E6 { + res, err := e.fp.NewHint(squareTorusHint, 6, &y.B0.A0, &y.B0.A1, &y.B1.A0, &y.B1.A1, &y.B2.A0, &y.B2.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + sq := E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + } + + // v = (2x-y)y + v := e.Ext6.Double(&sq) + v = e.Ext6.Sub(v, y) + v = e.Ext6.Mul(v, y) + + _v := E6{B0: *e.Ext2.Zero(), B1: *e.Ext2.One(), B2: *e.Ext2.Zero()} + e.Ext6.AssertIsEqual(v, &_v) + + return &sq + +} + +// FrobeniusTorus raises a compressed elements y ∈ E6 to the modulus p +// and returns y^p / v^((p-1)/2) +func (e Ext12) FrobeniusTorus(y *E6) *E6 { + t0 := e.Ext2.Conjugate(&y.B0) + t1 := e.Ext2.Conjugate(&y.B1) + t2 := e.Ext2.Conjugate(&y.B2) + t1 = e.Ext2.MulByNonResidue1Power2(t1) + t2 = e.Ext2.MulByNonResidue1Power4(t2) + + v0 := E2{emulated.ValueOf[emulated.BLS12381Fp]("877076961050607968509681729531255177986764537961432449499635504522207616027455086505066378536590128544573588734230"), emulated.ValueOf[emulated.BLS12381Fp]("877076961050607968509681729531255177986764537961432449499635504522207616027455086505066378536590128544573588734230")} + res := &E6{B0: *t0, B1: *t1, B2: *t2} + res = e.Ext6.MulBy0(res, &v0) + + return res +} + +// FrobeniusSquareTorus raises a compressed elements y ∈ E6 to the square modulus p^2 +// and returns y^(p^2) / v^((p^2-1)/2) +func (e Ext12) FrobeniusSquareTorus(y *E6) *E6 { + v0 := emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437") + t0 := e.Ext2.MulByElement(&y.B0, &v0) + t1 := e.Ext2.MulByNonResidue2Power2(&y.B1) + t1 = e.Ext2.MulByElement(t1, &v0) + t2 := e.Ext2.MulByNonResidue2Power4(&y.B2) + t2 = e.Ext2.MulByElement(t2, &v0) + + return &E6{B0: *t0, B1: *t1, B2: *t2} +} diff --git a/std/algebra/emulated/fields_bls12381/e12_test.go 
b/std/algebra/emulated/fields_bls12381/e12_test.go new file mode 100644 index 0000000000..cba62ef2be --- /dev/null +++ b/std/algebra/emulated/fields_bls12381/e12_test.go @@ -0,0 +1,581 @@ +package fields_bls12381 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type e12Add struct { + A, B, C E12 +} + +func (circuit *e12Add) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Add(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestAddFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Add(&a, &b) + + witness := e12Add{ + A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Add{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Sub struct { + A, B, C E12 +} + +func (circuit *e12Sub) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Sub(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSubFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Sub(&a, &b) + + witness := e12Sub{ + A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Sub{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Mul struct { + A, B, C E12 +} + +func (circuit *e12Mul) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Mul(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestMulFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Mul(&a, &b) + + witness := e12Mul{ + 
A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Mul{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Div struct { + A, B, C E12 +} + +func (circuit *e12Div) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.DivUnchecked(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestDivFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Div(&a, &b) + + witness := e12Div{ + A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Div{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Square struct { + A, C E12 +} + +func (circuit *e12Square) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Square(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSquareFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E12 + _, _ = a.SetRandom() + c.Square(&a) + + witness := e12Square{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Square{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Conjugate struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *e12Conjugate) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Conjugate(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestConjugateFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E12 + _, _ = a.SetRandom() + c.Conjugate(&a) + + witness := e12Conjugate{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Conjugate{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e12Inverse struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *e12Inverse) Define(api frontend.API) error { + e := 
NewExt12(api) + expected := e.Inverse(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestInverseFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E12 + _, _ = a.SetRandom() + c.Inverse(&a) + + witness := e12Inverse{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Inverse{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e12ExptTorus struct { + A E6 + C E12 `gnark:",public"` +} + +func (circuit *e12ExptTorus) Define(api frontend.API) error { + e := NewExt12(api) + z := e.ExptTorus(&circuit.A) + expected := e.DecompressTorus(z) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestFp12ExptTorus(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + var tmp bls12381.E12 + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + c.Expt(&a) + _a, _ := a.CompressTorus() + witness := e12ExptTorus{ + A: FromE6(&_a), + C: FromE12(&c), + } + + err := test.IsSolved(&e12ExptTorus{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e12MulBy014 struct { + A E12 `gnark:",public"` + W E12 + B, C E2 +} + +func (circuit *e12MulBy014) Define(api frontend.API) error { + e := NewExt12(api) + res := e.MulBy014(&circuit.A, &circuit.B, &circuit.C) + e.AssertIsEqual(res, &circuit.W) + return nil +} + +func TestFp12MulBy014(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, w bls12381.E12 + _, _ = a.SetRandom() + var one, b, c bls12381.E2 + one.SetOne() + _, _ = b.SetRandom() + _, _ = c.SetRandom() + w.Set(&a) + w.MulBy014(&b, &c, &one) + + witness := e12MulBy014{ + A: FromE12(&a), + B: FromE2(&b), + C: FromE2(&c), + W: FromE12(&w), + } + + err := test.IsSolved(&e12MulBy014{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +// Torus-based arithmetic +type 
torusCompress struct { + A E12 + C E6 `gnark:",public"` +} + +func (circuit *torusCompress) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.CompressTorus(&circuit.A) + e.Ext6.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusCompress(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a bls12381.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + var tmp bls12381.E12 + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + c, _ := a.CompressTorus() + + witness := torusCompress{ + A: FromE12(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&torusCompress{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusDecompress struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusDecompress) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusDecompress(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a bls12381.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + var tmp bls12381.E12 + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + d, _ := a.CompressTorus() + c := d.DecompressTorus() + + witness := torusDecompress{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusDecompress{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusMul struct { + A E12 + B E12 + C E12 `gnark:",public"` +} + +func (circuit *torusMul) Define(api frontend.API) error { + e := NewExt12(api) + compressedA := e.CompressTorus(&circuit.A) + compressedB := e.CompressTorus(&circuit.B) + compressedAB := e.MulTorus(compressedA, compressedB) + expected := e.DecompressTorus(compressedAB) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func 
TestTorusMul(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c, tmp bls12381.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + // put b in the cyclotomic subgroup + tmp.Conjugate(&b) + b.Inverse(&b) + tmp.Mul(&tmp, &b) + b.FrobeniusSquare(&tmp).Mul(&b, &tmp) + + // uncompressed mul + c.Mul(&a, &b) + + witness := torusMul{ + A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := test.IsSolved(&torusMul{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusInverse struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusInverse) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + compressed = e.InverseTorus(compressed) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusInverse(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c, tmp bls12381.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + // uncompressed inverse + c.Inverse(&a) + + witness := torusInverse{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusInverse{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusFrobenius struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusFrobenius) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + compressed = e.FrobeniusTorus(compressed) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusFrobenius(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c, tmp bls12381.E12 + _, _ = a.SetRandom() + + // put a in the 
cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + // uncompressed frobenius + c.Frobenius(&a) + + witness := torusFrobenius{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusFrobenius{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusFrobeniusSquare struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusFrobeniusSquare) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + compressed = e.FrobeniusSquareTorus(compressed) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusFrobeniusSquare(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c, tmp bls12381.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + // uncompressed frobeniusSquare + c.FrobeniusSquare(&a) + + witness := torusFrobeniusSquare{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusFrobeniusSquare{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusSquare struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusSquare) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + compressed = e.SquareTorus(compressed) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusSquare(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c, tmp bls12381.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + // uncompressed square + c.Square(&a) + + witness := torusSquare{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := 
test.IsSolved(&torusSquare{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/emulated/fields_bls12381/e2.go b/std/algebra/emulated/fields_bls12381/e2.go new file mode 100644 index 0000000000..e36178b215 --- /dev/null +++ b/std/algebra/emulated/fields_bls12381/e2.go @@ -0,0 +1,344 @@ +package fields_bls12381 + +import ( + "math/big" + + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" +) + +type curveF = emulated.Field[emulated.BLS12381Fp] +type baseEl = emulated.Element[emulated.BLS12381Fp] + +type E2 struct { + A0, A1 baseEl +} + +type Ext2 struct { + api frontend.API + fp *curveF + nonResidues map[int]map[int]*E2 +} + +func NewExt2(api frontend.API) *Ext2 { + fp, err := emulated.NewField[emulated.BLS12381Fp](api) + if err != nil { + panic(err) + } + pwrs := map[int]map[int]struct { + A0 string + A1 string + }{ + 1: { + 1: {"3850754370037169011952147076051364057158807420970682438676050522613628423219637725072182697113062777891589506424760", "151655185184498381465642749684540099398075398968325446656007613510403227271200139370504932015952886146304766135027"}, + 2: {"0", "4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436"}, + 3: {"1028732146235106349975324479215795277384839936929757896155643118032610843298655225875571310552543014690878354869257", "1028732146235106349975324479215795277384839936929757896155643118032610843298655225875571310552543014690878354869257"}, + 4: {"4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437", "0"}, + 5: {"877076961050607968509681729531255177986764537961432449499635504522207616027455086505066378536590128544573588734230", "3125332594171059424908108096204648978570118281977575435832422631601824034463382777937621250592425535493320683825557"}, + }, + 2: { + 1: 
{"793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620351", "0"}, + 2: {"793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620350", "0"}, + 3: {"4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559786", "0"}, + 4: {"4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436", "0"}, + 5: {"4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437", "0"}, + }, + } + nonResidues := make(map[int]map[int]*E2) + for pwr, v := range pwrs { + for coeff, v := range v { + el := E2{emulated.ValueOf[emulated.BLS12381Fp](v.A0), emulated.ValueOf[emulated.BLS12381Fp](v.A1)} + if nonResidues[pwr] == nil { + nonResidues[pwr] = make(map[int]*E2) + } + nonResidues[pwr][coeff] = &el + } + } + return &Ext2{api: api, fp: fp, nonResidues: nonResidues} +} + +func (e Ext2) MulByElement(x *E2, y *baseEl) *E2 { + z0 := e.fp.MulMod(&x.A0, y) + z1 := e.fp.MulMod(&x.A1, y) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) MulByConstElement(x *E2, y *big.Int) *E2 { + z0 := e.fp.MulConst(&x.A0, y) + z1 := e.fp.MulConst(&x.A1, y) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Conjugate(x *E2) *E2 { + z0 := x.A0 + z1 := e.fp.Neg(&x.A1) + return &E2{ + A0: z0, + A1: *z1, + } +} + +func (e Ext2) MulByNonResidueGeneric(x *E2, power, coef int) *E2 { + y := e.nonResidues[power][coef] + z := e.Mul(x, y) + return z +} + +// MulByNonResidue returns x*(1+u) +func (e Ext2) MulByNonResidue(x *E2) *E2 { + a := e.fp.Sub(&x.A0, &x.A1) + b := e.fp.Add(&x.A0, &x.A1) + + return &E2{ + A0: *a, + A1: *b, + } +} + +// MulByNonResidue1Power1 returns x*(1+u)^(1*(p^1-1)/6) +func (e Ext2) MulByNonResidue1Power1(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 1, 1) +} + +// MulByNonResidue1Power2 returns x*(1+u)^(2*(p^1-1)/6) +func (e Ext2) 
MulByNonResidue1Power2(x *E2) *E2 { + element := emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436") + a := e.fp.MulMod(&x.A1, &element) + a = e.fp.Neg(a) + b := e.fp.MulMod(&x.A0, &element) + return &E2{ + A0: *a, + A1: *b, + } +} + +// MulByNonResidue1Power3 returns x*(1+u)^(3*(p^1-1)/6) +func (e Ext2) MulByNonResidue1Power3(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 1, 3) +} + +// MulByNonResidue1Power4 returns x*(1+u)^(4*(p^1-1)/6) +func (e Ext2) MulByNonResidue1Power4(x *E2) *E2 { + element := emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue1Power5 returns x*(1+u)^(5*(p^1-1)/6) +func (e Ext2) MulByNonResidue1Power5(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 1, 5) +} + +// MulByNonResidue2Power1 returns x*(1+u)^(1*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power1(x *E2) *E2 { + element := emulated.ValueOf[emulated.BLS12381Fp]("793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620351") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue2Power2 returns x*(1+u)^(2*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power2(x *E2) *E2 { + element := emulated.ValueOf[emulated.BLS12381Fp]("793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620350") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue2Power3 returns x*(1+u)^(3*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power3(x *E2) *E2 { + element := 
emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559786") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue2Power4 returns x*(1+u)^(4*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power4(x *E2) *E2 { + element := emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue2Power5 returns x*(1+u)^(5*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power5(x *E2) *E2 { + element := emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +func (e Ext2) Mul(x, y *E2) *E2 { + a := e.fp.Add(&x.A0, &x.A1) + b := e.fp.Add(&y.A0, &y.A1) + a = e.fp.MulMod(a, b) + b = e.fp.MulMod(&x.A0, &y.A0) + c := e.fp.MulMod(&x.A1, &y.A1) + z1 := e.fp.Sub(a, b) + z1 = e.fp.Sub(z1, c) + z0 := e.fp.Sub(b, c) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Add(x, y *E2) *E2 { + z0 := e.fp.Add(&x.A0, &y.A0) + z1 := e.fp.Add(&x.A1, &y.A1) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Sub(x, y *E2) *E2 { + z0 := e.fp.Sub(&x.A0, &y.A0) + z1 := e.fp.Sub(&x.A1, &y.A1) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Neg(x *E2) *E2 { + z0 := e.fp.Neg(&x.A0) + z1 := e.fp.Neg(&x.A1) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) One() *E2 { + z0 := e.fp.One() + z1 := e.fp.Zero() + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Zero() *E2 { + z0 := e.fp.Zero() + z1 := e.fp.Zero() + return &E2{ + A0: *z0, + A1: *z1, + } +} +func (e Ext2) IsZero(z *E2) frontend.Variable { + a0 := 
e.fp.IsZero(&z.A0) + a1 := e.fp.IsZero(&z.A1) + return e.api.And(a0, a1) +} + +// returns 1+u +func (e Ext2) NonResidue() *E2 { + one := e.fp.One() + return &E2{ + A0: *one, + A1: *one, + } +} + +func (e Ext2) Square(x *E2) *E2 { + a := e.fp.Add(&x.A0, &x.A1) + b := e.fp.Sub(&x.A0, &x.A1) + a = e.fp.MulMod(a, b) + b = e.fp.MulMod(&x.A0, &x.A1) + b = e.fp.MulConst(b, big.NewInt(2)) + return &E2{ + A0: *a, + A1: *b, + } +} + +func (e Ext2) Double(x *E2) *E2 { + two := big.NewInt(2) + z0 := e.fp.MulConst(&x.A0, two) + z1 := e.fp.MulConst(&x.A1, two) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) AssertIsEqual(x, y *E2) { + e.fp.AssertIsEqual(&x.A0, &y.A0) + e.fp.AssertIsEqual(&x.A1, &y.A1) +} + +func FromE2(y *bls12381.E2) E2 { + return E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp](y.A0), + A1: emulated.ValueOf[emulated.BLS12381Fp](y.A1), + } +} + +func (e Ext2) Inverse(x *E2) *E2 { + res, err := e.fp.NewHint(inverseE2Hint, 2, &x.A0, &x.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + inv := E2{ + A0: *res[0], + A1: *res[1], + } + one := e.One() + + // 1 == inv * x + _one := e.Mul(&inv, x) + e.AssertIsEqual(one, _one) + + return &inv + +} + +func (e Ext2) DivUnchecked(x, y *E2) *E2 { + res, err := e.fp.NewHint(divE2Hint, 2, &x.A0, &x.A1, &y.A0, &y.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + div := E2{ + A0: *res[0], + A1: *res[1], + } + + // x == div * y + _x := e.Mul(&div, y) + e.AssertIsEqual(x, _x) + + return &div +} + +func (e Ext2) Select(selector frontend.Variable, z1, z0 *E2) *E2 { + a0 := e.fp.Select(selector, &z1.A0, &z0.A0) + a1 := e.fp.Select(selector, &z1.A1, &z0.A1) + return &E2{A0: *a0, A1: *a1} +} + +func (e Ext2) Lookup2(s1, s2 frontend.Variable, a, b, c, d *E2) *E2 { + a0 := e.fp.Lookup2(s1, s2, &a.A0, &b.A0, &c.A0, &d.A0) + a1 := e.fp.Lookup2(s1, s2, &a.A1, &b.A1, &c.A1, &d.A1) + return &E2{A0: *a0, A1: *a1} +} diff --git 
a/std/algebra/emulated/fields_bls12381/e2_test.go b/std/algebra/emulated/fields_bls12381/e2_test.go new file mode 100644 index 0000000000..4065af437c --- /dev/null +++ b/std/algebra/emulated/fields_bls12381/e2_test.go @@ -0,0 +1,351 @@ +package fields_bls12381 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/test" +) + +type e2Add struct { + A, B, C E2 +} + +func (circuit *e2Add) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Add(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestAddFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Add(&a, &b) + + witness := e2Add{ + A: FromE2(&a), + B: FromE2(&b), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Add{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Sub struct { + A, B, C E2 +} + +func (circuit *e2Sub) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Sub(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSubFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Sub(&a, &b) + + witness := e2Sub{ + A: FromE2(&a), + B: FromE2(&b), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Sub{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Double struct { + A, C E2 +} + +func (circuit *e2Double) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Double(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestDoubleFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness 
values + var a, b, c bls12381.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Double(&a) + + witness := e2Double{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Double{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Mul struct { + A, B, C E2 +} + +func (circuit *e2Mul) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Mul(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestMulFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Mul(&a, &b) + + witness := e2Mul{ + A: FromE2(&a), + B: FromE2(&b), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Mul{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Square struct { + A, C E2 +} + +func (circuit *e2Square) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Square(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSquareFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E2 + _, _ = a.SetRandom() + c.Square(&a) + + witness := e2Square{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Square{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Div struct { + A, B, C E2 +} + +func (circuit *e2Div) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.DivUnchecked(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestDivFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Div(&a, &b) + + witness := e2Div{ + A: FromE2(&a), + B: FromE2(&b), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Div{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2MulByElement struct { + A E2 + B baseEl + C E2 `gnark:",public"` 
+} + +func (circuit *e2MulByElement) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.MulByElement(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulByElement(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E2 + var b fp.Element + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.MulByElement(&a, &b) + + witness := e2MulByElement{ + A: FromE2(&a), + B: emulated.ValueOf[emulated.BLS12381Fp](b), + C: FromE2(&c), + } + + err := test.IsSolved(&e2MulByElement{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2MulByNonResidue struct { + A E2 + C E2 `gnark:",public"` +} + +func (circuit *e2MulByNonResidue) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.MulByNonResidue(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulFp2ByNonResidue(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E2 + _, _ = a.SetRandom() + c.MulByNonResidue(&a) + + witness := e2MulByNonResidue{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2MulByNonResidue{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Neg struct { + A E2 + C E2 `gnark:",public"` +} + +func (circuit *e2Neg) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Neg(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestNegFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E2 + _, _ = a.SetRandom() + c.Neg(&a) + + witness := e2Neg{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Neg{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e2Conjugate struct { + A E2 + C E2 `gnark:",public"` +} + +func (circuit *e2Conjugate) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Conjugate(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + 
return nil +} + +func TestConjugateFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E2 + _, _ = a.SetRandom() + c.Conjugate(&a) + + witness := e2Conjugate{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Conjugate{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e2Inverse struct { + A E2 + C E2 `gnark:",public"` +} + +func (circuit *e2Inverse) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Inverse(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestInverseFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E2 + _, _ = a.SetRandom() + c.Inverse(&a) + + witness := e2Inverse{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Inverse{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/emulated/fields_bls12381/e6.go b/std/algebra/emulated/fields_bls12381/e6.go new file mode 100644 index 0000000000..85e522eeb4 --- /dev/null +++ b/std/algebra/emulated/fields_bls12381/e6.go @@ -0,0 +1,315 @@ +package fields_bls12381 + +import ( + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" +) + +type E6 struct { + B0, B1, B2 E2 +} + +type Ext6 struct { + *Ext2 +} + +func NewExt6(api frontend.API) *Ext6 { + return &Ext6{Ext2: NewExt2(api)} +} + +func (e Ext6) One() *E6 { + z0 := e.Ext2.One() + z1 := e.Ext2.Zero() + z2 := e.Ext2.Zero() + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Zero() *E6 { + z0 := e.Ext2.Zero() + z1 := e.Ext2.Zero() + z2 := e.Ext2.Zero() + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) IsZero(z *E6) frontend.Variable { + b0 := e.Ext2.IsZero(&z.B0) + b1 := e.Ext2.IsZero(&z.B1) + b2 := e.Ext2.IsZero(&z.B2) + return e.api.And(e.api.And(b0, b1), b2) +} + +func (e Ext6) Add(x, y *E6) *E6 { + z0 := e.Ext2.Add(&x.B0, &y.B0) + z1 := e.Ext2.Add(&x.B1, 
&y.B1) + z2 := e.Ext2.Add(&x.B2, &y.B2) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Neg(x *E6) *E6 { + z0 := e.Ext2.Neg(&x.B0) + z1 := e.Ext2.Neg(&x.B1) + z2 := e.Ext2.Neg(&x.B2) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Sub(x, y *E6) *E6 { + z0 := e.Ext2.Sub(&x.B0, &y.B0) + z1 := e.Ext2.Sub(&x.B1, &y.B1) + z2 := e.Ext2.Sub(&x.B2, &y.B2) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Mul(x, y *E6) *E6 { + t0 := e.Ext2.Mul(&x.B0, &y.B0) + t1 := e.Ext2.Mul(&x.B1, &y.B1) + t2 := e.Ext2.Mul(&x.B2, &y.B2) + c0 := e.Ext2.Add(&x.B1, &x.B2) + tmp := e.Ext2.Add(&y.B1, &y.B2) + c0 = e.Ext2.Mul(c0, tmp) + c0 = e.Ext2.Sub(c0, t1) + c0 = e.Ext2.Sub(c0, t2) + c0 = e.Ext2.MulByNonResidue(c0) + c0 = e.Ext2.Add(c0, t0) + c1 := e.Ext2.Add(&x.B0, &x.B1) + tmp = e.Ext2.Add(&y.B0, &y.B1) + c1 = e.Ext2.Mul(c1, tmp) + c1 = e.Ext2.Sub(c1, t0) + c1 = e.Ext2.Sub(c1, t1) + tmp = e.Ext2.MulByNonResidue(t2) + c1 = e.Ext2.Add(c1, tmp) + tmp = e.Ext2.Add(&x.B0, &x.B2) + c2 := e.Ext2.Add(&y.B0, &y.B2) + c2 = e.Ext2.Mul(c2, tmp) + c2 = e.Ext2.Sub(c2, t0) + c2 = e.Ext2.Sub(c2, t2) + c2 = e.Ext2.Add(c2, t1) + return &E6{ + B0: *c0, + B1: *c1, + B2: *c2, + } +} + +func (e Ext6) Double(x *E6) *E6 { + z0 := e.Ext2.Double(&x.B0) + z1 := e.Ext2.Double(&x.B1) + z2 := e.Ext2.Double(&x.B2) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Square(x *E6) *E6 { + c4 := e.Ext2.Mul(&x.B0, &x.B1) + c4 = e.Ext2.Double(c4) + c5 := e.Ext2.Square(&x.B2) + c1 := e.Ext2.MulByNonResidue(c5) + c1 = e.Ext2.Add(c1, c4) + c2 := e.Ext2.Sub(c4, c5) + c3 := e.Ext2.Square(&x.B0) + c4 = e.Ext2.Sub(&x.B0, &x.B1) + c4 = e.Ext2.Add(c4, &x.B2) + c5 = e.Ext2.Mul(&x.B1, &x.B2) + c5 = e.Ext2.Double(c5) + c4 = e.Ext2.Square(c4) + c0 := e.Ext2.MulByNonResidue(c5) + c0 = e.Ext2.Add(c0, c3) + z2 := e.Ext2.Add(c2, c4) + z2 = e.Ext2.Add(z2, c5) + z2 = e.Ext2.Sub(z2, c3) + z0 := c0 + z1 := c1 + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } 
+} + +func (e Ext6) MulByE2(x *E6, y *E2) *E6 { + z0 := e.Ext2.Mul(&x.B0, y) + z1 := e.Ext2.Mul(&x.B1, y) + z2 := e.Ext2.Mul(&x.B2, y) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +// MulBy12 multiplication by sparse element (0,b1,b2) +func (e Ext6) MulBy12(x *E6, b1, b2 *E2) *E6 { + t1 := e.Ext2.Mul(&x.B1, b1) + t2 := e.Ext2.Mul(&x.B2, b2) + c0 := e.Ext2.Add(&x.B1, &x.B2) + tmp := e.Ext2.Add(b1, b2) + c0 = e.Ext2.Mul(c0, tmp) + c0 = e.Ext2.Sub(c0, t1) + c0 = e.Ext2.Sub(c0, t2) + c0 = e.Ext2.MulByNonResidue(c0) + c1 := e.Ext2.Add(&x.B0, &x.B1) + c1 = e.Ext2.Mul(c1, b1) + c1 = e.Ext2.Sub(c1, t1) + tmp = e.Ext2.MulByNonResidue(t2) + c1 = e.Ext2.Add(c1, tmp) + tmp = e.Ext2.Add(&x.B0, &x.B2) + c2 := e.Ext2.Mul(b2, tmp) + c2 = e.Ext2.Sub(c2, t2) + c2 = e.Ext2.Add(c2, t1) + return &E6{ + B0: *c0, + B1: *c1, + B2: *c2, + } +} + +// MulBy0 multiplies z by an E6 sparse element of the form +// +// E6{ +// B0: c0, +// B1: 0, +// B2: 0, +// } +func (e Ext6) MulBy0(z *E6, c0 *E2) *E6 { + a := e.Ext2.Mul(&z.B0, c0) + tmp := e.Ext2.Add(&z.B0, &z.B2) + t2 := e.Ext2.Mul(c0, tmp) + t2 = e.Ext2.Sub(t2, a) + tmp = e.Ext2.Add(&z.B0, &z.B1) + t1 := e.Ext2.Mul(c0, tmp) + t1 = e.Ext2.Sub(t1, a) + return &E6{ + B0: *a, + B1: *t1, + B2: *t2, + } +} + +// MulBy01 multiplication by sparse element (c0,c1,0) +func (e Ext6) MulBy01(z *E6, c0, c1 *E2) *E6 { + a := e.Ext2.Mul(&z.B0, c0) + b := e.Ext2.Mul(&z.B1, c1) + tmp := e.Ext2.Add(&z.B1, &z.B2) + t0 := e.Ext2.Mul(c1, tmp) + t0 = e.Ext2.Sub(t0, b) + t0 = e.Ext2.MulByNonResidue(t0) + t0 = e.Ext2.Add(t0, a) + tmp = e.Ext2.Add(&z.B0, &z.B2) + t2 := e.Ext2.Mul(c0, tmp) + t2 = e.Ext2.Sub(t2, a) + t2 = e.Ext2.Add(t2, b) + t1 := e.Ext2.Add(c0, c1) + tmp = e.Ext2.Add(&z.B0, &z.B1) + t1 = e.Ext2.Mul(t1, tmp) + t1 = e.Ext2.Sub(t1, a) + t1 = e.Ext2.Sub(t1, b) + return &E6{ + B0: *t0, + B1: *t1, + B2: *t2, + } +} + +func (e Ext6) MulByNonResidue(x *E6) *E6 { + z2, z1, z0 := &x.B1, &x.B0, &x.B2 + z0 = e.Ext2.MulByNonResidue(z0) + return &E6{ + 
B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) AssertIsEqual(x, y *E6) { + e.Ext2.AssertIsEqual(&x.B0, &y.B0) + e.Ext2.AssertIsEqual(&x.B1, &y.B1) + e.Ext2.AssertIsEqual(&x.B2, &y.B2) +} + +func FromE6(y *bls12381.E6) E6 { + return E6{ + B0: FromE2(&y.B0), + B1: FromE2(&y.B1), + B2: FromE2(&y.B2), + } + +} + +func (e Ext6) Inverse(x *E6) *E6 { + res, err := e.fp.NewHint(inverseE6Hint, 6, &x.B0.A0, &x.B0.A1, &x.B1.A0, &x.B1.A1, &x.B2.A0, &x.B2.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + inv := E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + } + + one := e.One() + + // 1 == inv * x + _one := e.Mul(&inv, x) + e.AssertIsEqual(one, _one) + + return &inv + +} + +func (e Ext6) DivUnchecked(x, y *E6) *E6 { + res, err := e.fp.NewHint(divE6Hint, 6, &x.B0.A0, &x.B0.A1, &x.B1.A0, &x.B1.A1, &x.B2.A0, &x.B2.A1, &y.B0.A0, &y.B0.A1, &y.B1.A0, &y.B1.A1, &y.B2.A0, &y.B2.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + div := E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + } + + // x == div * y + _x := e.Mul(&div, y) + e.AssertIsEqual(x, _x) + + return &div +} + +func (e Ext6) Select(selector frontend.Variable, z1, z0 *E6) *E6 { + b0 := e.Ext2.Select(selector, &z1.B0, &z0.B0) + b1 := e.Ext2.Select(selector, &z1.B1, &z0.B1) + b2 := e.Ext2.Select(selector, &z1.B2, &z0.B2) + return &E6{B0: *b0, B1: *b1, B2: *b2} +} + +func (e Ext6) Lookup2(s1, s2 frontend.Variable, a, b, c, d *E6) *E6 { + b0 := e.Ext2.Lookup2(s1, s2, &a.B0, &b.B0, &c.B0, &d.B0) + b1 := e.Ext2.Lookup2(s1, s2, &a.B1, &b.B1, &c.B1, &d.B1) + b2 := e.Ext2.Lookup2(s1, s2, &a.B2, &b.B2, &c.B2, &d.B2) + return &E6{B0: *b0, B1: *b1, B2: *b2} +} diff --git a/std/algebra/emulated/fields_bls12381/e6_test.go b/std/algebra/emulated/fields_bls12381/e6_test.go new file mode 100644 index 
0000000000..7964833bdf --- /dev/null +++ b/std/algebra/emulated/fields_bls12381/e6_test.go @@ -0,0 +1,327 @@ +package fields_bls12381 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type e6Add struct { + A, B, C E6 +} + +func (circuit *e6Add) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Add(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestAddFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E6 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Add(&a, &b) + + witness := e6Add{ + A: FromE6(&a), + B: FromE6(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Add{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Sub struct { + A, B, C E6 +} + +func (circuit *e6Sub) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Sub(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSubFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E6 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Sub(&a, &b) + + witness := e6Sub{ + A: FromE6(&a), + B: FromE6(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Sub{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Mul struct { + A, B, C E6 +} + +func (circuit *e6Mul) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Mul(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestMulFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E6 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Mul(&a, &b) + + witness := e6Mul{ + A: FromE6(&a), + B: FromE6(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Mul{}, &witness, 
ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Square struct { + A, C E6 +} + +func (circuit *e6Square) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Square(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSquareFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E6 + _, _ = a.SetRandom() + c.Square(&a) + + witness := e6Square{ + A: FromE6(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Square{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Div struct { + A, B, C E6 +} + +func (circuit *e6Div) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.DivUnchecked(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestDivFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E6 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Div(&a, &b) + + witness := e6Div{ + A: FromE6(&a), + B: FromE6(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Div{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6MulByNonResidue struct { + A E6 + C E6 `gnark:",public"` +} + +func (circuit *e6MulByNonResidue) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.MulByNonResidue(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulFp6ByNonResidue(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E6 + _, _ = a.SetRandom() + c.MulByNonResidue(&a) + + witness := e6MulByNonResidue{ + A: FromE6(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&e6MulByNonResidue{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6MulByE2 struct { + A E6 + B E2 + C E6 `gnark:",public"` +} + +func (circuit *e6MulByE2) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.MulByE2(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, 
&circuit.C) + + return nil +} + +func TestMulFp6ByE2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E6 + var b bls12381.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.MulByE2(&a, &b) + + witness := e6MulByE2{ + A: FromE6(&a), + B: FromE2(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6MulByE2{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6MulBy01 struct { + A E6 + C0, C1 E2 + C E6 `gnark:",public"` +} + +func (circuit *e6MulBy01) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.MulBy01(&circuit.A, &circuit.C0, &circuit.C1) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulFp6By01(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E6 + var C0, C1 bls12381.E2 + _, _ = a.SetRandom() + _, _ = C0.SetRandom() + _, _ = C1.SetRandom() + c.Set(&a) + c.MulBy01(&C0, &C1) + + witness := e6MulBy01{ + A: FromE6(&a), + C0: FromE2(&C0), + C1: FromE2(&C1), + C: FromE6(&c), + } + + err := test.IsSolved(&e6MulBy01{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Neg struct { + A E6 + C E6 `gnark:",public"` +} + +func (circuit *e6Neg) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Neg(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestNegFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E6 + _, _ = a.SetRandom() + c.Neg(&a) + + witness := e6Neg{ + A: FromE6(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Neg{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e6Inverse struct { + A E6 + C E6 `gnark:",public"` +} + +func (circuit *e6Inverse) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Inverse(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestInverseFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, 
c bls12381.E6 + _, _ = a.SetRandom() + c.Inverse(&a) + + witness := e6Inverse{ + A: FromE6(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Inverse{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/emulated/fields_bls12381/hints.go b/std/algebra/emulated/fields_bls12381/hints.go new file mode 100644 index 0000000000..c8455f40cf --- /dev/null +++ b/std/algebra/emulated/fields_bls12381/hints.go @@ -0,0 +1,238 @@ +package fields_bls12381 + +import ( + "math/big" + + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/math/emulated" +) + +func init() { + solver.RegisterHint(GetHints()...) +} + +// GetHints returns all hint functions used in the package. +func GetHints() []solver.Hint { + return []solver.Hint{ + // E2 + divE2Hint, + inverseE2Hint, + // E6 + divE6Hint, + inverseE6Hint, + squareTorusHint, + // E12 + divE12Hint, + inverseE12Hint, + } +} + +func inverseE2Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, c bls12381.E2 + + a.A0.SetBigInt(inputs[0]) + a.A1.SetBigInt(inputs[1]) + + c.Inverse(&a) + + c.A0.BigInt(outputs[0]) + c.A1.BigInt(outputs[1]) + + return nil + }) +} + +func divE2Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, b, c bls12381.E2 + + a.A0.SetBigInt(inputs[0]) + a.A1.SetBigInt(inputs[1]) + b.A0.SetBigInt(inputs[2]) + b.A1.SetBigInt(inputs[3]) + + c.Inverse(&b).Mul(&c, &a) + + c.A0.BigInt(outputs[0]) + c.A1.BigInt(outputs[1]) + + return nil + }) +} + +// E6 hints +func inverseE6Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod 
*big.Int, inputs, outputs []*big.Int) error { + var a, c bls12381.E6 + + a.B0.A0.SetBigInt(inputs[0]) + a.B0.A1.SetBigInt(inputs[1]) + a.B1.A0.SetBigInt(inputs[2]) + a.B1.A1.SetBigInt(inputs[3]) + a.B2.A0.SetBigInt(inputs[4]) + a.B2.A1.SetBigInt(inputs[5]) + + c.Inverse(&a) + + c.B0.A0.BigInt(outputs[0]) + c.B0.A1.BigInt(outputs[1]) + c.B1.A0.BigInt(outputs[2]) + c.B1.A1.BigInt(outputs[3]) + c.B2.A0.BigInt(outputs[4]) + c.B2.A1.BigInt(outputs[5]) + + return nil + }) +} + +func divE6Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, b, c bls12381.E6 + + a.B0.A0.SetBigInt(inputs[0]) + a.B0.A1.SetBigInt(inputs[1]) + a.B1.A0.SetBigInt(inputs[2]) + a.B1.A1.SetBigInt(inputs[3]) + a.B2.A0.SetBigInt(inputs[4]) + a.B2.A1.SetBigInt(inputs[5]) + + b.B0.A0.SetBigInt(inputs[6]) + b.B0.A1.SetBigInt(inputs[7]) + b.B1.A0.SetBigInt(inputs[8]) + b.B1.A1.SetBigInt(inputs[9]) + b.B2.A0.SetBigInt(inputs[10]) + b.B2.A1.SetBigInt(inputs[11]) + + c.Inverse(&b).Mul(&c, &a) + + c.B0.A0.BigInt(outputs[0]) + c.B0.A1.BigInt(outputs[1]) + c.B1.A0.BigInt(outputs[2]) + c.B1.A1.BigInt(outputs[3]) + c.B2.A0.BigInt(outputs[4]) + c.B2.A1.BigInt(outputs[5]) + + return nil + }) +} + +func squareTorusHint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, c bls12381.E6 + + a.B0.A0.SetBigInt(inputs[0]) + a.B0.A1.SetBigInt(inputs[1]) + a.B1.A0.SetBigInt(inputs[2]) + a.B1.A1.SetBigInt(inputs[3]) + a.B2.A0.SetBigInt(inputs[4]) + a.B2.A1.SetBigInt(inputs[5]) + + _c := a.DecompressTorus() + _c.CyclotomicSquare(&_c) + c, _ = _c.CompressTorus() + + c.B0.A0.BigInt(outputs[0]) + c.B0.A1.BigInt(outputs[1]) + c.B1.A0.BigInt(outputs[2]) + c.B1.A1.BigInt(outputs[3]) + c.B2.A0.BigInt(outputs[4]) + c.B2.A1.BigInt(outputs[5]) + + return 
nil + }) +} + +// E12 hints +func inverseE12Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, c bls12381.E12 + + a.C0.B0.A0.SetBigInt(inputs[0]) + a.C0.B0.A1.SetBigInt(inputs[1]) + a.C0.B1.A0.SetBigInt(inputs[2]) + a.C0.B1.A1.SetBigInt(inputs[3]) + a.C0.B2.A0.SetBigInt(inputs[4]) + a.C0.B2.A1.SetBigInt(inputs[5]) + a.C1.B0.A0.SetBigInt(inputs[6]) + a.C1.B0.A1.SetBigInt(inputs[7]) + a.C1.B1.A0.SetBigInt(inputs[8]) + a.C1.B1.A1.SetBigInt(inputs[9]) + a.C1.B2.A0.SetBigInt(inputs[10]) + a.C1.B2.A1.SetBigInt(inputs[11]) + + c.Inverse(&a) + + c.C0.B0.A0.BigInt(outputs[0]) + c.C0.B0.A1.BigInt(outputs[1]) + c.C0.B1.A0.BigInt(outputs[2]) + c.C0.B1.A1.BigInt(outputs[3]) + c.C0.B2.A0.BigInt(outputs[4]) + c.C0.B2.A1.BigInt(outputs[5]) + c.C1.B0.A0.BigInt(outputs[6]) + c.C1.B0.A1.BigInt(outputs[7]) + c.C1.B1.A0.BigInt(outputs[8]) + c.C1.B1.A1.BigInt(outputs[9]) + c.C1.B2.A0.BigInt(outputs[10]) + c.C1.B2.A1.BigInt(outputs[11]) + + return nil + }) +} + +func divE12Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, b, c bls12381.E12 + + a.C0.B0.A0.SetBigInt(inputs[0]) + a.C0.B0.A1.SetBigInt(inputs[1]) + a.C0.B1.A0.SetBigInt(inputs[2]) + a.C0.B1.A1.SetBigInt(inputs[3]) + a.C0.B2.A0.SetBigInt(inputs[4]) + a.C0.B2.A1.SetBigInt(inputs[5]) + a.C1.B0.A0.SetBigInt(inputs[6]) + a.C1.B0.A1.SetBigInt(inputs[7]) + a.C1.B1.A0.SetBigInt(inputs[8]) + a.C1.B1.A1.SetBigInt(inputs[9]) + a.C1.B2.A0.SetBigInt(inputs[10]) + a.C1.B2.A1.SetBigInt(inputs[11]) + + b.C0.B0.A0.SetBigInt(inputs[12]) + b.C0.B0.A1.SetBigInt(inputs[13]) + b.C0.B1.A0.SetBigInt(inputs[14]) + b.C0.B1.A1.SetBigInt(inputs[15]) + b.C0.B2.A0.SetBigInt(inputs[16]) + b.C0.B2.A1.SetBigInt(inputs[17]) + b.C1.B0.A0.SetBigInt(inputs[18]) + 
b.C1.B0.A1.SetBigInt(inputs[19]) + b.C1.B1.A0.SetBigInt(inputs[20]) + b.C1.B1.A1.SetBigInt(inputs[21]) + b.C1.B2.A0.SetBigInt(inputs[22]) + b.C1.B2.A1.SetBigInt(inputs[23]) + + c.Inverse(&b).Mul(&c, &a) + + c.C0.B0.A0.BigInt(outputs[0]) + c.C0.B0.A1.BigInt(outputs[1]) + c.C0.B1.A0.BigInt(outputs[2]) + c.C0.B1.A1.BigInt(outputs[3]) + c.C0.B2.A0.BigInt(outputs[4]) + c.C0.B2.A1.BigInt(outputs[5]) + c.C1.B0.A0.BigInt(outputs[6]) + c.C1.B0.A1.BigInt(outputs[7]) + c.C1.B1.A0.BigInt(outputs[8]) + c.C1.B1.A1.BigInt(outputs[9]) + c.C1.B2.A0.BigInt(outputs[10]) + c.C1.B2.A1.BigInt(outputs[11]) + + return nil + }) +} diff --git a/std/algebra/emulated/fields_bn254/doc.go b/std/algebra/emulated/fields_bn254/doc.go new file mode 100644 index 0000000000..ae94b3c243 --- /dev/null +++ b/std/algebra/emulated/fields_bn254/doc.go @@ -0,0 +1,7 @@ +// Package fields_bn254 implements the fields arithmetic of the Fp12 tower +// used to compute the pairing over the BN254 curve. +// +// 𝔽p²[u] = 𝔽p/u²+1 +// 𝔽p⁶[v] = 𝔽p²/v³-9-u +// 𝔽p¹²[w] = 𝔽p⁶/w²-v +package fields_bn254 diff --git a/std/algebra/emulated/fields_bn254/e12.go b/std/algebra/emulated/fields_bn254/e12.go new file mode 100644 index 0000000000..abe8bd9a9e --- /dev/null +++ b/std/algebra/emulated/fields_bn254/e12.go @@ -0,0 +1,199 @@ +package fields_bn254 + +import ( + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/frontend" +) + +type E12 struct { + C0, C1 E6 +} + +type Ext12 struct { + *Ext6 +} + +func NewExt12(api frontend.API) *Ext12 { + return &Ext12{Ext6: NewExt6(api)} +} + +func (e Ext12) Add(x, y *E12) *E12 { + z0 := e.Ext6.Add(&x.C0, &y.C0) + z1 := e.Ext6.Add(&x.C1, &y.C1) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +func (e Ext12) Sub(x, y *E12) *E12 { + z0 := e.Ext6.Sub(&x.C0, &y.C0) + z1 := e.Ext6.Sub(&x.C1, &y.C1) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +func (e Ext12) Conjugate(x *E12) *E12 { + z1 := e.Ext6.Neg(&x.C1) + return &E12{ + C0: x.C0, + C1: *z1, + } +} + +func (e Ext12) 
Mul(x, y *E12) *E12 { + a := e.Ext6.Add(&x.C0, &x.C1) + b := e.Ext6.Add(&y.C0, &y.C1) + a = e.Ext6.Mul(a, b) + b = e.Ext6.Mul(&x.C0, &y.C0) + c := e.Ext6.Mul(&x.C1, &y.C1) + z1 := e.Ext6.Sub(a, b) + z1 = e.Ext6.Sub(z1, c) + z0 := e.Ext6.MulByNonResidue(c) + z0 = e.Ext6.Add(z0, b) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +func (e Ext12) Zero() *E12 { + zero := e.fp.Zero() + return &E12{ + C0: E6{ + B0: E2{A0: *zero, A1: *zero}, + B1: E2{A0: *zero, A1: *zero}, + B2: E2{A0: *zero, A1: *zero}, + }, + C1: E6{ + B0: E2{A0: *zero, A1: *zero}, + B1: E2{A0: *zero, A1: *zero}, + B2: E2{A0: *zero, A1: *zero}, + }, + } +} + +func (e Ext12) One() *E12 { + z000 := e.fp.One() + zero := e.fp.Zero() + return &E12{ + C0: E6{ + B0: E2{A0: *z000, A1: *zero}, + B1: E2{A0: *zero, A1: *zero}, + B2: E2{A0: *zero, A1: *zero}, + }, + C1: E6{ + B0: E2{A0: *zero, A1: *zero}, + B1: E2{A0: *zero, A1: *zero}, + B2: E2{A0: *zero, A1: *zero}, + }, + } +} + +func (e Ext12) IsZero(z *E12) frontend.Variable { + c0 := e.Ext6.IsZero(&z.C0) + c1 := e.Ext6.IsZero(&z.C1) + return e.api.And(c0, c1) +} + +func (e Ext12) Square(x *E12) *E12 { + c0 := e.Ext6.Sub(&x.C0, &x.C1) + c3 := e.Ext6.MulByNonResidue(&x.C1) + c3 = e.Ext6.Neg(c3) + c3 = e.Ext6.Add(&x.C0, c3) + c2 := e.Ext6.Mul(&x.C0, &x.C1) + c0 = e.Ext6.Mul(c0, c3) + c0 = e.Ext6.Add(c0, c2) + z1 := e.Ext6.Double(c2) + c2 = e.Ext6.MulByNonResidue(c2) + z0 := e.Ext6.Add(c0, c2) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +func (e Ext12) AssertIsEqual(x, y *E12) { + e.Ext6.AssertIsEqual(&x.C0, &y.C0) + e.Ext6.AssertIsEqual(&x.C1, &y.C1) +} + +func FromE12(y *bn254.E12) E12 { + return E12{ + C0: FromE6(&y.C0), + C1: FromE6(&y.C1), + } + +} + +func (e Ext12) Inverse(x *E12) *E12 { + res, err := e.fp.NewHint(inverseE12Hint, 12, &x.C0.B0.A0, &x.C0.B0.A1, &x.C0.B1.A0, &x.C0.B1.A1, &x.C0.B2.A0, &x.C0.B2.A1, &x.C1.B0.A0, &x.C1.B0.A1, &x.C1.B1.A0, &x.C1.B1.A1, &x.C1.B2.A0, &x.C1.B2.A1) + if err != nil { + // err is non-nil only for invalid number of inputs 
+ panic(err) + } + + inv := E12{ + C0: E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + }, + C1: E6{ + B0: E2{A0: *res[6], A1: *res[7]}, + B1: E2{A0: *res[8], A1: *res[9]}, + B2: E2{A0: *res[10], A1: *res[11]}, + }, + } + + one := e.One() + + // 1 == inv * x + _one := e.Mul(&inv, x) + e.AssertIsEqual(one, _one) + + return &inv + +} + +func (e Ext12) DivUnchecked(x, y *E12) *E12 { + res, err := e.fp.NewHint(divE12Hint, 12, &x.C0.B0.A0, &x.C0.B0.A1, &x.C0.B1.A0, &x.C0.B1.A1, &x.C0.B2.A0, &x.C0.B2.A1, &x.C1.B0.A0, &x.C1.B0.A1, &x.C1.B1.A0, &x.C1.B1.A1, &x.C1.B2.A0, &x.C1.B2.A1, &y.C0.B0.A0, &y.C0.B0.A1, &y.C0.B1.A0, &y.C0.B1.A1, &y.C0.B2.A0, &y.C0.B2.A1, &y.C1.B0.A0, &y.C1.B0.A1, &y.C1.B1.A0, &y.C1.B1.A1, &y.C1.B2.A0, &y.C1.B2.A1) + + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + div := E12{ + C0: E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + }, + C1: E6{ + B0: E2{A0: *res[6], A1: *res[7]}, + B1: E2{A0: *res[8], A1: *res[9]}, + B2: E2{A0: *res[10], A1: *res[11]}, + }, + } + + // x == div * y + _x := e.Mul(&div, y) + e.AssertIsEqual(x, _x) + + return &div +} + +func (e Ext12) Select(selector frontend.Variable, z1, z0 *E12) *E12 { + c0 := e.Ext6.Select(selector, &z1.C0, &z0.C0) + c1 := e.Ext6.Select(selector, &z1.C1, &z0.C1) + return &E12{C0: *c0, C1: *c1} +} + +func (e Ext12) Lookup2(s1, s2 frontend.Variable, a, b, c, d *E12) *E12 { + c0 := e.Ext6.Lookup2(s1, s2, &a.C0, &b.C0, &c.C0, &d.C0) + c1 := e.Ext6.Lookup2(s1, s2, &a.C1, &b.C1, &c.C1, &d.C1) + return &E12{C0: *c0, C1: *c1} +} diff --git a/std/algebra/emulated/fields_bn254/e12_pairing.go b/std/algebra/emulated/fields_bn254/e12_pairing.go new file mode 100644 index 0000000000..ecec07bb2f --- /dev/null +++ b/std/algebra/emulated/fields_bn254/e12_pairing.go @@ -0,0 +1,312 @@ +package fields_bn254 + +import ( + 
"github.com/consensys/gnark/std/math/emulated" +) + +func (e Ext12) nSquareTorus(z *E6, n int) *E6 { + for i := 0; i < n; i++ { + z = e.SquareTorus(z) + } + return z +} + +// Exponentiation by the seed t=4965661367192848881 +// The computations are performed on E6 compressed form using Torus-based arithmetic. +func (e Ext12) ExptTorus(x *E6) *E6 { + // Expt computation is derived from the addition chain: + // + // _10 = 2*1 + // _100 = 2*_10 + // _1000 = 2*_100 + // _10000 = 2*_1000 + // _10001 = 1 + _10000 + // _10011 = _10 + _10001 + // _10100 = 1 + _10011 + // _11001 = _1000 + _10001 + // _100010 = 2*_10001 + // _100111 = _10011 + _10100 + // _101001 = _10 + _100111 + // i27 = (_100010 << 6 + _100 + _11001) << 7 + _11001 + // i44 = (i27 << 8 + _101001 + _10) << 6 + _10001 + // i70 = ((i44 << 8 + _101001) << 6 + _101001) << 10 + // return (_100111 + i70) << 6 + _101001 + _1000 + // + // Operations: 62 squares 17 multiplies + // + // Generated by github.com/mmcloughlin/addchain v0.4.0. 
+ + t3 := e.SquareTorus(x) + t5 := e.SquareTorus(t3) + result := e.SquareTorus(t5) + t0 := e.SquareTorus(result) + t2 := e.MulTorus(x, t0) + t0 = e.MulTorus(t3, t2) + t1 := e.MulTorus(x, t0) + t4 := e.MulTorus(result, t2) + t6 := e.SquareTorus(t2) + t1 = e.MulTorus(t0, t1) + t0 = e.MulTorus(t3, t1) + t6 = e.nSquareTorus(t6, 6) + t5 = e.MulTorus(t5, t6) + t5 = e.MulTorus(t4, t5) + t5 = e.nSquareTorus(t5, 7) + t4 = e.MulTorus(t4, t5) + t4 = e.nSquareTorus(t4, 8) + t4 = e.MulTorus(t0, t4) + t3 = e.MulTorus(t3, t4) + t3 = e.nSquareTorus(t3, 6) + t2 = e.MulTorus(t2, t3) + t2 = e.nSquareTorus(t2, 8) + t2 = e.MulTorus(t0, t2) + t2 = e.nSquareTorus(t2, 6) + t2 = e.MulTorus(t0, t2) + t2 = e.nSquareTorus(t2, 10) + t1 = e.MulTorus(t1, t2) + t1 = e.nSquareTorus(t1, 6) + t0 = e.MulTorus(t0, t1) + z := e.MulTorus(result, t0) + return z +} + +// MulBy034 multiplies z by an E12 sparse element of the form +// +// E12{ +// C0: E6{B0: 1, B1: 0, B2: 0}, +// C1: E6{B0: c3, B1: c4, B2: 0}, +// } +func (e *Ext12) MulBy034(z *E12, c3, c4 *E2) *E12 { + + a := z.C0 + b := z.C1 + b = *e.MulBy01(&b, c3, c4) + c3 = e.Ext2.Add(e.Ext2.One(), c3) + d := e.Ext6.Add(&z.C0, &z.C1) + d = e.MulBy01(d, c3, c4) + + zC1 := e.Ext6.Add(&a, &b) + zC1 = e.Ext6.Neg(zC1) + zC1 = e.Ext6.Add(zC1, d) + zC0 := e.Ext6.MulByNonResidue(&b) + zC0 = e.Ext6.Add(zC0, &a) + + return &E12{ + C0: *zC0, + C1: *zC1, + } +} + +// multiplies two E12 sparse element of the form: +// +// E12{ +// C0: E6{B0: 1, B1: 0, B2: 0}, +// C1: E6{B0: c3, B1: c4, B2: 0}, +// } +// +// and +// +// E12{ +// C0: E6{B0: 1, B1: 0, B2: 0}, +// C1: E6{B0: d3, B1: d4, B2: 0}, +// } +func (e *Ext12) Mul034By034(d3, d4, c3, c4 *E2) *[5]E2 { + x3 := e.Ext2.Mul(c3, d3) + x4 := e.Ext2.Mul(c4, d4) + x04 := e.Ext2.Add(c4, d4) + x03 := e.Ext2.Add(c3, d3) + tmp := e.Ext2.Add(c3, c4) + x34 := e.Ext2.Add(d3, d4) + x34 = e.Ext2.Mul(x34, tmp) + x34 = e.Ext2.Sub(x34, x3) + x34 = e.Ext2.Sub(x34, x4) + + zC0B0 := e.Ext2.MulByNonResidue(x4) + zC0B0 = 
e.Ext2.Add(zC0B0, e.Ext2.One()) + zC0B1 := x3 + zC0B2 := x34 + zC1B0 := x03 + zC1B1 := x04 + + return &[5]E2{*zC0B0, *zC0B1, *zC0B2, *zC1B0, *zC1B1} +} + +// MulBy01234 multiplies z by an E12 sparse element of the form +// +// E12{ +// C0: E6{B0: c0, B1: c1, B2: c2}, +// C1: E6{B0: c3, B1: c4, B2: 0}, +// } +func (e *Ext12) MulBy01234(z *E12, x *[5]E2) *E12 { + c0 := &E6{B0: x[0], B1: x[1], B2: x[2]} + c1 := &E6{B0: x[3], B1: x[4], B2: *e.Ext2.Zero()} + a := e.Ext6.Add(&z.C0, &z.C1) + b := e.Ext6.Add(c0, c1) + a = e.Ext6.Mul(a, b) + b = e.Ext6.Mul(&z.C0, c0) + c := e.Ext6.MulBy01(&z.C1, &x[3], &x[4]) + z1 := e.Ext6.Sub(a, b) + z1 = e.Ext6.Sub(z1, c) + z0 := e.Ext6.MulByNonResidue(c) + z0 = e.Ext6.Add(z0, b) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +// multiplies two E12 sparse element of the form: +// +// E12{ +// C0: E6{B0: x0, B1: x1, B2: x2}, +// C1: E6{B0: x3, B1: x4, B2: 0}, +// } +// +// and +// +// E12{ +// C0: E6{B0: 1, B1: 0, B2: 0}, +// C1: E6{B0: z3, B1: z4, B2: 0}, +// } +func (e *Ext12) Mul01234By034(x *[5]E2, z3, z4 *E2) *E12 { + c0 := &E6{B0: x[0], B1: x[1], B2: x[2]} + c1 := &E6{B0: x[3], B1: x[4], B2: *e.Ext2.Zero()} + a := e.Ext6.Add(e.Ext6.One(), &E6{B0: *z3, B1: *z4, B2: *e.Ext2.Zero()}) + b := e.Ext6.Add(c0, c1) + a = e.Ext6.Mul(a, b) + c := e.Ext6.Mul01By01(z3, z4, &x[3], &x[4]) + z1 := e.Ext6.Sub(a, c0) + z1 = e.Ext6.Sub(z1, c) + z0 := e.Ext6.MulByNonResidue(c) + z0 = e.Ext6.Add(z0, c0) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +// Torus-based arithmetic: +// +// After the easy part of the final exponentiation the elements are in a proper +// subgroup of Fpk (E12) that coincides with some algebraic tori. The elements +// are in the torus Tk(Fp) and thus in each torus Tk/d(Fp^d) for d|k, d≠k. We +// take d=6. So the elements are in T2(Fp6). +// Let G_{q,2} = {m ∈ Fq^2 | m^(q+1) = 1} where q = p^6. +// When m.C1 = 0, then m.C0 must be 1 or −1. 
+// +// We recall the tower construction: +// +// 𝔽p²[u] = 𝔽p/u²+1 +// 𝔽p⁶[v] = 𝔽p²/v³-9-u +// 𝔽p¹²[w] = 𝔽p⁶/w²-v + +// CompressTorus compresses x ∈ E12 to (x.C0 + 1)/x.C1 ∈ E6 +func (e Ext12) CompressTorus(x *E12) *E6 { + // x ∈ G_{q,2} \ {-1,1} + y := e.Ext6.Add(&x.C0, e.Ext6.One()) + y = e.Ext6.DivUnchecked(y, &x.C1) + return y +} + +// DecompressTorus decompresses y ∈ E6 to (y+w)/(y-w) ∈ E12 +func (e Ext12) DecompressTorus(y *E6) *E12 { + var n, d E12 + one := e.Ext6.One() + n.C0 = *y + n.C1 = *one + d.C0 = *y + d.C1 = *e.Ext6.Neg(one) + + x := e.DivUnchecked(&n, &d) + return x +} + +// MulTorus multiplies two compressed elements y1, y2 ∈ E6 +// and returns (y1 * y2 + v)/(y1 + y2) +// N.B.: we use MulTorus in the final exponentiation throughout y1 ≠ -y2 always. +func (e Ext12) MulTorus(y1, y2 *E6) *E6 { + n := e.Ext6.Mul(y1, y2) + n.B1 = *e.Ext2.Add(&n.B1, e.Ext2.One()) + d := e.Ext6.Add(y1, y2) + y3 := e.Ext6.DivUnchecked(n, d) + return y3 +} + +// InverseTorus inverses a compressed elements y ∈ E6 +// and returns -y +func (e Ext12) InverseTorus(y *E6) *E6 { + return e.Ext6.Neg(y) +} + +// SquareTorus squares a compressed elements y ∈ E6 +// and returns (y + v/y)/2 +// +// It uses a hint to verify that (2x-y)y = v saving one E6 AssertIsEqual. 
+func (e Ext12) SquareTorus(y *E6) *E6 { + res, err := e.fp.NewHint(squareTorusHint, 6, &y.B0.A0, &y.B0.A1, &y.B1.A0, &y.B1.A1, &y.B2.A0, &y.B2.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + sq := E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + } + + // v = (2x-y)y + v := e.Ext6.Double(&sq) + v = e.Ext6.Sub(v, y) + v = e.Ext6.Mul(v, y) + + _v := E6{B0: *e.Ext2.Zero(), B1: *e.Ext2.One(), B2: *e.Ext2.Zero()} + e.Ext6.AssertIsEqual(v, &_v) + + return &sq + +} + +// FrobeniusTorus raises a compressed elements y ∈ E6 to the modulus p +// and returns y^p / v^((p-1)/2) +func (e Ext12) FrobeniusTorus(y *E6) *E6 { + t0 := e.Ext2.Conjugate(&y.B0) + t1 := e.Ext2.Conjugate(&y.B1) + t2 := e.Ext2.Conjugate(&y.B2) + t1 = e.Ext2.MulByNonResidue1Power2(t1) + t2 = e.Ext2.MulByNonResidue1Power4(t2) + + v0 := E2{emulated.ValueOf[emulated.BN254Fp]("18566938241244942414004596690298913868373833782006617400804628704885040364344"), emulated.ValueOf[emulated.BN254Fp]("5722266937896532885780051958958348231143373700109372999374820235121374419868")} + res := &E6{B0: *t0, B1: *t1, B2: *t2} + res = e.Ext6.MulBy0(res, &v0) + + return res +} + +// FrobeniusSquareTorus raises a compressed elements y ∈ E6 to the square modulus p^2 +// and returns y^(p^2) / v^((p^2-1)/2) +func (e Ext12) FrobeniusSquareTorus(y *E6) *E6 { + v0 := emulated.ValueOf[emulated.BN254Fp]("2203960485148121921418603742825762020974279258880205651967") + t0 := e.Ext2.MulByElement(&y.B0, &v0) + t1 := e.Ext2.MulByNonResidue2Power2(&y.B1) + t1 = e.Ext2.MulByElement(t1, &v0) + t2 := e.Ext2.MulByNonResidue2Power4(&y.B2) + t2 = e.Ext2.MulByElement(t2, &v0) + + return &E6{B0: *t0, B1: *t1, B2: *t2} +} + +// FrobeniusCubeTorus raises a compressed elements y ∈ E6 to the cube modulus p^3 +// and returns y^(p^3) / v^((p^3-1)/2) +func (e Ext12) FrobeniusCubeTorus(y *E6) *E6 { + t0 := e.Ext2.Conjugate(&y.B0) + t1 := 
e.Ext2.Conjugate(&y.B1) + t2 := e.Ext2.Conjugate(&y.B2) + t1 = e.Ext2.MulByNonResidue3Power2(t1) + t2 = e.Ext2.MulByNonResidue3Power4(t2) + + v0 := E2{emulated.ValueOf[emulated.BN254Fp]("10190819375481120917420622822672549775783927716138318623895010788866272024264"), emulated.ValueOf[emulated.BN254Fp]("303847389135065887422783454877609941456349188919719272345083954437860409601")} + res := &E6{B0: *t0, B1: *t1, B2: *t2} + res = e.Ext6.MulBy0(res, &v0) + + return res +} diff --git a/std/algebra/emulated/fields_bn254/e12_test.go b/std/algebra/emulated/fields_bn254/e12_test.go new file mode 100644 index 0000000000..a3289b4698 --- /dev/null +++ b/std/algebra/emulated/fields_bn254/e12_test.go @@ -0,0 +1,620 @@ +package fields_bn254 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type e12Add struct { + A, B, C E12 +} + +func (circuit *e12Add) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Add(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestAddFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Add(&a, &b) + + witness := e12Add{ + A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Add{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Sub struct { + A, B, C E12 +} + +func (circuit *e12Sub) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Sub(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSubFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Sub(&a, &b) + + witness := e12Sub{ + A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := 
test.IsSolved(&e12Sub{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Mul struct { + A, B, C E12 +} + +func (circuit *e12Mul) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Mul(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestMulFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Mul(&a, &b) + + witness := e12Mul{ + A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Mul{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Div struct { + A, B, C E12 +} + +func (circuit *e12Div) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.DivUnchecked(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestDivFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Div(&a, &b) + + witness := e12Div{ + A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Div{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Square struct { + A, C E12 +} + +func (circuit *e12Square) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Square(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSquareFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E12 + _, _ = a.SetRandom() + c.Square(&a) + + witness := e12Square{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Square{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Conjugate struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *e12Conjugate) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Conjugate(&circuit.A) + e.AssertIsEqual(expected, 
&circuit.C) + + return nil +} + +func TestConjugateFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E12 + _, _ = a.SetRandom() + c.Conjugate(&a) + + witness := e12Conjugate{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Conjugate{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e12Inverse struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *e12Inverse) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Inverse(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestInverseFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E12 + _, _ = a.SetRandom() + c.Inverse(&a) + + witness := e12Inverse{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Inverse{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e12ExptTorus struct { + A E6 + C E12 `gnark:",public"` +} + +func (circuit *e12ExptTorus) Define(api frontend.API) error { + e := NewExt12(api) + z := e.ExptTorus(&circuit.A) + expected := e.DecompressTorus(z) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestFp12ExptTorus(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + var tmp bn254.E12 + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + c.Expt(&a) + _a, _ := a.CompressTorus() + witness := e12ExptTorus{ + A: FromE6(&_a), + C: FromE12(&c), + } + + err := test.IsSolved(&e12ExptTorus{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e12MulBy034 struct { + A E12 `gnark:",public"` + W E12 + B, C E2 +} + +func (circuit *e12MulBy034) Define(api frontend.API) error { + e := NewExt12(api) + res := e.MulBy034(&circuit.A, &circuit.B, &circuit.C) + e.AssertIsEqual(res, &circuit.W) + return nil +} + +func TestFp12MulBy034(t 
*testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, w bn254.E12 + _, _ = a.SetRandom() + var one, b, c bn254.E2 + one.SetOne() + _, _ = b.SetRandom() + _, _ = c.SetRandom() + w.Set(&a) + w.MulBy034(&one, &b, &c) + + witness := e12MulBy034{ + A: FromE12(&a), + B: FromE2(&b), + C: FromE2(&c), + W: FromE12(&w), + } + + err := test.IsSolved(&e12MulBy034{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +// Torus-based arithmetic +type torusCompress struct { + A E12 + C E6 `gnark:",public"` +} + +func (circuit *torusCompress) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.CompressTorus(&circuit.A) + e.Ext6.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusCompress(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a bn254.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + var tmp bn254.E12 + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + c, _ := a.CompressTorus() + + witness := torusCompress{ + A: FromE12(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&torusCompress{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusDecompress struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusDecompress) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusDecompress(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a bn254.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + var tmp bn254.E12 + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + d, _ := a.CompressTorus() + c := d.DecompressTorus() + + witness := torusDecompress{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusDecompress{}, &witness, 
ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusMul struct { + A E12 + B E12 + C E12 `gnark:",public"` +} + +func (circuit *torusMul) Define(api frontend.API) error { + e := NewExt12(api) + compressedA := e.CompressTorus(&circuit.A) + compressedB := e.CompressTorus(&circuit.B) + compressedAB := e.MulTorus(compressedA, compressedB) + expected := e.DecompressTorus(compressedAB) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusMul(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c, tmp bn254.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + // put b in the cyclotomic subgroup + tmp.Conjugate(&b) + b.Inverse(&b) + tmp.Mul(&tmp, &b) + b.FrobeniusSquare(&tmp).Mul(&b, &tmp) + + // uncompressed mul + c.Mul(&a, &b) + + witness := torusMul{ + A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := test.IsSolved(&torusMul{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusInverse struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusInverse) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + compressed = e.InverseTorus(compressed) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusInverse(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c, tmp bn254.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + // uncompressed inverse + c.Inverse(&a) + + witness := torusInverse{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusInverse{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusFrobenius struct { + A E12 + C E12 
`gnark:",public"` +} + +func (circuit *torusFrobenius) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + compressed = e.FrobeniusTorus(compressed) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusFrobenius(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c, tmp bn254.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + // uncompressed frobenius + c.Frobenius(&a) + + witness := torusFrobenius{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusFrobenius{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusFrobeniusSquare struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusFrobeniusSquare) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + compressed = e.FrobeniusSquareTorus(compressed) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusFrobeniusSquare(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c, tmp bn254.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + // uncompressed frobeniusSquare + c.FrobeniusSquare(&a) + + witness := torusFrobeniusSquare{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusFrobeniusSquare{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusFrobeniusCube struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusFrobeniusCube) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + compressed = e.FrobeniusCubeTorus(compressed) + expected := e.DecompressTorus(compressed) + 
e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusFrobeniusCube(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c, tmp bn254.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + // uncompressed frobeniusCube + c.FrobeniusCube(&a) + + witness := torusFrobeniusCube{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusFrobeniusCube{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusSquare struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusSquare) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + compressed = e.SquareTorus(compressed) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusSquare(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c, tmp bn254.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + // uncompressed square + c.Square(&a) + + witness := torusSquare{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusSquare{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/emulated/fields_bn254/e2.go b/std/algebra/emulated/fields_bn254/e2.go new file mode 100644 index 0000000000..161712f197 --- /dev/null +++ b/std/algebra/emulated/fields_bn254/e2.go @@ -0,0 +1,353 @@ +package fields_bn254 + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" +) + +type curveF = emulated.Field[emulated.BN254Fp] +type baseEl = emulated.Element[emulated.BN254Fp] + +type E2 struct { + A0, A1 baseEl +} + +type Ext2 struct { + api frontend.API + fp 
*curveF + nonResidues map[int]map[int]*E2 +} + +func NewExt2(api frontend.API) *Ext2 { + fp, err := emulated.NewField[emulated.BN254Fp](api) + if err != nil { + // TODO: we start returning errors when generifying + panic(err) + } + pwrs := map[int]map[int]struct { + A0 string + A1 string + }{ + 1: { + 1: {"8376118865763821496583973867626364092589906065868298776909617916018768340080", "16469823323077808223889137241176536799009286646108169935659301613961712198316"}, + 2: {"21575463638280843010398324269430826099269044274347216827212613867836435027261", "10307601595873709700152284273816112264069230130616436755625194854815875713954"}, + 3: {"2821565182194536844548159561693502659359617185244120367078079554186484126554", "3505843767911556378687030309984248845540243509899259641013678093033130930403"}, + 4: {"2581911344467009335267311115468803099551665605076196740867805258568234346338", "19937756971775647987995932169929341994314640652964949448313374472400716661030"}, + 5: {"685108087231508774477564247770172212460312782337200605669322048753928464687", "8447204650696766136447902020341177575205426561248465145919723016860428151883"}, + }, + 3: { + 1: {"11697423496358154304825782922584725312912383441159505038794027105778954184319", "303847389135065887422783454877609941456349188919719272345083954437860409601"}, + 2: {"3772000881919853776433695186713858239009073593817195771773381919316419345261", "2236595495967245188281701248203181795121068902605861227855261137820944008926"}, + 3: {"19066677689644738377698246183563772429336693972053703295610958340458742082029", "18382399103927718843559375435273026243156067647398564021675359801612095278180"}, + 4: {"5324479202449903542726783395506214481928257762400643279780343368557297135718", "16208900380737693084919495127334387981393726419856888799917914180988844123039"}, + 5: {"8941241848238582420466759817324047081148088512956452953208002715982955420483", "10338197737521362862238855242243140895517409139741313354160881284257516364953"}, + }, + } 
+ nonResidues := make(map[int]map[int]*E2) + for pwr, v := range pwrs { + for coeff, v := range v { + el := E2{emulated.ValueOf[emulated.BN254Fp](v.A0), emulated.ValueOf[emulated.BN254Fp](v.A1)} + if nonResidues[pwr] == nil { + nonResidues[pwr] = make(map[int]*E2) + } + nonResidues[pwr][coeff] = &el + } + } + return &Ext2{api: api, fp: fp, nonResidues: nonResidues} +} + +func (e Ext2) MulByElement(x *E2, y *baseEl) *E2 { + z0 := e.fp.MulMod(&x.A0, y) + z1 := e.fp.MulMod(&x.A1, y) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) MulByConstElement(x *E2, y *big.Int) *E2 { + z0 := e.fp.MulConst(&x.A0, y) + z1 := e.fp.MulConst(&x.A1, y) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Conjugate(x *E2) *E2 { + z0 := x.A0 + z1 := e.fp.Neg(&x.A1) + return &E2{ + A0: z0, + A1: *z1, + } +} + +func (e Ext2) MulByNonResidueGeneric(x *E2, power, coef int) *E2 { + y := e.nonResidues[power][coef] + z := e.Mul(x, y) + return z +} + +// MulByNonResidue return x*(9+u) +func (e Ext2) MulByNonResidue(x *E2) *E2 { + nine := big.NewInt(9) + a := e.fp.MulConst(&x.A0, nine) + a = e.fp.Sub(a, &x.A1) + b := e.fp.MulConst(&x.A1, nine) + b = e.fp.Add(b, &x.A0) + return &E2{ + A0: *a, + A1: *b, + } +} + +// MulByNonResidue1Power1 returns x*(9+u)^(1*(p^1-1)/6) +func (e Ext2) MulByNonResidue1Power1(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 1, 1) +} + +// MulByNonResidue1Power2 returns x*(9+u)^(2*(p^1-1)/6) +func (e Ext2) MulByNonResidue1Power2(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 1, 2) +} + +// MulByNonResidue1Power3 returns x*(9+u)^(3*(p^1-1)/6) +func (e Ext2) MulByNonResidue1Power3(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 1, 3) +} + +// MulByNonResidue1Power4 returns x*(9+u)^(4*(p^1-1)/6) +func (e Ext2) MulByNonResidue1Power4(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 1, 4) +} + +// MulByNonResidue1Power5 returns x*(9+u)^(5*(p^1-1)/6) +func (e Ext2) MulByNonResidue1Power5(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 1, 5) +} + 
+// MulByNonResidue2Power1 returns x*(9+u)^(1*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power1(x *E2) *E2 { + element := emulated.ValueOf[emulated.BN254Fp]("21888242871839275220042445260109153167277707414472061641714758635765020556617") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue2Power2 returns x*(9+u)^(2*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power2(x *E2) *E2 { + element := emulated.ValueOf[emulated.BN254Fp]("21888242871839275220042445260109153167277707414472061641714758635765020556616") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue2Power3 returns x*(9+u)^(3*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power3(x *E2) *E2 { + element := emulated.ValueOf[emulated.BN254Fp]("21888242871839275222246405745257275088696311157297823662689037894645226208582") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue2Power4 returns x*(9+u)^(4*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power4(x *E2) *E2 { + element := emulated.ValueOf[emulated.BN254Fp]("2203960485148121921418603742825762020974279258880205651966") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue2Power5 returns x*(9+u)^(5*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power5(x *E2) *E2 { + element := emulated.ValueOf[emulated.BN254Fp]("2203960485148121921418603742825762020974279258880205651967") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue3Power1 returns x*(9+u)^(1*(p^3-1)/6) +func (e Ext2) MulByNonResidue3Power1(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 3, 1) +} + +// MulByNonResidue3Power2 returns x*(9+u)^(2*(p^3-1)/6) +func (e Ext2) MulByNonResidue3Power2(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 3, 2) +} + +// MulByNonResidue3Power3 returns 
x*(9+u)^(3*(p^3-1)/6) +func (e Ext2) MulByNonResidue3Power3(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 3, 3) +} + +// MulByNonResidue3Power4 returns x*(9+u)^(4*(p^3-1)/6) +func (e Ext2) MulByNonResidue3Power4(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 3, 4) +} + +// MulByNonResidue3Power5 returns x*(9+u)^(5*(p^3-1)/6) +func (e Ext2) MulByNonResidue3Power5(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 3, 5) +} + +func (e Ext2) Mul(x, y *E2) *E2 { + a := e.fp.Add(&x.A0, &x.A1) + b := e.fp.Add(&y.A0, &y.A1) + a = e.fp.MulMod(a, b) + b = e.fp.MulMod(&x.A0, &y.A0) + c := e.fp.MulMod(&x.A1, &y.A1) + z1 := e.fp.Sub(a, b) + z1 = e.fp.Sub(z1, c) + z0 := e.fp.Sub(b, c) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Add(x, y *E2) *E2 { + z0 := e.fp.Add(&x.A0, &y.A0) + z1 := e.fp.Add(&x.A1, &y.A1) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Sub(x, y *E2) *E2 { + z0 := e.fp.Sub(&x.A0, &y.A0) + z1 := e.fp.Sub(&x.A1, &y.A1) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Neg(x *E2) *E2 { + z0 := e.fp.Neg(&x.A0) + z1 := e.fp.Neg(&x.A1) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) One() *E2 { + z0 := e.fp.One() + z1 := e.fp.Zero() + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Zero() *E2 { + z0 := e.fp.Zero() + z1 := e.fp.Zero() + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) IsZero(z *E2) frontend.Variable { + a0 := e.fp.IsZero(&z.A0) + a1 := e.fp.IsZero(&z.A1) + return e.api.And(a0, a1) +} + +func (e Ext2) Square(x *E2) *E2 { + a := e.fp.Add(&x.A0, &x.A1) + b := e.fp.Sub(&x.A0, &x.A1) + a = e.fp.MulMod(a, b) + b = e.fp.MulMod(&x.A0, &x.A1) + b = e.fp.MulConst(b, big.NewInt(2)) + return &E2{ + A0: *a, + A1: *b, + } +} + +func (e Ext2) Double(x *E2) *E2 { + two := big.NewInt(2) + z0 := e.fp.MulConst(&x.A0, two) + z1 := e.fp.MulConst(&x.A1, two) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) AssertIsEqual(x, y *E2) { + e.fp.AssertIsEqual(&x.A0, &y.A0) + 
e.fp.AssertIsEqual(&x.A1, &y.A1) +} + +func FromE2(y *bn254.E2) E2 { + return E2{ + A0: emulated.ValueOf[emulated.BN254Fp](y.A0), + A1: emulated.ValueOf[emulated.BN254Fp](y.A1), + } +} + +func (e Ext2) Inverse(x *E2) *E2 { + res, err := e.fp.NewHint(inverseE2Hint, 2, &x.A0, &x.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + inv := E2{ + A0: *res[0], + A1: *res[1], + } + one := e.One() + + // 1 == inv * x + _one := e.Mul(&inv, x) + e.AssertIsEqual(one, _one) + + return &inv + +} + +func (e Ext2) DivUnchecked(x, y *E2) *E2 { + res, err := e.fp.NewHint(divE2Hint, 2, &x.A0, &x.A1, &y.A0, &y.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + div := E2{ + A0: *res[0], + A1: *res[1], + } + + // x == div * y + _x := e.Mul(&div, y) + e.AssertIsEqual(x, _x) + + return &div +} + +func (e Ext2) Select(selector frontend.Variable, z1, z0 *E2) *E2 { + a0 := e.fp.Select(selector, &z1.A0, &z0.A0) + a1 := e.fp.Select(selector, &z1.A1, &z0.A1) + return &E2{A0: *a0, A1: *a1} +} + +func (e Ext2) Lookup2(s1, s2 frontend.Variable, a, b, c, d *E2) *E2 { + a0 := e.fp.Lookup2(s1, s2, &a.A0, &b.A0, &c.A0, &d.A0) + a1 := e.fp.Lookup2(s1, s2, &a.A1, &b.A1, &c.A1, &d.A1) + return &E2{A0: *a0, A1: *a1} +} diff --git a/std/algebra/emulated/fields_bn254/e2_test.go b/std/algebra/emulated/fields_bn254/e2_test.go new file mode 100644 index 0000000000..55c2564b02 --- /dev/null +++ b/std/algebra/emulated/fields_bn254/e2_test.go @@ -0,0 +1,351 @@ +package fields_bn254 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fp" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/test" +) + +type e2Add struct { + A, B, C E2 +} + +func (circuit *e2Add) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Add(&circuit.A, 
&circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestAddFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Add(&a, &b) + + witness := e2Add{ + A: FromE2(&a), + B: FromE2(&b), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Add{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Sub struct { + A, B, C E2 +} + +func (circuit *e2Sub) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Sub(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSubFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Sub(&a, &b) + + witness := e2Sub{ + A: FromE2(&a), + B: FromE2(&b), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Sub{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Double struct { + A, C E2 +} + +func (circuit *e2Double) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Double(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestDoubleFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Double(&a) + + witness := e2Double{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Double{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Mul struct { + A, B, C E2 +} + +func (circuit *e2Mul) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Mul(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestMulFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Mul(&a, &b) + + witness := e2Mul{ + A: FromE2(&a), + B: FromE2(&b), + C: FromE2(&c), + } + + err 
:= test.IsSolved(&e2Mul{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Square struct { + A, C E2 +} + +func (circuit *e2Square) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Square(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSquareFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E2 + _, _ = a.SetRandom() + c.Square(&a) + + witness := e2Square{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Square{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Div struct { + A, B, C E2 +} + +func (circuit *e2Div) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.DivUnchecked(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestDivFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Div(&a, &b) + + witness := e2Div{ + A: FromE2(&a), + B: FromE2(&b), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Div{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2MulByElement struct { + A E2 + B baseEl + C E2 `gnark:",public"` +} + +func (circuit *e2MulByElement) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.MulByElement(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulByElement(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E2 + var b fp.Element + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.MulByElement(&a, &b) + + witness := e2MulByElement{ + A: FromE2(&a), + B: emulated.ValueOf[emulated.BN254Fp](b), + C: FromE2(&c), + } + + err := test.IsSolved(&e2MulByElement{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2MulByNonResidue struct { + A E2 + C E2 `gnark:",public"` +} + +func (circuit *e2MulByNonResidue) Define(api 
frontend.API) error { + e := NewExt2(api) + expected := e.MulByNonResidue(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulFp2ByNonResidue(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E2 + _, _ = a.SetRandom() + c.MulByNonResidue(&a) + + witness := e2MulByNonResidue{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2MulByNonResidue{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Neg struct { + A E2 + C E2 `gnark:",public"` +} + +func (circuit *e2Neg) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Neg(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestNegFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E2 + _, _ = a.SetRandom() + c.Neg(&a) + + witness := e2Neg{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Neg{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e2Conjugate struct { + A E2 + C E2 `gnark:",public"` +} + +func (circuit *e2Conjugate) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Conjugate(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestConjugateFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E2 + _, _ = a.SetRandom() + c.Conjugate(&a) + + witness := e2Conjugate{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Conjugate{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e2Inverse struct { + A E2 + C E2 `gnark:",public"` +} + +func (circuit *e2Inverse) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Inverse(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestInverseFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E2 + _, _ = a.SetRandom() + c.Inverse(&a) + + witness := e2Inverse{ + A: 
FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Inverse{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/emulated/fields_bn254/e6.go b/std/algebra/emulated/fields_bn254/e6.go new file mode 100644 index 0000000000..4330fa49fb --- /dev/null +++ b/std/algebra/emulated/fields_bn254/e6.go @@ -0,0 +1,338 @@ +package fields_bn254 + +import ( + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/frontend" +) + +type E6 struct { + B0, B1, B2 E2 +} + +type Ext6 struct { + *Ext2 +} + +func NewExt6(api frontend.API) *Ext6 { + return &Ext6{Ext2: NewExt2(api)} +} + +func (e Ext6) One() *E6 { + z0 := e.Ext2.One() + z1 := e.Ext2.Zero() + z2 := e.Ext2.Zero() + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Zero() *E6 { + z0 := e.Ext2.Zero() + z1 := e.Ext2.Zero() + z2 := e.Ext2.Zero() + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) IsZero(z *E6) frontend.Variable { + b0 := e.Ext2.IsZero(&z.B0) + b1 := e.Ext2.IsZero(&z.B1) + b2 := e.Ext2.IsZero(&z.B2) + return e.api.And(e.api.And(b0, b1), b2) +} + +func (e Ext6) Add(x, y *E6) *E6 { + z0 := e.Ext2.Add(&x.B0, &y.B0) + z1 := e.Ext2.Add(&x.B1, &y.B1) + z2 := e.Ext2.Add(&x.B2, &y.B2) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Neg(x *E6) *E6 { + z0 := e.Ext2.Neg(&x.B0) + z1 := e.Ext2.Neg(&x.B1) + z2 := e.Ext2.Neg(&x.B2) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Sub(x, y *E6) *E6 { + z0 := e.Ext2.Sub(&x.B0, &y.B0) + z1 := e.Ext2.Sub(&x.B1, &y.B1) + z2 := e.Ext2.Sub(&x.B2, &y.B2) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Mul(x, y *E6) *E6 { + t0 := e.Ext2.Mul(&x.B0, &y.B0) + t1 := e.Ext2.Mul(&x.B1, &y.B1) + t2 := e.Ext2.Mul(&x.B2, &y.B2) + c0 := e.Ext2.Add(&x.B1, &x.B2) + tmp := e.Ext2.Add(&y.B1, &y.B2) + c0 = e.Ext2.Mul(c0, tmp) + c0 = e.Ext2.Sub(c0, t1) + c0 = e.Ext2.Sub(c0, t2) + c0 = e.Ext2.MulByNonResidue(c0) + c0 = e.Ext2.Add(c0, 
t0) + c1 := e.Ext2.Add(&x.B0, &x.B1) + tmp = e.Ext2.Add(&y.B0, &y.B1) + c1 = e.Ext2.Mul(c1, tmp) + c1 = e.Ext2.Sub(c1, t0) + c1 = e.Ext2.Sub(c1, t1) + tmp = e.Ext2.MulByNonResidue(t2) + c1 = e.Ext2.Add(c1, tmp) + tmp = e.Ext2.Add(&x.B0, &x.B2) + c2 := e.Ext2.Add(&y.B0, &y.B2) + c2 = e.Ext2.Mul(c2, tmp) + c2 = e.Ext2.Sub(c2, t0) + c2 = e.Ext2.Sub(c2, t2) + c2 = e.Ext2.Add(c2, t1) + return &E6{ + B0: *c0, + B1: *c1, + B2: *c2, + } +} + +func (e Ext6) Double(x *E6) *E6 { + z0 := e.Ext2.Double(&x.B0) + z1 := e.Ext2.Double(&x.B1) + z2 := e.Ext2.Double(&x.B2) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Square(x *E6) *E6 { + c4 := e.Ext2.Mul(&x.B0, &x.B1) + c4 = e.Ext2.Double(c4) + c5 := e.Ext2.Square(&x.B2) + c1 := e.Ext2.MulByNonResidue(c5) + c1 = e.Ext2.Add(c1, c4) + c2 := e.Ext2.Sub(c4, c5) + c3 := e.Ext2.Square(&x.B0) + c4 = e.Ext2.Sub(&x.B0, &x.B1) + c4 = e.Ext2.Add(c4, &x.B2) + c5 = e.Ext2.Mul(&x.B1, &x.B2) + c5 = e.Ext2.Double(c5) + c4 = e.Ext2.Square(c4) + c0 := e.Ext2.MulByNonResidue(c5) + c0 = e.Ext2.Add(c0, c3) + z2 := e.Ext2.Add(c2, c4) + z2 = e.Ext2.Add(z2, c5) + z2 = e.Ext2.Sub(z2, c3) + z0 := c0 + z1 := c1 + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) MulByE2(x *E6, y *E2) *E6 { + z0 := e.Ext2.Mul(&x.B0, y) + z1 := e.Ext2.Mul(&x.B1, y) + z2 := e.Ext2.Mul(&x.B2, y) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +// MulBy0 multiplies z by an E6 sparse element of the form +// +// E6{ +// B0: c0, +// B1: 0, +// B2: 0, +// } +func (e Ext6) MulBy0(z *E6, c0 *E2) *E6 { + a := e.Ext2.Mul(&z.B0, c0) + tmp := e.Ext2.Add(&z.B0, &z.B2) + t2 := e.Ext2.Mul(c0, tmp) + t2 = e.Ext2.Sub(t2, a) + tmp = e.Ext2.Add(&z.B0, &z.B1) + t1 := e.Ext2.Mul(c0, tmp) + t1 = e.Ext2.Sub(t1, a) + return &E6{ + B0: *a, + B1: *t1, + B2: *t2, + } +} + +// MulBy01 multiplies z by an E6 sparse element of the form +// +// E6{ +// B0: c0, +// B1: c1, +// B2: 0, +// } +func (e Ext6) MulBy01(z *E6, c0, c1 *E2) *E6 { + a := 
e.Ext2.Mul(&z.B0, c0) + b := e.Ext2.Mul(&z.B1, c1) + tmp := e.Ext2.Add(&z.B1, &z.B2) + t0 := e.Ext2.Mul(c1, tmp) + t0 = e.Ext2.Sub(t0, b) + t0 = e.Ext2.MulByNonResidue(t0) + t0 = e.Ext2.Add(t0, a) + tmp = e.Ext2.Add(&z.B0, &z.B2) + t2 := e.Ext2.Mul(c0, tmp) + t2 = e.Ext2.Sub(t2, a) + t2 = e.Ext2.Add(t2, b) + t1 := e.Ext2.Add(c0, c1) + tmp = e.Ext2.Add(&z.B0, &z.B1) + t1 = e.Ext2.Mul(t1, tmp) + t1 = e.Ext2.Sub(t1, a) + t1 = e.Ext2.Sub(t1, b) + return &E6{ + B0: *t0, + B1: *t1, + B2: *t2, + } +} + +// Mul01By01 multiplies two E6 sparse element of the form: +// +// E6{ +// B0: c0, +// B1: c1, +// B2: 0, +// } +// +// and +// +// E6{ +// B0: d0, +// B1: d1, +// B2: 0, +// } +func (e Ext6) Mul01By01(c0, c1, d0, d1 *E2) *E6 { + a := e.Ext2.Mul(d0, c0) + b := e.Ext2.Mul(d1, c1) + t0 := e.Ext2.Mul(c1, d1) + t0 = e.Ext2.Sub(t0, b) + t0 = e.Ext2.MulByNonResidue(t0) + t0 = e.Ext2.Add(t0, a) + t2 := e.Ext2.Mul(c0, d0) + t2 = e.Ext2.Sub(t2, a) + t2 = e.Ext2.Add(t2, b) + t1 := e.Ext2.Add(c0, c1) + tmp := e.Ext2.Add(d0, d1) + t1 = e.Ext2.Mul(t1, tmp) + t1 = e.Ext2.Sub(t1, a) + t1 = e.Ext2.Sub(t1, b) + return &E6{ + B0: *t0, + B1: *t1, + B2: *t2, + } +} + +func (e Ext6) MulByNonResidue(x *E6) *E6 { + z2, z1, z0 := &x.B1, &x.B0, &x.B2 + z0 = e.Ext2.MulByNonResidue(z0) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) FrobeniusSquare(x *E6) *E6 { + z01 := e.Ext2.MulByNonResidue2Power2(&x.B1) + z02 := e.Ext2.MulByNonResidue2Power4(&x.B2) + return &E6{B0: x.B0, B1: *z01, B2: *z02} +} + +func (e Ext6) AssertIsEqual(x, y *E6) { + e.Ext2.AssertIsEqual(&x.B0, &y.B0) + e.Ext2.AssertIsEqual(&x.B1, &y.B1) + e.Ext2.AssertIsEqual(&x.B2, &y.B2) +} + +func FromE6(y *bn254.E6) E6 { + return E6{ + B0: FromE2(&y.B0), + B1: FromE2(&y.B1), + B2: FromE2(&y.B2), + } + +} + +func (e Ext6) Inverse(x *E6) *E6 { + res, err := e.fp.NewHint(inverseE6Hint, 6, &x.B0.A0, &x.B0.A1, &x.B1.A0, &x.B1.A1, &x.B2.A0, &x.B2.A1) + if err != nil { + // err is non-nil only for invalid number of inputs 
+ panic(err) + } + + inv := E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + } + + one := e.One() + + // 1 == inv * x + _one := e.Mul(&inv, x) + e.AssertIsEqual(one, _one) + + return &inv + +} + +func (e Ext6) DivUnchecked(x, y *E6) *E6 { + res, err := e.fp.NewHint(divE6Hint, 6, &x.B0.A0, &x.B0.A1, &x.B1.A0, &x.B1.A1, &x.B2.A0, &x.B2.A1, &y.B0.A0, &y.B0.A1, &y.B1.A0, &y.B1.A1, &y.B2.A0, &y.B2.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + div := E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + } + + // x == div * y + _x := e.Mul(&div, y) + e.AssertIsEqual(x, _x) + + return &div +} + +func (e Ext6) Select(selector frontend.Variable, z1, z0 *E6) *E6 { + b0 := e.Ext2.Select(selector, &z1.B0, &z0.B0) + b1 := e.Ext2.Select(selector, &z1.B1, &z0.B1) + b2 := e.Ext2.Select(selector, &z1.B2, &z0.B2) + return &E6{B0: *b0, B1: *b1, B2: *b2} +} + +func (e Ext6) Lookup2(s1, s2 frontend.Variable, a, b, c, d *E6) *E6 { + b0 := e.Ext2.Lookup2(s1, s2, &a.B0, &b.B0, &c.B0, &d.B0) + b1 := e.Ext2.Lookup2(s1, s2, &a.B1, &b.B1, &c.B1, &d.B1) + b2 := e.Ext2.Lookup2(s1, s2, &a.B2, &b.B2, &c.B2, &d.B2) + return &E6{B0: *b0, B1: *b1, B2: *b2} +} diff --git a/std/algebra/emulated/fields_bn254/e6_test.go b/std/algebra/emulated/fields_bn254/e6_test.go new file mode 100644 index 0000000000..6c46bb7d46 --- /dev/null +++ b/std/algebra/emulated/fields_bn254/e6_test.go @@ -0,0 +1,363 @@ +package fields_bn254 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type e6Add struct { + A, B, C E6 +} + +func (circuit *e6Add) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Add(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func 
TestAddFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E6 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Add(&a, &b) + + witness := e6Add{ + A: FromE6(&a), + B: FromE6(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Add{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Sub struct { + A, B, C E6 +} + +func (circuit *e6Sub) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Sub(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSubFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E6 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Sub(&a, &b) + + witness := e6Sub{ + A: FromE6(&a), + B: FromE6(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Sub{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Mul struct { + A, B, C E6 +} + +func (circuit *e6Mul) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Mul(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestMulFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E6 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Mul(&a, &b) + + witness := e6Mul{ + A: FromE6(&a), + B: FromE6(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Mul{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Square struct { + A, C E6 +} + +func (circuit *e6Square) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Square(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSquareFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E6 + _, _ = a.SetRandom() + c.Square(&a) + + witness := e6Square{ + A: FromE6(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Square{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Div 
struct { + A, B, C E6 +} + +func (circuit *e6Div) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.DivUnchecked(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestDivFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E6 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Div(&a, &b) + + witness := e6Div{ + A: FromE6(&a), + B: FromE6(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Div{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6MulByNonResidue struct { + A E6 + C E6 `gnark:",public"` +} + +func (circuit *e6MulByNonResidue) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.MulByNonResidue(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulFp6ByNonResidue(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E6 + _, _ = a.SetRandom() + c.MulByNonResidue(&a) + + witness := e6MulByNonResidue{ + A: FromE6(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&e6MulByNonResidue{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6MulByE2 struct { + A E6 + B E2 + C E6 `gnark:",public"` +} + +func (circuit *e6MulByE2) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.MulByE2(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulFp6ByE2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E6 + var b bn254.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.MulByE2(&a, &b) + + witness := e6MulByE2{ + A: FromE6(&a), + B: FromE2(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6MulByE2{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6MulBy01 struct { + A E6 + C0, C1 E2 + C E6 `gnark:",public"` +} + +func (circuit *e6MulBy01) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.MulBy01(&circuit.A, 
&circuit.C0, &circuit.C1) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulFp6By01(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E6 + var C0, C1 bn254.E2 + _, _ = a.SetRandom() + _, _ = C0.SetRandom() + _, _ = C1.SetRandom() + c.Set(&a) + c.MulBy01(&C0, &C1) + + witness := e6MulBy01{ + A: FromE6(&a), + C0: FromE2(&C0), + C1: FromE2(&C1), + C: FromE6(&c), + } + + err := test.IsSolved(&e6MulBy01{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6MulBy0 struct { + A E6 + C0 E2 + C E6 `gnark:",public"` +} + +func (circuit *e6MulBy0) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.MulBy0(&circuit.A, &circuit.C0) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulFp6By0(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E6 + var C0, zero bn254.E2 + _, _ = a.SetRandom() + _, _ = C0.SetRandom() + c.Set(&a) + c.MulBy01(&C0, &zero) + + witness := e6MulBy0{ + A: FromE6(&a), + C0: FromE2(&C0), + C: FromE6(&c), + } + + err := test.IsSolved(&e6MulBy0{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Neg struct { + A E6 + C E6 `gnark:",public"` +} + +func (circuit *e6Neg) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Neg(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestNegFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E6 + _, _ = a.SetRandom() + c.Neg(&a) + + witness := e6Neg{ + A: FromE6(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Neg{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e6Inverse struct { + A E6 + C E6 `gnark:",public"` +} + +func (circuit *e6Inverse) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Inverse(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestInverseFp6(t *testing.T) { + + assert := 
test.NewAssert(t) + // witness values + var a, c bn254.E6 + _, _ = a.SetRandom() + c.Inverse(&a) + + witness := e6Inverse{ + A: FromE6(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Inverse{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/emulated/fields_bn254/hints.go b/std/algebra/emulated/fields_bn254/hints.go new file mode 100644 index 0000000000..b08d6ed977 --- /dev/null +++ b/std/algebra/emulated/fields_bn254/hints.go @@ -0,0 +1,238 @@ +package fields_bn254 + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/math/emulated" +) + +func init() { + solver.RegisterHint(GetHints()...) +} + +// GetHints returns all hint functions used in the package. +func GetHints() []solver.Hint { + return []solver.Hint{ + // E2 + divE2Hint, + inverseE2Hint, + // E6 + divE6Hint, + inverseE6Hint, + squareTorusHint, + // E12 + divE12Hint, + inverseE12Hint, + } +} + +func inverseE2Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, c bn254.E2 + + a.A0.SetBigInt(inputs[0]) + a.A1.SetBigInt(inputs[1]) + + c.Inverse(&a) + + c.A0.BigInt(outputs[0]) + c.A1.BigInt(outputs[1]) + + return nil + }) +} + +func divE2Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, b, c bn254.E2 + + a.A0.SetBigInt(inputs[0]) + a.A1.SetBigInt(inputs[1]) + b.A0.SetBigInt(inputs[2]) + b.A1.SetBigInt(inputs[3]) + + c.Inverse(&b).Mul(&c, &a) + + c.A0.BigInt(outputs[0]) + c.A1.BigInt(outputs[1]) + + return nil + }) +} + +// E6 hints +func inverseE6Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + 
func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, c bn254.E6 + + a.B0.A0.SetBigInt(inputs[0]) + a.B0.A1.SetBigInt(inputs[1]) + a.B1.A0.SetBigInt(inputs[2]) + a.B1.A1.SetBigInt(inputs[3]) + a.B2.A0.SetBigInt(inputs[4]) + a.B2.A1.SetBigInt(inputs[5]) + + c.Inverse(&a) + + c.B0.A0.BigInt(outputs[0]) + c.B0.A1.BigInt(outputs[1]) + c.B1.A0.BigInt(outputs[2]) + c.B1.A1.BigInt(outputs[3]) + c.B2.A0.BigInt(outputs[4]) + c.B2.A1.BigInt(outputs[5]) + + return nil + }) +} + +func divE6Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, b, c bn254.E6 + + a.B0.A0.SetBigInt(inputs[0]) + a.B0.A1.SetBigInt(inputs[1]) + a.B1.A0.SetBigInt(inputs[2]) + a.B1.A1.SetBigInt(inputs[3]) + a.B2.A0.SetBigInt(inputs[4]) + a.B2.A1.SetBigInt(inputs[5]) + + b.B0.A0.SetBigInt(inputs[6]) + b.B0.A1.SetBigInt(inputs[7]) + b.B1.A0.SetBigInt(inputs[8]) + b.B1.A1.SetBigInt(inputs[9]) + b.B2.A0.SetBigInt(inputs[10]) + b.B2.A1.SetBigInt(inputs[11]) + + c.Inverse(&b).Mul(&c, &a) + + c.B0.A0.BigInt(outputs[0]) + c.B0.A1.BigInt(outputs[1]) + c.B1.A0.BigInt(outputs[2]) + c.B1.A1.BigInt(outputs[3]) + c.B2.A0.BigInt(outputs[4]) + c.B2.A1.BigInt(outputs[5]) + + return nil + }) +} + +func squareTorusHint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, c bn254.E6 + + a.B0.A0.SetBigInt(inputs[0]) + a.B0.A1.SetBigInt(inputs[1]) + a.B1.A0.SetBigInt(inputs[2]) + a.B1.A1.SetBigInt(inputs[3]) + a.B2.A0.SetBigInt(inputs[4]) + a.B2.A1.SetBigInt(inputs[5]) + + _c := a.DecompressTorus() + _c.CyclotomicSquare(&_c) + c, _ = _c.CompressTorus() + + c.B0.A0.BigInt(outputs[0]) + c.B0.A1.BigInt(outputs[1]) + c.B1.A0.BigInt(outputs[2]) + c.B1.A1.BigInt(outputs[3]) + c.B2.A0.BigInt(outputs[4]) + c.B2.A1.BigInt(outputs[5]) + + return 
nil + }) +} + +// E12 hints +func inverseE12Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, c bn254.E12 + + a.C0.B0.A0.SetBigInt(inputs[0]) + a.C0.B0.A1.SetBigInt(inputs[1]) + a.C0.B1.A0.SetBigInt(inputs[2]) + a.C0.B1.A1.SetBigInt(inputs[3]) + a.C0.B2.A0.SetBigInt(inputs[4]) + a.C0.B2.A1.SetBigInt(inputs[5]) + a.C1.B0.A0.SetBigInt(inputs[6]) + a.C1.B0.A1.SetBigInt(inputs[7]) + a.C1.B1.A0.SetBigInt(inputs[8]) + a.C1.B1.A1.SetBigInt(inputs[9]) + a.C1.B2.A0.SetBigInt(inputs[10]) + a.C1.B2.A1.SetBigInt(inputs[11]) + + c.Inverse(&a) + + c.C0.B0.A0.BigInt(outputs[0]) + c.C0.B0.A1.BigInt(outputs[1]) + c.C0.B1.A0.BigInt(outputs[2]) + c.C0.B1.A1.BigInt(outputs[3]) + c.C0.B2.A0.BigInt(outputs[4]) + c.C0.B2.A1.BigInt(outputs[5]) + c.C1.B0.A0.BigInt(outputs[6]) + c.C1.B0.A1.BigInt(outputs[7]) + c.C1.B1.A0.BigInt(outputs[8]) + c.C1.B1.A1.BigInt(outputs[9]) + c.C1.B2.A0.BigInt(outputs[10]) + c.C1.B2.A1.BigInt(outputs[11]) + + return nil + }) +} + +func divE12Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, b, c bn254.E12 + + a.C0.B0.A0.SetBigInt(inputs[0]) + a.C0.B0.A1.SetBigInt(inputs[1]) + a.C0.B1.A0.SetBigInt(inputs[2]) + a.C0.B1.A1.SetBigInt(inputs[3]) + a.C0.B2.A0.SetBigInt(inputs[4]) + a.C0.B2.A1.SetBigInt(inputs[5]) + a.C1.B0.A0.SetBigInt(inputs[6]) + a.C1.B0.A1.SetBigInt(inputs[7]) + a.C1.B1.A0.SetBigInt(inputs[8]) + a.C1.B1.A1.SetBigInt(inputs[9]) + a.C1.B2.A0.SetBigInt(inputs[10]) + a.C1.B2.A1.SetBigInt(inputs[11]) + + b.C0.B0.A0.SetBigInt(inputs[12]) + b.C0.B0.A1.SetBigInt(inputs[13]) + b.C0.B1.A0.SetBigInt(inputs[14]) + b.C0.B1.A1.SetBigInt(inputs[15]) + b.C0.B2.A0.SetBigInt(inputs[16]) + b.C0.B2.A1.SetBigInt(inputs[17]) + b.C1.B0.A0.SetBigInt(inputs[18]) + 
b.C1.B0.A1.SetBigInt(inputs[19]) + b.C1.B1.A0.SetBigInt(inputs[20]) + b.C1.B1.A1.SetBigInt(inputs[21]) + b.C1.B2.A0.SetBigInt(inputs[22]) + b.C1.B2.A1.SetBigInt(inputs[23]) + + c.Inverse(&b).Mul(&c, &a) + + c.C0.B0.A0.BigInt(outputs[0]) + c.C0.B0.A1.BigInt(outputs[1]) + c.C0.B1.A0.BigInt(outputs[2]) + c.C0.B1.A1.BigInt(outputs[3]) + c.C0.B2.A0.BigInt(outputs[4]) + c.C0.B2.A1.BigInt(outputs[5]) + c.C1.B0.A0.BigInt(outputs[6]) + c.C1.B0.A1.BigInt(outputs[7]) + c.C1.B1.A0.BigInt(outputs[8]) + c.C1.B1.A1.BigInt(outputs[9]) + c.C1.B2.A0.BigInt(outputs[10]) + c.C1.B2.A1.BigInt(outputs[11]) + + return nil + }) +} diff --git a/std/algebra/emulated/sw_bls12381/doc.go b/std/algebra/emulated/sw_bls12381/doc.go new file mode 100644 index 0000000000..948f61e24f --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/doc.go @@ -0,0 +1,6 @@ +// Package sw_bls12381 implements G1 and G2 arithmetics and pairing computation over BLS12-381 curve. +// +// The implementation follows [Housni22]: "Pairings in Rank-1 Constraint Systems". 
+// +// [Housni22]: https://eprint.iacr.org/2022/1162 +package sw_bls12381 diff --git a/std/algebra/emulated/sw_bls12381/doc_test.go b/std/algebra/emulated/sw_bls12381/doc_test.go new file mode 100644 index 0000000000..a1ce0f5ca5 --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/doc_test.go @@ -0,0 +1,105 @@ +package sw_bls12381_test + +import ( + "crypto/rand" + "fmt" + + "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/algebra/emulated/sw_bls12381" +) + +type PairCircuit struct { + InG1 sw_bls12381.G1Affine + InG2 sw_bls12381.G2Affine + Res sw_bls12381.GTEl +} + +func (c *PairCircuit) Define(api frontend.API) error { + pairing, err := sw_bls12381.NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + // Pair method does not check that the points are in the proper groups. 
+ pairing.AssertIsOnG1(&c.InG1) + pairing.AssertIsOnG2(&c.InG2) + // Compute the pairing + res, err := pairing.Pair([]*sw_bls12381.G1Affine{&c.InG1}, []*sw_bls12381.G2Affine{&c.InG2}) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + pairing.AssertIsEqual(res, &c.Res) + return nil +} + +func ExamplePairing() { + p, q, err := randomG1G2Affines() + if err != nil { + panic(err) + } + res, err := bls12381.Pair([]bls12381.G1Affine{p}, []bls12381.G2Affine{q}) + if err != nil { + panic(err) + } + circuit := PairCircuit{} + witness := PairCircuit{ + InG1: sw_bls12381.NewG1Affine(p), + InG2: sw_bls12381.NewG2Affine(q), + Res: sw_bls12381.NewGTEl(res), + } + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + panic(err) + } else { + fmt.Println("compiled") + } + pk, vk, err := groth16.Setup(ccs) + if err != nil { + panic(err) + } else { + fmt.Println("setup done") + } + secretWitness, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } else { + fmt.Println("secret witness") + } + publicWitness, err := secretWitness.Public() + if err != nil { + panic(err) + } else { + fmt.Println("public witness") + } + proof, err := groth16.Prove(ccs, pk, secretWitness) + if err != nil { + panic(err) + } else { + fmt.Println("proof") + } + err = groth16.Verify(proof, vk, publicWitness) + if err != nil { + panic(err) + } else { + fmt.Println("verify") + } +} + +func randomG1G2Affines() (p bls12381.G1Affine, q bls12381.G2Affine, err error) { + _, _, G1AffGen, G2AffGen := bls12381.Generators() + mod := bls12381.ID.ScalarField() + s1, err := rand.Int(rand.Reader, mod) + if err != nil { + return p, q, err + } + s2, err := rand.Int(rand.Reader, mod) + if err != nil { + return p, q, err + } + p.ScalarMultiplication(&G1AffGen, s1) + q.ScalarMultiplication(&G2AffGen, s2) + return +} diff --git a/std/algebra/emulated/sw_bls12381/g1.go b/std/algebra/emulated/sw_bls12381/g1.go new file mode 100644 
index 0000000000..f7908b5ded --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/g1.go @@ -0,0 +1,45 @@ +package sw_bls12381 + +import ( + "fmt" + + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" +) + +type G1Affine = sw_emulated.AffinePoint[emulated.BLS12381Fp] + +func NewG1Affine(v bls12381.G1Affine) G1Affine { + return G1Affine{ + X: emulated.ValueOf[emulated.BLS12381Fp](v.X), + Y: emulated.ValueOf[emulated.BLS12381Fp](v.Y), + } +} + +type G1 struct { + curveF *emulated.Field[emulated.BLS12381Fp] + w *emulated.Element[emulated.BLS12381Fp] +} + +func NewG1(api frontend.API) (*G1, error) { + ba, err := emulated.NewField[emulated.BLS12381Fp](api) + if err != nil { + return nil, fmt.Errorf("new base api: %w", err) + } + w := emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436") + return &G1{ + curveF: ba, + w: &w, + }, nil +} + +func (g1 *G1) phi(q *G1Affine) *G1Affine { + x := g1.curveF.Mul(&q.X, g1.w) + + return &G1Affine{ + X: *x, + Y: q.Y, + } +} diff --git a/std/algebra/emulated/sw_bls12381/g2.go b/std/algebra/emulated/sw_bls12381/g2.go new file mode 100644 index 0000000000..45077a6cba --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/g2.go @@ -0,0 +1,219 @@ +package sw_bls12381 + +import ( + "math/big" + + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/fields_bls12381" + "github.com/consensys/gnark/std/math/emulated" +) + +type G2 struct { + *fields_bls12381.Ext2 + u1, w *emulated.Element[emulated.BLS12381Fp] + v *fields_bls12381.E2 +} + +type G2Affine struct { + X, Y fields_bls12381.E2 +} + +func NewG2(api frontend.API) *G2 { + w := 
emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436") + u1 := emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437") + v := fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp]("2973677408986561043442465346520108879172042883009249989176415018091420807192182638567116318576472649347015917690530"), + A1: emulated.ValueOf[emulated.BLS12381Fp]("1028732146235106349975324479215795277384839936929757896155643118032610843298655225875571310552543014690878354869257"), + } + return &G2{ + Ext2: fields_bls12381.NewExt2(api), + w: &w, + u1: &u1, + v: &v, + } +} + +func NewG2Affine(v bls12381.G2Affine) G2Affine { + return G2Affine{ + X: fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp](v.X.A0), + A1: emulated.ValueOf[emulated.BLS12381Fp](v.X.A1), + }, + Y: fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp](v.Y.A0), + A1: emulated.ValueOf[emulated.BLS12381Fp](v.Y.A1), + }, + } +} + +func (g2 *G2) psi(q *G2Affine) *G2Affine { + x := g2.Ext2.MulByElement(&q.X, g2.u1) + y := g2.Ext2.Conjugate(&q.Y) + y = g2.Ext2.Mul(y, g2.v) + + return &G2Affine{ + X: fields_bls12381.E2{A0: x.A1, A1: x.A0}, + Y: *y, + } +} + +func (g2 *G2) scalarMulBySeed(q *G2Affine) *G2Affine { + + z := g2.triple(q) + z = g2.double(z) + z = g2.doubleAndAdd(z, q) + z = g2.doubleN(z, 2) + z = g2.doubleAndAdd(z, q) + z = g2.doubleN(z, 8) + z = g2.doubleAndAdd(z, q) + z = g2.doubleN(z, 31) + z = g2.doubleAndAdd(z, q) + z = g2.doubleN(z, 16) + + return g2.neg(z) +} + +func (g2 G2) add(p, q *G2Affine) *G2Affine { + // compute λ = (q.y-p.y)/(q.x-p.x) + qypy := g2.Ext2.Sub(&q.Y, &p.Y) + qxpx := g2.Ext2.Sub(&q.X, &p.X) + λ := g2.Ext2.DivUnchecked(qypy, qxpx) + + // xr = λ²-p.x-q.x + λλ := g2.Ext2.Square(λ) + qxpx = g2.Ext2.Add(&p.X, &q.X) + xr := g2.Ext2.Sub(λλ, qxpx) + + // p.y = 
λ(p.x-r.x) - p.y + pxrx := g2.Ext2.Sub(&p.X, xr) + λpxrx := g2.Ext2.Mul(λ, pxrx) + yr := g2.Ext2.Sub(λpxrx, &p.Y) + + return &G2Affine{ + X: *xr, + Y: *yr, + } +} + +func (g2 G2) neg(p *G2Affine) *G2Affine { + xr := &p.X + yr := g2.Ext2.Neg(&p.Y) + return &G2Affine{ + X: *xr, + Y: *yr, + } +} + +func (g2 G2) sub(p, q *G2Affine) *G2Affine { + qNeg := g2.neg(q) + return g2.add(p, qNeg) +} + +func (g2 *G2) double(p *G2Affine) *G2Affine { + // compute λ = (3p.x²)/2*p.y + xx3a := g2.Square(&p.X) + xx3a = g2.MulByConstElement(xx3a, big.NewInt(3)) + y2 := g2.Double(&p.Y) + λ := g2.DivUnchecked(xx3a, y2) + + // xr = λ²-2p.x + x2 := g2.Double(&p.X) + λλ := g2.Square(λ) + xr := g2.Sub(λλ, x2) + + // yr = λ(p-xr) - p.y + pxrx := g2.Sub(&p.X, xr) + λpxrx := g2.Mul(λ, pxrx) + yr := g2.Sub(λpxrx, &p.Y) + + return &G2Affine{ + X: *xr, + Y: *yr, + } +} + +func (g2 *G2) doubleN(p *G2Affine, n int) *G2Affine { + pn := p + for s := 0; s < n; s++ { + pn = g2.double(pn) + } + return pn +} + +func (g2 G2) triple(p *G2Affine) *G2Affine { + + // compute λ1 = (3p.x²)/2p.y + xx := g2.Square(&p.X) + xx = g2.MulByConstElement(xx, big.NewInt(3)) + y2 := g2.Double(&p.Y) + λ1 := g2.DivUnchecked(xx, y2) + + // xr = λ1²-2p.x + x2 := g2.MulByConstElement(&p.X, big.NewInt(2)) + λ1λ1 := g2.Square(λ1) + x2 = g2.Sub(λ1λ1, x2) + + // ommit y2 computation, and + // compute λ2 = 2p.y/(x2 − p.x) − λ1. 
+ x1x2 := g2.Sub(&p.X, x2) + λ2 := g2.DivUnchecked(y2, x1x2) + λ2 = g2.Sub(λ2, λ1) + + // xr = λ²-p.x-x2 + λ2λ2 := g2.Square(λ2) + qxrx := g2.Add(x2, &p.X) + xr := g2.Sub(λ2λ2, qxrx) + + // yr = λ(p.x-xr) - p.y + pxrx := g2.Sub(&p.X, xr) + λ2pxrx := g2.Mul(λ2, pxrx) + yr := g2.Sub(λ2pxrx, &p.Y) + + return &G2Affine{ + X: *xr, + Y: *yr, + } +} + +func (g2 G2) doubleAndAdd(p, q *G2Affine) *G2Affine { + + // compute λ1 = (q.y-p.y)/(q.x-p.x) + yqyp := g2.Ext2.Sub(&q.Y, &p.Y) + xqxp := g2.Ext2.Sub(&q.X, &p.X) + λ1 := g2.Ext2.DivUnchecked(yqyp, xqxp) + + // compute x2 = λ1²-p.x-q.x + λ1λ1 := g2.Ext2.Square(λ1) + xqxp = g2.Ext2.Add(&p.X, &q.X) + x2 := g2.Ext2.Sub(λ1λ1, xqxp) + + // ommit y2 computation + // compute λ2 = -λ1-2*p.y/(x2-p.x) + ypyp := g2.Ext2.Add(&p.Y, &p.Y) + x2xp := g2.Ext2.Sub(x2, &p.X) + λ2 := g2.Ext2.DivUnchecked(ypyp, x2xp) + λ2 = g2.Ext2.Add(λ1, λ2) + λ2 = g2.Ext2.Neg(λ2) + + // compute x3 =λ2²-p.x-x3 + λ2λ2 := g2.Ext2.Square(λ2) + x3 := g2.Ext2.Sub(λ2λ2, &p.X) + x3 = g2.Ext2.Sub(x3, x2) + + // compute y3 = λ2*(p.x - x3)-p.y + y3 := g2.Ext2.Sub(&p.X, x3) + y3 = g2.Ext2.Mul(λ2, y3) + y3 = g2.Ext2.Sub(y3, &p.Y) + + return &G2Affine{ + X: *x3, + Y: *y3, + } +} + +// AssertIsEqual asserts that p and q are the same point. 
+func (g2 *G2) AssertIsEqual(p, q *G2Affine) { + g2.Ext2.AssertIsEqual(&p.X, &q.X) + g2.Ext2.AssertIsEqual(&p.Y, &q.Y) +} diff --git a/std/algebra/emulated/sw_bls12381/g2_test.go b/std/algebra/emulated/sw_bls12381/g2_test.go new file mode 100644 index 0000000000..76937d9623 --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/g2_test.go @@ -0,0 +1,120 @@ +package sw_bls12381 + +import ( + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type addG2Circuit struct { + In1, In2 G2Affine + Res G2Affine +} + +func (c *addG2Circuit) Define(api frontend.API) error { + g2 := NewG2(api) + res := g2.add(&c.In1, &c.In2) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestAddG2TestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines(assert) + _, in2 := randomG1G2Affines(assert) + var res bls12381.G2Affine + res.Add(&in1, &in2) + witness := addG2Circuit{ + In1: NewG2Affine(in1), + In2: NewG2Affine(in2), + Res: NewG2Affine(res), + } + err := test.IsSolved(&addG2Circuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type doubleG2Circuit struct { + In1 G2Affine + Res G2Affine +} + +func (c *doubleG2Circuit) Define(api frontend.API) error { + g2 := NewG2(api) + res := g2.double(&c.In1) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestDoubleG2TestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines(assert) + var res bls12381.G2Affine + var in1Jac, resJac bls12381.G2Jac + in1Jac.FromAffine(&in1) + resJac.Double(&in1Jac) + res.FromJacobian(&resJac) + witness := doubleG2Circuit{ + In1: NewG2Affine(in1), + Res: NewG2Affine(res), + } + err := test.IsSolved(&doubleG2Circuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type doubleAndAddG2Circuit struct { + In1, In2 G2Affine + Res G2Affine +} + +func (c 
*doubleAndAddG2Circuit) Define(api frontend.API) error { + g2 := NewG2(api) + res := g2.doubleAndAdd(&c.In1, &c.In2) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestDoubleAndAddG2TestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines(assert) + _, in2 := randomG1G2Affines(assert) + var res bls12381.G2Affine + res.Double(&in1). + Add(&res, &in2) + witness := doubleAndAddG2Circuit{ + In1: NewG2Affine(in1), + In2: NewG2Affine(in2), + Res: NewG2Affine(res), + } + err := test.IsSolved(&doubleAndAddG2Circuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type scalarMulG2BySeedCircuit struct { + In1 G2Affine + Res G2Affine +} + +func (c *scalarMulG2BySeedCircuit) Define(api frontend.API) error { + g2 := NewG2(api) + res := g2.scalarMulBySeed(&c.In1) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestScalarMulG2BySeedTestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines(assert) + var res bls12381.G2Affine + x0, _ := new(big.Int).SetString("15132376222941642752", 10) + res.ScalarMultiplication(&in1, x0).Neg(&res) + witness := scalarMulG2BySeedCircuit{ + In1: NewG2Affine(in1), + Res: NewG2Affine(res), + } + err := test.IsSolved(&scalarMulG2BySeedCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/emulated/sw_bls12381/pairing.go b/std/algebra/emulated/sw_bls12381/pairing.go new file mode 100644 index 0000000000..58038989c0 --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/pairing.go @@ -0,0 +1,625 @@ +package sw_bls12381 + +import ( + "errors" + "fmt" + "math/big" + + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/fields_bls12381" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" +) + +type Pairing struct { + api frontend.API + *fields_bls12381.Ext12 + curveF 
*emulated.Field[emulated.BLS12381Fp] + g2 *G2 + g1 *G1 + curve *sw_emulated.Curve[emulated.BLS12381Fp, emulated.BLS12381Fr] + bTwist *fields_bls12381.E2 +} + +type GTEl = fields_bls12381.E12 + +func NewGTEl(v bls12381.GT) GTEl { + return GTEl{ + C0: fields_bls12381.E6{ + B0: fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp](v.C0.B0.A0), + A1: emulated.ValueOf[emulated.BLS12381Fp](v.C0.B0.A1), + }, + B1: fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp](v.C0.B1.A0), + A1: emulated.ValueOf[emulated.BLS12381Fp](v.C0.B1.A1), + }, + B2: fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp](v.C0.B2.A0), + A1: emulated.ValueOf[emulated.BLS12381Fp](v.C0.B2.A1), + }, + }, + C1: fields_bls12381.E6{ + B0: fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp](v.C1.B0.A0), + A1: emulated.ValueOf[emulated.BLS12381Fp](v.C1.B0.A1), + }, + B1: fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp](v.C1.B1.A0), + A1: emulated.ValueOf[emulated.BLS12381Fp](v.C1.B1.A1), + }, + B2: fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp](v.C1.B2.A0), + A1: emulated.ValueOf[emulated.BLS12381Fp](v.C1.B2.A1), + }, + }, + } +} + +func NewPairing(api frontend.API) (*Pairing, error) { + ba, err := emulated.NewField[emulated.BLS12381Fp](api) + if err != nil { + return nil, fmt.Errorf("new base api: %w", err) + } + curve, err := sw_emulated.New[emulated.BLS12381Fp, emulated.BLS12381Fr](api, sw_emulated.GetBLS12381Params()) + if err != nil { + return nil, fmt.Errorf("new curve: %w", err) + } + bTwist := fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp]("4"), + A1: emulated.ValueOf[emulated.BLS12381Fp]("4"), + } + g1, err := NewG1(api) + if err != nil { + return nil, fmt.Errorf("new G1 struct: %w", err) + } + return &Pairing{ + api: api, + Ext12: fields_bls12381.NewExt12(api), + curveF: ba, + curve: curve, + g1: g1, + g2: NewG2(api), + bTwist: &bTwist, + }, nil +} + +// FinalExponentiation computes the exponentiation 
(∏ᵢ zᵢ)ᵈ where +// +// d = (p¹²-1)/r = (p¹²-1)/Φ₁₂(p) ⋅ Φ₁₂(p)/r = (p⁶-1)(p²+1)(p⁴ - p² +1)/r +// +// we use instead +// +// d=s ⋅ (p⁶-1)(p²+1)(p⁴ - p² +1)/r +// +// where s is the cofactor 3 (Hayashida et al.). +// +// This is the safe version of the method where e may be {-1,1}. If it is known +// that e ≠ {-1,1} then using the unsafe version of the method saves +// considerable amount of constraints. When called with the result of +// [MillerLoop], then current method is applicable when length of the inputs to +// Miller loop is 1. +func (pr Pairing) FinalExponentiation(e *GTEl) *GTEl { + return pr.finalExponentiation(e, false) +} + +// FinalExponentiationUnsafe computes the exponentiation (∏ᵢ zᵢ)ᵈ where +// +// d = (p¹²-1)/r = (p¹²-1)/Φ₁₂(p) ⋅ Φ₁₂(p)/r = (p⁶-1)(p²+1)(p⁴ - p² +1)/r +// +// we use instead +// +// d=s ⋅ (p⁶-1)(p²+1)(p⁴ - p² +1)/r +// +// where s is the cofactor 3 (Hayashida et al.). +// +// This is the unsafe version of the method where e may NOT be {-1,1}. If e ∈ +// {-1, 1}, then there exists no valid solution to the circuit. This method is +// applicable when called with the result of [MillerLoop] method when the length +// of the inputs to Miller loop is 1. +func (pr Pairing) FinalExponentiationUnsafe(e *GTEl) *GTEl { + return pr.finalExponentiation(e, true) +} + +// finalExponentiation computes the exponentiation (∏ᵢ zᵢ)ᵈ where +// +// d = (p¹²-1)/r = (p¹²-1)/Φ₁₂(p) ⋅ Φ₁₂(p)/r = (p⁶-1)(p²+1)(p⁴ - p² +1)/r +// +// we use instead +// +// d=s ⋅ (p⁶-1)(p²+1)(p⁴ - p² +1)/r +// +// where s is the cofactor 3 (Hayashida et al.). +func (pr Pairing) finalExponentiation(e *GTEl, unsafe bool) *GTEl { + + // 1. Easy part + // (p⁶-1)(p²+1) + var selector1, selector2 frontend.Variable + _dummy := pr.Ext6.One() + + if unsafe { + // The Miller loop result is ≠ {-1,1}, otherwise this means P and Q are + // linearly dependant and not from G1 and G2 respectively. + // So e ∈ G_{q,2} \ {-1,1} and hence e.C1 ≠ 0. + // Nothing to do. 
+ } else { + // However, for a product of Miller loops (n>=2) this might happen. If this is + // the case, the result is 1 in the torus. We assign a dummy value (1) to e.C1 + // and proceed further. + selector1 = pr.Ext6.IsZero(&e.C1) + e.C1 = *pr.Ext6.Select(selector1, _dummy, &e.C1) + } + + // Torus compression absorbed: + // Raising e to (p⁶-1) is + // e^(p⁶) / e = (e.C0 - w*e.C1) / (e.C0 + w*e.C1) + // = (-e.C0/e.C1 + w) / (-e.C0/e.C1 - w) + // So the fraction -e.C0/e.C1 is already in the torus. + // This absorbs the torus compression in the easy part. + c := pr.Ext6.DivUnchecked(&e.C0, &e.C1) + c = pr.Ext6.Neg(c) + t0 := pr.FrobeniusSquareTorus(c) + c = pr.MulTorus(t0, c) + + // 2. Hard part (up to permutation) + // 3(p⁴-p²+1)/r + // Daiki Hayashida, Kenichiro Hayasaka and Tadanori Teruya + // https://eprint.iacr.org/2020/875.pdf + // performed in torus compressed form + t0 = pr.SquareTorus(c) + t1 := pr.ExptHalfTorus(t0) + t2 := pr.InverseTorus(c) + t1 = pr.MulTorus(t1, t2) + t2 = pr.ExptTorus(t1) + t1 = pr.InverseTorus(t1) + t1 = pr.MulTorus(t1, t2) + t2 = pr.ExptTorus(t1) + t1 = pr.FrobeniusTorus(t1) + t1 = pr.MulTorus(t1, t2) + c = pr.MulTorus(c, t0) + t0 = pr.ExptTorus(t1) + t2 = pr.ExptTorus(t0) + t0 = pr.FrobeniusSquareTorus(t1) + t1 = pr.InverseTorus(t1) + t1 = pr.MulTorus(t1, t2) + t1 = pr.MulTorus(t1, t0) + + var result GTEl + // MulTorus(c, t1) requires c ≠ -t1. When c = -t1, it means the + // product is 1 in the torus. + if unsafe { + // For a single pairing, this does not happen because the pairing is non-degenerate. + result = *pr.DecompressTorus(pr.MulTorus(c, t1)) + } else { + // For a product of pairings this might happen when the result is expected to be 1. + // We assign a dummy value (1) to t1 and proceed furhter. + // Finally we do a select on both edge cases: + // - Only if seletor1=0 and selector2=0, we return MulTorus(c, t1) decompressed. + // - Otherwise, we return 1. 
+ _sum := pr.Ext6.Add(c, t1) + selector2 = pr.Ext6.IsZero(_sum) + t1 = pr.Ext6.Select(selector2, _dummy, t1) + selector := pr.api.Mul(pr.api.Sub(1, selector1), pr.api.Sub(1, selector2)) + result = *pr.Select(selector, pr.DecompressTorus(pr.MulTorus(c, t1)), pr.One()) + } + + return &result +} + +// lineEvaluation represents a sparse Fp12 Elmt (result of the line evaluation) +// line: 1 - R0(x/y) - R1(1/y) = 0 instead of R0'*y - R1'*x - R2' = 0 This +// makes the multiplication by lines (MulBy014) and between lines (Mul014By014) +// circuit-efficient. +type lineEvaluation struct { + R0, R1 fields_bls12381.E2 +} + +// Pair calculates the reduced pairing for a set of points +// ∏ᵢ e(Pᵢ, Qᵢ). +// +// This function doesn't check that the inputs are in the correct subgroups. +func (pr Pairing) Pair(P []*G1Affine, Q []*G2Affine) (*GTEl, error) { + res, err := pr.MillerLoop(P, Q) + if err != nil { + return nil, fmt.Errorf("miller loop: %w", err) + } + res = pr.finalExponentiation(res, len(P) == 1) + return res, nil +} + +// PairingCheck calculates the reduced pairing for a set of points and asserts if the result is One +// ∏ᵢ e(Pᵢ, Qᵢ) =? 1 +// +// This function doesn't check that the inputs are in the correct subgroups. 
+func (pr Pairing) PairingCheck(P []*G1Affine, Q []*G2Affine) error { + f, err := pr.Pair(P, Q) + if err != nil { + return err + + } + one := pr.One() + pr.AssertIsEqual(f, one) + + return nil +} + +func (pr Pairing) AssertIsEqual(x, y *GTEl) { + pr.Ext12.AssertIsEqual(x, y) +} + +func (pr Pairing) AssertIsOnCurve(P *G1Affine) { + pr.curve.AssertIsOnCurve(P) +} + +func (pr Pairing) AssertIsOnTwist(Q *G2Affine) { + // Twist: Y² == X³ + aX + b, where a=0 and b=4(1+u) + // (X,Y) ∈ {Y² == X³ + aX + b} U (0,0) + + // if Q=(0,0) we assign b=0 otherwise 3/(9+u), and continue + selector := pr.api.And(pr.Ext2.IsZero(&Q.X), pr.Ext2.IsZero(&Q.Y)) + + b := pr.Ext2.Select(selector, pr.Ext2.Zero(), pr.bTwist) + + left := pr.Ext2.Square(&Q.Y) + right := pr.Ext2.Square(&Q.X) + right = pr.Ext2.Mul(right, &Q.X) + right = pr.Ext2.Add(right, b) + pr.Ext2.AssertIsEqual(left, right) +} + +func (pr Pairing) AssertIsOnG1(P *G1Affine) { + // 1- Check P is on the curve + pr.AssertIsOnCurve(P) + + // 2- Check P has the right subgroup order + // TODO: add phi and scalarMulBySeedSquare to g1.go + // [x²]ϕ(P) + phiP := pr.g1.phi(P) + seedSquare := emulated.ValueOf[emulated.BLS12381Fr]("228988810152649578064853576960394133504") + // TODO: use addchain to construct a fixed-scalar ScalarMul + _P := pr.curve.ScalarMul(phiP, &seedSquare) + _P = pr.curve.Neg(_P) + + // [r]Q == 0 <==> P = -[x²]ϕ(P) + pr.curve.AssertIsEqual(_P, P) +} + +func (pr Pairing) AssertIsOnG2(Q *G2Affine) { + // 1- Check Q is on the curve + pr.AssertIsOnTwist(Q) + + // 2- Check Q has the right subgroup order + // [x₀]Q + xQ := pr.g2.scalarMulBySeed(Q) + // ψ(Q) + psiQ := pr.g2.psi(Q) + + // [r]Q == 0 <==> ψ(Q) == [x₀]Q + pr.g2.AssertIsEqual(xQ, psiQ) +} + +// loopCounter = seed in binary +// +// seed=-15132376222941642752 +var loopCounter = [64]int8{ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, + 0, 
0, 1, 0, 0, 1, 0, 1, 1, +} + +// MillerLoop computes the multi-Miller loop +// ∏ᵢ { fᵢ_{u,Q}(P) } +func (pr Pairing) MillerLoop(P []*G1Affine, Q []*G2Affine) (*GTEl, error) { + // check input size match + n := len(P) + if n == 0 || n != len(Q) { + return nil, errors.New("invalid inputs sizes") + } + + res := pr.Ext12.One() + + var l1, l2 *lineEvaluation + Qacc := make([]*G2Affine, n) + yInv := make([]*emulated.Element[emulated.BLS12381Fp], n) + xOverY := make([]*emulated.Element[emulated.BLS12381Fp], n) + + for k := 0; k < n; k++ { + Qacc[k] = Q[k] + // P and Q are supposed to be on G1 and G2 respectively of prime order r. + // The point (x,0) is of order 2. But this function does not check + // subgroup membership. + // Anyway (x,0) cannot be on BLS12-381 because -4 is a cubic non-residue in Fp. + // so, 1/y is well defined for all points P's + yInv[k] = pr.curveF.Inverse(&P[k].Y) + xOverY[k] = pr.curveF.MulMod(&P[k].X, yInv[k]) + } + + // Compute ∏ᵢ { fᵢ_{x₀,Q}(P) } + + // i = 62, separately to avoid an E12 Square + // (Square(res) = 1² = 1) + + // k = 0, separately to avoid MulBy034 (res × ℓ) + + // Qacc[k] ← 3Qacc[k], + // l1 the tangent ℓ to 2Q[k] + // l2 the line ℓ passing 2Q[k] and Q[k] + Qacc[0], l1, l2 = pr.tripleStep(Qacc[0]) + // line evaluation at P[0] + // and assign line to res (R1, R0, 0, 0, 1, 0) + res.C0.B1 = *pr.MulByElement(&l1.R0, xOverY[0]) + res.C0.B0 = *pr.MulByElement(&l1.R1, yInv[0]) + res.C1.B1 = *pr.Ext2.One() + // line evaluation at P[0] + l2.R0 = *pr.MulByElement(&l2.R0, xOverY[0]) + l2.R1 = *pr.MulByElement(&l2.R1, yInv[0]) + // res = ℓ × ℓ + prodLines := *pr.Mul014By014(&l2.R1, &l2.R0, &res.C0.B0, &res.C0.B1) + res.C0.B0 = prodLines[0] + res.C0.B1 = prodLines[1] + res.C0.B2 = prodLines[2] + res.C1.B1 = prodLines[3] + res.C1.B2 = prodLines[4] + + for k := 1; k < n; k++ { + // Qacc[k] ← 3Qacc[k], + // l1 the tangent ℓ to 2Q[k] + // l2 the line ℓ passing 2Q[k] and Q[k] + Qacc[k], l1, l2 = pr.tripleStep(Qacc[k]) + // line evaluation at 
P[k] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[k]) + // ℓ × res + res = pr.MulBy014(res, &l1.R1, &l1.R0) + // line evaluation at P[k] + l2.R0 = *pr.MulByElement(&l2.R0, xOverY[k]) + l2.R1 = *pr.MulByElement(&l2.R1, yInv[k]) + // ℓ × res + res = pr.MulBy014(res, &l2.R1, &l2.R0) + } + + // Compute ∏ᵢ { fᵢ_{u,Q}(P) } + for i := 61; i >= 1; i-- { + // mutualize the square among n Miller loops + // (∏ᵢfᵢ)² + res = pr.Square(res) + + if loopCounter[i] == 0 { + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k] and l1 the tangent ℓ passing 2Qacc[k] + Qacc[k], l1 = pr.doubleStep(Qacc[k]) + // line evaluation at P[k] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[k]) + // ℓ × res + res = pr.MulBy014(res, &l1.R1, &l1.R0) + } + } else { + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k]+Q[k], + // l1 the line ℓ passing Qacc[k] and Q[k] + // l2 the line ℓ passing (Qacc[k]+Q[k]) and Qacc[k] + Qacc[k], l1, l2 = pr.doubleAndAddStep(Qacc[k], Q[k]) + // line evaluation at P[k] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[k]) + // line evaluation at P[k] + l2.R0 = *pr.MulByElement(&l2.R0, xOverY[k]) + l2.R1 = *pr.MulByElement(&l2.R1, yInv[k]) + // ℓ × res + res = pr.MulBy014(res, &l1.R1, &l1.R0) + // ℓ × res + res = pr.MulBy014(res, &l2.R1, &l2.R0) + } + } + } + + // i = 0, separately to avoid a point doubling + res = pr.Square(res) + for k := 0; k < n; k++ { + // l1 the tangent ℓ passing 2Qacc[k] + l1 = pr.tangentCompute(Qacc[k]) + // line evaluation at P[k] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[k]) + // ℓ × res + res = pr.MulBy014(res, &l1.R1, &l1.R0) + } + + // negative x₀ + res = pr.Ext12.Conjugate(res) + + return res, nil +} + +// doubleAndAddStep doubles p1 and adds p2 to the result in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func (pr 
Pairing) doubleAndAddStep(p1, p2 *G2Affine) (*G2Affine, *lineEvaluation, *lineEvaluation) { + + var line1, line2 lineEvaluation + var p G2Affine + + // compute λ1 = (y2-y1)/(x2-x1) + n := pr.Ext2.Sub(&p1.Y, &p2.Y) + d := pr.Ext2.Sub(&p1.X, &p2.X) + l1 := pr.Ext2.DivUnchecked(n, d) + + // compute x3 =λ1²-x1-x2 + x3 := pr.Ext2.Square(l1) + x3 = pr.Ext2.Sub(x3, &p1.X) + x3 = pr.Ext2.Sub(x3, &p2.X) + + // omit y3 computation + + // compute line1 + line1.R0 = *pr.Ext2.Neg(l1) + line1.R1 = *pr.Ext2.Mul(l1, &p1.X) + line1.R1 = *pr.Ext2.Sub(&line1.R1, &p1.Y) + + // compute λ2 = -λ1-2y1/(x3-x1) + n = pr.Ext2.Double(&p1.Y) + d = pr.Ext2.Sub(x3, &p1.X) + l2 := pr.Ext2.DivUnchecked(n, d) + l2 = pr.Ext2.Add(l2, l1) + l2 = pr.Ext2.Neg(l2) + + // compute x4 = λ2²-x1-x3 + x4 := pr.Ext2.Square(l2) + x4 = pr.Ext2.Sub(x4, &p1.X) + x4 = pr.Ext2.Sub(x4, x3) + + // compute y4 = λ2(x1 - x4)-y1 + y4 := pr.Ext2.Sub(&p1.X, x4) + y4 = pr.Ext2.Mul(l2, y4) + y4 = pr.Ext2.Sub(y4, &p1.Y) + + p.X = *x4 + p.Y = *y4 + + // compute line2 + line2.R0 = *pr.Ext2.Neg(l2) + line2.R1 = *pr.Ext2.Mul(l2, &p1.X) + line2.R1 = *pr.Ext2.Sub(&line2.R1, &p1.Y) + + return &p, &line1, &line2 +} + +// doubleStep doubles a point in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func (pr Pairing) doubleStep(p1 *G2Affine) (*G2Affine, *lineEvaluation) { + + var p G2Affine + var line lineEvaluation + + // λ = 3x²/2y + n := pr.Ext2.Square(&p1.X) + three := big.NewInt(3) + n = pr.Ext2.MulByConstElement(n, three) + d := pr.Ext2.Double(&p1.Y) + λ := pr.Ext2.DivUnchecked(n, d) + + // xr = λ²-2x + xr := pr.Ext2.Square(λ) + xr = pr.Ext2.Sub(xr, &p1.X) + xr = pr.Ext2.Sub(xr, &p1.X) + + // yr = λ(x-xr)-y + yr := pr.Ext2.Sub(&p1.X, xr) + yr = pr.Ext2.Mul(λ, yr) + yr = pr.Ext2.Sub(yr, &p1.Y) + + p.X = *xr + p.Y = *yr + + line.R0 = *pr.Ext2.Neg(λ) + line.R1 = *pr.Ext2.Mul(λ, &p1.X) + line.R1 = *pr.Ext2.Sub(&line.R1, &p1.Y) + + return &p, &line + +} + +// addStep adds 
two points in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func (pr Pairing) addStep(p1, p2 *G2Affine) (*G2Affine, *lineEvaluation) { + + // compute λ = (y2-y1)/(x2-x1) + p2ypy := pr.Ext2.Sub(&p2.Y, &p1.Y) + p2xpx := pr.Ext2.Sub(&p2.X, &p1.X) + λ := pr.Ext2.DivUnchecked(p2ypy, p2xpx) + + // xr = λ²-x1-x2 + λλ := pr.Ext2.Square(λ) + p2xpx = pr.Ext2.Add(&p1.X, &p2.X) + xr := pr.Ext2.Sub(λλ, p2xpx) + + // yr = λ(x1-xr) - y1 + pxrx := pr.Ext2.Sub(&p1.X, xr) + λpxrx := pr.Ext2.Mul(λ, pxrx) + yr := pr.Ext2.Sub(λpxrx, &p1.Y) + + var res G2Affine + res.X = *xr + res.Y = *yr + + var line lineEvaluation + line.R0 = *pr.Ext2.Neg(λ) + line.R1 = *pr.Ext2.Mul(λ, &p1.X) + line.R1 = *pr.Ext2.Sub(&line.R1, &p1.Y) + + return &res, &line + +} + +// tripleStep triples p1 in affine coordinates, and evaluates the line in Miller loop +func (pr Pairing) tripleStep(p1 *G2Affine) (*G2Affine, *lineEvaluation, *lineEvaluation) { + + var line1, line2 lineEvaluation + var res G2Affine + + // λ1 = 3x²/2y + n := pr.Ext2.Square(&p1.X) + three := big.NewInt(3) + n = pr.Ext2.MulByConstElement(n, three) + d := pr.Ext2.Double(&p1.Y) + λ1 := pr.Ext2.DivUnchecked(n, d) + + // compute line1 + line1.R0 = *pr.Ext2.Neg(λ1) + line1.R1 = *pr.Ext2.Mul(λ1, &p1.X) + line1.R1 = *pr.Ext2.Sub(&line1.R1, &p1.Y) + + // x2 = λ1²-2x + x2 := pr.Ext2.Square(λ1) + x2 = pr.Ext2.Sub(x2, &p1.X) + x2 = pr.Ext2.Sub(x2, &p1.X) + + // ommit yr computation, and + // compute λ2 = 2y/(x2 − x) − λ1. 
+ x1x2 := pr.Ext2.Sub(&p1.X, x2) + λ2 := pr.Ext2.DivUnchecked(d, x1x2) + λ2 = pr.Ext2.Sub(λ2, λ1) + + // compute line2 + line2.R0 = *pr.Ext2.Neg(λ2) + line2.R1 = *pr.Ext2.Mul(λ2, &p1.X) + line2.R1 = *pr.Ext2.Sub(&line2.R1, &p1.Y) + + // xr = λ²-p.x-x2 + λ2λ2 := pr.Ext2.Mul(λ2, λ2) + qxrx := pr.Ext2.Add(x2, &p1.X) + xr := pr.Ext2.Sub(λ2λ2, qxrx) + + // yr = λ(p.x-xr) - p.y + pxrx := pr.Ext2.Sub(&p1.X, xr) + λ2pxrx := pr.Ext2.Mul(λ2, pxrx) + yr := pr.Ext2.Sub(λ2pxrx, &p1.Y) + + res.X = *xr + res.Y = *yr + + return &res, &line1, &line2 +} + +// tangentCompute computes the line that goes through p1 and p2 but does not compute p1+p2 +func (pr Pairing) tangentCompute(p1 *G2Affine) *lineEvaluation { + + // λ = 3x²/2y + n := pr.Ext2.Square(&p1.X) + three := big.NewInt(3) + n = pr.Ext2.MulByConstElement(n, three) + d := pr.Ext2.Double(&p1.Y) + λ := pr.Ext2.DivUnchecked(n, d) + + var line lineEvaluation + line.R0 = *pr.Ext2.Neg(λ) + line.R1 = *pr.Ext2.Mul(λ, &p1.X) + line.R1 = *pr.Ext2.Sub(&line.R1, &p1.Y) + + return &line + +} diff --git a/std/algebra/emulated/sw_bls12381/pairing_test.go b/std/algebra/emulated/sw_bls12381/pairing_test.go new file mode 100644 index 0000000000..5a104df470 --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/pairing_test.go @@ -0,0 +1,230 @@ +package sw_bls12381 + +import ( + "crypto/rand" + "fmt" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +func randomG1G2Affines(assert *test.Assert) (bls12381.G1Affine, bls12381.G2Affine) { + _, _, G1AffGen, G2AffGen := bls12381.Generators() + mod := bls12381.ID.ScalarField() + s1, err := rand.Int(rand.Reader, mod) + assert.NoError(err) + s2, err := rand.Int(rand.Reader, mod) + assert.NoError(err) + var p bls12381.G1Affine + p.ScalarMultiplication(&G1AffGen, s1) + var q bls12381.G2Affine + q.ScalarMultiplication(&G2AffGen, s2) + return p, q +} + +type 
FinalExponentiationCircuit struct { + InGt GTEl + Res GTEl +} + +func (c *FinalExponentiationCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + res1 := pairing.FinalExponentiation(&c.InGt) + pairing.AssertIsEqual(res1, &c.Res) + res2 := pairing.FinalExponentiationUnsafe(&c.InGt) + pairing.AssertIsEqual(res2, &c.Res) + return nil +} + +func TestFinalExponentiationTestSolve(t *testing.T) { + assert := test.NewAssert(t) + var gt bls12381.GT + gt.SetRandom() + res := bls12381.FinalExponentiation(>) + witness := FinalExponentiationCircuit{ + InGt: NewGTEl(gt), + Res: NewGTEl(res), + } + err := test.IsSolved(&FinalExponentiationCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type PairCircuit struct { + InG1 G1Affine + InG2 G2Affine + Res GTEl +} + +func (c *PairCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + pairing.AssertIsOnG1(&c.InG1) + pairing.AssertIsOnG2(&c.InG2) + res, err := pairing.Pair([]*G1Affine{&c.InG1}, []*G2Affine{&c.InG2}) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + pairing.AssertIsEqual(res, &c.Res) + return nil +} + +func TestPairTestSolve(t *testing.T) { + assert := test.NewAssert(t) + p, q := randomG1G2Affines(assert) + res, err := bls12381.Pair([]bls12381.G1Affine{p}, []bls12381.G2Affine{q}) + assert.NoError(err) + witness := PairCircuit{ + InG1: NewG1Affine(p), + InG2: NewG2Affine(q), + Res: NewGTEl(res), + } + err = test.IsSolved(&PairCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type MultiPairCircuit struct { + In1G1 G1Affine + In2G1 G1Affine + In1G2 G2Affine + In2G2 G2Affine + Res GTEl +} + +func (c *MultiPairCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + pairing.AssertIsOnG1(&c.In1G1) + 
pairing.AssertIsOnG1(&c.In2G1) + pairing.AssertIsOnG2(&c.In1G2) + pairing.AssertIsOnG2(&c.In2G2) + res, err := pairing.Pair([]*G1Affine{&c.In1G1, &c.In1G1, &c.In2G1, &c.In2G1}, []*G2Affine{&c.In1G2, &c.In2G2, &c.In1G2, &c.In2G2}) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + pairing.AssertIsEqual(res, &c.Res) + return nil +} + +func TestMultiPairTestSolve(t *testing.T) { + assert := test.NewAssert(t) + p1, q1 := randomG1G2Affines(assert) + p2, q2 := randomG1G2Affines(assert) + res, err := bls12381.Pair([]bls12381.G1Affine{p1, p1, p2, p2}, []bls12381.G2Affine{q1, q2, q1, q2}) + assert.NoError(err) + witness := MultiPairCircuit{ + In1G1: NewG1Affine(p1), + In1G2: NewG2Affine(q1), + In2G1: NewG1Affine(p2), + In2G2: NewG2Affine(q2), + Res: NewGTEl(res), + } + err = test.IsSolved(&MultiPairCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type PairingCheckCircuit struct { + In1G1 G1Affine + In2G1 G1Affine + In1G2 G2Affine + In2G2 G2Affine +} + +func (c *PairingCheckCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + err = pairing.PairingCheck([]*G1Affine{&c.In1G1, &c.In1G1, &c.In2G1, &c.In2G1}, []*G2Affine{&c.In1G2, &c.In2G2, &c.In1G2, &c.In2G2}) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + return nil +} + +func TestPairingCheckTestSolve(t *testing.T) { + assert := test.NewAssert(t) + p1, q1 := randomG1G2Affines(assert) + _, q2 := randomG1G2Affines(assert) + var p2 bls12381.G1Affine + p2.Neg(&p1) + witness := PairingCheckCircuit{ + In1G1: NewG1Affine(p1), + In1G2: NewG2Affine(q1), + In2G1: NewG1Affine(p2), + In2G2: NewG2Affine(q2), + } + err := test.IsSolved(&PairingCheckCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type FinalExponentiationSafeCircuit struct { + P1, P2 G1Affine + Q1, Q2 G2Affine +} + +func (c *FinalExponentiationSafeCircuit) Define(api frontend.API) error { + pairing, err := 
NewPairing(api) + if err != nil { + return err + } + res, err := pairing.MillerLoop([]*G1Affine{&c.P1, &c.P2}, []*G2Affine{&c.Q1, &c.Q2}) + if err != nil { + return err + } + res2 := pairing.FinalExponentiation(res) + one := pairing.Ext12.One() + pairing.AssertIsEqual(one, res2) + return nil +} + +func TestFinalExponentiationSafeCircuit(t *testing.T) { + assert := test.NewAssert(t) + _, _, p1, q1 := bls12381.Generators() + var p2 bls12381.G1Affine + var q2 bls12381.G2Affine + p2.Neg(&p1) + q2.Set(&q1) + err := test.IsSolved(&FinalExponentiationSafeCircuit{}, &FinalExponentiationSafeCircuit{ + P1: NewG1Affine(p1), + P2: NewG1Affine(p2), + Q1: NewG2Affine(q1), + Q2: NewG2Affine(q2), + }, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type GroupMembershipCircuit struct { + InG1 G1Affine + InG2 G2Affine +} + +func (c *GroupMembershipCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + pairing.AssertIsOnG1(&c.InG1) + pairing.AssertIsOnG2(&c.InG2) + return nil +} + +func TestGroupMembershipSolve(t *testing.T) { + assert := test.NewAssert(t) + p, q := randomG1G2Affines(assert) + witness := GroupMembershipCircuit{ + InG1: NewG1Affine(p), + InG2: NewG2Affine(q), + } + err := test.IsSolved(&GroupMembershipCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/emulated/sw_bn254/doc.go b/std/algebra/emulated/sw_bn254/doc.go new file mode 100644 index 0000000000..e66e237ee9 --- /dev/null +++ b/std/algebra/emulated/sw_bn254/doc.go @@ -0,0 +1,6 @@ +// Package sw_bn254 implements G1 and G2 arithmetics and pairing computation over BN254 curve. +// +// The implementation follows [[Housni22]]: "Pairings in Rank-1 Constraint Systems". 
+// +// [Housni22]: https://eprint.iacr.org/2022/1162 +package sw_bn254 diff --git a/std/algebra/emulated/sw_bn254/doc_test.go b/std/algebra/emulated/sw_bn254/doc_test.go new file mode 100644 index 0000000000..7d8ef6a6cd --- /dev/null +++ b/std/algebra/emulated/sw_bn254/doc_test.go @@ -0,0 +1,105 @@ +package sw_bn254_test + +import ( + "crypto/rand" + "fmt" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/algebra/emulated/sw_bn254" +) + +type PairCircuit struct { + InG1 sw_bn254.G1Affine + InG2 sw_bn254.G2Affine + Res sw_bn254.GTEl +} + +func (c *PairCircuit) Define(api frontend.API) error { + pairing, err := sw_bn254.NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + // Pair method does not check that the points are in the proper groups. + pairing.AssertIsOnG1(&c.InG1) + pairing.AssertIsOnG2(&c.InG2) + // Compute the pairing + res, err := pairing.Pair([]*sw_bn254.G1Affine{&c.InG1}, []*sw_bn254.G2Affine{&c.InG2}) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + pairing.AssertIsEqual(res, &c.Res) + return nil +} + +func ExamplePairing() { + p, q, err := randomG1G2Affines() + if err != nil { + panic(err) + } + res, err := bn254.Pair([]bn254.G1Affine{p}, []bn254.G2Affine{q}) + if err != nil { + panic(err) + } + circuit := PairCircuit{} + witness := PairCircuit{ + InG1: sw_bn254.NewG1Affine(p), + InG2: sw_bn254.NewG2Affine(q), + Res: sw_bn254.NewGTEl(res), + } + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + panic(err) + } else { + fmt.Println("compiled") + } + pk, vk, err := groth16.Setup(ccs) + if err != nil { + panic(err) + } else { + fmt.Println("setup done") + } + secretWitness, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + 
panic(err) + } else { + fmt.Println("secret witness") + } + publicWitness, err := secretWitness.Public() + if err != nil { + panic(err) + } else { + fmt.Println("public witness") + } + proof, err := groth16.Prove(ccs, pk, secretWitness) + if err != nil { + panic(err) + } else { + fmt.Println("proof") + } + err = groth16.Verify(proof, vk, publicWitness) + if err != nil { + panic(err) + } else { + fmt.Println("verify") + } +} + +func randomG1G2Affines() (p bn254.G1Affine, q bn254.G2Affine, err error) { + _, _, G1AffGen, G2AffGen := bn254.Generators() + mod := bn254.ID.ScalarField() + s1, err := rand.Int(rand.Reader, mod) + if err != nil { + return p, q, err + } + s2, err := rand.Int(rand.Reader, mod) + if err != nil { + return p, q, err + } + p.ScalarMultiplication(&G1AffGen, s1) + q.ScalarMultiplication(&G2AffGen, s2) + return +} diff --git a/std/algebra/emulated/sw_bn254/g1.go b/std/algebra/emulated/sw_bn254/g1.go new file mode 100644 index 0000000000..69ce54898c --- /dev/null +++ b/std/algebra/emulated/sw_bn254/g1.go @@ -0,0 +1,16 @@ +package sw_bn254 + +import ( + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" +) + +type G1Affine = sw_emulated.AffinePoint[emulated.BN254Fp] + +func NewG1Affine(v bn254.G1Affine) G1Affine { + return G1Affine{ + X: emulated.ValueOf[emulated.BN254Fp](v.X), + Y: emulated.ValueOf[emulated.BN254Fp](v.Y), + } +} diff --git a/std/algebra/emulated/sw_bn254/g2.go b/std/algebra/emulated/sw_bn254/g2.go new file mode 100644 index 0000000000..71026f66f0 --- /dev/null +++ b/std/algebra/emulated/sw_bn254/g2.go @@ -0,0 +1,216 @@ +package sw_bn254 + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/fields_bn254" + "github.com/consensys/gnark/std/math/emulated" +) + +type G2 struct { + *fields_bn254.Ext2 + w 
*emulated.Element[emulated.BN254Fp] + u, v *fields_bn254.E2 +} + +type G2Affine struct { + X, Y fields_bn254.E2 +} + +func NewG2(api frontend.API) *G2 { + w := emulated.ValueOf[emulated.BN254Fp]("21888242871839275220042445260109153167277707414472061641714758635765020556616") + u := fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp]("21575463638280843010398324269430826099269044274347216827212613867836435027261"), + A1: emulated.ValueOf[emulated.BN254Fp]("10307601595873709700152284273816112264069230130616436755625194854815875713954"), + } + v := fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp]("2821565182194536844548159561693502659359617185244120367078079554186484126554"), + A1: emulated.ValueOf[emulated.BN254Fp]("3505843767911556378687030309984248845540243509899259641013678093033130930403"), + } + return &G2{ + Ext2: fields_bn254.NewExt2(api), + w: &w, + u: &u, + v: &v, + } +} + +func NewG2Affine(v bn254.G2Affine) G2Affine { + return G2Affine{ + X: fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp](v.X.A0), + A1: emulated.ValueOf[emulated.BN254Fp](v.X.A1), + }, + Y: fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp](v.Y.A0), + A1: emulated.ValueOf[emulated.BN254Fp](v.Y.A1), + }, + } +} + +func (g2 *G2) phi(q *G2Affine) *G2Affine { + x := g2.Ext2.MulByElement(&q.X, g2.w) + + return &G2Affine{ + X: *x, + Y: q.Y, + } +} + +func (g2 *G2) psi(q *G2Affine) *G2Affine { + x := g2.Ext2.Conjugate(&q.X) + x = g2.Ext2.Mul(x, g2.u) + y := g2.Ext2.Conjugate(&q.Y) + y = g2.Ext2.Mul(y, g2.v) + + return &G2Affine{ + X: *x, + Y: *y, + } +} + +func (g2 *G2) scalarMulBySeed(q *G2Affine) *G2Affine { + z := g2.double(q) + t0 := g2.add(q, z) + t2 := g2.add(q, t0) + t1 := g2.add(z, t2) + z = g2.doubleAndAdd(t1, t0) + t0 = g2.add(t0, z) + t2 = g2.add(t2, t0) + t1 = g2.add(t1, t2) + t0 = g2.add(t0, t1) + t1 = g2.add(t1, t0) + t0 = g2.add(t0, t1) + t2 = g2.add(t2, t0) + t1 = g2.doubleAndAdd(t2, t1) + t2 = g2.add(t2, t1) + z = g2.add(z, t2) + t2 = g2.add(t2, z) + 
z = g2.doubleAndAdd(t2, z) + t0 = g2.add(t0, z) + t1 = g2.add(t1, t0) + t3 := g2.double(t1) + t3 = g2.doubleAndAdd(t3, t1) + t2 = g2.add(t2, t3) + t1 = g2.add(t1, t2) + t2 = g2.add(t2, t1) + t2 = g2.doubleN(t2, 16) + t1 = g2.doubleAndAdd(t2, t1) + t1 = g2.doubleN(t1, 13) + t0 = g2.doubleAndAdd(t1, t0) + t0 = g2.doubleN(t0, 15) + z = g2.doubleAndAdd(t0, z) + + return z +} + +func (g2 G2) add(p, q *G2Affine) *G2Affine { + // compute λ = (q.y-p.y)/(q.x-p.x) + qypy := g2.Ext2.Sub(&q.Y, &p.Y) + qxpx := g2.Ext2.Sub(&q.X, &p.X) + λ := g2.Ext2.DivUnchecked(qypy, qxpx) + + // xr = λ²-p.x-q.x + λλ := g2.Ext2.Square(λ) + qxpx = g2.Ext2.Add(&p.X, &q.X) + xr := g2.Ext2.Sub(λλ, qxpx) + + // p.y = λ(p.x-r.x) - p.y + pxrx := g2.Ext2.Sub(&p.X, xr) + λpxrx := g2.Ext2.Mul(λ, pxrx) + yr := g2.Ext2.Sub(λpxrx, &p.Y) + + return &G2Affine{ + X: *xr, + Y: *yr, + } +} + +func (g2 G2) neg(p *G2Affine) *G2Affine { + xr := &p.X + yr := g2.Ext2.Neg(&p.Y) + return &G2Affine{ + X: *xr, + Y: *yr, + } +} + +func (g2 G2) sub(p, q *G2Affine) *G2Affine { + qNeg := g2.neg(q) + return g2.add(p, qNeg) +} + +func (g2 *G2) double(p *G2Affine) *G2Affine { + // compute λ = (3p.x²)/2*p.y + xx3a := g2.Square(&p.X) + xx3a = g2.MulByConstElement(xx3a, big.NewInt(3)) + y2 := g2.Double(&p.Y) + λ := g2.DivUnchecked(xx3a, y2) + + // xr = λ²-2p.x + x2 := g2.Double(&p.X) + λλ := g2.Square(λ) + xr := g2.Sub(λλ, x2) + + // yr = λ(p-xr) - p.y + pxrx := g2.Sub(&p.X, xr) + λpxrx := g2.Mul(λ, pxrx) + yr := g2.Sub(λpxrx, &p.Y) + + return &G2Affine{ + X: *xr, + Y: *yr, + } +} + +func (g2 *G2) doubleN(p *G2Affine, n int) *G2Affine { + pn := p + for s := 0; s < n; s++ { + pn = g2.double(pn) + } + return pn +} + +func (g2 G2) doubleAndAdd(p, q *G2Affine) *G2Affine { + + // compute λ1 = (q.y-p.y)/(q.x-p.x) + yqyp := g2.Ext2.Sub(&q.Y, &p.Y) + xqxp := g2.Ext2.Sub(&q.X, &p.X) + λ1 := g2.Ext2.DivUnchecked(yqyp, xqxp) + + // compute x2 = λ1²-p.x-q.x + λ1λ1 := g2.Ext2.Square(λ1) + xqxp = g2.Ext2.Add(&p.X, &q.X) + x2 := 
g2.Ext2.Sub(λ1λ1, xqxp) + + // omit y2 computation + // compute λ2 = -λ1-2*p.y/(x2-p.x) + ypyp := g2.Ext2.Add(&p.Y, &p.Y) + x2xp := g2.Ext2.Sub(x2, &p.X) + λ2 := g2.Ext2.DivUnchecked(ypyp, x2xp) + λ2 = g2.Ext2.Add(λ1, λ2) + λ2 = g2.Ext2.Neg(λ2) + + // compute x3 = λ2²-p.x-x2 + λ2λ2 := g2.Ext2.Square(λ2) + x3 := g2.Ext2.Sub(λ2λ2, &p.X) + x3 = g2.Ext2.Sub(x3, x2) + + // compute y3 = λ2*(p.x - x3)-p.y + y3 := g2.Ext2.Sub(&p.X, x3) + y3 = g2.Ext2.Mul(λ2, y3) + y3 = g2.Ext2.Sub(y3, &p.Y) + + return &G2Affine{ + X: *x3, + Y: *y3, + } +} + +// AssertIsEqual asserts that p and q are the same point. +func (g2 *G2) AssertIsEqual(p, q *G2Affine) { + g2.Ext2.AssertIsEqual(&p.X, &q.X) + g2.Ext2.AssertIsEqual(&p.Y, &q.Y) +} diff --git a/std/algebra/emulated/sw_bn254/g2_test.go b/std/algebra/emulated/sw_bn254/g2_test.go new file mode 100644 index 0000000000..3e61f05cc4 --- /dev/null +++ b/std/algebra/emulated/sw_bn254/g2_test.go @@ -0,0 +1,144 @@ +package sw_bn254 + +import ( + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type addG2Circuit struct { + In1, In2 G2Affine + Res G2Affine +} + +func (c *addG2Circuit) Define(api frontend.API) error { + g2 := NewG2(api) + res := g2.add(&c.In1, &c.In2) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestAddG2TestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + _, in2 := randomG1G2Affines() + var res bn254.G2Affine + res.Add(&in1, &in2) + witness := addG2Circuit{ + In1: NewG2Affine(in1), + In2: NewG2Affine(in2), + Res: NewG2Affine(res), + } + err := test.IsSolved(&addG2Circuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type doubleG2Circuit struct { + In1 G2Affine + Res G2Affine +} + +func (c *doubleG2Circuit) Define(api frontend.API) error { + g2 := NewG2(api) + res := g2.double(&c.In1) + g2.AssertIsEqual(res, &c.Res) + 
return nil +} + +func TestDoubleG2TestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + var res bn254.G2Affine + var in1Jac, resJac bn254.G2Jac + in1Jac.FromAffine(&in1) + resJac.Double(&in1Jac) + res.FromJacobian(&resJac) + witness := doubleG2Circuit{ + In1: NewG2Affine(in1), + Res: NewG2Affine(res), + } + err := test.IsSolved(&doubleG2Circuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type doubleAndAddG2Circuit struct { + In1, In2 G2Affine + Res G2Affine +} + +func (c *doubleAndAddG2Circuit) Define(api frontend.API) error { + g2 := NewG2(api) + res := g2.doubleAndAdd(&c.In1, &c.In2) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestDoubleAndAddG2TestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + _, in2 := randomG1G2Affines() + var res bn254.G2Affine + res.Double(&in1). + Add(&res, &in2) + witness := doubleAndAddG2Circuit{ + In1: NewG2Affine(in1), + In2: NewG2Affine(in2), + Res: NewG2Affine(res), + } + err := test.IsSolved(&doubleAndAddG2Circuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type scalarMulG2BySeedCircuit struct { + In1 G2Affine + Res G2Affine +} + +func (c *scalarMulG2BySeedCircuit) Define(api frontend.API) error { + g2 := NewG2(api) + res := g2.scalarMulBySeed(&c.In1) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestScalarMulG2BySeedTestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + var res bn254.G2Affine + x0, _ := new(big.Int).SetString("4965661367192848881", 10) + res.ScalarMultiplication(&in1, x0) + witness := scalarMulG2BySeedCircuit{ + In1: NewG2Affine(in1), + Res: NewG2Affine(res), + } + err := test.IsSolved(&scalarMulG2BySeedCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type endomorphismG2Circuit struct { + In1 G2Affine +} + +func (c *endomorphismG2Circuit) Define(api frontend.API) error { + g2 := NewG2(api) + res1 := g2.phi(&c.In1) + 
res1 = g2.neg(res1) + res2 := g2.psi(&c.In1) + res2 = g2.psi(res2) + g2.AssertIsEqual(res1, res2) + return nil +} + +func TestEndomorphismG2TestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + witness := endomorphismG2Circuit{ + In1: NewG2Affine(in1), + } + err := test.IsSolved(&endomorphismG2Circuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/emulated/sw_bn254/pairing.go b/std/algebra/emulated/sw_bn254/pairing.go new file mode 100644 index 0000000000..5ed4e08f81 --- /dev/null +++ b/std/algebra/emulated/sw_bn254/pairing.go @@ -0,0 +1,672 @@ +package sw_bn254 + +import ( + "errors" + "fmt" + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/fields_bn254" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" +) + +type Pairing struct { + api frontend.API + *fields_bn254.Ext12 + curveF *emulated.Field[emulated.BN254Fp] + curve *sw_emulated.Curve[emulated.BN254Fp, emulated.BN254Fr] + g2 *G2 + bTwist *fields_bn254.E2 +} + +type GTEl = fields_bn254.E12 + +func NewGTEl(v bn254.GT) GTEl { + return GTEl{ + C0: fields_bn254.E6{ + B0: fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp](v.C0.B0.A0), + A1: emulated.ValueOf[emulated.BN254Fp](v.C0.B0.A1), + }, + B1: fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp](v.C0.B1.A0), + A1: emulated.ValueOf[emulated.BN254Fp](v.C0.B1.A1), + }, + B2: fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp](v.C0.B2.A0), + A1: emulated.ValueOf[emulated.BN254Fp](v.C0.B2.A1), + }, + }, + C1: fields_bn254.E6{ + B0: fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp](v.C1.B0.A0), + A1: emulated.ValueOf[emulated.BN254Fp](v.C1.B0.A1), + }, + B1: fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp](v.C1.B1.A0), + A1: emulated.ValueOf[emulated.BN254Fp](v.C1.B1.A1), + }, + B2: 
fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp](v.C1.B2.A0), + A1: emulated.ValueOf[emulated.BN254Fp](v.C1.B2.A1), + }, + }, + } +} + +func NewPairing(api frontend.API) (*Pairing, error) { + ba, err := emulated.NewField[emulated.BN254Fp](api) + if err != nil { + return nil, fmt.Errorf("new base api: %w", err) + } + curve, err := sw_emulated.New[emulated.BN254Fp, emulated.BN254Fr](api, sw_emulated.GetBN254Params()) + if err != nil { + return nil, fmt.Errorf("new curve: %w", err) + } + bTwist := fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp]("19485874751759354771024239261021720505790618469301721065564631296452457478373"), + A1: emulated.ValueOf[emulated.BN254Fp]("266929791119991161246907387137283842545076965332900288569378510910307636690"), + } + return &Pairing{ + api: api, + Ext12: fields_bn254.NewExt12(api), + curveF: ba, + curve: curve, + g2: NewG2(api), + bTwist: &bTwist, + }, nil +} + +// FinalExponentiation computes the exponentiation eᵈ where +// +// d = (p¹²-1)/r = (p¹²-1)/Φ₁₂(p) ⋅ Φ₁₂(p)/r = (p⁶-1)(p²+1)(p⁴ - p² +1)/r. +// +// We use instead d'= s ⋅ d, where s is the cofactor +// +// 2x₀(6x₀²+3x₀+1) +// +// and r does NOT divide d' +// +// FinalExponentiation returns a decompressed element in E12. +// +// This is the safe version of the method where e may be {-1,1}. If it is known +// that e ≠ {-1,1} then using the unsafe version of the method saves +// considerable amount of constraints. When called with the result of +// [MillerLoop], then current method is applicable when length of the inputs to +// Miller loop is 1. +func (pr Pairing) FinalExponentiation(e *GTEl) *GTEl { + return pr.finalExponentiation(e, false) +} + +// FinalExponentiationUnsafe computes the exponentiation eᵈ where +// +// d = (p¹²-1)/r = (p¹²-1)/Φ₁₂(p) ⋅ Φ₁₂(p)/r = (p⁶-1)(p²+1)(p⁴ - p² +1)/r. 
+// +// We use instead d'= s ⋅ d, where s is the cofactor +// +// 2x₀(6x₀²+3x₀+1) +// +// and r does NOT divide d' +// +// FinalExponentiationUnsafe returns a decompressed element in E12. +// +// This is the unsafe version of the method where e may NOT be {-1,1}. If e ∈ +// {-1, 1}, then there exists no valid solution to the circuit. This method is +// applicable when called with the result of [MillerLoop] method when the length +// of the inputs to Miller loop is 1. +func (pr Pairing) FinalExponentiationUnsafe(e *GTEl) *GTEl { + return pr.finalExponentiation(e, true) +} + +// finalExponentiation computes the exponentiation eᵈ where +// +// d = (p¹²-1)/r = (p¹²-1)/Φ₁₂(p) ⋅ Φ₁₂(p)/r = (p⁶-1)(p²+1)(p⁴ - p² +1)/r. +// +// We use instead d'= s ⋅ d, where s is the cofactor +// +// 2x₀(6x₀²+3x₀+1) +// +// and r does NOT divide d' +// +// finalExponentiation returns a decompressed element in E12 +func (pr Pairing) finalExponentiation(e *GTEl, unsafe bool) *GTEl { + + // 1. Easy part + // (p⁶-1)(p²+1) + var selector1, selector2 frontend.Variable + _dummy := pr.Ext6.One() + + if unsafe { + // The Miller loop result is ≠ {-1,1}, otherwise this means P and Q are + // linearly dependent and not from G1 and G2 respectively. + // So e ∈ G_{q,2} \ {-1,1} and hence e.C1 ≠ 0. + // Nothing to do. + + } else { + // However, for a product of Miller loops (n>=2) this might happen. If this is + // the case, the result is 1 in the torus. We assign a dummy value (1) to e.C1 + // and proceed further. + selector1 = pr.Ext6.IsZero(&e.C1) + e.C1 = *pr.Ext6.Select(selector1, _dummy, &e.C1) + } + + // Torus compression absorbed: + // Raising e to (p⁶-1) is + // e^(p⁶) / e = (e.C0 - w*e.C1) / (e.C0 + w*e.C1) + // = (-e.C0/e.C1 + w) / (-e.C0/e.C1 - w) + // So the fraction -e.C0/e.C1 is already in the torus. + // This absorbs the torus compression in the easy part. + c := pr.Ext6.DivUnchecked(&e.C0, &e.C1) + c = pr.Ext6.Neg(c) + t0 := pr.FrobeniusSquareTorus(c) + c = pr.MulTorus(t0, c) + + // 2. 
Hard part (up to permutation) + // 2x₀(6x₀²+3x₀+1)(p⁴-p²+1)/r + // Duquesne and Ghammam + // https://eprint.iacr.org/2015/192.pdf + // Fuentes et al. (alg. 6) + // performed in torus compressed form + t0 = pr.ExptTorus(c) + t0 = pr.InverseTorus(t0) + t0 = pr.SquareTorus(t0) + t1 := pr.SquareTorus(t0) + t1 = pr.MulTorus(t0, t1) + t2 := pr.ExptTorus(t1) + t2 = pr.InverseTorus(t2) + t3 := pr.InverseTorus(t1) + t1 = pr.MulTorus(t2, t3) + t3 = pr.SquareTorus(t2) + t4 := pr.ExptTorus(t3) + t4 = pr.MulTorus(t1, t4) + t3 = pr.MulTorus(t0, t4) + t0 = pr.MulTorus(t2, t4) + t0 = pr.MulTorus(c, t0) + t2 = pr.FrobeniusTorus(t3) + t0 = pr.MulTorus(t2, t0) + t2 = pr.FrobeniusSquareTorus(t4) + t0 = pr.MulTorus(t2, t0) + t2 = pr.InverseTorus(c) + t2 = pr.MulTorus(t2, t3) + t2 = pr.FrobeniusCubeTorus(t2) + + var result GTEl + // MulTorus(t0, t2) requires t0 ≠ -t2. When t0 = -t2, it means the + // product is 1 in the torus. + if unsafe { + // For a single pairing, this does not happen because the pairing is non-degenerate. + result = *pr.DecompressTorus(pr.MulTorus(t2, t0)) + } else { + // For a product of pairings this might happen when the result is expected to be 1. + // We assign a dummy value (1) to t0 and proceed further. + // Finally we do a select on both edge cases: + // - Only if selector1=0 and selector2=0, we return MulTorus(t2, t0) decompressed. + // - Otherwise, we return 1. + _sum := pr.Ext6.Add(t0, t2) + selector2 = pr.Ext6.IsZero(_sum) + t0 = pr.Ext6.Select(selector2, _dummy, t0) + selector := pr.api.Mul(pr.api.Sub(1, selector1), pr.api.Sub(1, selector2)) + result = *pr.Select(selector, pr.DecompressTorus(pr.MulTorus(t2, t0)), pr.One()) + } + + return &result +} + +// Pair calculates the reduced pairing for a set of points +// ∏ᵢ e(Pᵢ, Qᵢ). +// +// This function doesn't check that the inputs are in the correct subgroups. See AssertIsOnG1 and AssertIsOnG2. 
+func (pr Pairing) Pair(P []*G1Affine, Q []*G2Affine) (*GTEl, error) { + res, err := pr.MillerLoop(P, Q) + if err != nil { + return nil, fmt.Errorf("miller loop: %w", err) + } + res = pr.finalExponentiation(res, len(P) == 1) + return res, nil +} + +// PairingCheck calculates the reduced pairing for a set of points and asserts if the result is One +// ∏ᵢ e(Pᵢ, Qᵢ) =? 1 +// +// This function doesn't check that the inputs are in the correct subgroups. See AssertIsOnG1 and AssertIsOnG2. +func (pr Pairing) PairingCheck(P []*G1Affine, Q []*G2Affine) error { + f, err := pr.Pair(P, Q) + if err != nil { + return err + + } + one := pr.One() + pr.AssertIsEqual(f, one) + + return nil +} + +func (pr Pairing) AssertIsEqual(x, y *GTEl) { + pr.Ext12.AssertIsEqual(x, y) +} + +func (pr Pairing) AssertIsOnCurve(P *G1Affine) { + pr.curve.AssertIsOnCurve(P) +} + +func (pr Pairing) AssertIsOnTwist(Q *G2Affine) { + // Twist: Y² == X³ + aX + b, where a=0 and b=3/(9+u) + // (X,Y) ∈ {Y² == X³ + aX + b} U (0,0) + + // if Q=(0,0) we assign b=0 otherwise 3/(9+u), and continue + selector := pr.api.And(pr.Ext2.IsZero(&Q.X), pr.Ext2.IsZero(&Q.Y)) + b := pr.Ext2.Select(selector, pr.Ext2.Zero(), pr.bTwist) + + left := pr.Ext2.Square(&Q.Y) + right := pr.Ext2.Square(&Q.X) + right = pr.Ext2.Mul(right, &Q.X) + right = pr.Ext2.Add(right, b) + pr.Ext2.AssertIsEqual(left, right) +} + +func (pr Pairing) AssertIsOnG1(P *G1Affine) { + // BN254 has a prime order, so we only + // 1- Check P is on the curve + pr.AssertIsOnCurve(P) +} + +func (pr Pairing) AssertIsOnG2(Q *G2Affine) { + // 1- Check Q is on the curve + pr.AssertIsOnTwist(Q) + + // 2- Check Q has the right subgroup order + + // [x₀]Q + xQ := pr.g2.scalarMulBySeed(Q) + // ψ([x₀]Q) + psixQ := pr.g2.psi(xQ) + // ψ²([x₀]Q) = -ϕ([x₀]Q) + psi2xQ := pr.g2.phi(xQ) + // ψ³([2x₀]Q) + psi3xxQ := pr.g2.double(psi2xQ) + psi3xxQ = pr.g2.psi(psi3xxQ) + + // _Q = ψ³([2x₀]Q) - ψ²([x₀]Q) - ψ([x₀]Q) - [x₀]Q + _Q := pr.g2.sub(psi2xQ, psi3xxQ) + _Q = pr.g2.sub(_Q, 
psixQ) + _Q = pr.g2.sub(_Q, xQ) + + // [r]Q == 0 <==> _Q == Q + pr.g2.AssertIsEqual(Q, _Q) +} + +// loopCounter = 6x₀+2 = 29793968203157093288 +// +// in 2-NAF +var loopCounter = [66]int8{ + 0, 0, 0, 1, 0, 1, 0, -1, 0, 0, -1, + 0, 0, 0, 1, 0, 0, -1, 0, -1, 0, 0, + 0, 1, 0, -1, 0, 0, 0, 0, -1, 0, 0, + 1, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, + -1, 0, 0, -1, 0, 1, 0, -1, 0, 0, 0, + -1, 0, -1, 0, 0, 0, 1, 0, -1, 0, 1, +} + +// lineEvaluation represents a sparse Fp12 Elmt (result of the line evaluation) +// line: 1 + R0(x/y) + R1(1/y) = 0 instead of R0'*y + R1'*x + R2' = 0 This +// makes the multiplication by lines (MulBy034) and between lines (Mul034By034) +// circuit-efficient. +type lineEvaluation struct { + R0, R1 fields_bn254.E2 +} + +// MillerLoop computes the multi-Miller loop +// ∏ᵢ { fᵢ_{6x₀+2,Q}(P) · ℓᵢ_{[6x₀+2]Q,π(Q)}(P) · ℓᵢ_{[6x₀+2]Q+π(Q),-π²(Q)}(P) } +func (pr Pairing) MillerLoop(P []*G1Affine, Q []*G2Affine) (*GTEl, error) { + // check input size match + n := len(P) + if n == 0 || n != len(Q) { + return nil, errors.New("invalid inputs sizes") + } + + res := pr.Ext12.One() + var prodLines [5]fields_bn254.E2 + + var l1, l2 *lineEvaluation + Qacc := make([]*G2Affine, n) + QNeg := make([]*G2Affine, n) + yInv := make([]*emulated.Element[emulated.BN254Fp], n) + xOverY := make([]*emulated.Element[emulated.BN254Fp], n) + + for k := 0; k < n; k++ { + Qacc[k] = Q[k] + QNeg[k] = &G2Affine{X: Q[k].X, Y: *pr.Ext2.Neg(&Q[k].Y)} + // P and Q are supposed to be on G1 and G2 respectively of prime order r. + // The point (x,0) is of order 2. But this function does not check + // subgroup membership. + // Anyway (x,0) cannot be on BN254 because -3 is a cubic non-residue in Fp. + // So, 1/y is well defined for all points P's. 
+ yInv[k] = pr.curveF.Inverse(&P[k].Y) + xOverY[k] = pr.curveF.MulMod(&P[k].X, yInv[k]) + } + + // Compute ∏ᵢ { fᵢ_{6x₀+2,Q}(P) } + // i = 64, separately to avoid an E12 Square + // (Square(res) = 1² = 1) + + // k = 0, separately to avoid MulBy034 (res × ℓ) + // (assign line to res) + Qacc[0], l1 = pr.doubleStep(Qacc[0]) + // line evaluation at P[0] + res.C1.B0 = *pr.MulByElement(&l1.R0, xOverY[0]) + res.C1.B1 = *pr.MulByElement(&l1.R1, yInv[0]) + + if n >= 2 { + // k = 1, separately to avoid MulBy034 (res × ℓ) + // (res is also a line at this point, so we use Mul034By034 ℓ × ℓ) + Qacc[1], l1 = pr.doubleStep(Qacc[1]) + + // line evaluation at P[1] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[1]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[1]) + + // ℓ × res + prodLines = *pr.Mul034By034(&l1.R0, &l1.R1, &res.C1.B0, &res.C1.B1) + res.C0.B0 = prodLines[0] + res.C0.B1 = prodLines[1] + res.C0.B2 = prodLines[2] + res.C1.B0 = prodLines[3] + res.C1.B1 = prodLines[4] + } + + if n >= 3 { + // k = 2, separately to avoid MulBy034 (res × ℓ) + // (res has a zero E2 element, so we use Mul01234By034) + Qacc[2], l1 = pr.doubleStep(Qacc[2]) + + // line evaluation at P[2] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[2]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[2]) + + // ℓ × res + res = pr.Mul01234By034(&prodLines, &l1.R0, &l1.R1) + + // k >= 3 + for k := 3; k < n; k++ { + // Qacc[k] ← 2Qacc[k] and l1 the tangent ℓ passing 2Qacc[k] + Qacc[k], l1 = pr.doubleStep(Qacc[k]) + + // line evaluation at P[k] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[k]) + + // ℓ × res + res = pr.MulBy034(res, &l1.R0, &l1.R1) + } + } + + // i = 63, separately to avoid a doubleStep + // (at this point Qacc = 2Q, so 2Qacc-Q=3Q is equivalent to Qacc+Q=3Q + // this means doubleAndAddStep is equivalent to addStep here) + res = pr.Square(res) + for k := 0; k < n; k++ { + // l2 the line passing Qacc[k] and -Q + l2 = pr.lineCompute(Qacc[k], QNeg[k]) + + // line evaluation at P[k] + 
l2.R0 = *pr.MulByElement(&l2.R0, xOverY[k]) + l2.R1 = *pr.MulByElement(&l2.R1, yInv[k]) + + // Qacc[k] ← Qacc[k]+Q[k] and + // l1 the line ℓ passing Qacc[k] and Q[k] + Qacc[k], l1 = pr.addStep(Qacc[k], Q[k]) + + // line evaluation at P[k] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[k]) + + // ℓ × ℓ + prodLines = *pr.Mul034By034(&l1.R0, &l1.R1, &l2.R0, &l2.R1) + // (ℓ × ℓ) × res + res = pr.MulBy01234(res, &prodLines) + } + + for i := 62; i >= 0; i-- { + // mutualize the square among n Miller loops + // (∏ᵢfᵢ)² + res = pr.Square(res) + + switch loopCounter[i] { + + case 0: + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k] and l1 the tangent ℓ passing 2Qacc[k] + Qacc[k], l1 = pr.doubleStep(Qacc[k]) + + // line evaluation at P[k] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[k]) + + // ℓ × res + res = pr.MulBy034(res, &l1.R0, &l1.R1) + } + + case 1: + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k]+Q[k], + // l1 the line ℓ passing Qacc[k] and Q[k] + // l2 the line ℓ passing (Qacc[k]+Q[k]) and Qacc[k] + Qacc[k], l1, l2 = pr.doubleAndAddStep(Qacc[k], Q[k]) + + // line evaluation at P[k] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[k]) + + // line evaluation at P[k] + l2.R0 = *pr.MulByElement(&l2.R0, xOverY[k]) + l2.R1 = *pr.MulByElement(&l2.R1, yInv[k]) + + // ℓ × ℓ + prodLines = *pr.Mul034By034(&l1.R0, &l1.R1, &l2.R0, &l2.R1) + // (ℓ × ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + } + + case -1: + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k]-Q[k], + // l1 the line ℓ passing Qacc[k] and -Q[k] + // l2 the line ℓ passing (Qacc[k]-Q[k]) and Qacc[k] + Qacc[k], l1, l2 = pr.doubleAndAddStep(Qacc[k], QNeg[k]) + + // line evaluation at P[k] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[k]) + + // line evaluation at P[k] + l2.R0 = *pr.MulByElement(&l2.R0, xOverY[k]) + l2.R1 = *pr.MulByElement(&l2.R1, 
yInv[k]) + + // ℓ × ℓ + prodLines = *pr.Mul034By034(&l1.R0, &l1.R1, &l2.R0, &l2.R1) + // (ℓ × ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + } + + default: + return nil, errors.New("invalid loopCounter") + } + } + + // Compute ∏ᵢ { ℓᵢ_{[6x₀+2]Q,π(Q)}(P) · ℓᵢ_{[6x₀+2]Q+π(Q),-π²(Q)}(P) } + Q1, Q2 := new(G2Affine), new(G2Affine) + for k := 0; k < n; k++ { + //Q1 = π(Q) + Q1.X = *pr.Ext2.Conjugate(&Q[k].X) + Q1.X = *pr.Ext2.MulByNonResidue1Power2(&Q1.X) + Q1.Y = *pr.Ext2.Conjugate(&Q[k].Y) + Q1.Y = *pr.Ext2.MulByNonResidue1Power3(&Q1.Y) + + // Q2 = -π²(Q) + Q2.X = *pr.Ext2.MulByNonResidue2Power2(&Q[k].X) + Q2.Y = *pr.Ext2.MulByNonResidue2Power3(&Q[k].Y) + Q2.Y = *pr.Ext2.Neg(&Q2.Y) + + // Qacc[k] ← Qacc[k]+π(Q) and + // l1 the line passing Qacc[k] and π(Q) + Qacc[k], l1 = pr.addStep(Qacc[k], Q1) + + // line evaluation at P[k] + l1.R0 = *pr.Ext2.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.Ext2.MulByElement(&l1.R1, yInv[k]) + + // l2 the line passing Qacc[k] and -π²(Q) + l2 = pr.lineCompute(Qacc[k], Q2) + // line evaluation at P[k] + l2.R0 = *pr.MulByElement(&l2.R0, xOverY[k]) + l2.R1 = *pr.MulByElement(&l2.R1, yInv[k]) + + // ℓ × ℓ + prodLines = *pr.Mul034By034(&l1.R0, &l1.R1, &l2.R0, &l2.R1) + // (ℓ × ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + } + + return res, nil +} + +// doubleAndAddStep doubles p1 and adds p2 to the result in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func (pr Pairing) doubleAndAddStep(p1, p2 *G2Affine) (*G2Affine, *lineEvaluation, *lineEvaluation) { + + var line1, line2 lineEvaluation + var p G2Affine + + // compute λ1 = (y2-y1)/(x2-x1) + n := pr.Ext2.Sub(&p1.Y, &p2.Y) + d := pr.Ext2.Sub(&p1.X, &p2.X) + l1 := pr.Ext2.DivUnchecked(n, d) + + // compute x3 =λ1²-x1-x2 + x3 := pr.Ext2.Square(l1) + x3 = pr.Ext2.Sub(x3, &p1.X) + x3 = pr.Ext2.Sub(x3, &p2.X) + + // omit y3 computation + + // compute line1 + line1.R0 = *pr.Ext2.Neg(l1) + line1.R1 = *pr.Ext2.Mul(l1, &p1.X) + 
line1.R1 = *pr.Ext2.Sub(&line1.R1, &p1.Y) + + // compute λ2 = -λ1-2y1/(x3-x1) + n = pr.Ext2.Double(&p1.Y) + d = pr.Ext2.Sub(x3, &p1.X) + l2 := pr.Ext2.DivUnchecked(n, d) + l2 = pr.Ext2.Add(l2, l1) + l2 = pr.Ext2.Neg(l2) + + // compute x4 = λ2²-x1-x3 + x4 := pr.Ext2.Square(l2) + x4 = pr.Ext2.Sub(x4, &p1.X) + x4 = pr.Ext2.Sub(x4, x3) + + // compute y4 = λ2(x1 - x4)-y1 + y4 := pr.Ext2.Sub(&p1.X, x4) + y4 = pr.Ext2.Mul(l2, y4) + y4 = pr.Ext2.Sub(y4, &p1.Y) + + p.X = *x4 + p.Y = *y4 + + // compute line2 + line2.R0 = *pr.Ext2.Neg(l2) + line2.R1 = *pr.Ext2.Mul(l2, &p1.X) + line2.R1 = *pr.Ext2.Sub(&line2.R1, &p1.Y) + + return &p, &line1, &line2 +} + +// doubleStep doubles a point in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func (pr Pairing) doubleStep(p1 *G2Affine) (*G2Affine, *lineEvaluation) { + + var p G2Affine + var line lineEvaluation + + // λ = 3x²/2y + n := pr.Ext2.Square(&p1.X) + three := big.NewInt(3) + n = pr.Ext2.MulByConstElement(n, three) + d := pr.Ext2.Double(&p1.Y) + λ := pr.Ext2.DivUnchecked(n, d) + + // xr = λ²-2x + xr := pr.Ext2.Square(λ) + xr = pr.Ext2.Sub(xr, &p1.X) + xr = pr.Ext2.Sub(xr, &p1.X) + + // yr = λ(x-xr)-y + yr := pr.Ext2.Sub(&p1.X, xr) + yr = pr.Ext2.Mul(λ, yr) + yr = pr.Ext2.Sub(yr, &p1.Y) + + p.X = *xr + p.Y = *yr + + line.R0 = *pr.Ext2.Neg(λ) + line.R1 = *pr.Ext2.Mul(λ, &p1.X) + line.R1 = *pr.Ext2.Sub(&line.R1, &p1.Y) + + return &p, &line + +} + +// addStep adds two points in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func (pr Pairing) addStep(p1, p2 *G2Affine) (*G2Affine, *lineEvaluation) { + + // compute λ = (y2-y1)/(x2-x1) + p2ypy := pr.Ext2.Sub(&p2.Y, &p1.Y) + p2xpx := pr.Ext2.Sub(&p2.X, &p1.X) + λ := pr.Ext2.DivUnchecked(p2ypy, p2xpx) + + // xr = λ²-x1-x2 + λλ := pr.Ext2.Square(λ) + p2xpx = pr.Ext2.Add(&p1.X, &p2.X) + xr := pr.Ext2.Sub(λλ, p2xpx) + + // yr = λ(x1-xr) - y1 + pxrx := 
pr.Ext2.Sub(&p1.X, xr) + λpxrx := pr.Ext2.Mul(λ, pxrx) + yr := pr.Ext2.Sub(λpxrx, &p1.Y) + + var res G2Affine + res.X = *xr + res.Y = *yr + + var line lineEvaluation + line.R0 = *pr.Ext2.Neg(λ) + line.R1 = *pr.Ext2.Mul(λ, &p1.X) + line.R1 = *pr.Ext2.Sub(&line.R1, &p1.Y) + + return &res, &line + +} + +// lineCompute computes the line that goes through p1 and p2 but does not compute p1+p2 +func (pr Pairing) lineCompute(p1, p2 *G2Affine) *lineEvaluation { + + // compute λ = (y2-y1)/(x2-x1) + qypy := pr.Ext2.Sub(&p2.Y, &p1.Y) + qxpx := pr.Ext2.Sub(&p2.X, &p1.X) + λ := pr.Ext2.DivUnchecked(qypy, qxpx) + + var line lineEvaluation + line.R0 = *pr.Ext2.Neg(λ) + line.R1 = *pr.Ext2.Mul(λ, &p1.X) + line.R1 = *pr.Ext2.Sub(&line.R1, &p1.Y) + + return &line + +} diff --git a/std/algebra/emulated/sw_bn254/pairing_test.go b/std/algebra/emulated/sw_bn254/pairing_test.go new file mode 100644 index 0000000000..66a77219fa --- /dev/null +++ b/std/algebra/emulated/sw_bn254/pairing_test.go @@ -0,0 +1,302 @@ +package sw_bn254 + +import ( + "bytes" + "crypto/rand" + "fmt" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/test" +) + +func randomG1G2Affines() (bn254.G1Affine, bn254.G2Affine) { + _, _, G1AffGen, G2AffGen := bn254.Generators() + mod := bn254.ID.ScalarField() + s1, err := rand.Int(rand.Reader, mod) + if err != nil { + panic(err) + } + s2, err := rand.Int(rand.Reader, mod) + if err != nil { + panic(err) + } + var p bn254.G1Affine + p.ScalarMultiplication(&G1AffGen, s1) + var q bn254.G2Affine + q.ScalarMultiplication(&G2AffGen, s2) + return p, q +} + +type FinalExponentiationCircuit struct { + InGt GTEl + Res GTEl +} + +func (c *FinalExponentiationCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + 
if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + res1 := pairing.FinalExponentiation(&c.InGt) + pairing.AssertIsEqual(res1, &c.Res) + res2 := pairing.FinalExponentiationUnsafe(&c.InGt) + pairing.AssertIsEqual(res2, &c.Res) + return nil +} + +func TestFinalExponentiationTestSolve(t *testing.T) { + assert := test.NewAssert(t) + var gt bn254.GT + gt.SetRandom() + res := bn254.FinalExponentiation(>) + witness := FinalExponentiationCircuit{ + InGt: NewGTEl(gt), + Res: NewGTEl(res), + } + err := test.IsSolved(&FinalExponentiationCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type PairCircuit struct { + InG1 G1Affine + InG2 G2Affine + Res GTEl +} + +func (c *PairCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + pairing.AssertIsOnG1(&c.InG1) + pairing.AssertIsOnG2(&c.InG2) + res, err := pairing.Pair([]*G1Affine{&c.InG1}, []*G2Affine{&c.InG2}) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + pairing.AssertIsEqual(res, &c.Res) + return nil +} + +func TestPairTestSolve(t *testing.T) { + assert := test.NewAssert(t) + p, q := randomG1G2Affines() + res, err := bn254.Pair([]bn254.G1Affine{p}, []bn254.G2Affine{q}) + assert.NoError(err) + witness := PairCircuit{ + InG1: NewG1Affine(p), + InG2: NewG2Affine(q), + Res: NewGTEl(res), + } + err = test.IsSolved(&PairCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type MultiPairCircuit struct { + In1G1 G1Affine + In2G1 G1Affine + In1G2 G2Affine + In2G2 G2Affine + Res GTEl +} + +func (c *MultiPairCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + pairing.AssertIsOnG1(&c.In1G1) + pairing.AssertIsOnG1(&c.In2G1) + pairing.AssertIsOnG2(&c.In1G2) + pairing.AssertIsOnG2(&c.In2G2) + res, err := pairing.Pair([]*G1Affine{&c.In1G1, &c.In1G1, &c.In2G1, &c.In2G1}, []*G2Affine{&c.In1G2, 
&c.In2G2, &c.In1G2, &c.In2G2}) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + pairing.AssertIsEqual(res, &c.Res) + return nil +} + +func TestMultiPairTestSolve(t *testing.T) { + assert := test.NewAssert(t) + p1, q1 := randomG1G2Affines() + p2, q2 := randomG1G2Affines() + res, err := bn254.Pair([]bn254.G1Affine{p1, p1, p2, p2}, []bn254.G2Affine{q1, q2, q1, q2}) + assert.NoError(err) + witness := MultiPairCircuit{ + In1G1: NewG1Affine(p1), + In1G2: NewG2Affine(q1), + In2G1: NewG1Affine(p2), + In2G2: NewG2Affine(q2), + Res: NewGTEl(res), + } + err = test.IsSolved(&MultiPairCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type PairingCheckCircuit struct { + In1G1 G1Affine + In2G1 G1Affine + In1G2 G2Affine + In2G2 G2Affine +} + +func (c *PairingCheckCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + err = pairing.PairingCheck([]*G1Affine{&c.In1G1, &c.In1G1, &c.In2G1, &c.In2G1}, []*G2Affine{&c.In1G2, &c.In2G2, &c.In1G2, &c.In2G2}) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + return nil +} + +func TestPairingCheckTestSolve(t *testing.T) { + assert := test.NewAssert(t) + p1, q1 := randomG1G2Affines() + _, q2 := randomG1G2Affines() + var p2 bn254.G1Affine + p2.Neg(&p1) + witness := PairingCheckCircuit{ + In1G1: NewG1Affine(p1), + In1G2: NewG2Affine(q1), + In2G1: NewG1Affine(p2), + In2G2: NewG2Affine(q2), + } + err := test.IsSolved(&PairingCheckCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type FinalExponentiationSafeCircuit struct { + P1, P2 G1Affine + Q1, Q2 G2Affine +} + +func (c *FinalExponentiationSafeCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return err + } + res, err := pairing.MillerLoop([]*G1Affine{&c.P1, &c.P2}, []*G2Affine{&c.Q1, &c.Q2}) + if err != nil { + return err + } + res2 := pairing.FinalExponentiation(res) + one := 
pairing.Ext12.One() + pairing.AssertIsEqual(one, res2) + return nil +} + +func TestFinalExponentiationSafeCircuit(t *testing.T) { + assert := test.NewAssert(t) + _, _, p1, q1 := bn254.Generators() + var p2 bn254.G1Affine + var q2 bn254.G2Affine + p2.Neg(&p1) + q2.Set(&q1) + err := test.IsSolved(&FinalExponentiationSafeCircuit{}, &FinalExponentiationSafeCircuit{ + P1: NewG1Affine(p1), + P2: NewG1Affine(p2), + Q1: NewG2Affine(q1), + Q2: NewG2Affine(q2), + }, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type GroupMembershipCircuit struct { + InG1 G1Affine + InG2 G2Affine +} + +func (c *GroupMembershipCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + pairing.AssertIsOnG1(&c.InG1) + pairing.AssertIsOnG2(&c.InG2) + return nil +} + +func TestGroupMembershipSolve(t *testing.T) { + assert := test.NewAssert(t) + p, q := randomG1G2Affines() + witness := GroupMembershipCircuit{ + InG1: NewG1Affine(p), + InG2: NewG2Affine(q), + } + err := test.IsSolved(&GroupMembershipCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +func BenchmarkPairing(b *testing.B) { + + p1, q1 := randomG1G2Affines() + _, q2 := randomG1G2Affines() + var p2 bn254.G1Affine + p2.Neg(&p1) + witness := PairingCheckCircuit{ + In1G1: NewG1Affine(p1), + In1G2: NewG2Affine(q1), + In2G1: NewG1Affine(p2), + In2G2: NewG2Affine(q2), + } + w, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + b.Fatal(err) + } + var ccs constraint.ConstraintSystem + b.Run("compile scs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if ccs, err = frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &PairingCheckCircuit{}); err != nil { + b.Fatal(err) + } + } + }) + var buf bytes.Buffer + _, err = ccs.WriteTo(&buf) + if err != nil { + b.Fatal(err) + } + b.Logf("scs size: %d (bytes), nb constraints %d, nbInstructions: %d", buf.Len(), ccs.GetNbConstraints(), 
ccs.GetNbInstructions()) + b.Run("solve scs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := ccs.Solve(w); err != nil { + b.Fatal(err) + } + } + }) + b.Run("compile r1cs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if ccs, err = frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &PairingCheckCircuit{}); err != nil { + b.Fatal(err) + } + } + }) + buf.Reset() + _, err = ccs.WriteTo(&buf) + if err != nil { + b.Fatal(err) + } + b.Logf("r1cs size: %d (bytes), nb constraints %d, nbInstructions: %d", buf.Len(), ccs.GetNbConstraints(), ccs.GetNbInstructions()) + + b.Run("solve r1cs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := ccs.Solve(w); err != nil { + b.Fatal(err) + } + } + }) +} diff --git a/std/algebra/weierstrass/doc.go b/std/algebra/emulated/sw_emulated/doc.go similarity index 68% rename from std/algebra/weierstrass/doc.go rename to std/algebra/emulated/sw_emulated/doc.go index 72ff229dad..dc83a96fcb 100644 --- a/std/algebra/weierstrass/doc.go +++ b/std/algebra/emulated/sw_emulated/doc.go @@ -1,5 +1,5 @@ /* -Package weierstrass implements elliptic curve group operations in (short) +Package sw_emulated implements elliptic curve group operations in (short) Weierstrass form. The elliptic curve is the set of points (X,Y) satisfying the equation: @@ -10,6 +10,10 @@ over some base field 𝐅p for some constants a, b ∈ 𝐅p. Additionally, for every curve we also define its generator (base point) G. All these parameters are stored in the variable of type [CurveParams]. +This package implements unified and complete point addition. The method +[Curve.AddUnified] can be used for point additions or in case of points at +infinity. As such, this package does not expose separate Add and Double methods. + The package provides a few curve parameters, see functions [GetSecp256k1Params] and [GetBN254Params]. @@ -22,12 +26,9 @@ field. 
For now, we only have a single curve defined on every base field, but this may change in the future with the addition of additional curves. This package uses field emulation (unlike packages -[github.com/consensys/gnark/std/algebra/sw_bls12377] and -[github.com/consensys/gnark/std/algebra/sw_bls24315], which use 2-chains). This +[github.com/consensys/gnark/std/algebra/native/sw_bls12377] and +[github.com/consensys/gnark/std/algebra/native/sw_bls24315], which use 2-chains). This allows to use any curve over any native (SNARK) field. The drawback of this -approach is the extreme cost of the operations. In R1CS, point addition on -256-bit fields is approximately 3500 constraints and doubling is approximately -4300 constraints. A full scalar multiplication is approximately 2M constraints. -It is several times more in PLONKish aritmetisation. +approach is the extreme cost of the operations. */ -package weierstrass +package sw_emulated diff --git a/std/algebra/weierstrass/doc_test.go b/std/algebra/emulated/sw_emulated/doc_test.go similarity index 87% rename from std/algebra/weierstrass/doc_test.go rename to std/algebra/emulated/sw_emulated/doc_test.go index cfa81a1fe0..2d63969ecf 100644 --- a/std/algebra/weierstrass/doc_test.go +++ b/std/algebra/emulated/sw_emulated/doc_test.go @@ -1,4 +1,4 @@ -package weierstrass_test +package sw_emulated_test import ( "fmt" @@ -9,16 +9,16 @@ import ( "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/std/algebra/weierstrass" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" "github.com/consensys/gnark/std/math/emulated" ) type ExampleCurveCircuit[Base, Scalar emulated.FieldParams] struct { - Res weierstrass.AffinePoint[Base] + Res sw_emulated.AffinePoint[Base] } func (c *ExampleCurveCircuit[B, S]) Define(api frontend.API) error { - curve, err := weierstrass.New[B, S](api, 
weierstrass.GetCurveParams[emulated.BN254Fp]()) + curve, err := sw_emulated.New[B, S](api, sw_emulated.GetCurveParams[emulated.BN254Fp]()) if err != nil { panic("initalize new curve") } @@ -27,7 +27,7 @@ func (c *ExampleCurveCircuit[B, S]) Define(api frontend.API) error { g4 := curve.ScalarMul(G, &scalar4) // 4*G scalar5 := emulated.ValueOf[S](5) g5 := curve.ScalarMul(G, &scalar5) // 5*G - g9 := curve.Add(g4, g5) // 9*G + g9 := curve.AddUnified(g4, g5) // 9*G curve.AssertIsEqual(g9, &c.Res) return nil } @@ -41,7 +41,7 @@ func ExampleCurve() { circuit := ExampleCurveCircuit[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} witness := ExampleCurveCircuit[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ - Res: weierstrass.AffinePoint[emulated.Secp256k1Fp]{ + Res: sw_emulated.AffinePoint[emulated.Secp256k1Fp]{ X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), }, diff --git a/std/algebra/emulated/sw_emulated/params.go b/std/algebra/emulated/sw_emulated/params.go new file mode 100644 index 0000000000..efce7a1566 --- /dev/null +++ b/std/algebra/emulated/sw_emulated/params.go @@ -0,0 +1,134 @@ +package sw_emulated + +import ( + "crypto/elliptic" + "math/big" + + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/secp256k1" + "github.com/consensys/gnark/std/math/emulated" +) + +// CurveParams defines parameters of an elliptic curve in short Weierstrass form +// given by the equation +// +// Y² = X³ + aX + b +// +// The base point is defined by (Gx, Gy). +type CurveParams struct { + A *big.Int // a in curve equation + B *big.Int // b in curve equation + Gx *big.Int // base point x + Gy *big.Int // base point y + Gm [][2]*big.Int // m*base point coords +} + +// GetSecp256k1Params returns curve parameters for the curve secp256k1. 
When +// initialising new curve, use the base field [emulated.Secp256k1Fp] and scalar +// field [emulated.Secp256k1Fr]. +func GetSecp256k1Params() CurveParams { + _, g1aff := secp256k1.Generators() + return CurveParams{ + A: big.NewInt(0), + B: big.NewInt(7), + Gx: g1aff.X.BigInt(new(big.Int)), + Gy: g1aff.Y.BigInt(new(big.Int)), + Gm: computeSecp256k1Table(), + } +} + +// GetBN254Params returns the curve parameters for the curve BN254 (alt_bn128). +// When initialising new curve, use the base field [emulated.BN254Fp] and scalar +// field [emulated.BN254Fr]. +func GetBN254Params() CurveParams { + _, _, g1aff, _ := bn254.Generators() + return CurveParams{ + A: big.NewInt(0), + B: big.NewInt(3), + Gx: g1aff.X.BigInt(new(big.Int)), + Gy: g1aff.Y.BigInt(new(big.Int)), + Gm: computeBN254Table(), + } +} + +// GetBLS12381Params returns the curve parameters for the curve BLS12-381. +// When initialising new curve, use the base field [emulated.BLS12381Fp] and scalar +// field [emulated.BLS12381Fr]. +func GetBLS12381Params() CurveParams { + _, _, g1aff, _ := bls12381.Generators() + return CurveParams{ + A: big.NewInt(0), + B: big.NewInt(4), + Gx: g1aff.X.BigInt(new(big.Int)), + Gy: g1aff.Y.BigInt(new(big.Int)), + Gm: computeBLS12381Table(), + } +} + +// GetP256Params returns the curve parameters for the curve P-256 (also +// SECP256r1). When initialising new curve, use the base field +// [emulated.P256Fp] and scalar field [emulated.P256Fr]. +func GetP256Params() CurveParams { + pr := elliptic.P256().Params() + a := new(big.Int).Sub(pr.P, big.NewInt(3)) + return CurveParams{ + A: a, + B: pr.B, + Gx: pr.Gx, + Gy: pr.Gy, + Gm: computeP256Table(), + } +} + +// GetP384Params returns the curve parameters for the curve P-384 (also +// SECP384r1). When initialising new curve, use the base field +// [emulated.P384Fp] and scalar field [emulated.P384Fr]. 
+func GetP384Params() CurveParams { + pr := elliptic.P384().Params() + a := new(big.Int).Sub(pr.P, big.NewInt(3)) + return CurveParams{ + A: a, + B: pr.B, + Gx: pr.Gx, + Gy: pr.Gy, + Gm: computeP384Table(), + } +} + +// GetCurveParams returns suitable curve parameters given the parametric type +// Base as base field. It caches the parameters and modifying the values in the +// parameters struct leads to undefined behaviour. +func GetCurveParams[Base emulated.FieldParams]() CurveParams { + var t Base + switch t.Modulus().String() { + case emulated.Secp256k1Fp{}.Modulus().String(): + return secp256k1Params + case emulated.BN254Fp{}.Modulus().String(): + return bn254Params + case emulated.BLS12381Fp{}.Modulus().String(): + return bls12381Params + case emulated.P256Fp{}.Modulus().String(): + return p256Params + case emulated.P384Fp{}.Modulus().String(): + return p384Params + default: + panic("no stored parameters") + } +} + +var ( + secp256k1Params CurveParams + bn254Params CurveParams + bls12381Params CurveParams + p256Params CurveParams + p384Params CurveParams +) + +func init() { + secp256k1Params = GetSecp256k1Params() + bn254Params = GetBN254Params() + bls12381Params = GetBLS12381Params() + p256Params = GetP256Params() + p384Params = GetP384Params() +} diff --git a/std/algebra/emulated/sw_emulated/params_compute.go b/std/algebra/emulated/sw_emulated/params_compute.go new file mode 100644 index 0000000000..5eaf21e87b --- /dev/null +++ b/std/algebra/emulated/sw_emulated/params_compute.go @@ -0,0 +1,132 @@ +package sw_emulated + +import ( + "crypto/elliptic" + "math/big" + + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/secp256k1" +) + +func computeSecp256k1Table() [][2]*big.Int { + Gjac, _ := secp256k1.Generators() + table := make([][2]*big.Int, 256) + tmp := new(secp256k1.G1Jac).Set(&Gjac) + aff := new(secp256k1.G1Affine) + jac := new(secp256k1.G1Jac) + for i := 1; 
i < 256; i++ { + tmp = tmp.Double(tmp) + switch i { + case 1, 2: + jac.Set(tmp).AddAssign(&Gjac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + case 3: + jac.Set(tmp).SubAssign(&Gjac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + fallthrough + default: + aff.FromJacobian(tmp) + table[i] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + } + } + return table[:] +} + +func computeBN254Table() [][2]*big.Int { + Gjac, _, _, _ := bn254.Generators() + table := make([][2]*big.Int, 256) + tmp := new(bn254.G1Jac).Set(&Gjac) + aff := new(bn254.G1Affine) + jac := new(bn254.G1Jac) + for i := 1; i < 256; i++ { + tmp = tmp.Double(tmp) + switch i { + case 1, 2: + jac.Set(tmp).AddAssign(&Gjac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + case 3: + jac.Set(tmp).SubAssign(&Gjac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + fallthrough + default: + aff.FromJacobian(tmp) + table[i] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + } + } + return table +} + +func computeBLS12381Table() [][2]*big.Int { + Gjac, _, _, _ := bls12381.Generators() + table := make([][2]*big.Int, 256) + tmp := new(bls12381.G1Jac).Set(&Gjac) + aff := new(bls12381.G1Affine) + jac := new(bls12381.G1Jac) + for i := 1; i < 256; i++ { + tmp = tmp.Double(tmp) + switch i { + case 1, 2: + jac.Set(tmp).AddAssign(&Gjac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + case 3: + jac.Set(tmp).SubAssign(&Gjac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + fallthrough + default: + aff.FromJacobian(tmp) + table[i] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + } + } + return table +} + 
+func computeP256Table() [][2]*big.Int { + table := make([][2]*big.Int, 256) + p256 := elliptic.P256() + gx, gy := p256.Params().Gx, p256.Params().Gy + tmpx, tmpy := new(big.Int).Set(gx), new(big.Int).Set(gy) + for i := 1; i < 256; i++ { + tmpx, tmpy = p256.Double(tmpx, tmpy) + switch i { + case 1, 2: + xx, yy := p256.Add(tmpx, tmpy, gx, gy) + table[i-1] = [2]*big.Int{xx, yy} + case 3: + xx, yy := p256.Add(tmpx, tmpy, gx, new(big.Int).Sub(p256.Params().P, gy)) + table[i-1] = [2]*big.Int{xx, yy} + fallthrough + default: + table[i] = [2]*big.Int{tmpx, tmpy} + } + } + return table +} + +func computeP384Table() [][2]*big.Int { + table := make([][2]*big.Int, 384) + p384 := elliptic.P384() + gx, gy := p384.Params().Gx, p384.Params().Gy + tmpx, tmpy := new(big.Int).Set(gx), new(big.Int).Set(gy) + for i := 1; i < 384; i++ { + tmpx, tmpy = p384.Double(tmpx, tmpy) + switch i { + case 1, 2: + xx, yy := p384.Add(tmpx, tmpy, gx, gy) + table[i-1] = [2]*big.Int{xx, yy} + case 3: + xx, yy := p384.Add(tmpx, tmpy, gx, new(big.Int).Sub(p384.Params().P, gy)) + table[i-1] = [2]*big.Int{xx, yy} + fallthrough + default: + table[i] = [2]*big.Int{tmpx, tmpy} + } + } + return table +} diff --git a/std/algebra/emulated/sw_emulated/point.go b/std/algebra/emulated/sw_emulated/point.go new file mode 100644 index 0000000000..7338bff86c --- /dev/null +++ b/std/algebra/emulated/sw_emulated/point.go @@ -0,0 +1,533 @@ +package sw_emulated + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" +) + +// New returns a new [Curve] instance over the base field Base and scalar field +// Scalars defined by the curve parameters params. It returns an error if +// initialising the field emulation fails (for example, when the native field is +// too small) or when the curve parameters are incompatible with the fields. 
+func New[Base, Scalars emulated.FieldParams](api frontend.API, params CurveParams) (*Curve[Base, Scalars], error) { + ba, err := emulated.NewField[Base](api) + if err != nil { + return nil, fmt.Errorf("new base api: %w", err) + } + sa, err := emulated.NewField[Scalars](api) + if err != nil { + return nil, fmt.Errorf("new scalar api: %w", err) + } + emuGm := make([]AffinePoint[Base], len(params.Gm)) + for i, v := range params.Gm { + emuGm[i] = AffinePoint[Base]{emulated.ValueOf[Base](v[0]), emulated.ValueOf[Base](v[1])} + } + Gx := emulated.ValueOf[Base](params.Gx) + Gy := emulated.ValueOf[Base](params.Gy) + return &Curve[Base, Scalars]{ + params: params, + api: api, + baseApi: ba, + scalarApi: sa, + g: AffinePoint[Base]{ + X: Gx, + Y: Gy, + }, + gm: emuGm, + a: emulated.ValueOf[Base](params.A), + b: emulated.ValueOf[Base](params.B), + addA: params.A.Cmp(big.NewInt(0)) != 0, + }, nil +} + +// Curve is an initialised curve which allows performing group operations. +type Curve[Base, Scalars emulated.FieldParams] struct { + // params is the parameters of the curve + params CurveParams + // api is the native api, we construct it ourselves to be sure + api frontend.API + // baseApi is the api for point operations + baseApi *emulated.Field[Base] + // scalarApi is the api for scalar operations + scalarApi *emulated.Field[Scalars] + + // g is the generator (base point) of the curve. + g AffinePoint[Base] + + // gm are the pre-computed multiples the generator (base point) of the curve. + gm []AffinePoint[Base] + + a emulated.Element[Base] + b emulated.Element[Base] + addA bool +} + +// Generator returns the base point of the curve. The method does not copy and +// modifying the returned element leads to undefined behaviour! +func (c *Curve[B, S]) Generator() *AffinePoint[B] { + return &c.g +} + +// GeneratorMultiples returns the pre-computed multiples of the base point of the curve. 
The method does not copy and +// modifying the returned element leads to undefined behaviour! +func (c *Curve[B, S]) GeneratorMultiples() []AffinePoint[B] { + return c.gm +} + +// AffinePoint represents a point on the elliptic curve. We do not check that +// the point is actually on the curve. +// +// Point (0,0) represents point at the infinity. This representation is +// compatible with the EVM representations of points at infinity. +type AffinePoint[Base emulated.FieldParams] struct { + X, Y emulated.Element[Base] +} + +// Neg returns an inverse of p. It doesn't modify p. +func (c *Curve[B, S]) Neg(p *AffinePoint[B]) *AffinePoint[B] { + return &AffinePoint[B]{ + X: p.X, + Y: *c.baseApi.Neg(&p.Y), + } +} + +// AssertIsEqual asserts that p and q are the same point. +func (c *Curve[B, S]) AssertIsEqual(p, q *AffinePoint[B]) { + c.baseApi.AssertIsEqual(&p.X, &q.X) + c.baseApi.AssertIsEqual(&p.Y, &q.Y) +} + +// add adds p and q and returns it. It doesn't modify p nor q. +// +// ⚠️ p must be different than q and -q, and both nonzero. +// +// It uses incomplete formulas in affine coordinates. +func (c *Curve[B, S]) add(p, q *AffinePoint[B]) *AffinePoint[B] { + // compute λ = (q.y-p.y)/(q.x-p.x) + qypy := c.baseApi.Sub(&q.Y, &p.Y) + qxpx := c.baseApi.Sub(&q.X, &p.X) + λ := c.baseApi.Div(qypy, qxpx) + + // xr = λ²-p.x-q.x + λλ := c.baseApi.MulMod(λ, λ) + qxpx = c.baseApi.Add(&p.X, &q.X) + xr := c.baseApi.Sub(λλ, qxpx) + + // p.y = λ(p.x-r.x) - p.y + pxrx := c.baseApi.Sub(&p.X, xr) + λpxrx := c.baseApi.MulMod(λ, pxrx) + yr := c.baseApi.Sub(λpxrx, &p.Y) + + return &AffinePoint[B]{ + X: *c.baseApi.Reduce(xr), + Y: *c.baseApi.Reduce(yr), + } +} + +// AssertIsOnCurve asserts if p belongs to the curve. It doesn't modify p. 
+func (c *Curve[B, S]) AssertIsOnCurve(p *AffinePoint[B]) { + // (X,Y) ∈ {Y² == X³ + aX + b} U (0,0) + + // if p=(0,0) we assign b=0 and continue + selector := c.api.And(c.baseApi.IsZero(&p.X), c.baseApi.IsZero(&p.Y)) + b := c.baseApi.Select(selector, c.baseApi.Zero(), &c.b) + + left := c.baseApi.Mul(&p.Y, &p.Y) + right := c.baseApi.Mul(&p.X, c.baseApi.Mul(&p.X, &p.X)) + right = c.baseApi.Add(right, b) + if c.addA { + ax := c.baseApi.Mul(&c.a, &p.X) + right = c.baseApi.Add(right, ax) + } + c.baseApi.AssertIsEqual(left, right) +} + +// AddUnified adds p and q and returns it. It doesn't modify p nor q. +// +// ✅ p can be equal to q, and either or both can be (0,0). +// (0,0) is not on the curve but we conventionally take it as the +// neutral/infinity point as per the [EVM]. +// +// It uses the unified formulas of Brier and Joye ([[BriJoy02]] (Corollary 1)). +// +// [BriJoy02]: https://link.springer.com/content/pdf/10.1007/3-540-45664-3_24.pdf +// [EVM]: https://ethereum.github.io/yellowpaper/paper.pdf +func (c *Curve[B, S]) AddUnified(p, q *AffinePoint[B]) *AffinePoint[B] { + + // selector1 = 1 when p is (0,0) and 0 otherwise + selector1 := c.api.And(c.baseApi.IsZero(&p.X), c.baseApi.IsZero(&p.Y)) + // selector2 = 1 when q is (0,0) and 0 otherwise + selector2 := c.api.And(c.baseApi.IsZero(&q.X), c.baseApi.IsZero(&q.Y)) + + // λ = ((p.x+q.x)² - p.x*q.x + a)/(p.y + q.y) + pxqx := c.baseApi.MulMod(&p.X, &q.X) + pxplusqx := c.baseApi.Add(&p.X, &q.X) + num := c.baseApi.MulMod(pxplusqx, pxplusqx) + num = c.baseApi.Sub(num, pxqx) + if c.addA { + num = c.baseApi.Add(num, &c.a) + } + denum := c.baseApi.Add(&p.Y, &q.Y) + // if p.y + q.y = 0, assign dummy 1 to denum and continue + selector3 := c.baseApi.IsZero(denum) + denum = c.baseApi.Select(selector3, c.baseApi.One(), denum) + λ := c.baseApi.Div(num, denum) + + // x = λ^2 - p.x - q.x + xr := c.baseApi.MulMod(λ, λ) + xr = c.baseApi.Sub(xr, pxplusqx) + + // y = λ(p.x - xr) - p.y + yr := c.baseApi.Sub(&p.X, xr) + yr = 
c.baseApi.MulMod(yr, λ) + yr = c.baseApi.Sub(yr, &p.Y) + result := AffinePoint[B]{ + X: *c.baseApi.Reduce(xr), + Y: *c.baseApi.Reduce(yr), + } + + zero := c.baseApi.Zero() + infinity := AffinePoint[B]{X: *zero, Y: *zero} + // if p=(0,0) return q + result = *c.Select(selector1, q, &result) + // if q=(0,0) return p + result = *c.Select(selector2, p, &result) + // if p.y + q.y = 0, return (0, 0) + result = *c.Select(selector3, &infinity, &result) + + return &result +} + +// double doubles p and return it. It doesn't modify p. +// +// ⚠️ p.Y must be nonzero. +// +// It uses affine coordinates. +func (c *Curve[B, S]) double(p *AffinePoint[B]) *AffinePoint[B] { + + // compute λ = (3p.x²+a)/2*p.y, here we assume a=0 (j invariant 0 curve) + xx3a := c.baseApi.MulMod(&p.X, &p.X) + xx3a = c.baseApi.MulConst(xx3a, big.NewInt(3)) + if c.addA { + xx3a = c.baseApi.Add(xx3a, &c.a) + } + y2 := c.baseApi.MulConst(&p.Y, big.NewInt(2)) + λ := c.baseApi.Div(xx3a, y2) + + // xr = λ²-2p.x + x2 := c.baseApi.MulConst(&p.X, big.NewInt(2)) + λλ := c.baseApi.MulMod(λ, λ) + xr := c.baseApi.Sub(λλ, x2) + + // yr = λ(p-xr) - p.y + pxrx := c.baseApi.Sub(&p.X, xr) + λpxrx := c.baseApi.MulMod(λ, pxrx) + yr := c.baseApi.Sub(λpxrx, &p.Y) + + return &AffinePoint[B]{ + X: *c.baseApi.Reduce(xr), + Y: *c.baseApi.Reduce(yr), + } +} + +// triple triples p and return it. It follows [ELM03] (Section 3.1). +// Saves the computation of the y coordinate of 2p as it is used only in the computation of λ2, +// which can be computed as +// +// λ2 = -λ1-2*p.y/(x2-p.x) +// +// instead. It doesn't modify p. +// +// ⚠️ p.Y must be nonzero. 
+// +// [ELM03]: https://arxiv.org/pdf/math/0208038.pdf +func (c *Curve[B, S]) triple(p *AffinePoint[B]) *AffinePoint[B] { + + // compute λ1 = (3p.x²+a)/2p.y, here we assume a=0 (j invariant 0 curve) + xx := c.baseApi.MulMod(&p.X, &p.X) + xx = c.baseApi.MulConst(xx, big.NewInt(3)) + if c.addA { + xx = c.baseApi.Add(xx, &c.a) + } + y2 := c.baseApi.MulConst(&p.Y, big.NewInt(2)) + λ1 := c.baseApi.Div(xx, y2) + + // xr = λ1²-2p.x + x2 := c.baseApi.MulConst(&p.X, big.NewInt(2)) + λ1λ1 := c.baseApi.MulMod(λ1, λ1) + x2 = c.baseApi.Sub(λ1λ1, x2) + + // omit y2 computation, and + // compute λ2 = 2p.y/(x2 − p.x) − λ1. + x1x2 := c.baseApi.Sub(&p.X, x2) + λ2 := c.baseApi.Div(y2, x1x2) + λ2 = c.baseApi.Sub(λ2, λ1) + + // xr = λ²-p.x-x2 + λ2λ2 := c.baseApi.MulMod(λ2, λ2) + qxrx := c.baseApi.Add(x2, &p.X) + xr := c.baseApi.Sub(λ2λ2, qxrx) + + // yr = λ(p.x-xr) - p.y + pxrx := c.baseApi.Sub(&p.X, xr) + λ2pxrx := c.baseApi.MulMod(λ2, pxrx) + yr := c.baseApi.Sub(λ2pxrx, &p.Y) + + return &AffinePoint[B]{ + X: *c.baseApi.Reduce(xr), + Y: *c.baseApi.Reduce(yr), + } +} + +// doubleAndAdd computes 2p+q as (p+q)+p. It follows [ELM03] (Section 3.1) +// Saves the computation of the y coordinate of p+q as it is used only in the computation of λ2, +// which can be computed as +// +// λ2 = -λ1-2*p.y/(x2-p.x) +// +// instead. It doesn't modify p nor q. +// +// ⚠️ p must be different than q and -q, and both nonzero. 
+// +// [ELM03]: https://arxiv.org/pdf/math/0208038.pdf +func (c *Curve[B, S]) doubleAndAdd(p, q *AffinePoint[B]) *AffinePoint[B] { + + // compute λ1 = (q.y-p.y)/(q.x-p.x) + yqyp := c.baseApi.Sub(&q.Y, &p.Y) + xqxp := c.baseApi.Sub(&q.X, &p.X) + λ1 := c.baseApi.Div(yqyp, xqxp) + + // compute x2 = λ1²-p.x-q.x + λ1λ1 := c.baseApi.MulMod(λ1, λ1) + xqxp = c.baseApi.Add(&p.X, &q.X) + x2 := c.baseApi.Sub(λ1λ1, xqxp) + + // omit y2 computation + // compute λ2 = -λ1-2*p.y/(x2-p.x) + ypyp := c.baseApi.Add(&p.Y, &p.Y) + x2xp := c.baseApi.Sub(x2, &p.X) + λ2 := c.baseApi.Div(ypyp, x2xp) + λ2 = c.baseApi.Add(λ1, λ2) + λ2 = c.baseApi.Neg(λ2) + + // compute x3 = λ2²-p.x-x2 + λ2λ2 := c.baseApi.MulMod(λ2, λ2) + x3 := c.baseApi.Sub(λ2λ2, &p.X) + x3 = c.baseApi.Sub(x3, x2) + + // compute y3 = λ2*(p.x - x3)-p.y + y3 := c.baseApi.Sub(&p.X, x3) + y3 = c.baseApi.Mul(λ2, y3) + y3 = c.baseApi.Sub(y3, &p.Y) + + return &AffinePoint[B]{ + X: *c.baseApi.Reduce(x3), + Y: *c.baseApi.Reduce(y3), + } + +} + +// Select selects between p and q given the selector b. If b == 1, then returns +// p, and q otherwise. +func (c *Curve[B, S]) Select(b frontend.Variable, p, q *AffinePoint[B]) *AffinePoint[B] { + x := c.baseApi.Select(b, &p.X, &q.X) + y := c.baseApi.Select(b, &p.Y, &q.Y) + return &AffinePoint[B]{ + X: *x, + Y: *y, + } +} + +// Lookup2 performs a 2-bit lookup between i0, i1, i2, i3 based on bits b0 +// and b1. Returns: +// - i0 if b0=0 and b1=0, +// - i1 if b0=1 and b1=0, +// - i2 if b0=0 and b1=1, +// - i3 if b0=1 and b1=1. +func (c *Curve[B, S]) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 *AffinePoint[B]) *AffinePoint[B] { + x := c.baseApi.Lookup2(b0, b1, &i0.X, &i1.X, &i2.X, &i3.X) + y := c.baseApi.Lookup2(b0, b1, &i0.Y, &i1.Y, &i2.Y, &i3.Y) + return &AffinePoint[B]{ + X: *x, + Y: *y, + } +} + +// ScalarMul computes s * p and returns it. It doesn't modify p nor s. +// This function doesn't check that the p is on the curve. See AssertIsOnCurve. 
+// +// ✅ p can be (0,0) and s can be 0. +// (0,0) is not on the curve but we conventionally take it as the +// neutral/infinity point as per the [EVM]. +// +// It computes the standard little-endian variable-base double-and-add algorithm +// [HMV04] (Algorithm 3.26). +// +// Since we use incomplete formulas for the addition law, we need to start with +// a non-zero accumulator point (res). To do this, we skip the LSB (bit at +// position 0) and proceed assuming it was 1. At the end, we conditionally +// subtract the initial value (p) if LSB is 1. We also handle the bits at +// positions 1, n-2 and n-1 outside of the loop to optimize the number of +// constraints using [ELM03] (Section 3.1) +// +// [ELM03]: https://arxiv.org/pdf/math/0208038.pdf +// [HMV04]: https://link.springer.com/book/10.1007/b97644 +// [EVM]: https://ethereum.github.io/yellowpaper/paper.pdf +func (c *Curve[B, S]) ScalarMul(p *AffinePoint[B], s *emulated.Element[S]) *AffinePoint[B] { + + // if p=(0,0) we assign a dummy (1,1) to p and continue + selector := c.api.And(c.baseApi.IsZero(&p.X), c.baseApi.IsZero(&p.Y)) + one := c.baseApi.One() + p = c.Select(selector, &AffinePoint[B]{X: *one, Y: *one}, p) + + var st S + sr := c.scalarApi.Reduce(s) + sBits := c.scalarApi.ToBits(sr) + n := st.Modulus().BitLen() + + // i = 1 + tmp := c.triple(p) + res := c.Select(sBits[1], tmp, p) + acc := c.add(tmp, p) + + for i := 2; i <= n-3; i++ { + tmp := c.add(res, acc) + res = c.Select(sBits[i], tmp, res) + acc = c.double(acc) + } + + // i = n-2 + tmp = c.add(res, acc) + res = c.Select(sBits[n-2], tmp, res) + + // i = n-1 + tmp = c.doubleAndAdd(acc, res) + res = c.Select(sBits[n-1], tmp, res) + + // i = 0 + // we use AddUnified here instead of Add so that when s=0, res=(0,0) + // because AddUnified(p, -p) = (0,0) + tmp = c.AddUnified(res, c.Neg(p)) + res = c.Select(sBits[0], res, tmp) + + // if p=(0,0), return (0,0) + zero := c.baseApi.Zero() + res = c.Select(selector, &AffinePoint[B]{X: *zero, Y: *zero}, 
res) + + return res +} + +// ScalarMulBase computes s * g and returns it, where g is the fixed generator. +// It doesn't modify s. +// +// ✅ When s=0, it returns (0,0). +// (0,0) is not on the curve but we conventionally take it as the +// neutral/infinity point as per the [EVM]. +// +// It computes the standard little-endian fixed-base double-and-add algorithm +// [HMV04] (Algorithm 3.26). +// +// The method proceeds similarly to ScalarMul but with the points [2^i]g +// precomputed. The bits at positions 1 and 2 are handled outside of the loop +// to optimize the number of constraints using a Lookup2 with pre-computed +// [3]g, [5]g and [7]g points. +// +// [HMV04]: https://link.springer.com/book/10.1007/b97644 +// [EVM]: https://ethereum.github.io/yellowpaper/paper.pdf +func (c *Curve[B, S]) ScalarMulBase(s *emulated.Element[S]) *AffinePoint[B] { + g := c.Generator() + gm := c.GeneratorMultiples() + + var st S + sr := c.scalarApi.Reduce(s) + sBits := c.scalarApi.ToBits(sr) + + // i = 1, 2 + // gm[0] = 3g, gm[1] = 5g, gm[2] = 7g + res := c.Lookup2(sBits[1], sBits[2], g, &gm[0], &gm[1], &gm[2]) + + for i := 3; i < st.Modulus().BitLen(); i++ { + // gm[i] = [2^i]g + tmp := c.add(res, &gm[i]) + res = c.Select(sBits[i], tmp, res) + } + + // i = 0 + tmp := c.AddUnified(res, c.Neg(g)) + res = c.Select(sBits[0], res, tmp) + + return res +} + +// JointScalarMulBase computes s2 * p + s1 * g and returns it, where g is the +// fixed generator. It doesn't modify p, s1 and s2. +// +// ⚠️ p must NOT be (0,0). +// ⚠️ s1 and s2 must NOT be 0. +// +// JointScalarMulBase is used to verify an ECDSA signature (r,s) on the +// secp256k1 curve. In this case, p is a public key, s2=r/s and s1=hash/s. +// - hash cannot be 0, because of pre-image resistance. +// - r cannot be 0, because r is the x coordinate of a random point on +// secp256k1 (y²=x³+7 mod p) and 7 is not a square mod p. For any other +// curve, (_,0) is a point of order 2 which is not the prime subgroup. 
+// - (0,0) is not a valid public key. +// +// The [EVM] specifies these checks, which are performed on the zkEVM +// arithmetization side before calling the circuit that uses this method. +// +// This saves the Select logic related to (0,0) and the use of AddUnified to +// handle the 0-scalar edge case. +func (c *Curve[B, S]) JointScalarMulBase(p *AffinePoint[B], s2, s1 *emulated.Element[S]) *AffinePoint[B] { + g := c.Generator() + gm := c.GeneratorMultiples() + + var st S + s1r := c.scalarApi.Reduce(s1) + s1Bits := c.scalarApi.ToBits(s1r) + s2r := c.scalarApi.Reduce(s2) + s2Bits := c.scalarApi.ToBits(s2r) + n := st.Modulus().BitLen() + + // i = 1, 2 + // gm[0] = 3g, gm[1] = 5g, gm[2] = 7g + res1 := c.Lookup2(s1Bits[1], s1Bits[2], g, &gm[0], &gm[1], &gm[2]) + tmp2 := c.triple(p) + res2 := c.Select(s2Bits[1], tmp2, p) + acc := c.add(tmp2, p) + tmp2 = c.add(res2, acc) + res2 = c.Select(s2Bits[2], tmp2, res2) + acc = c.double(acc) + + for i := 3; i <= n-3; i++ { + // gm[i] = [2^i]g + tmp1 := c.add(res1, &gm[i]) + res1 = c.Select(s1Bits[i], tmp1, res1) + tmp2 = c.add(res2, acc) + res2 = c.Select(s2Bits[i], tmp2, res2) + acc = c.double(acc) + } + + // i = 0 + tmp1 := c.add(res1, c.Neg(g)) + res1 = c.Select(s1Bits[0], res1, tmp1) + tmp2 = c.add(res2, c.Neg(p)) + res2 = c.Select(s2Bits[0], res2, tmp2) + + // i = n-2 + tmp1 = c.add(res1, &gm[n-2]) + res1 = c.Select(s1Bits[n-2], tmp1, res1) + tmp2 = c.add(res2, acc) + res2 = c.Select(s2Bits[n-2], tmp2, res2) + + // i = n-1 + tmp1 = c.add(res1, &gm[n-1]) + res1 = c.Select(s1Bits[n-1], tmp1, res1) + tmp2 = c.doubleAndAdd(acc, res2) + res2 = c.Select(s2Bits[n-1], tmp2, res2) + + return c.add(res1, res2) +} diff --git a/std/algebra/emulated/sw_emulated/point_test.go b/std/algebra/emulated/sw_emulated/point_test.go new file mode 100644 index 0000000000..e006a8d7f2 --- /dev/null +++ b/std/algebra/emulated/sw_emulated/point_test.go @@ -0,0 +1,727 @@ +package sw_emulated + +import ( + "crypto/elliptic" + "crypto/rand" + "math/big" 
+ "testing" + + "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + fr_bls381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bn254" + fr_bn "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/secp256k1" + fp_secp "github.com/consensys/gnark-crypto/ecc/secp256k1/fp" + fr_secp "github.com/consensys/gnark-crypto/ecc/secp256k1/fr" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/test" +) + +var testCurve = ecc.BN254 + +type NegTest[T, S emulated.FieldParams] struct { + P, Q AffinePoint[T] +} + +func (c *NegTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.Neg(&c.P) + cr.AssertIsEqual(res, &c.Q) + return nil +} + +func TestNeg(t *testing.T) { + assert := test.NewAssert(t) + _, g := secp256k1.Generators() + var yn fp_secp.Element + yn.Neg(&g.Y) + circuit := NegTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := NegTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), + }, + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](yn), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +type AddTest[T, S emulated.FieldParams] struct { + P, Q, R AffinePoint[T] +} + +func (c *AddTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res1 := cr.add(&c.P, &c.Q) + res2 := cr.AddUnified(&c.P, &c.Q) + cr.AssertIsEqual(res1, &c.R) + cr.AssertIsEqual(res2, &c.R) + return nil +} + +func TestAdd(t *testing.T) { + assert := test.NewAssert(t) + var 
dJac, aJac secp256k1.G1Jac + g, _ := secp256k1.Generators() + dJac.Double(&g) + aJac.Set(&dJac). + AddAssign(&g) + var dAff, aAff secp256k1.G1Affine + dAff.FromJacobian(&dJac) + aAff.FromJacobian(&aJac) + circuit := AddTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := AddTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), + }, + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](dAff.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](dAff.Y), + }, + R: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](aAff.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](aAff.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +type DoubleTest[T, S emulated.FieldParams] struct { + P, Q AffinePoint[T] +} + +func (c *DoubleTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res1 := cr.double(&c.P) + res2 := cr.AddUnified(&c.P, &c.P) + cr.AssertIsEqual(res1, &c.Q) + cr.AssertIsEqual(res2, &c.Q) + return nil +} + +func TestDouble(t *testing.T) { + assert := test.NewAssert(t) + g, _ := secp256k1.Generators() + var dJac secp256k1.G1Jac + dJac.Double(&g) + var dAff secp256k1.G1Affine + dAff.FromJacobian(&dJac) + circuit := DoubleTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := DoubleTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), + }, + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](dAff.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](dAff.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +type 
TripleTest[T, S emulated.FieldParams] struct { + P, Q AffinePoint[T] +} + +func (c *TripleTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.triple(&c.P) + cr.AssertIsEqual(res, &c.Q) + return nil +} + +func TestTriple(t *testing.T) { + assert := test.NewAssert(t) + g, _ := secp256k1.Generators() + var dJac secp256k1.G1Jac + dJac.Double(&g).AddAssign(&g) + var dAff secp256k1.G1Affine + dAff.FromJacobian(&dJac) + circuit := TripleTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := TripleTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), + }, + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](dAff.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](dAff.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +type DoubleAndAddTest[T, S emulated.FieldParams] struct { + P, Q, R AffinePoint[T] +} + +func (c *DoubleAndAddTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.doubleAndAdd(&c.P, &c.Q) + cr.AssertIsEqual(res, &c.R) + return nil +} + +func TestDoubleAndAdd(t *testing.T) { + assert := test.NewAssert(t) + var pJac, qJac, rJac secp256k1.G1Jac + g, _ := secp256k1.Generators() + pJac.Double(&g) + qJac.Set(&g) + rJac.Double(&pJac). 
+ AddAssign(&qJac) + var pAff, qAff, rAff secp256k1.G1Affine + pAff.FromJacobian(&pJac) + qAff.FromJacobian(&qJac) + rAff.FromJacobian(&rJac) + circuit := DoubleAndAddTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := DoubleAndAddTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](pAff.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](pAff.Y), + }, + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](qAff.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](qAff.Y), + }, + R: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](rAff.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](rAff.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +type AddUnifiedEdgeCases[T, S emulated.FieldParams] struct { + P, Q, R AffinePoint[T] +} + +func (c *AddUnifiedEdgeCases[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.AddUnified(&c.P, &c.Q) + cr.AssertIsEqual(res, &c.R) + return nil +} + +func TestAddUnifiedEdgeCases(t *testing.T) { + assert := test.NewAssert(t) + var infinity bn254.G1Affine + _, _, g, _ := bn254.Generators() + var r fr_bn.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var S, Sn bn254.G1Affine + S.ScalarMultiplication(&g, s) + Sn.Neg(&S) + + circuit := AddUnifiedEdgeCases[emulated.BN254Fp, emulated.BN254Fr]{} + + // (0,0) + (0,0) == (0,0) + witness1 := AddUnifiedEdgeCases[emulated.BN254Fp, emulated.BN254Fr]{ + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + Q: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: 
emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + } + err := test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // S + (0,0) == S + witness2 := AddUnifiedEdgeCases[emulated.BN254Fp, emulated.BN254Fr]{ + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](S.X), + Y: emulated.ValueOf[emulated.BN254Fp](S.Y), + }, + Q: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](S.X), + Y: emulated.ValueOf[emulated.BN254Fp](S.Y), + }, + } + err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) + assert.NoError(err) + + // (0,0) + S == S + witness3 := AddUnifiedEdgeCases[emulated.BN254Fp, emulated.BN254Fr]{ + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + Q: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](S.X), + Y: emulated.ValueOf[emulated.BN254Fp](S.Y), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](S.X), + Y: emulated.ValueOf[emulated.BN254Fp](S.Y), + }, + } + err = test.IsSolved(&circuit, &witness3, testCurve.ScalarField()) + assert.NoError(err) + + // S + (-S) == (0,0) + witness4 := AddUnifiedEdgeCases[emulated.BN254Fp, emulated.BN254Fr]{ + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](S.X), + Y: emulated.ValueOf[emulated.BN254Fp](S.Y), + }, + Q: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](Sn.X), + Y: emulated.ValueOf[emulated.BN254Fp](Sn.Y), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + } + err = test.IsSolved(&circuit, &witness4, testCurve.ScalarField()) + 
assert.NoError(err) + + // (-S) + S == (0,0) + witness5 := AddUnifiedEdgeCases[emulated.BN254Fp, emulated.BN254Fr]{ + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](Sn.X), + Y: emulated.ValueOf[emulated.BN254Fp](Sn.Y), + }, + Q: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](S.X), + Y: emulated.ValueOf[emulated.BN254Fp](S.Y), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + } + err = test.IsSolved(&circuit, &witness5, testCurve.ScalarField()) + assert.NoError(err) +} + +type ScalarMulBaseTest[T, S emulated.FieldParams] struct { + Q AffinePoint[T] + S emulated.Element[S] +} + +func (c *ScalarMulBaseTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.ScalarMulBase(&c.S) + cr.AssertIsEqual(res, &c.Q) + return nil +} + +func TestScalarMulBase(t *testing.T) { + assert := test.NewAssert(t) + _, g := secp256k1.Generators() + var r fr_secp.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var S secp256k1.G1Affine + S.ScalarMultiplication(&g, s) + + circuit := ScalarMulBaseTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := ScalarMulBaseTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + S: emulated.ValueOf[emulated.Secp256k1Fr](s), + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](S.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](S.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestScalarMulBase2(t *testing.T) { + assert := test.NewAssert(t) + _, _, g, _ := bn254.Generators() + var r fr_bn.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var S bn254.G1Affine + S.ScalarMultiplication(&g, s) + + circuit := ScalarMulBaseTest[emulated.BN254Fp, emulated.BN254Fr]{} + witness := 
ScalarMulBaseTest[emulated.BN254Fp, emulated.BN254Fr]{ + S: emulated.ValueOf[emulated.BN254Fr](s), + Q: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](S.X), + Y: emulated.ValueOf[emulated.BN254Fp](S.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestScalarMulBase3(t *testing.T) { + assert := test.NewAssert(t) + _, _, g, _ := bls12381.Generators() + var r fr_bn.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var S bls12381.G1Affine + S.ScalarMultiplication(&g, s) + + circuit := ScalarMulBaseTest[emulated.BLS12381Fp, emulated.BLS12381Fr]{} + witness := ScalarMulBaseTest[emulated.BLS12381Fp, emulated.BLS12381Fr]{ + S: emulated.ValueOf[emulated.BLS12381Fr](s), + Q: AffinePoint[emulated.BLS12381Fp]{ + X: emulated.ValueOf[emulated.BLS12381Fp](S.X), + Y: emulated.ValueOf[emulated.BLS12381Fp](S.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +type ScalarMulTest[T, S emulated.FieldParams] struct { + P, Q AffinePoint[T] + S emulated.Element[S] +} + +func (c *ScalarMulTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.ScalarMul(&c.P, &c.S) + cr.AssertIsEqual(res, &c.Q) + return nil +} + +func TestScalarMul(t *testing.T) { + assert := test.NewAssert(t) + _, g := secp256k1.Generators() + var r fr_secp.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var S secp256k1.G1Affine + S.ScalarMultiplication(&g, s) + + circuit := ScalarMulTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := ScalarMulTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + S: emulated.ValueOf[emulated.Secp256k1Fr](s), + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), + }, + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: 
emulated.ValueOf[emulated.Secp256k1Fp](S.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](S.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestScalarMul2(t *testing.T) { + assert := test.NewAssert(t) + var r fr_bn.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var res bn254.G1Affine + _, _, gen, _ := bn254.Generators() + res.ScalarMultiplication(&gen, s) + + circuit := ScalarMulTest[emulated.BN254Fp, emulated.BN254Fr]{} + witness := ScalarMulTest[emulated.BN254Fp, emulated.BN254Fr]{ + S: emulated.ValueOf[emulated.BN254Fr](s), + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](gen.X), + Y: emulated.ValueOf[emulated.BN254Fp](gen.Y), + }, + Q: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](res.X), + Y: emulated.ValueOf[emulated.BN254Fp](res.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +type ScalarMulEdgeCases[T, S emulated.FieldParams] struct { + P, R AffinePoint[T] + S emulated.Element[S] +} + +func (c *ScalarMulEdgeCases[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.ScalarMul(&c.P, &c.S) + cr.AssertIsEqual(res, &c.R) + return nil +} + +func TestScalarMulEdgeCasesEdgeCases(t *testing.T) { + assert := test.NewAssert(t) + var infinity bn254.G1Affine + _, _, g, _ := bn254.Generators() + var r fr_bn.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var S bn254.G1Affine + S.ScalarMultiplication(&g, s) + + circuit := ScalarMulEdgeCases[emulated.BN254Fp, emulated.BN254Fr]{} + + // s * (0,0) == (0,0) + witness1 := ScalarMulEdgeCases[emulated.BN254Fp, emulated.BN254Fr]{ + S: emulated.ValueOf[emulated.BN254Fr](s), + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + R: 
AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + } + err := test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // 0 * S == (0,0) + witness2 := ScalarMulEdgeCases[emulated.BN254Fp, emulated.BN254Fr]{ + S: emulated.ValueOf[emulated.BN254Fr](new(big.Int)), + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](S.X), + Y: emulated.ValueOf[emulated.BN254Fp](S.Y), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + } + err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestScalarMul3(t *testing.T) { + assert := test.NewAssert(t) + var r fr_bls381.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var res bls12381.G1Affine + _, _, gen, _ := bls12381.Generators() + res.ScalarMultiplication(&gen, s) + + circuit := ScalarMulTest[emulated.BLS12381Fp, emulated.BLS12381Fr]{} + witness := ScalarMulTest[emulated.BLS12381Fp, emulated.BLS12381Fr]{ + S: emulated.ValueOf[emulated.BLS12381Fr](s), + P: AffinePoint[emulated.BLS12381Fp]{ + X: emulated.ValueOf[emulated.BLS12381Fp](gen.X), + Y: emulated.ValueOf[emulated.BLS12381Fp](gen.Y), + }, + Q: AffinePoint[emulated.BLS12381Fp]{ + X: emulated.ValueOf[emulated.BLS12381Fp](res.X), + Y: emulated.ValueOf[emulated.BLS12381Fp](res.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestScalarMul4(t *testing.T) { + assert := test.NewAssert(t) + p256 := elliptic.P256() + s, err := rand.Int(rand.Reader, p256.Params().N) + assert.NoError(err) + px, py := p256.ScalarBaseMult(s.Bytes()) + + circuit := ScalarMulTest[emulated.P256Fp, emulated.P256Fr]{} + witness := ScalarMulTest[emulated.P256Fp, emulated.P256Fr]{ + S: emulated.ValueOf[emulated.P256Fr](s), + P: 
AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](p256.Params().Gx), + Y: emulated.ValueOf[emulated.P256Fp](p256.Params().Gy), + }, + Q: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](px), + Y: emulated.ValueOf[emulated.P256Fp](py), + }, + } + err = test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestScalarMul5(t *testing.T) { + assert := test.NewAssert(t) + p384 := elliptic.P384() + s, err := rand.Int(rand.Reader, p384.Params().N) + assert.NoError(err) + px, py := p384.ScalarBaseMult(s.Bytes()) + + circuit := ScalarMulTest[emulated.P384Fp, emulated.P384Fr]{} + witness := ScalarMulTest[emulated.P384Fp, emulated.P384Fr]{ + S: emulated.ValueOf[emulated.P384Fr](s), + P: AffinePoint[emulated.P384Fp]{ + X: emulated.ValueOf[emulated.P384Fp](p384.Params().Gx), + Y: emulated.ValueOf[emulated.P384Fp](p384.Params().Gy), + }, + Q: AffinePoint[emulated.P384Fp]{ + X: emulated.ValueOf[emulated.P384Fp](px), + Y: emulated.ValueOf[emulated.P384Fp](py), + }, + } + err = test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +type IsOnCurveTest[T, S emulated.FieldParams] struct { + Q AffinePoint[T] +} + +func (c *IsOnCurveTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + cr.AssertIsOnCurve(&c.Q) + return nil +} + +func TestIsOnCurve(t *testing.T) { + assert := test.NewAssert(t) + _, g := secp256k1.Generators() + var r fr_secp.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var Q, infinity secp256k1.G1Affine + Q.ScalarMultiplication(&g, s) + + // Q=[s]G is on curve + circuit := IsOnCurveTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness1 := IsOnCurveTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](Q.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](Q.Y), + }, + } + err := 
test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // (0,0) is on curve + witness2 := IsOnCurveTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](infinity.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](infinity.Y), + }, + } + err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestIsOnCurve2(t *testing.T) { + assert := test.NewAssert(t) + _, _, g, _ := bn254.Generators() + var r fr_secp.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var Q, infinity bn254.G1Affine + Q.ScalarMultiplication(&g, s) + + // Q=[s]G is on curve + circuit := IsOnCurveTest[emulated.BN254Fp, emulated.BN254Fr]{} + witness1 := IsOnCurveTest[emulated.BN254Fp, emulated.BN254Fr]{ + Q: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](Q.X), + Y: emulated.ValueOf[emulated.BN254Fp](Q.Y), + }, + } + err := test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // (0,0) is on curve + witness2 := IsOnCurveTest[emulated.BN254Fp, emulated.BN254Fr]{ + Q: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + } + err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestIsOnCurve3(t *testing.T) { + assert := test.NewAssert(t) + _, _, g, _ := bls12381.Generators() + var r fr_secp.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var Q, infinity bls12381.G1Affine + Q.ScalarMultiplication(&g, s) + + // Q=[s]G is on curve + circuit := IsOnCurveTest[emulated.BLS12381Fp, emulated.BLS12381Fr]{} + witness1 := IsOnCurveTest[emulated.BLS12381Fp, emulated.BLS12381Fr]{ + Q: AffinePoint[emulated.BLS12381Fp]{ + X: emulated.ValueOf[emulated.BLS12381Fp](Q.X), + Y: emulated.ValueOf[emulated.BLS12381Fp](Q.Y), + }, + } + err := 
test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // (0,0) is on curve + witness2 := IsOnCurveTest[emulated.BLS12381Fp, emulated.BLS12381Fr]{ + Q: AffinePoint[emulated.BLS12381Fp]{ + X: emulated.ValueOf[emulated.BLS12381Fp](infinity.X), + Y: emulated.ValueOf[emulated.BLS12381Fp](infinity.Y), + }, + } + err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/native/fields_bls12377/doc.go b/std/algebra/native/fields_bls12377/doc.go new file mode 100644 index 0000000000..69308f8280 --- /dev/null +++ b/std/algebra/native/fields_bls12377/doc.go @@ -0,0 +1,9 @@ +// Package fields_bls12377 implements the fields arithmetic of the Fp12 tower +// used to compute the pairing over the BLS12-377 curve. +// +// 𝔽p²[u] = 𝔽p/u²+5 +// 𝔽p⁶[v] = 𝔽p²/v³-u +// 𝔽p¹²[w] = 𝔽p⁶/w²-v +// +// Reference: https://eprint.iacr.org/2022/1162 +package fields_bls12377 diff --git a/std/algebra/fields_bls12377/e12.go b/std/algebra/native/fields_bls12377/e12.go similarity index 90% rename from std/algebra/fields_bls12377/e12.go rename to std/algebra/native/fields_bls12377/e12.go index 64ed8fc024..ed787978d5 100644 --- a/std/algebra/fields_bls12377/e12.go +++ b/std/algebra/native/fields_bls12377/e12.go @@ -20,7 +20,8 @@ import ( "math/big" bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" - "github.com/consensys/gnark/backend/hint" + + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -262,6 +263,8 @@ func (e *E12) CyclotomicSquareCompressed(api frontend.API, x E12) *E12 { // Decompress Karabina's cyclotomic square result func (e *E12) Decompress(api frontend.API, x E12) *E12 { + // TODO: hadle the g3==0 case with MUX + var t [3]E2 var one E2 one.SetOne() @@ -344,51 +347,6 @@ func (e *E12) Conjugate(api frontend.API, e1 E12) *E12 { return e } -// MulBy034 multiplication by sparse element -func (e *E12) MulBy034(api frontend.API, c3, c4 E2) *E12 { - - var d E6 
- - a := e.C0 - b := e.C1 - - b.MulBy01(api, c3, c4) - - c3.Add(api, E2{A0: 1, A1: 0}, c3) - d.Add(api, e.C0, e.C1) - d.MulBy01(api, c3, c4) - - e.C1.Add(api, a, b).Neg(api, e.C1).Add(api, e.C1, d) - e.C0.MulByNonResidue(api, b).Add(api, e.C0, a) - - return e -} - -// Mul034By034 multiplication of sparse element (1,0,0,c3,c4,0) by sparse element (1,0,0,d3,d4,0) -func (e *E12) Mul034By034(api frontend.API, d3, d4, c3, c4 E2) *E12 { - var one, tmp, x3, x4, x04, x03, x34 E2 - one.SetOne() - x3.Mul(api, c3, d3) - x4.Mul(api, c4, d4) - x04.Add(api, c4, d4) - x03.Add(api, c3, d3) - tmp.Add(api, c3, c4) - x34.Add(api, d3, d4). - Mul(api, x34, tmp). - Sub(api, x34, x3). - Sub(api, x34, x4) - - e.C0.B0.MulByNonResidue(api, x4). - Add(api, e.C0.B0, one) - e.C0.B1 = x3 - e.C0.B2 = x34 - e.C1.B0 = x03 - e.C1.B1 = x04 - e.C1.B2.SetZero() - - return e -} - // Frobenius applies frob to an fp12 elmt func (e *E12) Frobenius(api frontend.API, e1 E12) *E12 { @@ -464,7 +422,7 @@ var InverseE12Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(InverseE12Hint) + solver.RegisterHint(InverseE12Hint) } // Inverse e12 elmts @@ -537,7 +495,7 @@ var DivE12Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(DivE12Hint) + solver.RegisterHint(DivE12Hint) } // DivUnchecked e12 elmts @@ -587,35 +545,6 @@ func (e *E12) nSquareCompressed(api frontend.API, n int) { } } -// Expt compute e1**exponent, where the exponent is hardcoded -// This function is only used for the final expo of the pairing for bls12377, so the exponent is supposed to be hardcoded -// and on 64 bits. 
-func (e *E12) Expt(api frontend.API, e1 E12, exponent uint64) *E12 { - - res := e1 - - res.nSquareCompressed(api, 5) - res.Decompress(api, res) - res.Mul(api, res, e1) - x33 := res - res.nSquareCompressed(api, 7) - res.Decompress(api, res) - res.Mul(api, res, x33) - res.nSquareCompressed(api, 4) - res.Decompress(api, res) - res.Mul(api, res, e1) - res.CyclotomicSquare(api, res) - res.Mul(api, res, e1) - res.nSquareCompressed(api, 46) - res.Decompress(api, res) - res.Mul(api, res, e1) - - *e = res - - return e - -} - // Assign a value to self (witness assignment) func (e *E12) Assign(a *bls12377.E12) { e.C0.Assign(&a.C0) diff --git a/std/algebra/native/fields_bls12377/e12_pairing.go b/std/algebra/native/fields_bls12377/e12_pairing.go new file mode 100644 index 0000000000..6cb7242f14 --- /dev/null +++ b/std/algebra/native/fields_bls12377/e12_pairing.go @@ -0,0 +1,115 @@ +package fields_bls12377 + +import "github.com/consensys/gnark/frontend" + +// MulBy034 multiplication by sparse element +func (e *E12) MulBy034(api frontend.API, c3, c4 E2) *E12 { + + var d E6 + + a := e.C0 + b := e.C1 + + b.MulBy01(api, c3, c4) + + c3.Add(api, E2{A0: 1, A1: 0}, c3) + d.Add(api, e.C0, e.C1) + d.MulBy01(api, c3, c4) + + e.C1.Add(api, a, b).Neg(api, e.C1).Add(api, e.C1, d) + e.C0.MulByNonResidue(api, b).Add(api, e.C0, a) + + return e +} + +// Mul034By034 multiplication of sparse element (1,0,0,c3,c4,0) by sparse element (1,0,0,d3,d4,0) +func Mul034By034(api frontend.API, d3, d4, c3, c4 E2) *[5]E2 { + var one, tmp, x00, x3, x4, x04, x03, x34 E2 + one.SetOne() + x3.Mul(api, c3, d3) + x4.Mul(api, c4, d4) + x04.Add(api, c4, d4) + x03.Add(api, c3, d3) + tmp.Add(api, c3, c4) + x34.Add(api, d3, d4). + Mul(api, x34, tmp). + Sub(api, x34, x3). + Sub(api, x34, x4) + + x00.MulByNonResidue(api, x4). 
+ Add(api, x00, one) + + return &[5]E2{x00, x3, x34, x03, x04} +} + +func Mul01234By034(api frontend.API, x [5]E2, z3, z4 E2) *E12 { + var a, b, z1, z0, one E6 + var zero E2 + zero.SetZero() + one.SetOne() + c0 := &E6{B0: x[0], B1: x[1], B2: x[2]} + c1 := &E6{B0: x[3], B1: x[4], B2: zero} + a.Add(api, one, E6{B0: z3, B1: z4, B2: zero}) + b.Add(api, *c0, *c1) + a.Mul(api, a, b) + c := *Mul01By01(api, z3, z4, x[3], x[4]) + z1.Sub(api, a, *c0) + z1.Sub(api, z1, c) + z0.MulByNonResidue(api, c) + z0.Add(api, z0, *c0) + return &E12{ + C0: z0, + C1: z1, + } +} + +func (e *E12) MulBy01234(api frontend.API, x [5]E2) *E12 { + var a, b, c, z1, z0 E6 + var zero E2 + zero.SetZero() + c0 := &E6{B0: x[0], B1: x[1], B2: x[2]} + c1 := &E6{B0: x[3], B1: x[4], B2: zero} + a.Add(api, e.C0, e.C1) + b.Add(api, *c0, *c1) + a.Mul(api, a, b) + b.Mul(api, e.C0, *c0) + c = e.C1 + c.MulBy01(api, x[3], x[4]) + z1.Sub(api, a, b) + z1.Sub(api, z1, c) + z0.MulByNonResidue(api, c) + z0.Add(api, z0, b) + + e.C0 = z0 + e.C1 = z1 + return e +} + +// Expt compute e1**exponent, where the exponent is hardcoded +// This function is only used for the final expo of the pairing for bls12377, so the exponent is supposed to be hardcoded +// and on 64 bits. 
+func (e *E12) Expt(api frontend.API, e1 E12, exponent uint64) *E12 { + + res := e1 + + res.nSquareCompressed(api, 5) + res.Decompress(api, res) + res.Mul(api, res, e1) + x33 := res + res.nSquareCompressed(api, 7) + res.Decompress(api, res) + res.Mul(api, res, x33) + res.nSquareCompressed(api, 4) + res.Decompress(api, res) + res.Mul(api, res, e1) + res.CyclotomicSquare(api, res) + res.Mul(api, res, e1) + res.nSquareCompressed(api, 46) + res.Decompress(api, res) + res.Mul(api, res, e1) + + *e = res + + return e + +} diff --git a/std/algebra/fields_bls12377/e12_test.go b/std/algebra/native/fields_bls12377/e12_test.go similarity index 100% rename from std/algebra/fields_bls12377/e12_test.go rename to std/algebra/native/fields_bls12377/e12_test.go diff --git a/std/algebra/fields_bls12377/e2.go b/std/algebra/native/fields_bls12377/e2.go similarity index 90% rename from std/algebra/fields_bls12377/e2.go rename to std/algebra/native/fields_bls12377/e2.go index 067e8242c4..fbb282fae8 100644 --- a/std/algebra/fields_bls12377/e2.go +++ b/std/algebra/native/fields_bls12377/e2.go @@ -21,7 +21,8 @@ import ( bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - "github.com/consensys/gnark/backend/hint" + + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -154,7 +155,7 @@ var InverseE2Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(InverseE2Hint) + solver.RegisterHint(InverseE2Hint) } // Inverse e2 elmts @@ -196,7 +197,7 @@ var DivE2Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(DivE2Hint) + solver.RegisterHint(DivE2Hint) } // DivUnchecked e2 elmts @@ -240,3 +241,16 @@ func (e *E2) Select(api frontend.API, b frontend.Variable, r1, r2 E2) *E2 { return e } + +// Lookup2 implements two-bit lookup. 
It returns: +// - r1 if b1=0 and b2=0, +// - r2 if b1=0 and b2=1, +// - r3 if b1=1 and b2=0, +// - r4 if b1=1 and b2=1. +func (e *E2) Lookup2(api frontend.API, b1, b2 frontend.Variable, r1, r2, r3, r4 E2) *E2 { + + e.A0 = api.Lookup2(b1, b2, r1.A0, r2.A0, r3.A0, r4.A0) + e.A1 = api.Lookup2(b1, b2, r1.A1, r2.A1, r3.A1, r4.A1) + + return e +} diff --git a/std/algebra/fields_bls12377/e2_test.go b/std/algebra/native/fields_bls12377/e2_test.go similarity index 100% rename from std/algebra/fields_bls12377/e2_test.go rename to std/algebra/native/fields_bls12377/e2_test.go diff --git a/std/algebra/fields_bls12377/e6.go b/std/algebra/native/fields_bls12377/e6.go similarity index 92% rename from std/algebra/fields_bls12377/e6.go rename to std/algebra/native/fields_bls12377/e6.go index b8d15605e8..79a61d1768 100644 --- a/std/algebra/fields_bls12377/e6.go +++ b/std/algebra/native/fields_bls12377/e6.go @@ -20,7 +20,8 @@ import ( "math/big" bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" - "github.com/consensys/gnark/backend/hint" + + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -198,7 +199,7 @@ var DivE6Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(DivE6Hint) + solver.RegisterHint(DivE6Hint) } // DivUnchecked e6 elmts @@ -246,7 +247,7 @@ var InverseE6Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(InverseE6Hint) + solver.RegisterHint(InverseE6Hint) } // Inverse e6 elmts @@ -324,3 +325,28 @@ func (e *E6) MulBy01(api frontend.API, c0, c1 E2) *E6 { return e } + +func Mul01By01(api frontend.API, c0, c1, d0, d1 E2) *E6 { + var a, b, t0, t1, t2, tmp E2 + + a.Mul(api, d0, c0) + b.Mul(api, d1, c1) + t0.Mul(api, c1, d1) + t0.Sub(api, t0, b) + t0.MulByNonResidue(api, t0) + t0.Add(api, t0, a) + t2.Mul(api, c0, d0) + t2.Sub(api, t2, a) + t2.Add(api, t2, b) + t1.Add(api, c0, c1) + tmp.Add(api, d0, d1) + t1.Mul(api, t1, tmp) + 
t1.Sub(api, t1, a) + t1.Sub(api, t1, b) + + return &E6{ + B0: t0, + B1: t1, + B2: t2, + } +} diff --git a/std/algebra/fields_bls12377/e6_test.go b/std/algebra/native/fields_bls12377/e6_test.go similarity index 100% rename from std/algebra/fields_bls12377/e6_test.go rename to std/algebra/native/fields_bls12377/e6_test.go diff --git a/std/algebra/native/fields_bls24315/doc.go b/std/algebra/native/fields_bls24315/doc.go new file mode 100644 index 0000000000..8b669480b7 --- /dev/null +++ b/std/algebra/native/fields_bls24315/doc.go @@ -0,0 +1,10 @@ +// Package fields_bls24315 implements the fields arithmetic of the Fp24 tower +// used to compute the pairing over the BLS24-315 curve. +// +// 𝔽p²[u] = 𝔽p/u²-13 +// 𝔽p⁴[v] = 𝔽p²/v²-u +// 𝔽p¹²[w] = 𝔽p⁴/w³-v +// 𝔽p²⁴[i] = 𝔽p¹²/i²-w +// +// Reference: https://eprint.iacr.org/2022/1162 +package fields_bls24315 diff --git a/std/algebra/fields_bls24315/e12.go b/std/algebra/native/fields_bls24315/e12.go similarity index 98% rename from std/algebra/fields_bls24315/e12.go rename to std/algebra/native/fields_bls24315/e12.go index 2a8583f3f8..569602c222 100644 --- a/std/algebra/fields_bls24315/e12.go +++ b/std/algebra/native/fields_bls24315/e12.go @@ -20,7 +20,7 @@ import ( "math/big" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -209,7 +209,7 @@ var InverseE12Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(InverseE12Hint) + solver.RegisterHint(InverseE12Hint) } // Inverse e12 elmts @@ -282,7 +282,7 @@ var DivE12Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(DivE12Hint) + solver.RegisterHint(DivE12Hint) } // DivUnchecked e12 elmts diff --git a/std/algebra/fields_bls24315/e12_test.go b/std/algebra/native/fields_bls24315/e12_test.go similarity index 100% rename from 
std/algebra/fields_bls24315/e12_test.go rename to std/algebra/native/fields_bls24315/e12_test.go diff --git a/std/algebra/fields_bls24315/e2.go b/std/algebra/native/fields_bls24315/e2.go similarity index 91% rename from std/algebra/fields_bls24315/e2.go rename to std/algebra/native/fields_bls24315/e2.go index b850ef025b..441f575948 100644 --- a/std/algebra/fields_bls24315/e2.go +++ b/std/algebra/native/fields_bls24315/e2.go @@ -21,7 +21,7 @@ import ( bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/utils" ) @@ -160,7 +160,7 @@ var DivE2Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(DivE2Hint) + solver.RegisterHint(DivE2Hint) } // DivUnchecked e2 elmts @@ -199,7 +199,7 @@ var InverseE2Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(InverseE2Hint) + solver.RegisterHint(InverseE2Hint) } // Inverse e2 elmts @@ -244,3 +244,16 @@ func (e *E2) Select(api frontend.API, b frontend.Variable, r1, r2 E2) *E2 { return e } + +// Lookup2 implements two-bit lookup. It returns: +// - r1 if b1=0 and b2=0, +// - r2 if b1=0 and b2=1, +// - r3 if b1=1 and b2=0, +// - r4 if b1=1 and b2=1. 
+func (e *E2) Lookup2(api frontend.API, b1, b2 frontend.Variable, r1, r2, r3, r4 E2) *E2 { + + e.A0 = api.Lookup2(b1, b2, r1.A0, r2.A0, r3.A0, r4.A0) + e.A1 = api.Lookup2(b1, b2, r1.A1, r2.A1, r3.A1, r4.A1) + + return e +} diff --git a/std/algebra/fields_bls24315/e24.go b/std/algebra/native/fields_bls24315/e24.go similarity index 91% rename from std/algebra/fields_bls24315/e24.go rename to std/algebra/native/fields_bls24315/e24.go index d70722cb70..ea21cf09c5 100644 --- a/std/algebra/fields_bls24315/e24.go +++ b/std/algebra/native/fields_bls24315/e24.go @@ -20,7 +20,7 @@ import ( "math/big" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -341,53 +341,6 @@ func (e *E24) Conjugate(api frontend.API, e1 E24) *E24 { return e } -// MulBy034 multiplication by sparse element -func (e *E24) MulBy034(api frontend.API, c3, c4 E4) *E24 { - - var d E12 - var one E4 - one.SetOne() - - a := e.D0 - b := e.D1 - - b.MulBy01(api, c3, c4) - - c3.Add(api, one, c3) - d.Add(api, e.D0, e.D1) - d.MulBy01(api, c3, c4) - - e.D1.Add(api, a, b).Neg(api, e.D1).Add(api, e.D1, d) - e.D0.MulByNonResidue(api, b).Add(api, e.D0, a) - - return e -} - -// Mul034By034 multiplication of sparse element (1,0,0,c3,c4,0) by sparse element (1,0,0,d3,d4,0) -func (e *E24) Mul034By034(api frontend.API, d3, d4, c3, c4 E4) *E24 { - var one, tmp, x3, x4, x04, x03, x34 E4 - one.SetOne() - x3.Mul(api, c3, d3) - x4.Mul(api, c4, d4) - x04.Add(api, c4, d4) - x03.Add(api, c3, d3) - tmp.Add(api, c3, c4) - x34.Add(api, d3, d4). - Mul(api, x34, tmp). - Sub(api, x34, x3). - Sub(api, x34, x4) - - e.D0.C0.MulByNonResidue(api, x4). 
- Add(api, e.D0.C0, one) - e.D0.C1 = x3 - e.D0.C2 = x34 - e.D1.C0 = x03 - e.D1.C1 = x04 - e.D1.C2.SetZero() - - return e -} - var InverseE24Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { var a, c bls24315.E24 @@ -447,7 +400,7 @@ var InverseE24Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(InverseE24Hint) + solver.RegisterHint(InverseE24Hint) } // Inverse e24 elmts @@ -557,7 +510,7 @@ var DivE24Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(DivE24Hint) + solver.RegisterHint(DivE24Hint) } // DivUnchecked e24 elmts @@ -595,31 +548,6 @@ func (e *E24) nSquare(api frontend.API, n int) { } } -// Expt compute e1**exponent, where the exponent is hardcoded -// This function is only used for the final expo of the pairing for bls24315, so the exponent is supposed to be hardcoded and on 32 bits. -func (e *E24) Expt(api frontend.API, x E24, exponent uint64) *E24 { - - xInv := E24{} - res := x - xInv.Conjugate(api, x) - - res.nSquare(api, 2) - res.Mul(api, res, xInv) - res.nSquareCompressed(api, 8) - res.Decompress(api, res) - res.Mul(api, res, xInv) - res.nSquare(api, 2) - res.Mul(api, res, x) - res.nSquareCompressed(api, 20) - res.Decompress(api, res) - res.Mul(api, res, xInv) - res.Conjugate(api, res) - - *e = res - - return e -} - // AssertIsEqual constraint self to be equal to other into the given constraint system func (e *E24) AssertIsEqual(api frontend.API, other E24) { e.D0.AssertIsEqual(api, other.D0) diff --git a/std/algebra/native/fields_bls24315/e24_pairing.go b/std/algebra/native/fields_bls24315/e24_pairing.go new file mode 100644 index 0000000000..03ce073b5a --- /dev/null +++ b/std/algebra/native/fields_bls24315/e24_pairing.go @@ -0,0 +1,95 @@ +package fields_bls24315 + +import "github.com/consensys/gnark/frontend" + +// MulBy034 multiplication by sparse element +func (e *E24) MulBy034(api frontend.API, c3, c4 E4) *E24 { + + var d E12 + var 
one E4 + one.SetOne() + + a := e.D0 + b := e.D1 + + b.MulBy01(api, c3, c4) + + c3.Add(api, one, c3) + d.Add(api, e.D0, e.D1) + d.MulBy01(api, c3, c4) + + e.D1.Add(api, a, b).Neg(api, e.D1).Add(api, e.D1, d) + e.D0.MulByNonResidue(api, b).Add(api, e.D0, a) + + return e +} + +// Mul034By034 multiplication of sparse element (1,0,0,c3,c4,0) by sparse element (1,0,0,d3,d4,0) +func (e *E24) Mul034By034(api frontend.API, d3, d4, c3, c4 E4) *E24 { + var one, tmp, x3, x4, x04, x03, x34 E4 + one.SetOne() + x3.Mul(api, c3, d3) + x4.Mul(api, c4, d4) + x04.Add(api, c4, d4) + x03.Add(api, c3, d3) + tmp.Add(api, c3, c4) + x34.Add(api, d3, d4). + Mul(api, x34, tmp). + Sub(api, x34, x3). + Sub(api, x34, x4) + + e.D0.C0.MulByNonResidue(api, x4). + Add(api, e.D0.C0, one) + e.D0.C1 = x3 + e.D0.C2 = x34 + e.D1.C0 = x03 + e.D1.C1 = x04 + e.D1.C2.SetZero() + + return e +} + +// Mul034By034 multiplication of sparse element (1,0,0,c3,c4,0) by sparse element (1,0,0,d3,d4,0) +func Mul034By034(api frontend.API, d3, d4, c3, c4 E4) *[5]E4 { + var one, tmp, x00, x3, x4, x04, x03, x34 E4 + one.SetOne() + x3.Mul(api, c3, d3) + x4.Mul(api, c4, d4) + x04.Add(api, c4, d4) + x03.Add(api, c3, d3) + tmp.Add(api, c3, c4) + x34.Add(api, d3, d4). + Mul(api, x34, tmp). + Sub(api, x34, x3). + Sub(api, x34, x4) + + x00.MulByNonResidue(api, x4). + Add(api, x00, one) + + return &[5]E4{x00, x3, x34, x03, x04} +} + +// Expt compute e1**exponent, where the exponent is hardcoded +// This function is only used for the final expo of the pairing for bls24315, so the exponent is supposed to be hardcoded and on 32 bits. 
+func (e *E24) Expt(api frontend.API, x E24, exponent uint64) *E24 { + + xInv := E24{} + res := x + xInv.Conjugate(api, x) + + res.nSquare(api, 2) + res.Mul(api, res, xInv) + res.nSquareCompressed(api, 8) + res.Decompress(api, res) + res.Mul(api, res, xInv) + res.nSquare(api, 2) + res.Mul(api, res, x) + res.nSquareCompressed(api, 20) + res.Decompress(api, res) + res.Mul(api, res, xInv) + res.Conjugate(api, res) + + *e = res + + return e +} diff --git a/std/algebra/fields_bls24315/e24_test.go b/std/algebra/native/fields_bls24315/e24_test.go similarity index 100% rename from std/algebra/fields_bls24315/e24_test.go rename to std/algebra/native/fields_bls24315/e24_test.go diff --git a/std/algebra/fields_bls24315/e2_test.go b/std/algebra/native/fields_bls24315/e2_test.go similarity index 100% rename from std/algebra/fields_bls24315/e2_test.go rename to std/algebra/native/fields_bls24315/e2_test.go diff --git a/std/algebra/fields_bls24315/e4.go b/std/algebra/native/fields_bls24315/e4.go similarity index 91% rename from std/algebra/fields_bls24315/e4.go rename to std/algebra/native/fields_bls24315/e4.go index 57abb6ccf8..f6fc78c405 100644 --- a/std/algebra/fields_bls24315/e4.go +++ b/std/algebra/native/fields_bls24315/e4.go @@ -20,7 +20,7 @@ import ( "math/big" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -165,7 +165,7 @@ var DivE4Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(DivE4Hint) + solver.RegisterHint(DivE4Hint) } // DivUnchecked e4 elmts @@ -208,7 +208,7 @@ var InverseE4Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(InverseE4Hint) + solver.RegisterHint(InverseE4Hint) } // Inverse e4 elmts @@ -253,3 +253,16 @@ func (e *E4) Select(api frontend.API, b frontend.Variable, r1, r2 E4) *E4 { return e } + +// Lookup2 
implements two-bit lookup. It returns: +// - r1 if b1=0 and b2=0, +// - r2 if b1=0 and b2=1, +// - r3 if b1=1 and b2=0, +// - r4 if b1=1 and b2=1. +func (e *E4) Lookup2(api frontend.API, b1, b2 frontend.Variable, r1, r2, r3, r4 E4) *E4 { + + e.B0.Lookup2(api, b1, b2, r1.B0, r2.B0, r3.B0, r4.B0) + e.B1.Lookup2(api, b1, b2, r1.B1, r2.B1, r3.B1, r4.B1) + + return e +} diff --git a/std/algebra/fields_bls24315/e4_test.go b/std/algebra/native/fields_bls24315/e4_test.go similarity index 100% rename from std/algebra/fields_bls24315/e4_test.go rename to std/algebra/native/fields_bls24315/e4_test.go diff --git a/std/algebra/native/sw_bls12377/doc.go b/std/algebra/native/sw_bls12377/doc.go new file mode 100644 index 0000000000..3888dba835 --- /dev/null +++ b/std/algebra/native/sw_bls12377/doc.go @@ -0,0 +1,8 @@ +// Package sw_bls12377 implements the arithmetics of G1, G2 and the pairing +// computation on BLS12-377 as a SNARK circuit over BW6-761. These two curves +// form a 2-chain so the operations use native field arithmetic. 
+// +// References: +// BW6-761: https://eprint.iacr.org/2020/351 +// Pairings in R1CS: https://eprint.iacr.org/2022/1162 +package sw_bls12377 diff --git a/std/algebra/sw_bls12377/g1.go b/std/algebra/native/sw_bls12377/g1.go similarity index 92% rename from std/algebra/sw_bls12377/g1.go rename to std/algebra/native/sw_bls12377/g1.go index 303315a833..546fd19857 100644 --- a/std/algebra/sw_bls12377/g1.go +++ b/std/algebra/native/sw_bls12377/g1.go @@ -22,7 +22,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - "github.com/consensys/gnark/backend/hint" + + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -228,7 +229,7 @@ var DecomposeScalarG1 = func(scalarField *big.Int, inputs []*big.Int, res []*big } func init() { - hint.Register(DecomposeScalarG1) + solver.RegisterHint(DecomposeScalarG1) } // varScalarMul sets P = [s] Q and returns P. @@ -299,20 +300,18 @@ func (P *G1Affine) varScalarMul(api frontend.API, Q G1Affine, s frontend.Variabl // step value from [2] Acc (instead of conditionally adding step value to // Acc): // Acc = [2] (Q + Φ(Q)) ± Q ± Φ(Q) - Acc.Double(api, Acc) // only y coordinate differs for negation, select on that instead. 
B.X = tableQ[0].X B.Y = api.Select(s1bits[nbits-1], tableQ[1].Y, tableQ[0].Y) - Acc.AddAssign(api, B) + Acc.DoubleAndAdd(api, &Acc, &B) B.X = tablePhiQ[0].X B.Y = api.Select(s2bits[nbits-1], tablePhiQ[1].Y, tablePhiQ[0].Y) Acc.AddAssign(api, B) // second bit - Acc.Double(api, Acc) B.X = tableQ[0].X B.Y = api.Select(s1bits[nbits-2], tableQ[1].Y, tableQ[0].Y) - Acc.AddAssign(api, B) + Acc.DoubleAndAdd(api, &Acc, &B) B.X = tablePhiQ[0].X B.Y = api.Select(s2bits[nbits-2], tablePhiQ[1].Y, tablePhiQ[0].Y) Acc.AddAssign(api, B) @@ -451,3 +450,36 @@ func (p *G1Affine) DoubleAndAdd(api frontend.API, p1, p2 *G1Affine) *G1Affine { return p } + +// ScalarMulBase computes s * g1 and returns it, where g1 is the fixed generator. It doesn't modify s. +func (P *G1Affine) ScalarMulBase(api frontend.API, s frontend.Variable) *G1Affine { + + points := getCurvePoints() + + sBits := api.ToBinary(s, 253) + + var res, tmp G1Affine + + // i = 1, 2 + // gm[0] = 3g, gm[1] = 5g, gm[2] = 7g + res.X = api.Lookup2(sBits[1], sBits[2], points.G1x, points.G1m[0][0], points.G1m[1][0], points.G1m[2][0]) + res.Y = api.Lookup2(sBits[1], sBits[2], points.G1y, points.G1m[0][1], points.G1m[1][1], points.G1m[2][1]) + + for i := 3; i < 253; i++ { + // gm[i] = [2^i]g + tmp.X = res.X + tmp.Y = res.Y + tmp.AddAssign(api, G1Affine{points.G1m[i][0], points.G1m[i][1]}) + res.Select(api, sBits[i], tmp, res) + } + + // i = 0 + tmp.Neg(api, G1Affine{points.G1x, points.G1y}) + tmp.AddAssign(api, res) + res.Select(api, sBits[0], res, tmp) + + P.X = res.X + P.Y = res.Y + + return P +} diff --git a/std/algebra/sw_bls12377/g1_test.go b/std/algebra/native/sw_bls12377/g1_test.go similarity index 93% rename from std/algebra/sw_bls12377/g1_test.go rename to std/algebra/native/sw_bls12377/g1_test.go index 2a0a83a1a5..55e5c524bc 100644 --- a/std/algebra/sw_bls12377/g1_test.go +++ b/std/algebra/native/sw_bls12377/g1_test.go @@ -369,6 +369,37 @@ func TestScalarMulG1(t *testing.T) { assert.SolvingSucceeded(&circuit, &witness, 
test.WithCurves(ecc.BW6_761)) } +type g1varScalarMulBase struct { + C G1Affine `gnark:",public"` + R frontend.Variable +} + +func (circuit *g1varScalarMulBase) Define(api frontend.API) error { + expected := G1Affine{} + expected.ScalarMulBase(api, circuit.R) + expected.AssertIsEqual(api, circuit.C) + return nil +} + +func TestVarScalarMulBaseG1(t *testing.T) { + var c bls12377.G1Affine + gJac, _, _, _ := bls12377.Generators() + + // create the cs + var circuit, witness g1varScalarMulBase + var r fr.Element + _, _ = r.SetRandom() + witness.R = r.String() + // compute the result + var br big.Int + gJac.ScalarMultiplication(&gJac, r.BigInt(&br)) + c.FromJacobian(&gJac) + witness.C.Assign(&c) + + assert := test.NewAssert(t) + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_761)) +} + func randomPointG1() bls12377.G1Jac { p1, _, _, _ := bls12377.Generators() diff --git a/std/algebra/sw_bls12377/g2.go b/std/algebra/native/sw_bls12377/g2.go similarity index 87% rename from std/algebra/sw_bls12377/g2.go rename to std/algebra/native/sw_bls12377/g2.go index fcfcdbcf20..64fdf9e87f 100644 --- a/std/algebra/sw_bls12377/g2.go +++ b/std/algebra/native/sw_bls12377/g2.go @@ -21,9 +21,10 @@ import ( "github.com/consensys/gnark-crypto/ecc" bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" - "github.com/consensys/gnark/backend/hint" + + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/fields_bls12377" + "github.com/consensys/gnark/std/algebra/native/fields_bls12377" ) // G2Jac point in Jacobian coords @@ -243,7 +244,7 @@ var DecomposeScalarG2 = func(scalarField *big.Int, inputs []*big.Int, res []*big } func init() { - hint.Register(DecomposeScalarG2) + solver.RegisterHint(DecomposeScalarG2) } // varScalarMul sets P = [s] Q and returns P. 
@@ -314,20 +315,18 @@ func (P *G2Affine) varScalarMul(api frontend.API, Q G2Affine, s frontend.Variabl // step value from [2] Acc (instead of conditionally adding step value to // Acc): // Acc = [2] (Q + Φ(Q)) ± Q ± Φ(Q) - Acc.Double(api, Acc) // only y coordinate differs for negation, select on that instead. B.X = tableQ[0].X B.Y.Select(api, s1bits[nbits-1], tableQ[1].Y, tableQ[0].Y) - Acc.AddAssign(api, B) + Acc.DoubleAndAdd(api, &Acc, &B) B.X = tablePhiQ[0].X B.Y.Select(api, s2bits[nbits-1], tablePhiQ[1].Y, tablePhiQ[0].Y) Acc.AddAssign(api, B) // second bit - Acc.Double(api, Acc) B.X = tableQ[0].X B.Y.Select(api, s1bits[nbits-2], tableQ[1].Y, tableQ[0].Y) - Acc.AddAssign(api, B) + Acc.DoubleAndAdd(api, &Acc, &B) B.X = tablePhiQ[0].X B.Y.Select(api, s2bits[nbits-2], tablePhiQ[1].Y, tablePhiQ[0].Y) Acc.AddAssign(api, B) @@ -471,3 +470,68 @@ func (p *G2Affine) DoubleAndAdd(api frontend.API, p1, p2 *G2Affine) *G2Affine { return p } + +// ScalarMulBase computes s * g2 and returns it, where g2 is the fixed generator. It doesn't modify s. 
+func (P *G2Affine) ScalarMulBase(api frontend.API, s frontend.Variable) *G2Affine { + + points := getTwistPoints() + + sBits := api.ToBinary(s, 253) + + var res, tmp G2Affine + + // i = 1, 2 + // gm[0] = 3g, gm[1] = 5g, gm[2] = 7g + res.X.Lookup2(api, sBits[1], sBits[2], + fields_bls12377.E2{ + A0: points.G2x[0], + A1: points.G2x[1]}, + fields_bls12377.E2{ + A0: points.G2m[0][0], + A1: points.G2m[0][1]}, + fields_bls12377.E2{ + A0: points.G2m[1][0], + A1: points.G2m[1][1]}, + fields_bls12377.E2{ + A0: points.G2m[2][0], + A1: points.G2m[2][1]}) + res.Y.Lookup2(api, sBits[1], sBits[2], + fields_bls12377.E2{ + A0: points.G2y[0], + A1: points.G2y[1]}, + fields_bls12377.E2{ + A0: points.G2m[0][2], + A1: points.G2m[0][3]}, + fields_bls12377.E2{ + A0: points.G2m[1][2], + A1: points.G2m[1][3]}, + fields_bls12377.E2{ + A0: points.G2m[2][2], + A1: points.G2m[2][3]}) + + for i := 3; i < 253; i++ { + // gm[i] = [2^i]g + tmp.X = res.X + tmp.Y = res.Y + tmp.AddAssign(api, G2Affine{ + fields_bls12377.E2{ + A0: points.G2m[i][0], + A1: points.G2m[i][1]}, + fields_bls12377.E2{ + A0: points.G2m[i][2], + A1: points.G2m[i][3]}}) + res.Select(api, sBits[i], tmp, res) + } + + // i = 0 + tmp.Neg(api, G2Affine{ + fields_bls12377.E2{A0: points.G2x[0], A1: points.G2x[1]}, + fields_bls12377.E2{A0: points.G2y[0], A1: points.G2y[1]}}) + tmp.AddAssign(api, res) + res.Select(api, sBits[0], res, tmp) + + P.X = res.X + P.Y = res.Y + + return P +} diff --git a/std/algebra/sw_bls12377/g2_test.go b/std/algebra/native/sw_bls12377/g2_test.go similarity index 93% rename from std/algebra/sw_bls12377/g2_test.go rename to std/algebra/native/sw_bls12377/g2_test.go index 534a58b54c..2bca05f714 100644 --- a/std/algebra/sw_bls12377/g2_test.go +++ b/std/algebra/native/sw_bls12377/g2_test.go @@ -374,6 +374,38 @@ func TestScalarMulG2(t *testing.T) { assert := test.NewAssert(t) assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_761)) } + +type g2varScalarMulBase struct { + C G2Affine 
`gnark:",public"` + R frontend.Variable +} + +func (circuit *g2varScalarMulBase) Define(api frontend.API) error { + expected := G2Affine{} + expected.ScalarMulBase(api, circuit.R) + expected.AssertIsEqual(api, circuit.C) + return nil +} + +func TestVarScalarMulBaseG2(t *testing.T) { + var c bls12377.G2Affine + _, gJac, _, _ := bls12377.Generators() + + // create the cs + var circuit, witness g2varScalarMulBase + var r fr.Element + _, _ = r.SetRandom() + witness.R = r.String() + // compute the result + var br big.Int + gJac.ScalarMultiplication(&gJac, r.BigInt(&br)) + c.FromJacobian(&gJac) + witness.C.Assign(&c) + + assert := test.NewAssert(t) + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_761)) +} + func randomPointG2() bls12377.G2Jac { _, p2, _, _ := bls12377.Generators() diff --git a/std/algebra/sw_bls12377/inner.go b/std/algebra/native/sw_bls12377/inner.go similarity index 66% rename from std/algebra/sw_bls12377/inner.go rename to std/algebra/native/sw_bls12377/inner.go index 7843ac8f17..3396afa7d1 100644 --- a/std/algebra/sw_bls12377/inner.go +++ b/std/algebra/native/sw_bls12377/inner.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/consensys/gnark-crypto/ecc" + bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark/frontend" ) @@ -62,3 +63,44 @@ func getInnerCurveConfig(outerCurveScalarField *big.Int) *innerConfig { return &innerConfigBW6_761 } + +var ( + computedCurveTable [][2]*big.Int + computedTwistTable [][4]*big.Int +) + +func init() { + computedCurveTable = computeCurveTable() + computedTwistTable = computeTwistTable() +} + +type curvePoints struct { + G1x *big.Int // base point x + G1y *big.Int // base point y + G1m [][2]*big.Int // m*base points (x,y) +} + +func getCurvePoints() curvePoints { + _, _, g1aff, _ := bls12377.Generators() + return curvePoints{ + G1x: g1aff.X.BigInt(new(big.Int)), + G1y: g1aff.Y.BigInt(new(big.Int)), + G1m: computedCurveTable, + } +} + +type twistPoints struct { + G2x 
[2]*big.Int // base point x ∈ E2 + G2y [2]*big.Int // base point y ∈ E2 + G2m [][4]*big.Int // m*base points (x,y) +} + +func getTwistPoints() twistPoints { + _, _, _, g2aff := bls12377.Generators() + return twistPoints{ + G2x: [2]*big.Int{g2aff.X.A0.BigInt(new(big.Int)), g2aff.X.A1.BigInt(new(big.Int))}, + G2y: [2]*big.Int{g2aff.Y.A0.BigInt(new(big.Int)), g2aff.Y.A1.BigInt(new(big.Int))}, + G2m: computedTwistTable, + } + +} diff --git a/std/algebra/native/sw_bls12377/inner_compute.go b/std/algebra/native/sw_bls12377/inner_compute.go new file mode 100644 index 0000000000..1bf12a2f5c --- /dev/null +++ b/std/algebra/native/sw_bls12377/inner_compute.go @@ -0,0 +1,59 @@ +package sw_bls12377 + +import ( + "math/big" + + bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" +) + +func computeCurveTable() [][2]*big.Int { + G1jac, _, _, _ := bls12377.Generators() + table := make([][2]*big.Int, 253) + tmp := new(bls12377.G1Jac).Set(&G1jac) + aff := new(bls12377.G1Affine) + jac := new(bls12377.G1Jac) + for i := 1; i < 253; i++ { + tmp = tmp.Double(tmp) + switch i { + case 1, 2: + jac.Set(tmp).AddAssign(&G1jac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + case 3: + jac.Set(tmp).SubAssign(&G1jac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + fallthrough + default: + aff.FromJacobian(tmp) + table[i] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + } + } + return table[:] +} + +func computeTwistTable() [][4]*big.Int { + _, G2jac, _, _ := bls12377.Generators() + table := make([][4]*big.Int, 253) + tmp := new(bls12377.G2Jac).Set(&G2jac) + aff := new(bls12377.G2Affine) + jac := new(bls12377.G2Jac) + for i := 1; i < 253; i++ { + tmp = tmp.Double(tmp) + switch i { + case 1, 2: + jac.Set(tmp).AddAssign(&G2jac) + aff.FromJacobian(jac) + table[i-1] = [4]*big.Int{aff.X.A0.BigInt(new(big.Int)), aff.X.A1.BigInt(new(big.Int)), 
aff.Y.A0.BigInt(new(big.Int)), aff.Y.A1.BigInt(new(big.Int))} + case 3: + jac.Set(tmp).SubAssign(&G2jac) + aff.FromJacobian(jac) + table[i-1] = [4]*big.Int{aff.X.A0.BigInt(new(big.Int)), aff.X.A1.BigInt(new(big.Int)), aff.Y.A0.BigInt(new(big.Int)), aff.Y.A1.BigInt(new(big.Int))} + fallthrough + default: + aff.FromJacobian(tmp) + table[i] = [4]*big.Int{aff.X.A0.BigInt(new(big.Int)), aff.X.A1.BigInt(new(big.Int)), aff.Y.A0.BigInt(new(big.Int)), aff.Y.A1.BigInt(new(big.Int))} + } + } + return table[:] +} diff --git a/std/algebra/native/sw_bls12377/pairing.go b/std/algebra/native/sw_bls12377/pairing.go new file mode 100644 index 0000000000..2f4eb32c16 --- /dev/null +++ b/std/algebra/native/sw_bls12377/pairing.go @@ -0,0 +1,353 @@ +/* +Copyright © 2020 ConsenSys + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package sw_bls12377 + +import ( + "errors" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/native/fields_bls12377" +) + +// GT target group of the pairing +type GT = fields_bls12377.E12 + +// binary decomposition of x₀=9586122913090633729 little endian +var loopCounter = [64]int8{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1} + +// lineEvaluation represents a sparse Fp12 Elmt (result of the line evaluation) +// line: 1 + R0(x/y) + R1(1/y) = 0 instead of R0'*y + R1'*x + R2' = 0 This +// makes the multiplication by lines (MulBy034) and between lines (Mul034By034) +// circuit-efficient. +type lineEvaluation struct { + R0, R1 fields_bls12377.E2 +} + +// MillerLoop computes the product of n miller loops (n can be 1) +// ∏ᵢ { fᵢ_{x₀,Q}(P) } +func MillerLoop(api frontend.API, P []G1Affine, Q []G2Affine) (GT, error) { + // check input size match + n := len(P) + if n == 0 || n != len(Q) { + return GT{}, errors.New("invalid inputs sizes") + } + + var res GT + res.SetOne() + var prodLines [5]fields_bls12377.E2 + + var l1, l2 lineEvaluation + Qacc := make([]G2Affine, n) + yInv := make([]frontend.Variable, n) + xOverY := make([]frontend.Variable, n) + for k := 0; k < n; k++ { + Qacc[k] = Q[k] + // x=0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000000 + // TODO: point P=(x,0) should be ruled out + yInv[k] = api.DivUnchecked(1, P[k].Y) + xOverY[k] = api.Mul(P[k].X, yInv[k]) + } + + // Compute ∏ᵢ { fᵢ_{x₀,Q}(P) } + // i = 64, separately to avoid an E12 Square + // (Square(res) = 1² = 1) + + // k = 0, separately to avoid MulBy034 (res × ℓ) + // (assign line to res) + Qacc[0], l1 = doubleStep(api, &Qacc[0]) + // line evaluation at P[0] + res.C1.B0.MulByFp(api, l1.R0, xOverY[0]) + res.C1.B1.MulByFp(api, l1.R1, yInv[0]) + + if n >= 2 { + // k = 1, 
separately to avoid MulBy034 (res × ℓ) + // (res is also a line at this point, so we use Mul034By034 ℓ × ℓ) + Qacc[1], l1 = doubleStep(api, &Qacc[1]) + + // line evaluation at P[1] + l1.R0.MulByFp(api, l1.R0, xOverY[1]) + l1.R1.MulByFp(api, l1.R1, yInv[1]) + + // ℓ × res + prodLines = *fields_bls12377.Mul034By034(api, l1.R0, l1.R1, res.C1.B0, res.C1.B1) + res.C0.B0 = prodLines[0] + res.C0.B1 = prodLines[1] + res.C0.B2 = prodLines[2] + res.C1.B0 = prodLines[3] + res.C1.B1 = prodLines[4] + + } + + if n >= 3 { + // k = 2, separately to avoid MulBy034 (res × ℓ) + // (res has a zero E2 element, so we use Mul01234By034) + Qacc[2], l1 = doubleStep(api, &Qacc[2]) + + // line evaluation at P[1] + l1.R0.MulByFp(api, l1.R0, xOverY[2]) + l1.R1.MulByFp(api, l1.R1, yInv[2]) + + // ℓ × res + res = *fields_bls12377.Mul01234By034(api, prodLines, l1.R0, l1.R1) + + // k >= 3 + for k := 3; k < n; k++ { + // Qacc[k] ← 2Qacc[k] and l1 the tangent ℓ passing 2Qacc[k] + Qacc[k], l1 = doubleStep(api, &Qacc[k]) + + // line evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + } + } + + for i := 61; i >= 1; i-- { + // mutualize the square among n Miller loops + // (∏ᵢfᵢ)² + res.Square(api, res) + + if loopCounter[i] == 0 { + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k] and l1 the tangent ℓ passing 2Qacc[k] + Qacc[k], l1 = doubleStep(api, &Qacc[k]) + + // line evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + } + continue + } + + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k]+Q[k], + // l1 the line ℓ passing Qacc[k] and Q[k] + // l2 the line ℓ passing (Qacc[k]+Q[k]) and Qacc[k] + Qacc[k], l1, l2 = doubleAndAddStep(api, &Qacc[k], &Q[k]) + + // lines evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + l2.R0.MulByFp(api, l2.R0, xOverY[k]) + l2.R1.MulByFp(api, 
l2.R1, yInv[k]) + + // ℓ × ℓ + prodLines = *fields_bls12377.Mul034By034(api, l1.R0, l1.R1, l2.R0, l2.R1) + // (ℓ × ℓ) × res + res.MulBy01234(api, prodLines) + } + } + + // i = 0 + res.Square(api, res) + for k := 0; k < n; k++ { + // l1 line through Qacc[k] and Q[k] + // l2 line through Qacc[k]+Q[k] and Qacc[k] + l1, l2 = linesCompute(api, &Qacc[k], &Q[k]) + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + l2.R0.MulByFp(api, l2.R0, xOverY[k]) + l2.R1.MulByFp(api, l2.R1, yInv[k]) + + // ℓ × ℓ + prodLines = *fields_bls12377.Mul034By034(api, l1.R0, l1.R1, l2.R0, l2.R1) + // (ℓ × ℓ) × res + res.MulBy01234(api, prodLines) + } + + return res, nil +} + +// FinalExponentiation computes the exponentiation e1ᵈ +// where d = (p¹²-1)/r = (p¹²-1)/Φ₁₂(p) ⋅ Φ₁₂(p)/r = (p⁶-1)(p²+1)(p⁴ - p² +1)/r +// we use instead d=s ⋅ (p⁶-1)(p²+1)(p⁴ - p² +1)/r +// where s is the cofactor 3 (Hayashida et al.) +func FinalExponentiation(api frontend.API, e1 GT) GT { + const genT = 9586122913090633729 + + result := e1 + + // https://eprint.iacr.org/2016/130.pdf + var t [3]GT + + // easy part + // (p⁶-1)(p²+1) + t[0].Conjugate(api, result) + t[0].DivUnchecked(api, t[0], result) + result.FrobeniusSquare(api, t[0]). 
+ Mul(api, result, t[0]) + + // hard part (up to permutation) + // Daiki Hayashida and Kenichiro Hayasaka + // and Tadanori Teruya + // https://eprint.iacr.org/2020/875.pdf + t[0].CyclotomicSquare(api, result) + t[1].Expt(api, result, genT) + t[2].Conjugate(api, result) + t[1].Mul(api, t[1], t[2]) + t[2].Expt(api, t[1], genT) + t[1].Conjugate(api, t[1]) + t[1].Mul(api, t[1], t[2]) + t[2].Expt(api, t[1], genT) + t[1].Frobenius(api, t[1]) + t[1].Mul(api, t[1], t[2]) + result.Mul(api, result, t[0]) + t[0].Expt(api, t[1], genT) + t[2].Expt(api, t[0], genT) + t[0].FrobeniusSquare(api, t[1]) + t[1].Conjugate(api, t[1]) + t[1].Mul(api, t[1], t[2]) + t[1].Mul(api, t[1], t[0]) + result.Mul(api, result, t[1]) + + return result +} + +// Pair calculates the reduced pairing for a set of points +// ∏ᵢ e(Pᵢ, Qᵢ). +// +// This function doesn't check that the inputs are in the correct subgroup. See IsInSubGroup. +func Pair(api frontend.API, P []G1Affine, Q []G2Affine) (GT, error) { + f, err := MillerLoop(api, P, Q) + if err != nil { + return GT{}, err + } + return FinalExponentiation(api, f), nil +} + +// doubleAndAddStep doubles p1 and adds p2 to the result in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func doubleAndAddStep(api frontend.API, p1, p2 *G2Affine) (G2Affine, lineEvaluation, lineEvaluation) { + + var n, d, l1, l2, x3, x4, y4 fields_bls12377.E2 + var line1, line2 lineEvaluation + var p G2Affine + + // compute lambda1 = (y2-y1)/(x2-x1) + n.Sub(api, p1.Y, p2.Y) + d.Sub(api, p1.X, p2.X) + l1.DivUnchecked(api, n, d) + + // x3 =lambda1**2-p1.x-p2.x + x3.Square(api, l1). + Sub(api, x3, p1.X). 
+ Sub(api, x3, p2.X) + + // omit y3 computation + + // compute line1 + line1.R0.Neg(api, l1) + line1.R1.Mul(api, l1, p1.X).Sub(api, line1.R1, p1.Y) + + // compute lambda2 = -lambda1-2*y1/(x3-x1) + n.Double(api, p1.Y) + d.Sub(api, x3, p1.X) + l2.DivUnchecked(api, n, d) + l2.Add(api, l2, l1).Neg(api, l2) + + // compute x4 = lambda2**2-x1-x3 + x4.Square(api, l2). + Sub(api, x4, p1.X). + Sub(api, x4, x3) + + // compute y4 = lambda2*(x1 - x4)-y1 + y4.Sub(api, p1.X, x4). + Mul(api, l2, y4). + Sub(api, y4, p1.Y) + + p.X = x4 + p.Y = y4 + + // compute line2 + line2.R0.Neg(api, l2) + line2.R1.Mul(api, l2, p1.X).Sub(api, line2.R1, p1.Y) + + return p, line1, line2 +} + +// doubleStep doubles a point in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func doubleStep(api frontend.API, p1 *G2Affine) (G2Affine, lineEvaluation) { + + var n, d, l, xr, yr fields_bls12377.E2 + var p G2Affine + var line lineEvaluation + + // lambda = 3*p1.x**2/2*p.y + n.Square(api, p1.X).MulByFp(api, n, 3) + d.MulByFp(api, p1.Y, 2) + l.DivUnchecked(api, n, d) + + // xr = lambda**2-2*p1.x + xr.Square(api, l). + Sub(api, xr, p1.X). + Sub(api, xr, p1.X) + + // yr = lambda*(p.x-xr)-p.y + yr.Sub(api, p1.X, xr). + Mul(api, l, yr). + Sub(api, yr, p1.Y) + + p.X = xr + p.Y = yr + + line.R0.Neg(api, l) + line.R1.Mul(api, l, p1.X).Sub(api, line.R1, p1.Y) + + return p, line + +} + +// linesCompute computes the lines that goes through p1 and p2, and (p1+p2) and p1 but does not compute 2p1+p2 +func linesCompute(api frontend.API, p1, p2 *G2Affine) (lineEvaluation, lineEvaluation) { + + var n, d, l1, l2, x3 fields_bls12377.E2 + var line1, line2 lineEvaluation + + // compute lambda1 = (y2-y1)/(x2-x1) + n.Sub(api, p1.Y, p2.Y) + d.Sub(api, p1.X, p2.X) + l1.DivUnchecked(api, n, d) + + // x3 =lambda1**2-p1.x-p2.x + x3.Square(api, l1). + Sub(api, x3, p1.X). 
+ Sub(api, x3, p2.X) + + // omit y3 computation + + // compute line1 + line1.R0.Neg(api, l1) + line1.R1.Mul(api, l1, p1.X).Sub(api, line1.R1, p1.Y) + + // compute lambda2 = -lambda1-2*y1/(x3-x1) + n.Double(api, p1.Y) + d.Sub(api, x3, p1.X) + l2.DivUnchecked(api, n, d) + l2.Add(api, l2, l1).Neg(api, l2) + + // compute line2 + line2.R0.Neg(api, l2) + line2.R1.Mul(api, l2, p1.X).Sub(api, line2.R1, p1.Y) + + return line1, line2 +} diff --git a/std/algebra/sw_bls12377/pairing_test.go b/std/algebra/native/sw_bls12377/pairing_test.go similarity index 98% rename from std/algebra/sw_bls12377/pairing_test.go rename to std/algebra/native/sw_bls12377/pairing_test.go index bf09b6efad..dbe5639919 100644 --- a/std/algebra/sw_bls12377/pairing_test.go +++ b/std/algebra/native/sw_bls12377/pairing_test.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/std/algebra/fields_bls12377" + "github.com/consensys/gnark/std/algebra/native/fields_bls12377" "github.com/consensys/gnark/test" ) diff --git a/std/algebra/native/sw_bls24315/doc.go b/std/algebra/native/sw_bls24315/doc.go new file mode 100644 index 0000000000..9a13dd4769 --- /dev/null +++ b/std/algebra/native/sw_bls24315/doc.go @@ -0,0 +1,8 @@ +// Package sw_bls24315 implements the arithmetics of G1, G2 and the pairing +// computation on BLS24-315 as a SNARK circuit over BW6-633. These two curves +// form a 2-chain so the operations use native field arithmetic. 
+// +// References: +// BLS24-315/BW6-633: https://eprint.iacr.org/2021/1359 +// Pairings in R1CS: https://eprint.iacr.org/2022/1162 +package sw_bls24315 diff --git a/std/algebra/sw_bls24315/g1.go b/std/algebra/native/sw_bls24315/g1.go similarity index 92% rename from std/algebra/sw_bls24315/g1.go rename to std/algebra/native/sw_bls24315/g1.go index 1e8b12fd0e..61cda042ff 100644 --- a/std/algebra/sw_bls24315/g1.go +++ b/std/algebra/native/sw_bls24315/g1.go @@ -22,7 +22,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -228,7 +228,7 @@ var DecomposeScalarG1 = func(scalarField *big.Int, inputs []*big.Int, res []*big } func init() { - hint.Register(DecomposeScalarG1) + solver.RegisterHint(DecomposeScalarG1) } // varScalarMul sets P = [s] Q and returns P. @@ -299,20 +299,18 @@ func (P *G1Affine) varScalarMul(api frontend.API, Q G1Affine, s frontend.Variabl // step value from [2] Acc (instead of conditionally adding step value to // Acc): // Acc = [2] (Q + Φ(Q)) ± Q ± Φ(Q) - Acc.Double(api, Acc) // only y coordinate differs for negation, select on that instead. 
B.X = tableQ[0].X B.Y = api.Select(s1bits[nbits-1], tableQ[1].Y, tableQ[0].Y) - Acc.AddAssign(api, B) + Acc.DoubleAndAdd(api, &Acc, &B) B.X = tablePhiQ[0].X B.Y = api.Select(s2bits[nbits-1], tablePhiQ[1].Y, tablePhiQ[0].Y) Acc.AddAssign(api, B) // second bit - Acc.Double(api, Acc) B.X = tableQ[0].X B.Y = api.Select(s1bits[nbits-2], tableQ[1].Y, tableQ[0].Y) - Acc.AddAssign(api, B) + Acc.DoubleAndAdd(api, &Acc, &B) B.X = tablePhiQ[0].X B.Y = api.Select(s2bits[nbits-2], tablePhiQ[1].Y, tablePhiQ[0].Y) Acc.AddAssign(api, B) @@ -451,3 +449,36 @@ func (p *G1Affine) DoubleAndAdd(api frontend.API, p1, p2 *G1Affine) *G1Affine { return p } + +// ScalarMulBase computes s * g1 and returns it, where g1 is the fixed generator. It doesn't modify s. +func (P *G1Affine) ScalarMulBase(api frontend.API, s frontend.Variable) *G1Affine { + + points := getCurvePoints() + + sBits := api.ToBinary(s, 253) + + var res, tmp G1Affine + + // i = 1, 2 + // gm[0] = 3g, gm[1] = 5g, gm[2] = 7g + res.X = api.Lookup2(sBits[1], sBits[2], points.G1x, points.G1m[0][0], points.G1m[1][0], points.G1m[2][0]) + res.Y = api.Lookup2(sBits[1], sBits[2], points.G1y, points.G1m[0][1], points.G1m[1][1], points.G1m[2][1]) + + for i := 3; i < 253; i++ { + // gm[i] = [2^i]g + tmp.X = res.X + tmp.Y = res.Y + tmp.AddAssign(api, G1Affine{points.G1m[i][0], points.G1m[i][1]}) + res.Select(api, sBits[i], tmp, res) + } + + // i = 0 + tmp.Neg(api, G1Affine{points.G1x, points.G1y}) + tmp.AddAssign(api, res) + res.Select(api, sBits[0], res, tmp) + + P.X = res.X + P.Y = res.Y + + return P +} diff --git a/std/algebra/sw_bls24315/g1_test.go b/std/algebra/native/sw_bls24315/g1_test.go similarity index 93% rename from std/algebra/sw_bls24315/g1_test.go rename to std/algebra/native/sw_bls24315/g1_test.go index 86feeeb571..c4e02779d0 100644 --- a/std/algebra/sw_bls24315/g1_test.go +++ b/std/algebra/native/sw_bls24315/g1_test.go @@ -369,6 +369,37 @@ func TestScalarMulG1(t *testing.T) { assert.SolvingSucceeded(&circuit, &witness, 
test.WithCurves(ecc.BW6_633)) } +type g1varScalarMulBase struct { + C G1Affine `gnark:",public"` + R frontend.Variable +} + +func (circuit *g1varScalarMulBase) Define(api frontend.API) error { + expected := G1Affine{} + expected.ScalarMulBase(api, circuit.R) + expected.AssertIsEqual(api, circuit.C) + return nil +} + +func TestVarScalarMulBaseG1(t *testing.T) { + var c bls24315.G1Affine + gJac, _, _, _ := bls24315.Generators() + + // create the cs + var circuit, witness g1varScalarMulBase + var r fr.Element + _, _ = r.SetRandom() + witness.R = r.String() + // compute the result + var br big.Int + gJac.ScalarMultiplication(&gJac, r.BigInt(&br)) + c.FromJacobian(&gJac) + witness.C.Assign(&c) + + assert := test.NewAssert(t) + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_633)) +} + func randomPointG1() bls24315.G1Jac { p1, _, _, _ := bls24315.Generators() diff --git a/std/algebra/sw_bls24315/g2.go b/std/algebra/native/sw_bls24315/g2.go similarity index 81% rename from std/algebra/sw_bls24315/g2.go rename to std/algebra/native/sw_bls24315/g2.go index 001fe45f5d..cd5dc9d064 100644 --- a/std/algebra/sw_bls24315/g2.go +++ b/std/algebra/native/sw_bls24315/g2.go @@ -21,9 +21,9 @@ import ( "github.com/consensys/gnark-crypto/ecc" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/fields_bls24315" + "github.com/consensys/gnark/std/algebra/native/fields_bls24315" ) // G2Jac point in Jacobian coords @@ -243,7 +243,7 @@ var DecomposeScalarG2 = func(scalarField *big.Int, inputs []*big.Int, res []*big } func init() { - hint.Register(DecomposeScalarG2) + solver.RegisterHint(DecomposeScalarG2) } // varScalarMul sets P = [s] Q and returns P. 
@@ -314,20 +314,18 @@ func (P *G2Affine) varScalarMul(api frontend.API, Q G2Affine, s frontend.Variabl // step value from [2] Acc (instead of conditionally adding step value to // Acc): // Acc = [2] (Q + Φ(Q)) ± Q ± Φ(Q) - Acc.Double(api, Acc) // only y coordinate differs for negation, select on that instead. B.X = tableQ[0].X B.Y.Select(api, s1bits[nbits-1], tableQ[1].Y, tableQ[0].Y) - Acc.AddAssign(api, B) + Acc.DoubleAndAdd(api, &Acc, &B) B.X = tablePhiQ[0].X B.Y.Select(api, s2bits[nbits-1], tablePhiQ[1].Y, tablePhiQ[0].Y) Acc.AddAssign(api, B) // second bit - Acc.Double(api, Acc) B.X = tableQ[0].X B.Y.Select(api, s1bits[nbits-2], tableQ[1].Y, tableQ[0].Y) - Acc.AddAssign(api, B) + Acc.DoubleAndAdd(api, &Acc, &B) B.X = tablePhiQ[0].X B.Y.Select(api, s2bits[nbits-2], tablePhiQ[1].Y, tablePhiQ[0].Y) Acc.AddAssign(api, B) @@ -471,3 +469,73 @@ func (p *G2Affine) DoubleAndAdd(api frontend.API, p1, p2 *G2Affine) *G2Affine { return p } + +// ScalarMulBase computes s * g2 and returns it, where g2 is the fixed generator. It doesn't modify s. 
+func (P *G2Affine) ScalarMulBase(api frontend.API, s frontend.Variable) *G2Affine { + + points := getTwistPoints() + + sBits := api.ToBinary(s, 253) + + var res, tmp G2Affine + + // i = 1, 2 + // gm[0] = 3g, gm[1] = 5g, gm[2] = 7g + res.X.Lookup2(api, sBits[1], sBits[2], + fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2x[0], A1: points.G2x[1]}, + B1: fields_bls24315.E2{A0: points.G2x[2], A1: points.G2x[3]}}, + fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2m[0][0], A1: points.G2m[0][1]}, + B1: fields_bls24315.E2{A0: points.G2m[0][2], A1: points.G2m[0][3]}}, + fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2m[1][0], A1: points.G2m[1][1]}, + B1: fields_bls24315.E2{A0: points.G2m[1][2], A1: points.G2m[1][3]}}, + fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2m[2][0], A1: points.G2m[2][1]}, + B1: fields_bls24315.E2{A0: points.G2m[2][2], A1: points.G2m[2][3]}}) + + res.Y.Lookup2(api, sBits[1], sBits[2], + fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2y[0], A1: points.G2y[1]}, + B1: fields_bls24315.E2{A0: points.G2y[2], A1: points.G2y[3]}}, + fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2m[0][4], A1: points.G2m[0][5]}, + B1: fields_bls24315.E2{A0: points.G2m[0][6], A1: points.G2m[0][7]}}, + fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2m[1][4], A1: points.G2m[1][5]}, + B1: fields_bls24315.E2{A0: points.G2m[1][6], A1: points.G2m[1][7]}}, + fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2m[2][4], A1: points.G2m[2][5]}, + B1: fields_bls24315.E2{A0: points.G2m[2][6], A1: points.G2m[2][7]}}) + + for i := 3; i < 253; i++ { + // gm[i] = [2^i]g + tmp.X = res.X + tmp.Y = res.Y + tmp.AddAssign(api, G2Affine{ + X: fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2m[i][0], A1: points.G2m[i][1]}, + B1: fields_bls24315.E2{A0: points.G2m[i][2], A1: points.G2m[i][3]}}, + Y: fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2m[i][4], A1: points.G2m[i][5]}, + B1: fields_bls24315.E2{A0: 
points.G2m[i][6], A1: points.G2m[i][7]}}}) + res.Select(api, sBits[i], tmp, res) + } + + // i = 0 + tmp.Neg(api, G2Affine{ + X: fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2x[0], A1: points.G2x[1]}, + B1: fields_bls24315.E2{A0: points.G2x[2], A1: points.G2x[3]}}, + Y: fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2y[0], A1: points.G2y[1]}, + B1: fields_bls24315.E2{A0: points.G2y[2], A1: points.G2y[3]}}}) + tmp.AddAssign(api, res) + res.Select(api, sBits[0], res, tmp) + + P.X = res.X + P.Y = res.Y + + return P +} diff --git a/std/algebra/sw_bls24315/g2_test.go b/std/algebra/native/sw_bls24315/g2_test.go similarity index 93% rename from std/algebra/sw_bls24315/g2_test.go rename to std/algebra/native/sw_bls24315/g2_test.go index 0c8fb3ad20..13034f058e 100644 --- a/std/algebra/sw_bls24315/g2_test.go +++ b/std/algebra/native/sw_bls24315/g2_test.go @@ -374,6 +374,38 @@ func TestScalarMulG2(t *testing.T) { assert := test.NewAssert(t) assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_633)) } + +type g2varScalarMulBase struct { + C G2Affine `gnark:",public"` + R frontend.Variable +} + +func (circuit *g2varScalarMulBase) Define(api frontend.API) error { + expected := G2Affine{} + expected.ScalarMulBase(api, circuit.R) + expected.AssertIsEqual(api, circuit.C) + return nil +} + +func TestVarScalarMulBaseG2(t *testing.T) { + var c bls24315.G2Affine + _, gJac, _, _ := bls24315.Generators() + + // create the cs + var circuit, witness g2varScalarMulBase + var r fr.Element + _, _ = r.SetRandom() + witness.R = r.String() + // compute the result + var br big.Int + gJac.ScalarMultiplication(&gJac, r.BigInt(&br)) + c.FromJacobian(&gJac) + witness.C.Assign(&c) + + assert := test.NewAssert(t) + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_633)) +} + func randomPointG2() bls24315.G2Jac { _, p2, _, _ := bls24315.Generators() diff --git a/std/algebra/sw_bls24315/inner.go b/std/algebra/native/sw_bls24315/inner.go similarity 
index 63% rename from std/algebra/sw_bls24315/inner.go rename to std/algebra/native/sw_bls24315/inner.go index 42b6355cd0..8630c645ed 100644 --- a/std/algebra/sw_bls24315/inner.go +++ b/std/algebra/native/sw_bls24315/inner.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/consensys/gnark-crypto/ecc" + bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark/frontend" ) @@ -65,3 +66,54 @@ func (cc *innerConfig) phi2(api frontend.API, res, P *G2Affine) *G2Affine { res.Y = P.Y return res } + +type curvePoints struct { + G1x *big.Int // base point x + G1y *big.Int // base point y + G1m [][2]*big.Int // m*base points (x,y) +} + +var ( + computedCurveTable [][2]*big.Int + computedTwistTable [][8]*big.Int +) + +func init() { + computedCurveTable = computeCurveTable() + computedTwistTable = computeTwistTable() +} + +func getCurvePoints() curvePoints { + _, _, g1aff, _ := bls24315.Generators() + return curvePoints{ + G1x: g1aff.X.BigInt(new(big.Int)), + G1y: g1aff.Y.BigInt(new(big.Int)), + G1m: computedCurveTable, + } +} + +type twistPoints struct { + G2x [4]*big.Int // base point x ∈ E4 + G2y [4]*big.Int // base point y ∈ E4 + G2m [][8]*big.Int // m*base points (x,y) +} + +func getTwistPoints() twistPoints { + _, _, _, g2aff := bls24315.Generators() + return twistPoints{ + G2x: [4]*big.Int{ + g2aff.X.B0.A0.BigInt(new(big.Int)), + g2aff.X.B0.A1.BigInt(new(big.Int)), + g2aff.X.B1.A0.BigInt(new(big.Int)), + g2aff.X.B1.A1.BigInt(new(big.Int)), + }, + G2y: [4]*big.Int{ + g2aff.Y.B0.A0.BigInt(new(big.Int)), + g2aff.Y.B0.A1.BigInt(new(big.Int)), + g2aff.Y.B1.A0.BigInt(new(big.Int)), + g2aff.Y.B1.A1.BigInt(new(big.Int)), + }, + G2m: computedTwistTable, + } + +} diff --git a/std/algebra/native/sw_bls24315/inner_compute.go b/std/algebra/native/sw_bls24315/inner_compute.go new file mode 100644 index 0000000000..158ef72999 --- /dev/null +++ b/std/algebra/native/sw_bls24315/inner_compute.go @@ -0,0 +1,59 @@ +package sw_bls24315 + +import ( + "math/big" + + 
bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" +) + +func computeCurveTable() [][2]*big.Int { + G1jac, _, _, _ := bls24315.Generators() + table := make([][2]*big.Int, 253) + tmp := new(bls24315.G1Jac).Set(&G1jac) + aff := new(bls24315.G1Affine) + jac := new(bls24315.G1Jac) + for i := 1; i < 253; i++ { + tmp = tmp.Double(tmp) + switch i { + case 1, 2: + jac.Set(tmp).AddAssign(&G1jac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + case 3: + jac.Set(tmp).SubAssign(&G1jac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + fallthrough + default: + aff.FromJacobian(tmp) + table[i] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + } + } + return table[:] +} + +func computeTwistTable() [][8]*big.Int { + _, G2jac, _, _ := bls24315.Generators() + table := make([][8]*big.Int, 253) + tmp := new(bls24315.G2Jac).Set(&G2jac) + aff := new(bls24315.G2Affine) + jac := new(bls24315.G2Jac) + for i := 1; i < 253; i++ { + tmp = tmp.Double(tmp) + switch i { + case 1, 2: + jac.Set(tmp).AddAssign(&G2jac) + aff.FromJacobian(jac) + table[i-1] = [8]*big.Int{aff.X.B0.A0.BigInt(new(big.Int)), aff.X.B0.A1.BigInt(new(big.Int)), aff.X.B1.A0.BigInt(new(big.Int)), aff.X.B1.A1.BigInt(new(big.Int)), aff.Y.B0.A0.BigInt(new(big.Int)), aff.Y.B0.A1.BigInt(new(big.Int)), aff.Y.B1.A0.BigInt(new(big.Int)), aff.Y.B1.A1.BigInt(new(big.Int))} + case 3: + jac.Set(tmp).SubAssign(&G2jac) + aff.FromJacobian(jac) + table[i-1] = [8]*big.Int{aff.X.B0.A0.BigInt(new(big.Int)), aff.X.B0.A1.BigInt(new(big.Int)), aff.X.B1.A0.BigInt(new(big.Int)), aff.X.B1.A1.BigInt(new(big.Int)), aff.Y.B0.A0.BigInt(new(big.Int)), aff.Y.B0.A1.BigInt(new(big.Int)), aff.Y.B1.A0.BigInt(new(big.Int)), aff.Y.B1.A1.BigInt(new(big.Int))} + fallthrough + default: + aff.FromJacobian(tmp) + table[i] = [8]*big.Int{aff.X.B0.A0.BigInt(new(big.Int)), aff.X.B0.A1.BigInt(new(big.Int)), 
aff.X.B1.A0.BigInt(new(big.Int)), aff.X.B1.A1.BigInt(new(big.Int)), aff.Y.B0.A0.BigInt(new(big.Int)), aff.Y.B0.A1.BigInt(new(big.Int)), aff.Y.B1.A0.BigInt(new(big.Int)), aff.Y.B1.A1.BigInt(new(big.Int))} + } + } + return table[:] +} diff --git a/std/algebra/native/sw_bls24315/pairing.go b/std/algebra/native/sw_bls24315/pairing.go new file mode 100644 index 0000000000..c796688cb5 --- /dev/null +++ b/std/algebra/native/sw_bls24315/pairing.go @@ -0,0 +1,462 @@ +/* +Copyright © 2020 ConsenSys + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package sw_bls24315 + +import ( + "errors" + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/native/fields_bls24315" +) + +// GT target group of the pairing +type GT = fields_bls24315.E24 + +const ateLoop = 3218079743 + +// lineEvaluation represents a sparse Fp12 Elmt (result of the line evaluation) +// line: 1 + R0(x/y) + R1(1/y) = 0 instead of R0'*y + R1'*x + R2' = 0 This +// makes the multiplication by lines (MulBy034) and between lines (Mul034By034) +type lineEvaluation struct { + R0, R1 fields_bls24315.E4 +} + +// MillerLoop computes the product of n miller loops (n can be 1) +// ∏ᵢ { fᵢ_{x₀,Q}(P) } +func MillerLoop(api frontend.API, P []G1Affine, Q []G2Affine) (GT, error) { + // check input size match + n := len(P) + if n == 0 || n != len(Q) { + return GT{}, errors.New("invalid inputs sizes") + } + + var ateLoop2NAF [33]int8 + ecc.NafDecomposition(big.NewInt(ateLoop), ateLoop2NAF[:]) + + var res GT + res.SetOne() + var prodLines [5]fields_bls24315.E4 + + var l1, l2 lineEvaluation + Qacc := make([]G2Affine, n) + Qneg := make([]G2Affine, n) + yInv := make([]frontend.Variable, n) + xOverY := make([]frontend.Variable, n) + for k := 0; k < n; k++ { + Qacc[k] = Q[k] + Qneg[k].Neg(api, Q[k]) + // TODO: point P=(x,O) should be ruled out + yInv[k] = api.DivUnchecked(1, P[k].Y) + xOverY[k] = api.Mul(P[k].X, yInv[k]) + } + + // Compute ∏ᵢ { fᵢ_{x₀,Q}(P) } + // i = 32, separately to avoid an E24 Square + // (Square(res) = 1² = 1) + + // k = 0, separately to avoid MulBy034 (res × ℓ) + // (assign line to res) + Qacc[0], l1 = doubleStep(api, &Qacc[0]) + res.D1.C0.MulByFp(api, l1.R0, xOverY[0]) + res.D1.C1.MulByFp(api, l1.R1, yInv[0]) + + if n >= 2 { + // k = 1, separately to avoid MulBy034 (res × ℓ) + // (res is also a line at this point, so we use Mul034By034 ℓ × ℓ) + Qacc[1], l1 = doubleStep(api, &Qacc[1]) + + // line evaluation at P[1] + l1.R0.MulByFp(api, l1.R0, xOverY[1]) + 
l1.R1.MulByFp(api, l1.R1, yInv[1]) + + // ℓ × res + prodLines = *fields_bls24315.Mul034By034(api, l1.R0, l1.R1, res.D1.C0, res.D1.C1) + res.D0.C0 = prodLines[0] + res.D0.C1 = prodLines[1] + res.D0.C2 = prodLines[2] + res.D1.C0 = prodLines[3] + res.D1.C1 = prodLines[4] + + } + + if n >= 3 { + // k >= 2 + for k := 2; k < n; k++ { + // Qacc[k] ← 2Qacc[k] and l1 the tangent ℓ passing 2Qacc[k] + Qacc[k], l1 = doubleStep(api, &Qacc[k]) + + // line evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + } + } + + // i = 30, separately to avoid a doubleStep + // (at this point Qacc = 2Q, so 2Qacc-Q=3Q is equivalent to Qacc+Q=3Q + // this means doubleAndAddStep is equivalent to addStep here) + res.Square(api, res) + for k := 0; k < n; k++ { + // l2 the line passing Qacc[k] and -Q + l2 = lineCompute(api, &Qacc[k], &Qneg[k]) + + // line evaluation at P[k] + l2.R0.MulByFp(api, l2.R0, xOverY[k]) + l2.R1.MulByFp(api, l2.R1, yInv[k]) + + // Qacc[k] ← Qacc[k]+Q[k] and + // l1 the line ℓ passing Qacc[k] and Q[k] + Qacc[k], l1 = addStep(api, &Qacc[k], &Q[k]) + + // line evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + // ℓ × res + res.MulBy034(api, l2.R0, l2.R1) + } + + for i := 29; i >= 1; i-- { + // mutualize the square among n Miller loops + // (∏ᵢfᵢ)² + res.Square(api, res) + + switch ateLoop2NAF[i] { + case 0: + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k] and l1 the tangent ℓ passing 2Qacc[k] + Qacc[k], l1 = doubleStep(api, &Qacc[k]) + + // line evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + } + case 1: + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k]+Q[k], + // l1 the line ℓ passing Qacc[k] and Q[k] + // l2 the line ℓ passing (Qacc[k]+Q[k]) and Qacc[k] + Qacc[k], l1, l2 = 
doubleAndAddStep(api, &Qacc[k], &Q[k]) + + // line evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + + // line evaluation at P[k] + l2.R0.MulByFp(api, l2.R0, xOverY[k]) + l2.R1.MulByFp(api, l2.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l2.R0, l2.R1) + } + case -1: + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k]-Q[k], + // l1 the line ℓ passing Qacc[k] and Q[k] + // l2 the line ℓ passing (Qacc[k]-Q[k]) and Qacc[k] + Qacc[k], l1, l2 = doubleAndAddStep(api, &Qacc[k], &Qneg[k]) + + // line evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + + // line evaluation at P[k] + l2.R0.MulByFp(api, l2.R0, xOverY[k]) + l2.R1.MulByFp(api, l2.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l2.R0, l2.R1) + } + default: + return GT{}, errors.New("invalid loopCounter") + } + } + + // i = 0 + res.Square(api, res) + for k := 0; k < n; k++ { + // l1 the line ℓ passing Qacc[k] and -Q[k] + // l2 the line ℓ passing (Qacc[k]-Q[k]) and Qacc[k] + l1, l2 = linesCompute(api, &Qacc[k], &Qneg[k]) + + // line evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + + // line evaluation at P[k] + l2.R0.MulByFp(api, l2.R0, xOverY[k]) + l2.R1.MulByFp(api, l2.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l2.R0, l2.R1) + } + + res.Conjugate(api, res) + + return res, nil +} + +// FinalExponentiation computes the exponentiation e1ᵈ +// where d = (p²⁴-1)/r = (p²⁴-1)/Φ₂₄(p) ⋅ Φ₂₄(p)/r = (p¹²-1)(p⁴+1)(p⁸ - p⁴ +1)/r +// we use instead d=s ⋅ (p¹²-1)(p⁴+1)(p⁸ - p⁴ +1)/r +// where s is the cofactor 3 (Hayashida et al.) 
+func FinalExponentiation(api frontend.API, e1 GT) GT { + const genT = ateLoop + result := e1 + + // https://eprint.iacr.org/2012/232.pdf, section 7 + var t [9]GT + + // easy part + // (p¹²-1)(p⁴+1) + t[0].Conjugate(api, result) + t[0].DivUnchecked(api, t[0], result) + result.FrobeniusQuad(api, t[0]). + Mul(api, result, t[0]) + + // hard part (api, up to permutation) + // Daiki Hayashida and Kenichiro Hayasaka + // and Tadanori Teruya + // https://eprint.iacr.org/2020/875.pdf + // 3(p⁸ - p⁴ +1)/r = (x₀-1)² * (x₀+p) * (x₀²+p²) * (x₀⁴+p⁴-1) + 3 + t[0].CyclotomicSquare(api, result) + t[1].Expt(api, result, genT) + t[2].Conjugate(api, result) + t[1].Mul(api, t[1], t[2]) + t[2].Expt(api, t[1], genT) + t[1].Conjugate(api, t[1]) + t[1].Mul(api, t[1], t[2]) + t[2].Expt(api, t[1], genT) + t[1].Frobenius(api, t[1]) + t[1].Mul(api, t[1], t[2]) + result.Mul(api, result, t[0]) + t[0].Expt(api, t[1], genT) + t[2].Expt(api, t[0], genT) + t[0].FrobeniusSquare(api, t[1]) + t[2].Mul(api, t[0], t[2]) + t[1].Expt(api, t[2], genT) + t[1].Expt(api, t[1], genT) + t[1].Expt(api, t[1], genT) + t[1].Expt(api, t[1], genT) + t[0].FrobeniusQuad(api, t[2]) + t[0].Mul(api, t[0], t[1]) + t[2].Conjugate(api, t[2]) + t[0].Mul(api, t[0], t[2]) + result.Mul(api, result, t[0]) + + return result +} + +// PairingCheck calculates the reduced pairing for a set of points and returns True if the result is One +// ∏ᵢ e(Pᵢ, Qᵢ) =? 1 +// +// This function doesn't check that the inputs are in the correct subgroup. See IsInSubGroup. 
+func Pair(api frontend.API, P []G1Affine, Q []G2Affine) (GT, error) { + f, err := MillerLoop(api, P, Q) + if err != nil { + return GT{}, err + } + return FinalExponentiation(api, f), nil +} + +// doubleAndAddStep doubles p1 and adds p2 to the result in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func doubleAndAddStep(api frontend.API, p1, p2 *G2Affine) (G2Affine, lineEvaluation, lineEvaluation) { + + var n, d, l1, l2, x3, x4, y4 fields_bls24315.E4 + var line1, line2 lineEvaluation + var p G2Affine + + // compute lambda1 = (y2-y1)/(x2-x1) + n.Sub(api, p1.Y, p2.Y) + d.Sub(api, p1.X, p2.X) + l1.DivUnchecked(api, n, d) + + // x3 =lambda1**2-p1.x-p2.x + x3.Square(api, l1). + Sub(api, x3, p1.X). + Sub(api, x3, p2.X) + + // omit y3 computation + + // compute line1 + line1.R0.Neg(api, l1) + line1.R1.Mul(api, l1, p1.X).Sub(api, line1.R1, p1.Y) + + // compute lambda2 = -lambda1-2*y1/(x3-x1) + n.Double(api, p1.Y) + d.Sub(api, x3, p1.X) + l2.DivUnchecked(api, n, d) + l2.Add(api, l2, l1).Neg(api, l2) + + // compute x4 = lambda2**2-x1-x3 + x4.Square(api, l2). + Sub(api, x4, p1.X). + Sub(api, x4, x3) + + // compute y4 = lambda2*(x1 - x4)-y1 + y4.Sub(api, p1.X, x4). + Mul(api, l2, y4). + Sub(api, y4, p1.Y) + + p.X = x4 + p.Y = y4 + + // compute line2 + line2.R0.Neg(api, l2) + line2.R1.Mul(api, l2, p1.X).Sub(api, line2.R1, p1.Y) + + return p, line1, line2 +} + +// doubleStep doubles a point in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func doubleStep(api frontend.API, p1 *G2Affine) (G2Affine, lineEvaluation) { + + var n, d, l, xr, yr fields_bls24315.E4 + var p G2Affine + var line lineEvaluation + + // lambda = 3*p1.x**2/2*p.y + n.Square(api, p1.X).MulByFp(api, n, 3) + d.MulByFp(api, p1.Y, 2) + l.DivUnchecked(api, n, d) + + // xr = lambda**2-2*p1.x + xr.Square(api, l). + Sub(api, xr, p1.X). 
+ Sub(api, xr, p1.X) + + // yr = lambda*(p.x-xr)-p.y + yr.Sub(api, p1.X, xr). + Mul(api, l, yr). + Sub(api, yr, p1.Y) + + p.X = xr + p.Y = yr + + line.R0.Neg(api, l) + line.R1.Mul(api, l, p1.X).Sub(api, line.R1, p1.Y) + + return p, line + +} + +// addStep adds two points in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func addStep(api frontend.API, p1, p2 *G2Affine) (G2Affine, lineEvaluation) { + + var p2ypy, p2xpx, λ, λλ, pxrx, λpxrx, xr, yr fields_bls24315.E4 + // compute λ = (y2-y1)/(x2-x1) + p2ypy.Sub(api, p2.Y, p1.Y) + p2xpx.Sub(api, p2.X, p1.X) + λ.DivUnchecked(api, p2ypy, p2xpx) + + // xr = λ²-x1-x2 + λλ.Square(api, λ) + p2xpx.Add(api, p1.X, p2.X) + xr.Sub(api, λλ, p2xpx) + + // yr = λ(x1-xr) - y1 + pxrx.Sub(api, p1.X, xr) + λpxrx.Mul(api, λ, pxrx) + yr.Sub(api, λpxrx, p1.Y) + + var res G2Affine + res.X = xr + res.Y = yr + + var line lineEvaluation + line.R0.Neg(api, λ) + line.R1.Mul(api, λ, p1.X) + line.R1.Sub(api, line.R1, p1.Y) + + return res, line + +} + +// linesCompute computes the lines that go through p1 and p2, and (p1+p2) and p1 but does not compute 2p1+p2 +func linesCompute(api frontend.API, p1, p2 *G2Affine) (lineEvaluation, lineEvaluation) { + + var n, d, l1, l2, x3 fields_bls24315.E4 + var line1, line2 lineEvaluation + + // compute lambda1 = (y2-y1)/(x2-x1) + n.Sub(api, p1.Y, p2.Y) + d.Sub(api, p1.X, p2.X) + l1.DivUnchecked(api, n, d) + + // x3 =lambda1**2-p1.x-p2.x + x3.Square(api, l1). + Sub(api, x3, p1.X). 
+ Sub(api, x3, p2.X) + + // omit y3 computation + + // compute line1 + line1.R0.Neg(api, l1) + line1.R1.Mul(api, l1, p1.X).Sub(api, line1.R1, p1.Y) + + // compute lambda2 = -lambda1-2*y1/(x3-x1) + n.Double(api, p1.Y) + d.Sub(api, x3, p1.X) + l2.DivUnchecked(api, n, d) + l2.Add(api, l2, l1).Neg(api, l2) + + // compute line2 + line2.R0.Neg(api, l2) + line2.R1.Mul(api, l2, p1.X).Sub(api, line2.R1, p1.Y) + + return line1, line2 +} + +// lineCompute computes the line that goes through p1 and p2 but does not compute p1+p2 +func lineCompute(api frontend.API, p1, p2 *G2Affine) lineEvaluation { + + var qypy, qxpx, λ fields_bls24315.E4 + + // compute λ = (y2-y1)/(x2-x1) + qypy.Sub(api, p2.Y, p1.Y) + qxpx.Sub(api, p2.X, p1.X) + λ.DivUnchecked(api, qypy, qxpx) + + var line lineEvaluation + line.R0.Neg(api, λ) + line.R1.Mul(api, λ, p1.X) + line.R1.Sub(api, line.R1, p1.Y) + + return line + +} diff --git a/std/algebra/sw_bls24315/pairing_test.go b/std/algebra/native/sw_bls24315/pairing_test.go similarity index 98% rename from std/algebra/sw_bls24315/pairing_test.go rename to std/algebra/native/sw_bls24315/pairing_test.go index c7bb53b4a5..44233d888d 100644 --- a/std/algebra/sw_bls24315/pairing_test.go +++ b/std/algebra/native/sw_bls24315/pairing_test.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/std/algebra/fields_bls24315" + "github.com/consensys/gnark/std/algebra/native/fields_bls24315" "github.com/consensys/gnark/test" ) diff --git a/std/algebra/twistededwards/curve.go b/std/algebra/native/twistededwards/curve.go similarity index 100% rename from std/algebra/twistededwards/curve.go rename to std/algebra/native/twistededwards/curve.go diff --git a/std/algebra/twistededwards/curve_test.go b/std/algebra/native/twistededwards/curve_test.go similarity index 100% rename from std/algebra/twistededwards/curve_test.go rename to 
std/algebra/native/twistededwards/curve_test.go diff --git a/std/algebra/native/twistededwards/doc.go b/std/algebra/native/twistededwards/doc.go new file mode 100644 index 0000000000..95b1486681 --- /dev/null +++ b/std/algebra/native/twistededwards/doc.go @@ -0,0 +1,8 @@ +// Package twistededwards implements the arithmetic of twisted Edwards curves +// in native fields. This uses associated twisted Edwards curves defined over +// the scalar field of the SNARK curves. +// +// Examples: +// Jubjub, Bandersnatch (a twisted Edwards) is defined over BLS12-381's scalar field +// Baby-Jubjub (a twisted Edwards) is defined over BN254's scalar field +package twistededwards diff --git a/std/algebra/twistededwards/point.go b/std/algebra/native/twistededwards/point.go similarity index 100% rename from std/algebra/twistededwards/point.go rename to std/algebra/native/twistededwards/point.go diff --git a/std/algebra/twistededwards/scalarmul_glv.go b/std/algebra/native/twistededwards/scalarmul_glv.go similarity index 98% rename from std/algebra/twistededwards/scalarmul_glv.go rename to std/algebra/native/twistededwards/scalarmul_glv.go index b429d866ce..7b959a2db4 100644 --- a/std/algebra/twistededwards/scalarmul_glv.go +++ b/std/algebra/native/twistededwards/scalarmul_glv.go @@ -22,7 +22,7 @@ import ( "sync" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -82,7 +82,7 @@ var DecomposeScalar = func(scalarField *big.Int, inputs []*big.Int, res []*big.I } func init() { - hint.Register(DecomposeScalar) + solver.RegisterHint(DecomposeScalar) } // ScalarMul computes the scalar multiplication of a point on a twisted Edwards curve diff --git a/std/algebra/twistededwards/twistededwards.go b/std/algebra/native/twistededwards/twistededwards.go similarity index 100% rename from std/algebra/twistededwards/twistededwards.go rename to 
std/algebra/native/twistededwards/twistededwards.go diff --git a/std/algebra/sw_bls12377/doc.go b/std/algebra/sw_bls12377/doc.go deleted file mode 100644 index c9270127e8..0000000000 --- a/std/algebra/sw_bls12377/doc.go +++ /dev/null @@ -1,18 +0,0 @@ -/* -Copyright © 2020 ConsenSys - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Package sw (short weierstrass) -package sw_bls12377 diff --git a/std/algebra/sw_bls12377/pairing.go b/std/algebra/sw_bls12377/pairing.go deleted file mode 100644 index 7ca86131e7..0000000000 --- a/std/algebra/sw_bls12377/pairing.go +++ /dev/null @@ -1,243 +0,0 @@ -/* -Copyright © 2020 ConsenSys - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package sw_bls12377 - -import ( - "errors" - "math/big" - - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/fields_bls12377" -) - -// GT target group of the pairing -type GT = fields_bls12377.E12 - -const ateLoop = 9586122913090633729 - -// LineEvaluation represents a sparse Fp12 Elmt (result of the line evaluation) -type LineEvaluation struct { - R0, R1 fields_bls12377.E2 -} - -// MillerLoop computes the product of n miller loops (n can be 1) -func MillerLoop(api frontend.API, P []G1Affine, Q []G2Affine) (GT, error) { - // check input size match - n := len(P) - if n == 0 || n != len(Q) { - return GT{}, errors.New("invalid inputs sizes") - } - - var ateLoopBin [64]uint - var ateLoopBigInt big.Int - ateLoopBigInt.SetUint64(ateLoop) - for i := 0; i < 64; i++ { - ateLoopBin[i] = ateLoopBigInt.Bit(i) - } - - var res GT - res.SetOne() - - var l1, l2 LineEvaluation - Qacc := make([]G2Affine, n) - yInv := make([]frontend.Variable, n) - xOverY := make([]frontend.Variable, n) - for k := 0; k < n; k++ { - Qacc[k] = Q[k] - yInv[k] = api.DivUnchecked(1, P[k].Y) - xOverY[k] = api.DivUnchecked(P[k].X, P[k].Y) - } - - // k = 0 - Qacc[0], l1 = DoubleStep(api, &Qacc[0]) - res.C1.B0.MulByFp(api, l1.R0, xOverY[0]) - res.C1.B1.MulByFp(api, l1.R1, yInv[0]) - - if n >= 2 { - // k = 1 - Qacc[1], l1 = DoubleStep(api, &Qacc[1]) - l1.R0.MulByFp(api, l1.R0, xOverY[1]) - l1.R1.MulByFp(api, l1.R1, yInv[1]) - res.Mul034By034(api, l1.R0, l1.R1, res.C1.B0, res.C1.B1) - } - - if n >= 3 { - // k >= 2 - for k := 2; k < n; k++ { - Qacc[k], l1 = DoubleStep(api, &Qacc[k]) - l1.R0.MulByFp(api, l1.R0, xOverY[k]) - l1.R1.MulByFp(api, l1.R1, yInv[k]) - res.MulBy034(api, l1.R0, l1.R1) - } - } - - for i := len(ateLoopBin) - 3; i >= 0; i-- { - res.Square(api, res) - - if ateLoopBin[i] == 0 { - for k := 0; k < n; k++ { - Qacc[k], l1 = DoubleStep(api, &Qacc[k]) - l1.R0.MulByFp(api, l1.R0, xOverY[k]) - l1.R1.MulByFp(api, l1.R1, yInv[k]) - res.MulBy034(api, l1.R0, l1.R1) - } 
- continue - } - - for k := 0; k < n; k++ { - Qacc[k], l1, l2 = DoubleAndAddStep(api, &Qacc[k], &Q[k]) - l1.R0.MulByFp(api, l1.R0, xOverY[k]) - l1.R1.MulByFp(api, l1.R1, yInv[k]) - res.MulBy034(api, l1.R0, l1.R1) - l2.R0.MulByFp(api, l2.R0, xOverY[k]) - l2.R1.MulByFp(api, l2.R1, yInv[k]) - res.MulBy034(api, l2.R0, l2.R1) - } - } - - return res, nil -} - -// FinalExponentiation computes the final expo x**(p**6-1)(p**2+1)(p**4 - p**2 +1)/r -func FinalExponentiation(api frontend.API, e1 GT) GT { - const genT = ateLoop - - result := e1 - - // https://eprint.iacr.org/2016/130.pdf - var t [3]GT - - // easy part - t[0].Conjugate(api, result) - t[0].DivUnchecked(api, t[0], result) - result.FrobeniusSquare(api, t[0]). - Mul(api, result, t[0]) - - // hard part (up to permutation) - // Daiki Hayashida and Kenichiro Hayasaka - // and Tadanori Teruya - // https://eprint.iacr.org/2020/875.pdf - t[0].CyclotomicSquare(api, result) - t[1].Expt(api, result, genT) - t[2].Conjugate(api, result) - t[1].Mul(api, t[1], t[2]) - t[2].Expt(api, t[1], genT) - t[1].Conjugate(api, t[1]) - t[1].Mul(api, t[1], t[2]) - t[2].Expt(api, t[1], genT) - t[1].Frobenius(api, t[1]) - t[1].Mul(api, t[1], t[2]) - result.Mul(api, result, t[0]) - t[0].Expt(api, t[1], genT) - t[2].Expt(api, t[0], genT) - t[0].FrobeniusSquare(api, t[1]) - t[1].Conjugate(api, t[1]) - t[1].Mul(api, t[1], t[2]) - t[1].Mul(api, t[1], t[0]) - result.Mul(api, result, t[1]) - - return result -} - -// Pair calculates the reduced pairing for a set of points -func Pair(api frontend.API, P []G1Affine, Q []G2Affine) (GT, error) { - f, err := MillerLoop(api, P, Q) - if err != nil { - return GT{}, err - } - return FinalExponentiation(api, f), nil -} - -// DoubleAndAddStep -func DoubleAndAddStep(api frontend.API, p1, p2 *G2Affine) (G2Affine, LineEvaluation, LineEvaluation) { - - var n, d, l1, l2, x3, x4, y4 fields_bls12377.E2 - var line1, line2 LineEvaluation - var p G2Affine - - // compute lambda1 = (y2-y1)/(x2-x1) - n.Sub(api, p1.Y, p2.Y) - 
d.Sub(api, p1.X, p2.X) - l1.DivUnchecked(api, n, d) - - // x3 =lambda1**2-p1.x-p2.x - x3.Square(api, l1). - Sub(api, x3, p1.X). - Sub(api, x3, p2.X) - - // omit y3 computation - - // compute line1 - line1.R0.Neg(api, l1) - line1.R1.Mul(api, l1, p1.X).Sub(api, line1.R1, p1.Y) - - // compute lambda2 = -lambda1-2*y1/(x3-x1) - n.Double(api, p1.Y) - d.Sub(api, x3, p1.X) - l2.DivUnchecked(api, n, d) - l2.Add(api, l2, l1).Neg(api, l2) - - // compute x4 = lambda2**2-x1-x3 - x4.Square(api, l2). - Sub(api, x4, p1.X). - Sub(api, x4, x3) - - // compute y4 = lambda2*(x1 - x4)-y1 - y4.Sub(api, p1.X, x4). - Mul(api, l2, y4). - Sub(api, y4, p1.Y) - - p.X = x4 - p.Y = y4 - - // compute line2 - line2.R0.Neg(api, l2) - line2.R1.Mul(api, l2, p1.X).Sub(api, line2.R1, p1.Y) - - return p, line1, line2 -} - -func DoubleStep(api frontend.API, p1 *G2Affine) (G2Affine, LineEvaluation) { - - var n, d, l, xr, yr fields_bls12377.E2 - var p G2Affine - var line LineEvaluation - - // lambda = 3*p1.x**2/2*p.y - n.Square(api, p1.X).MulByFp(api, n, 3) - d.MulByFp(api, p1.Y, 2) - l.DivUnchecked(api, n, d) - - // xr = lambda**2-2*p1.x - xr.Square(api, l). - Sub(api, xr, p1.X). - Sub(api, xr, p1.X) - - // yr = lambda*(p.x-xr)-p.y - yr.Sub(api, p1.X, xr). - Mul(api, l, yr). - Sub(api, yr, p1.Y) - - p.X = xr - p.Y = yr - - line.R0.Neg(api, l) - line.R1.Mul(api, l, p1.X).Sub(api, line.R1, p1.Y) - - return p, line - -} diff --git a/std/algebra/sw_bls24315/doc.go b/std/algebra/sw_bls24315/doc.go deleted file mode 100644 index 279e09879b..0000000000 --- a/std/algebra/sw_bls24315/doc.go +++ /dev/null @@ -1,18 +0,0 @@ -/* -Copyright © 2020 ConsenSys - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Package sw (short weierstrass) -package sw_bls24315 diff --git a/std/algebra/sw_bls24315/pairing.go b/std/algebra/sw_bls24315/pairing.go deleted file mode 100644 index b1e1fb11e0..0000000000 --- a/std/algebra/sw_bls24315/pairing.go +++ /dev/null @@ -1,259 +0,0 @@ -/* -Copyright © 2020 ConsenSys - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package sw_bls24315 - -import ( - "errors" - "math/big" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/fields_bls24315" -) - -// GT target group of the pairing -type GT = fields_bls24315.E24 - -const ateLoop = 3218079743 - -// LineEvaluation represents a sparse Fp12 Elmt (result of the line evaluation) -type LineEvaluation struct { - R0, R1 fields_bls24315.E4 -} - -// MillerLoop computes the product of n miller loops (n can be 1) -func MillerLoop(api frontend.API, P []G1Affine, Q []G2Affine) (GT, error) { - // check input size match - n := len(P) - if n == 0 || n != len(Q) { - return GT{}, errors.New("invalid inputs sizes") - } - - var ateLoop2NAF [33]int8 - ecc.NafDecomposition(big.NewInt(ateLoop), ateLoop2NAF[:]) - - var res GT - res.SetOne() - - var l1, l2 LineEvaluation - Qacc := make([]G2Affine, n) - Qneg := make([]G2Affine, n) - yInv := make([]frontend.Variable, n) - xOverY := make([]frontend.Variable, n) - for k := 0; k < n; k++ { - Qacc[k] = Q[k] - Qneg[k].Neg(api, Q[k]) - yInv[k] = api.DivUnchecked(1, P[k].Y) - xOverY[k] = api.DivUnchecked(P[k].X, P[k].Y) - } - - // k = 0 - Qacc[0], l1 = DoubleStep(api, &Qacc[0]) - res.D1.C0.MulByFp(api, l1.R0, xOverY[0]) - res.D1.C1.MulByFp(api, l1.R1, yInv[0]) - - if n >= 2 { - // k = 1 - Qacc[1], l1 = DoubleStep(api, &Qacc[1]) - l1.R0.MulByFp(api, l1.R0, xOverY[1]) - l1.R1.MulByFp(api, l1.R1, yInv[1]) - res.Mul034By034(api, l1.R0, l1.R1, res.D1.C0, res.D1.C1) - } - - if n >= 3 { - // k >= 2 - for k := 2; k < n; k++ { - Qacc[k], l1 = DoubleStep(api, &Qacc[k]) - l1.R0.MulByFp(api, l1.R0, xOverY[k]) - l1.R1.MulByFp(api, l1.R1, yInv[k]) - res.MulBy034(api, l1.R0, l1.R1) - } - } - - for i := len(ateLoop2NAF) - 3; i >= 0; i-- { - res.Square(api, res) - - if ateLoop2NAF[i] == 0 { - for k := 0; k < n; k++ { - Qacc[k], l1 = DoubleStep(api, &Qacc[k]) - l1.R0.MulByFp(api, l1.R0, xOverY[k]) - l1.R1.MulByFp(api, l1.R1, yInv[k]) - res.MulBy034(api, 
l1.R0, l1.R1) - } - } else if ateLoop2NAF[i] == 1 { - for k := 0; k < n; k++ { - Qacc[k], l1, l2 = DoubleAndAddStep(api, &Qacc[k], &Q[k]) - l1.R0.MulByFp(api, l1.R0, xOverY[k]) - l1.R1.MulByFp(api, l1.R1, yInv[k]) - res.MulBy034(api, l1.R0, l1.R1) - l2.R0.MulByFp(api, l2.R0, xOverY[k]) - l2.R1.MulByFp(api, l2.R1, yInv[k]) - res.MulBy034(api, l2.R0, l2.R1) - } - } else { - for k := 0; k < n; k++ { - Qacc[k], l1, l2 = DoubleAndAddStep(api, &Qacc[k], &Qneg[k]) - l1.R0.MulByFp(api, l1.R0, xOverY[k]) - l1.R1.MulByFp(api, l1.R1, yInv[k]) - res.MulBy034(api, l1.R0, l1.R1) - l2.R0.MulByFp(api, l2.R0, xOverY[k]) - l2.R1.MulByFp(api, l2.R1, yInv[k]) - res.MulBy034(api, l2.R0, l2.R1) - } - } - } - - res.Conjugate(api, res) - - return res, nil -} - -// FinalExponentiation computes the final expo x**(p**12-1)(p**4+1)(p**8 - p**4 +1)/r -func FinalExponentiation(api frontend.API, e1 GT) GT { - const genT = ateLoop - result := e1 - - // https://eprint.iacr.org/2012/232.pdf, section 7 - var t [9]GT - - // easy part - t[0].Conjugate(api, result) - t[0].DivUnchecked(api, t[0], result) - result.FrobeniusQuad(api, t[0]). 
- Mul(api, result, t[0]) - - // hard part (api, up to permutation) - // Daiki Hayashida and Kenichiro Hayasaka - // and Tadanori Teruya - // https://eprint.iacr.org/2020/875.pdf - // 3*Phi_24(p)/r = (u-1)² * (u+p) * (u²+p²) * (u⁴+p⁴-1) + 3 - t[0].CyclotomicSquare(api, result) - t[1].Expt(api, result, genT) - t[2].Conjugate(api, result) - t[1].Mul(api, t[1], t[2]) - t[2].Expt(api, t[1], genT) - t[1].Conjugate(api, t[1]) - t[1].Mul(api, t[1], t[2]) - t[2].Expt(api, t[1], genT) - t[1].Frobenius(api, t[1]) - t[1].Mul(api, t[1], t[2]) - result.Mul(api, result, t[0]) - t[0].Expt(api, t[1], genT) - t[2].Expt(api, t[0], genT) - t[0].FrobeniusSquare(api, t[1]) - t[2].Mul(api, t[0], t[2]) - t[1].Expt(api, t[2], genT) - t[1].Expt(api, t[1], genT) - t[1].Expt(api, t[1], genT) - t[1].Expt(api, t[1], genT) - t[0].FrobeniusQuad(api, t[2]) - t[0].Mul(api, t[0], t[1]) - t[2].Conjugate(api, t[2]) - t[0].Mul(api, t[0], t[2]) - result.Mul(api, result, t[0]) - - return result -} - -// Pair calculates the reduced pairing for a set of points -func Pair(api frontend.API, P []G1Affine, Q []G2Affine) (GT, error) { - f, err := MillerLoop(api, P, Q) - if err != nil { - return GT{}, err - } - return FinalExponentiation(api, f), nil -} - -// DoubleAndAddStep -func DoubleAndAddStep(api frontend.API, p1, p2 *G2Affine) (G2Affine, LineEvaluation, LineEvaluation) { - - var n, d, l1, l2, x3, x4, y4 fields_bls24315.E4 - var line1, line2 LineEvaluation - var p G2Affine - - // compute lambda1 = (y2-y1)/(x2-x1) - n.Sub(api, p1.Y, p2.Y) - d.Sub(api, p1.X, p2.X) - l1.DivUnchecked(api, n, d) - - // x3 =lambda1**2-p1.x-p2.x - x3.Square(api, l1). - Sub(api, x3, p1.X). 
- Sub(api, x3, p2.X) - - // omit y3 computation - - // compute line1 - line1.R0.Neg(api, l1) - line1.R1.Mul(api, l1, p1.X).Sub(api, line1.R1, p1.Y) - - // compute lambda2 = -lambda1-2*y1/(x3-x1) - n.Double(api, p1.Y) - d.Sub(api, x3, p1.X) - l2.DivUnchecked(api, n, d) - l2.Add(api, l2, l1).Neg(api, l2) - - // compute x4 = lambda2**2-x1-x3 - x4.Square(api, l2). - Sub(api, x4, p1.X). - Sub(api, x4, x3) - - // compute y4 = lambda2*(x1 - x4)-y1 - y4.Sub(api, p1.X, x4). - Mul(api, l2, y4). - Sub(api, y4, p1.Y) - - p.X = x4 - p.Y = y4 - - // compute line2 - line2.R0.Neg(api, l2) - line2.R1.Mul(api, l2, p1.X).Sub(api, line2.R1, p1.Y) - - return p, line1, line2 -} - -func DoubleStep(api frontend.API, p1 *G2Affine) (G2Affine, LineEvaluation) { - - var n, d, l, xr, yr fields_bls24315.E4 - var p G2Affine - var line LineEvaluation - - // lambda = 3*p1.x**2/2*p.y - n.Square(api, p1.X).MulByFp(api, n, 3) - d.MulByFp(api, p1.Y, 2) - l.DivUnchecked(api, n, d) - - // xr = lambda**2-2*p1.x - xr.Square(api, l). - Sub(api, xr, p1.X). - Sub(api, xr, p1.X) - - // yr = lambda*(p.x-xr)-p.y - yr.Sub(api, p1.X, xr). - Mul(api, l, yr). - Sub(api, yr, p1.Y) - - p.X = xr - p.Y = yr - - line.R0.Neg(api, l) - line.R1.Mul(api, l, p1.X).Sub(api, line.R1, p1.Y) - - return p, line - -} diff --git a/std/algebra/weierstrass/params.go b/std/algebra/weierstrass/params.go deleted file mode 100644 index aadf2bce6a..0000000000 --- a/std/algebra/weierstrass/params.go +++ /dev/null @@ -1,71 +0,0 @@ -package weierstrass - -import ( - "math/big" - - "github.com/consensys/gnark/std/math/emulated" -) - -// CurveParams defines parameters of an elliptic curve in short Weierstrass form -// given by the equation -// -// Y² = X³ + aX + b -// -// The base point is defined by (Gx, Gy). -type CurveParams struct { - A *big.Int // a in curve equation - B *big.Int // b in curve equation - Gx *big.Int // base point x - Gy *big.Int // base point y -} - -// GetSecp256k1Params returns curve parameters for the curve secp256k1. 
When -// initialising new curve, use the base field [emulated.Secp256k1Fp] and scalar -// field [emulated.Secp256k1Fr]. -func GetSecp256k1Params() CurveParams { - gx, _ := new(big.Int).SetString("79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798", 16) - gy, _ := new(big.Int).SetString("483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8", 16) - return CurveParams{ - A: big.NewInt(0), - B: big.NewInt(7), - Gx: gx, - Gy: gy, - } -} - -// GetBN254Params returns the curve parameters for the curve BN254 (alt_bn128). -// When initialising new curve, use the base field [emulated.BN254Fp] and scalar -// field [emulated.BN254Fr]. -func GetBN254Params() CurveParams { - gx := big.NewInt(1) - gy := big.NewInt(2) - return CurveParams{ - A: big.NewInt(0), - B: big.NewInt(3), - Gx: gx, - Gy: gy, - } -} - -// GetCurveParams returns suitable curve parameters given the parametric type Base as base field. -func GetCurveParams[Base emulated.FieldParams]() CurveParams { - var t Base - switch t.Modulus().Text(16) { - case "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f": - return secp256k1Params - case "30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47": - return bn254Params - default: - panic("no stored parameters") - } -} - -var ( - secp256k1Params CurveParams - bn254Params CurveParams -) - -func init() { - secp256k1Params = GetSecp256k1Params() - bn254Params = GetBN254Params() -} diff --git a/std/algebra/weierstrass/point.go b/std/algebra/weierstrass/point.go deleted file mode 100644 index 8c91ed2761..0000000000 --- a/std/algebra/weierstrass/point.go +++ /dev/null @@ -1,163 +0,0 @@ -package weierstrass - -import ( - "fmt" - "math/big" - - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/math/emulated" -) - -// New returns a new [Curve] instance over the base field Base and scalar field -// Scalars defined by the curve parameters params. 
It returns an error if -// initialising the field emulation fails (for example, when the native field is -// too small) or when the curve parameters are incompatible with the fields. -func New[Base, Scalars emulated.FieldParams](api frontend.API, params CurveParams) (*Curve[Base, Scalars], error) { - ba, err := emulated.NewField[Base](api) - if err != nil { - return nil, fmt.Errorf("new base api: %w", err) - } - sa, err := emulated.NewField[Scalars](api) - if err != nil { - return nil, fmt.Errorf("new scalar api: %w", err) - } - Gx := emulated.ValueOf[Base](params.Gx) - Gy := emulated.ValueOf[Base](params.Gy) - return &Curve[Base, Scalars]{ - params: params, - api: api, - baseApi: ba, - scalarApi: sa, - g: AffinePoint[Base]{ - X: Gx, - Y: Gy, - }, - a: emulated.ValueOf[Base](params.A), - addA: params.A.Cmp(big.NewInt(0)) != 0, - }, nil -} - -// Curve is an initialised curve which allows performing group operations. -type Curve[Base, Scalars emulated.FieldParams] struct { - // params is the parameters of the curve - params CurveParams - // api is the native api, we construct it ourselves to be sure - api frontend.API - // baseApi is the api for point operations - baseApi *emulated.Field[Base] - // scalarApi is the api for scalar operations - scalarApi *emulated.Field[Scalars] - - // g is the generator (base point) of the curve. - g AffinePoint[Base] - - a emulated.Element[Base] - addA bool -} - -// Generator returns the base point of the curve. The method does not copy and -// modifying the returned element leads to undefined behaviour! -func (c *Curve[B, S]) Generator() *AffinePoint[B] { - return &c.g -} - -// AffinePoint represents a point on the elliptic curve. We do not check that -// the point is actually on the curve. -type AffinePoint[Base emulated.FieldParams] struct { - X, Y emulated.Element[Base] -} - -// Neg returns an inverse of p. It doesn't modify p. 
-func (c *Curve[B, S]) Neg(p *AffinePoint[B]) *AffinePoint[B] { - return &AffinePoint[B]{ - X: p.X, - Y: *c.baseApi.Neg(&p.Y), - } -} - -// AssertIsEqual asserts that p and q are the same point. -func (c *Curve[B, S]) AssertIsEqual(p, q *AffinePoint[B]) { - c.baseApi.AssertIsEqual(&p.X, &q.X) - c.baseApi.AssertIsEqual(&p.Y, &q.Y) -} - -// Add adds q and r and returns it. -func (c *Curve[B, S]) Add(q, r *AffinePoint[B]) *AffinePoint[B] { - // compute lambda = (p1.y-p.y)/(p1.x-p.x) - p1ypy := c.baseApi.Sub(&r.Y, &q.Y) - p1xpx := c.baseApi.Sub(&r.X, &q.X) - lambda := c.baseApi.Div(p1ypy, p1xpx) - - // xr = lambda**2-p.x-p1.x - lambdaSq := c.baseApi.MulMod(lambda, lambda) - qxrx := c.baseApi.Add(&q.X, &r.X) - xr := c.baseApi.Sub(lambdaSq, qxrx) - - // p.y = lambda(p.x-xr) - p.y - pxxr := c.baseApi.Sub(&q.X, xr) - lpxxr := c.baseApi.MulMod(lambda, pxxr) - py := c.baseApi.Sub(lpxxr, &q.Y) - - return &AffinePoint[B]{ - X: *c.baseApi.Reduce(xr), - Y: *c.baseApi.Reduce(py), - } -} - -// Double doubles p and return it. It doesn't modify p. -func (c *Curve[B, S]) Double(p *AffinePoint[B]) *AffinePoint[B] { - - // compute lambda = (3*p1.x**2+a)/2*p1.y, here we assume a=0 (j invariant 0 curve) - xSq3a := c.baseApi.MulMod(&p.X, &p.X) - xSq3a = c.baseApi.MulConst(xSq3a, big.NewInt(3)) - if c.addA { - xSq3a = c.baseApi.Add(xSq3a, &c.a) - } - y2 := c.baseApi.MulConst(&p.Y, big.NewInt(2)) - lambda := c.baseApi.Div(xSq3a, y2) - - // xr = lambda**2-p1.x-p1.x - x2 := c.baseApi.MulConst(&p.X, big.NewInt(2)) - lambdaSq := c.baseApi.MulMod(lambda, lambda) - xr := c.baseApi.Sub(lambdaSq, x2) - - // p.y = lambda(p.x-xr) - p.y - pxxr := c.baseApi.Sub(&p.X, xr) - lpxxr := c.baseApi.MulMod(lambda, pxxr) - py := c.baseApi.Sub(lpxxr, &p.Y) - - return &AffinePoint[B]{ - X: *c.baseApi.Reduce(xr), - Y: *c.baseApi.Reduce(py), - } -} - -// Select selects between p and q given the selector b. If b == 0, then returns -// p and q otherwise. 
-func (c *Curve[B, S]) Select(b frontend.Variable, p, q *AffinePoint[B]) *AffinePoint[B] { - x := c.baseApi.Select(b, &p.X, &q.X) - y := c.baseApi.Select(b, &p.Y, &q.Y) - return &AffinePoint[B]{ - X: *x, - Y: *y, - } -} - -// ScalarMul computes s * p and returns it. It doesn't modify p nor s. -func (c *Curve[B, S]) ScalarMul(p *AffinePoint[B], s *emulated.Element[S]) *AffinePoint[B] { - res := p - acc := c.Double(p) - - var st S - sr := c.scalarApi.Reduce(s) - sBits := c.scalarApi.ToBits(sr) - for i := 1; i < st.Modulus().BitLen(); i++ { - tmp := c.Add(res, acc) - res = c.Select(sBits[i], tmp, res) - acc = c.Double(acc) - } - - tmp := c.Add(res, c.Neg(p)) - res = c.Select(sBits[0], res, tmp) - return res -} diff --git a/std/algebra/weierstrass/point_test.go b/std/algebra/weierstrass/point_test.go deleted file mode 100644 index 28faf42109..0000000000 --- a/std/algebra/weierstrass/point_test.go +++ /dev/null @@ -1,197 +0,0 @@ -package weierstrass - -import ( - "math/big" - "testing" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/bn254" - "github.com/consensys/gnark-crypto/ecc/secp256k1" - "github.com/consensys/gnark-crypto/ecc/secp256k1/fp" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/std/math/emulated" - "github.com/consensys/gnark/test" -) - -var testCurve = ecc.BN254 - -type NegTest[T, S emulated.FieldParams] struct { - P, Q AffinePoint[T] -} - -func (c *NegTest[T, S]) Define(api frontend.API) error { - cr, err := New[T, S](api, GetCurveParams[T]()) - if err != nil { - return err - } - res := cr.Neg(&c.P) - cr.AssertIsEqual(res, &c.Q) - return nil -} - -func TestNeg(t *testing.T) { - assert := test.NewAssert(t) - _, g := secp256k1.Generators() - var yn fp.Element - yn.Neg(&g.Y) - circuit := NegTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} - witness := NegTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ - P: AffinePoint[emulated.Secp256k1Fp]{ - X: 
emulated.ValueOf[emulated.Secp256k1Fp](g.X), - Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), - }, - Q: AffinePoint[emulated.Secp256k1Fp]{ - X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), - Y: emulated.ValueOf[emulated.Secp256k1Fp](yn), - }, - } - err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) - assert.NoError(err) -} - -type AddTest[T, S emulated.FieldParams] struct { - P, Q, R AffinePoint[T] -} - -func (c *AddTest[T, S]) Define(api frontend.API) error { - cr, err := New[T, S](api, GetCurveParams[T]()) - if err != nil { - return err - } - res := cr.Add(&c.P, &c.Q) - cr.AssertIsEqual(res, &c.R) - return nil -} - -func TestAdd(t *testing.T) { - assert := test.NewAssert(t) - var dJac, aJac secp256k1.G1Jac - g, _ := secp256k1.Generators() - dJac.Double(&g) - aJac.Set(&dJac). - AddAssign(&g) - var dAff, aAff secp256k1.G1Affine - dAff.FromJacobian(&dJac) - aAff.FromJacobian(&aJac) - circuit := AddTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} - witness := AddTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ - P: AffinePoint[emulated.Secp256k1Fp]{ - X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), - Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), - }, - Q: AffinePoint[emulated.Secp256k1Fp]{ - X: emulated.ValueOf[emulated.Secp256k1Fp](dAff.X), - Y: emulated.ValueOf[emulated.Secp256k1Fp](dAff.Y), - }, - R: AffinePoint[emulated.Secp256k1Fp]{ - X: emulated.ValueOf[emulated.Secp256k1Fp](aAff.X), - Y: emulated.ValueOf[emulated.Secp256k1Fp](aAff.Y), - }, - } - err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) - assert.NoError(err) -} - -type DoubleTest[T, S emulated.FieldParams] struct { - P, Q AffinePoint[T] -} - -func (c *DoubleTest[T, S]) Define(api frontend.API) error { - cr, err := New[T, S](api, GetCurveParams[T]()) - if err != nil { - return err - } - res := cr.Double(&c.P) - cr.AssertIsEqual(res, &c.Q) - return nil -} - -func TestDouble(t *testing.T) { - assert := test.NewAssert(t) - g, _ := secp256k1.Generators() - var dJac 
secp256k1.G1Jac - dJac.Double(&g) - var dAff secp256k1.G1Affine - dAff.FromJacobian(&dJac) - circuit := DoubleTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} - witness := DoubleTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ - P: AffinePoint[emulated.Secp256k1Fp]{ - X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), - Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), - }, - Q: AffinePoint[emulated.Secp256k1Fp]{ - X: emulated.ValueOf[emulated.Secp256k1Fp](dAff.X), - Y: emulated.ValueOf[emulated.Secp256k1Fp](dAff.Y), - }, - } - err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) - assert.NoError(err) -} - -type ScalarMulTest[T, S emulated.FieldParams] struct { - P, Q AffinePoint[T] - S emulated.Element[S] -} - -func (c *ScalarMulTest[T, S]) Define(api frontend.API) error { - cr, err := New[T, S](api, GetCurveParams[T]()) - if err != nil { - return err - } - res := cr.ScalarMul(&c.P, &c.S) - cr.AssertIsEqual(res, &c.Q) - return nil -} - -func TestScalarMul(t *testing.T) { - assert := test.NewAssert(t) - _, g := secp256k1.Generators() - s, ok := new(big.Int).SetString("44693544921776318736021182399461740191514036429448770306966433218654680512345", 10) - assert.True(ok) - var S secp256k1.G1Affine - S.ScalarMultiplication(&g, s) - - circuit := ScalarMulTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} - witness := ScalarMulTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ - S: emulated.ValueOf[emulated.Secp256k1Fr](s), - P: AffinePoint[emulated.Secp256k1Fp]{ - X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), - Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), - }, - Q: AffinePoint[emulated.Secp256k1Fp]{ - X: emulated.ValueOf[emulated.Secp256k1Fp](S.X), - Y: emulated.ValueOf[emulated.Secp256k1Fp](S.Y), - }, - } - err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) - assert.NoError(err) - _, err = frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit) - assert.NoError(err) -} - -func TestScalarMul2(t *testing.T) { - assert := 
test.NewAssert(t) - s, ok := new(big.Int).SetString("14108069686105661647148607545884343550368786660735262576656400957535521042679", 10) - assert.True(ok) - var res bn254.G1Affine - _, _, gen, _ := bn254.Generators() - res.ScalarMultiplication(&gen, s) - - circuit := ScalarMulTest[emulated.BN254Fp, emulated.BN254Fr]{} - witness := ScalarMulTest[emulated.BN254Fp, emulated.BN254Fr]{ - S: emulated.ValueOf[emulated.BN254Fr](s), - P: AffinePoint[emulated.BN254Fp]{ - X: emulated.ValueOf[emulated.BN254Fp](gen.X), - Y: emulated.ValueOf[emulated.BN254Fp](gen.Y), - }, - Q: AffinePoint[emulated.BN254Fp]{ - X: emulated.ValueOf[emulated.BN254Fp](res.X), - Y: emulated.ValueOf[emulated.BN254Fp](res.Y), - }, - } - err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) - assert.NoError(err) - _, err = frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit) - assert.NoError(err) -} diff --git a/std/commitments/fri/fri.go b/std/commitments/fri/fri.go index 96b8ac2c3d..93dd0b301f 100644 --- a/std/commitments/fri/fri.go +++ b/std/commitments/fri/fri.go @@ -50,7 +50,7 @@ type RadixTwoFri struct { // hash function that is used for Fiat Shamir and for committing to // the oracles. - h hash.Hash + h hash.FieldHasher // nbSteps number of interactions between the prover and the verifier nbSteps int @@ -66,7 +66,7 @@ type RadixTwoFri struct { // NewRadixTwoFri creates an FFT-like oracle proof of proximity. 
// * h is the hash function that is used for the Merkle proofs // * gen is the generator of the cyclic group of unity of size \rho * size -func NewRadixTwoFri(size uint64, h hash.Hash, gen big.Int) RadixTwoFri { +func NewRadixTwoFri(size uint64, h hash.FieldHasher, gen big.Int) RadixTwoFri { var res RadixTwoFri diff --git a/std/commitments/fri/utils.go b/std/commitments/fri/utils.go index 13efcf98d9..24c0716073 100644 --- a/std/commitments/fri/utils.go +++ b/std/commitments/fri/utils.go @@ -3,7 +3,7 @@ package fri import ( "math/big" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -84,5 +84,5 @@ var DeriveQueriesPositions = func(_ *big.Int, inputs []*big.Int, res []*big.Int) } func init() { - hint.Register(DeriveQueriesPositions) + solver.RegisterHint(DeriveQueriesPositions) } diff --git a/std/commitments/kzg_bls12377/verifier.go b/std/commitments/kzg_bls12377/verifier.go index 49f9ed5b19..e7cffd85a3 100644 --- a/std/commitments/kzg_bls12377/verifier.go +++ b/std/commitments/kzg_bls12377/verifier.go @@ -19,8 +19,8 @@ package kzg_bls12377 import ( "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/fields_bls12377" - "github.com/consensys/gnark/std/algebra/sw_bls12377" + "github.com/consensys/gnark/std/algebra/native/fields_bls12377" + "github.com/consensys/gnark/std/algebra/native/sw_bls12377" ) // Digest commitment of a polynomial. 
@@ -28,7 +28,6 @@ type Digest = sw_bls12377.G1Affine // VK verification key (G2 part of SRS) type VK struct { - G1 sw_bls12377.G1Affine // G₁ G2 [2]sw_bls12377.G2Affine // [G₂, [α]G₂] } @@ -53,7 +52,7 @@ func Verify(api frontend.API, commitment Digest, proof OpeningProof, point front // [f(a)]G₁ var claimedValueG1Aff sw_bls12377.G1Affine - claimedValueG1Aff.ScalarMul(api, srs.G1, proof.ClaimedValue) + claimedValueG1Aff.ScalarMulBase(api, proof.ClaimedValue) // [f(α) - f(a)]G₁ var fminusfaG1 sw_bls12377.G1Affine @@ -64,17 +63,16 @@ func Verify(api frontend.API, commitment Digest, proof OpeningProof, point front var negH sw_bls12377.G1Affine negH.Neg(api, proof.H) - // [α-a]G₂ - var alphaMinusaG2 sw_bls12377.G2Affine - alphaMinusaG2.ScalarMul(api, srs.G2[0], point). - Neg(api, alphaMinusaG2). - AddAssign(api, srs.G2[1]) + // [f(α) - f(a) + a*H(α)]G₁ + var totalG1 sw_bls12377.G1Affine + totalG1.ScalarMul(api, proof.H, point). + AddAssign(api, fminusfaG1) - // e([f(α) - f(a)]G₁, G₂).e([-H(α)]G₁, [α-a]G₂) ==? 
1 + // e([f(α)-f(a)+aH(α)]G₁, G₂).e([-H(α)]G₁, [α]G₂) == 1 resPairing, _ := sw_bls12377.Pair( api, - []sw_bls12377.G1Affine{fminusfaG1, negH}, - []sw_bls12377.G2Affine{srs.G2[0], alphaMinusaG2}, + []sw_bls12377.G1Affine{totalG1, negH}, + []sw_bls12377.G2Affine{srs.G2[0], srs.G2[1]}, ) var one fields_bls12377.E12 diff --git a/std/commitments/kzg_bls12377/verifier_test.go b/std/commitments/kzg_bls12377/verifier_test.go index 55f9c9453a..09b7684b5e 100644 --- a/std/commitments/kzg_bls12377/verifier_test.go +++ b/std/commitments/kzg_bls12377/verifier_test.go @@ -69,17 +69,17 @@ func TestVerifierDynamic(t *testing.T) { } // commit to the polynomial - com, err := kzg.Commit(f, srs) + com, err := kzg.Commit(f, srs.Pk) assert.NoError(err) // create opening proof var point fr.Element point.SetRandom() - proof, err := kzg.Open(f, point, srs) + proof, err := kzg.Open(f, point, srs.Pk) assert.NoError(err) // check that the proof is correct - err = kzg.Verify(&com, &proof, point, srs) + err = kzg.Verify(&com, &proof, point, srs.Vk) if err != nil { t.Fatal(err) } @@ -98,17 +98,14 @@ func TestVerifierDynamic(t *testing.T) { witness.S = point.String() - witness.VerifKey.G1.X = srs.G1[0].X.String() - witness.VerifKey.G1.Y = srs.G1[0].Y.String() - - witness.VerifKey.G2[0].X.A0 = srs.G2[0].X.A0.String() - witness.VerifKey.G2[0].X.A1 = srs.G2[0].X.A1.String() - witness.VerifKey.G2[0].Y.A0 = srs.G2[0].Y.A0.String() - witness.VerifKey.G2[0].Y.A1 = srs.G2[0].Y.A1.String() - witness.VerifKey.G2[1].X.A0 = srs.G2[1].X.A0.String() - witness.VerifKey.G2[1].X.A1 = srs.G2[1].X.A1.String() - witness.VerifKey.G2[1].Y.A0 = srs.G2[1].Y.A0.String() - witness.VerifKey.G2[1].Y.A1 = srs.G2[1].Y.A1.String() + witness.VerifKey.G2[0].X.A0 = srs.Vk.G2[0].X.A0.String() + witness.VerifKey.G2[0].X.A1 = srs.Vk.G2[0].X.A1.String() + witness.VerifKey.G2[0].Y.A0 = srs.Vk.G2[0].Y.A0.String() + witness.VerifKey.G2[0].Y.A1 = srs.Vk.G2[0].Y.A1.String() + witness.VerifKey.G2[1].X.A0 = srs.Vk.G2[1].X.A0.String() + 
witness.VerifKey.G2[1].X.A1 = srs.Vk.G2[1].X.A1.String() + witness.VerifKey.G2[1].Y.A0 = srs.Vk.G2[1].Y.A0.String() + witness.VerifKey.G2[1].Y.A1 = srs.Vk.G2[1].Y.A1.String() // check if the circuit is solved var circuit verifierCircuit @@ -132,8 +129,6 @@ func TestVerifier(t *testing.T) { witness.Proof.ClaimedValue = "7211341386127354417397285211336133449231039596179023429378585109196698597268" witness.S = "4321" - witness.VerifKey.G1.X = "81937999373150964239938255573465948239988671502647976594219695644855304257327692006745978603320413799295628339695" - witness.VerifKey.G1.Y = "241266749859715473739788878240585681733927191168601896383759122102112907357779751001206799952863815012735208165030" witness.VerifKey.G2[0].X.A0 = "233578398248691099356572568220835526895379068987715365179118596935057653620464273615301663571204657964920925606294" witness.VerifKey.G2[0].X.A1 = "140913150380207355837477652521042157274541796891053068589147167627541651775299824604154852141315666357241556069118" witness.VerifKey.G2[0].Y.A0 = "63160294768292073209381361943935198908131692476676907196754037919244929611450776219210369229519898517858833747423" diff --git a/std/commitments/kzg_bls24315/verifier.go b/std/commitments/kzg_bls24315/verifier.go index 4d7625e2e9..260064b691 100644 --- a/std/commitments/kzg_bls24315/verifier.go +++ b/std/commitments/kzg_bls24315/verifier.go @@ -19,8 +19,8 @@ package kzg_bls24315 import ( "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/fields_bls24315" - "github.com/consensys/gnark/std/algebra/sw_bls24315" + "github.com/consensys/gnark/std/algebra/native/fields_bls24315" + "github.com/consensys/gnark/std/algebra/native/sw_bls24315" ) // Digest commitment of a polynomial. 
@@ -28,7 +28,6 @@ type Digest = sw_bls24315.G1Affine // VK verification key (G2 part of SRS) type VK struct { - G1 sw_bls24315.G1Affine // G₁ G2 [2]sw_bls24315.G2Affine // [G₂, [α]G₂] } @@ -53,7 +52,7 @@ func Verify(api frontend.API, commitment Digest, proof OpeningProof, point front // [f(a)]G₁ var claimedValueG1Aff sw_bls24315.G1Affine - claimedValueG1Aff.ScalarMul(api, srs.G1, proof.ClaimedValue) + claimedValueG1Aff.ScalarMulBase(api, proof.ClaimedValue) // [f(α) - f(a)]G₁ var fminusfaG1 sw_bls24315.G1Affine @@ -64,17 +63,16 @@ func Verify(api frontend.API, commitment Digest, proof OpeningProof, point front var negH sw_bls24315.G1Affine negH.Neg(api, proof.H) - // [α-a]G₂ - var alphaMinusaG2 sw_bls24315.G2Affine - alphaMinusaG2.ScalarMul(api, srs.G2[0], point). - Neg(api, alphaMinusaG2). - AddAssign(api, srs.G2[1]) + // [f(α) - f(a) + a*H(α)]G₁ + var totalG1 sw_bls24315.G1Affine + totalG1.ScalarMul(api, proof.H, point). + AddAssign(api, fminusfaG1) - // e([f(α) - f(a)]G₁, G₂).e([-H(α)]G₁, [α-a]G₂) ==? 
1 + // e([f(α)-f(a)+aH(α)]G₁, G₂).e([-H(α)]G₁, [α]G₂) == 1 resPairing, _ := sw_bls24315.Pair( api, - []sw_bls24315.G1Affine{fminusfaG1, negH}, - []sw_bls24315.G2Affine{srs.G2[0], alphaMinusaG2}, + []sw_bls24315.G1Affine{totalG1, negH}, + []sw_bls24315.G2Affine{srs.G2[0], srs.G2[1]}, ) var one fields_bls24315.E24 diff --git a/std/commitments/kzg_bls24315/verifier_test.go b/std/commitments/kzg_bls24315/verifier_test.go index d331ce82ed..d2588cacee 100644 --- a/std/commitments/kzg_bls24315/verifier_test.go +++ b/std/commitments/kzg_bls24315/verifier_test.go @@ -69,17 +69,17 @@ func TestVerifierDynamic(t *testing.T) { } // commit to the polynomial - com, err := kzg.Commit(f, srs) + com, err := kzg.Commit(f, srs.Pk) assert.NoError(err) // create opening proof var point fr.Element point.SetRandom() - proof, err := kzg.Open(f, point, srs) + proof, err := kzg.Open(f, point, srs.Pk) assert.NoError(err) // check that the proof is correct - err = kzg.Verify(&com, &proof, point, srs) + err = kzg.Verify(&com, &proof, point, srs.Vk) if err != nil { t.Fatal(err) } @@ -98,26 +98,23 @@ func TestVerifierDynamic(t *testing.T) { witness.S = point.String() - witness.VerifKey.G1.X = srs.G1[0].X.String() - witness.VerifKey.G1.Y = srs.G1[0].Y.String() - - witness.VerifKey.G2[0].X.B0.A0 = srs.G2[0].X.B0.A0.String() - witness.VerifKey.G2[0].X.B0.A1 = srs.G2[0].X.B0.A1.String() - witness.VerifKey.G2[0].X.B1.A0 = srs.G2[0].X.B1.A0.String() - witness.VerifKey.G2[0].X.B1.A1 = srs.G2[0].X.B1.A1.String() - witness.VerifKey.G2[0].Y.B0.A0 = srs.G2[0].Y.B0.A0.String() - witness.VerifKey.G2[0].Y.B0.A1 = srs.G2[0].Y.B0.A1.String() - witness.VerifKey.G2[0].Y.B1.A0 = srs.G2[0].Y.B1.A0.String() - witness.VerifKey.G2[0].Y.B1.A1 = srs.G2[0].Y.B1.A1.String() - - witness.VerifKey.G2[1].X.B0.A0 = srs.G2[1].X.B0.A0.String() - witness.VerifKey.G2[1].X.B0.A1 = srs.G2[1].X.B0.A1.String() - witness.VerifKey.G2[1].X.B1.A0 = srs.G2[1].X.B1.A0.String() - witness.VerifKey.G2[1].X.B1.A1 = srs.G2[1].X.B1.A1.String() - 
witness.VerifKey.G2[1].Y.B0.A0 = srs.G2[1].Y.B0.A0.String() - witness.VerifKey.G2[1].Y.B0.A1 = srs.G2[1].Y.B0.A1.String() - witness.VerifKey.G2[1].Y.B1.A0 = srs.G2[1].Y.B1.A0.String() - witness.VerifKey.G2[1].Y.B1.A1 = srs.G2[1].Y.B1.A1.String() + witness.VerifKey.G2[0].X.B0.A0 = srs.Vk.G2[0].X.B0.A0.String() + witness.VerifKey.G2[0].X.B0.A1 = srs.Vk.G2[0].X.B0.A1.String() + witness.VerifKey.G2[0].X.B1.A0 = srs.Vk.G2[0].X.B1.A0.String() + witness.VerifKey.G2[0].X.B1.A1 = srs.Vk.G2[0].X.B1.A1.String() + witness.VerifKey.G2[0].Y.B0.A0 = srs.Vk.G2[0].Y.B0.A0.String() + witness.VerifKey.G2[0].Y.B0.A1 = srs.Vk.G2[0].Y.B0.A1.String() + witness.VerifKey.G2[0].Y.B1.A0 = srs.Vk.G2[0].Y.B1.A0.String() + witness.VerifKey.G2[0].Y.B1.A1 = srs.Vk.G2[0].Y.B1.A1.String() + + witness.VerifKey.G2[1].X.B0.A0 = srs.Vk.G2[1].X.B0.A0.String() + witness.VerifKey.G2[1].X.B0.A1 = srs.Vk.G2[1].X.B0.A1.String() + witness.VerifKey.G2[1].X.B1.A0 = srs.Vk.G2[1].X.B1.A0.String() + witness.VerifKey.G2[1].X.B1.A1 = srs.Vk.G2[1].X.B1.A1.String() + witness.VerifKey.G2[1].Y.B0.A0 = srs.Vk.G2[1].Y.B0.A0.String() + witness.VerifKey.G2[1].Y.B0.A1 = srs.Vk.G2[1].Y.B0.A1.String() + witness.VerifKey.G2[1].Y.B1.A0 = srs.Vk.G2[1].Y.B1.A0.String() + witness.VerifKey.G2[1].Y.B1.A1 = srs.Vk.G2[1].Y.B1.A1.String() // check if the circuit is solved var circuit verifierCircuit @@ -141,8 +138,6 @@ func TestVerifier(t *testing.T) { witness.Proof.ClaimedValue = "10347231107172233075459792371577505115223937655290126532055162077965558980163" witness.S = "4321" - witness.VerifKey.G1.X = "34223510504517033132712852754388476272837911830964394866541204856091481856889569724484362330263" - witness.VerifKey.G1.Y = "24215295174889464585413596429561903295150472552154479431771837786124301185073987899223459122783" witness.VerifKey.G2[0].X.B0.A0 = "24614737899199071964341749845083777103809664018538138889239909664991294445469052467064654073699" witness.VerifKey.G2[0].X.B0.A1 = 
"17049297748993841127032249156255993089778266476087413538366212660716380683149731996715975282972" diff --git a/std/evmprecompiles/01-ecrecover.go b/std/evmprecompiles/01-ecrecover.go new file mode 100644 index 0000000000..2e7dea01f3 --- /dev/null +++ b/std/evmprecompiles/01-ecrecover.go @@ -0,0 +1,86 @@ +package evmprecompiles + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/bits" + "github.com/consensys/gnark/std/math/emulated" +) + +// ECRecover implements [ECRECOVER] precompile contract at address 0x01. +// +// [ECRECOVER]: https://ethereum.github.io/execution-specs/autoapi/ethereum/paris/vm/precompiled_contracts/ecrecover/index.html +func ECRecover(api frontend.API, msg emulated.Element[emulated.Secp256k1Fr], + v frontend.Variable, r, s emulated.Element[emulated.Secp256k1Fr], + strictRange frontend.Variable) *sw_emulated.AffinePoint[emulated.Secp256k1Fp] { + // EVM uses v \in {27, 28}, but everyone else v >= 0. Convert back + v = api.Sub(v, 27) + var emfp emulated.Secp256k1Fp + var emfr emulated.Secp256k1Fr + fpField, err := emulated.NewField[emulated.Secp256k1Fp](api) + if err != nil { + panic(fmt.Sprintf("new field: %v", err)) + } + frField, err := emulated.NewField[emulated.Secp256k1Fr](api) + if err != nil { + panic(fmt.Sprintf("new field: %v", err)) + } + // with the encoding we may have that r,s < 2*Fr (i.e. not r,s < Fr). Apply more thorough checks. + frField.AssertIsLessOrEqual(&r, frField.Modulus()) + // Ethereum Yellow Paper defines that the check for s should be more strict + // when checking transaction signatures (Appendix F). 
There we should check + // that s <= (Fr-1)/2 + halfFr := new(big.Int).Sub(emfr.Modulus(), big.NewInt(1)) + halfFr.Div(halfFr, big.NewInt(2)) + bound := frField.Select(strictRange, frField.NewElement(halfFr), frField.Modulus()) + frField.AssertIsLessOrEqual(&s, bound) + + curve, err := sw_emulated.New[emulated.Secp256k1Fp, emulated.Secp256k1Fr](api, sw_emulated.GetSecp256k1Params()) + if err != nil { + panic(fmt.Sprintf("new curve: %v", err)) + } + // we cannot directly use the field emulation hint calling wrappers as we work between two fields. + Rlimbs, err := api.Compiler().NewHint(recoverPointHint, 2*int(emfp.NbLimbs()), recoverPointHintArgs(v, r)...) + if err != nil { + panic(fmt.Sprintf("point hint: %v", err)) + } + R := sw_emulated.AffinePoint[emulated.Secp256k1Fp]{ + X: *fpField.NewElement(Rlimbs[0:emfp.NbLimbs()]), + Y: *fpField.NewElement(Rlimbs[emfp.NbLimbs() : 2*emfp.NbLimbs()]), + } + // we cannot directly use the field emulation hint calling wrappers as we work between two fields. + Plimbs, err := api.Compiler().NewHint(recoverPublicKeyHint, 2*int(emfp.NbLimbs()), recoverPublicKeyHintArgs(msg, v, r, s)...) + if err != nil { + panic(fmt.Sprintf("point hint: %v", err)) + } + P := sw_emulated.AffinePoint[emulated.Secp256k1Fp]{ + X: *fpField.NewElement(Plimbs[0:emfp.NbLimbs()]), + Y: *fpField.NewElement(Plimbs[emfp.NbLimbs() : 2*emfp.NbLimbs()]), + } + // check that len(v) = 2 + vbits := bits.ToBinary(api, v, bits.WithNbDigits(2)) + // check that Rx is correct: x = r+v[1]*fr + tmp := fpField.Select(vbits[1], fpField.NewElement(emfr.Modulus()), fpField.NewElement(0)) + rbits := frField.ToBits(&r) + rfp := fpField.FromBits(rbits...) 
+ tmp = fpField.Add(rfp, tmp) + fpField.AssertIsEqual(tmp, &R.X) + // check that Ry is correct: oddity(y) = v[0] + Rynormal := fpField.Reduce(&R.Y) + Rybits := fpField.ToBits(Rynormal) + api.AssertIsEqual(vbits[0], Rybits[0]) + // compute rinv = r^{-1} mod fr + rinv := frField.Inverse(&r) + // compute u1 = -msg * rinv + u1 := frField.MulMod(&msg, rinv) + u1 = frField.Neg(u1) + // compute u2 = s * rinv + u2 := frField.MulMod(&s, rinv) + // check u1 * G + u2 R == P + C := curve.JointScalarMulBase(&R, u2, u1) + curve.AssertIsEqual(C, &P) + return &P +} diff --git a/std/evmprecompiles/01-ecrecover_test.go b/std/evmprecompiles/01-ecrecover_test.go new file mode 100644 index 0000000000..577a45b75b --- /dev/null +++ b/std/evmprecompiles/01-ecrecover_test.go @@ -0,0 +1,141 @@ +package evmprecompiles + +import ( + "crypto/rand" + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/secp256k1/ecdsa" + "github.com/consensys/gnark-crypto/ecc/secp256k1/fr" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/test" +) + +func TestSignForRecoverCorrectness(t *testing.T) { + sk, err := ecdsa.GenerateKey(rand.Reader) + if err != nil { + t.Fatal("generate", err) + } + pk := sk.PublicKey + msg := []byte("test") + _, r, s, err := sk.SignForRecover(msg, nil) + if err != nil { + t.Fatal("sign", err) + } + var sig ecdsa.Signature + r.FillBytes(sig.R[:fr.Bytes]) + s.FillBytes(sig.S[:fr.Bytes]) + sigM := sig.Bytes() + ok, err := pk.Verify(sigM, msg, nil) + if err != nil { + t.Fatal("verify", err) + } + if !ok { + t.Fatal("not verified") + } +} + +type ecrecoverCircuit struct { + Message emulated.Element[emulated.Secp256k1Fr] + V frontend.Variable + R emulated.Element[emulated.Secp256k1Fr] + S emulated.Element[emulated.Secp256k1Fr] + Strict frontend.Variable + 
Expected sw_emulated.AffinePoint[emulated.Secp256k1Fp] +} + +func (c *ecrecoverCircuit) Define(api frontend.API) error { + curve, err := sw_emulated.New[emulated.Secp256k1Fp, emulated.Secp256k1Fr](api, sw_emulated.GetSecp256k1Params()) + if err != nil { + return fmt.Errorf("new curve: %w", err) + } + res := ECRecover(api, c.Message, c.V, c.R, c.S, c.Strict) + curve.AssertIsEqual(&c.Expected, res) + return nil +} + +func testRoutineECRecover(t *testing.T, wantStrict bool) (circ, wit *ecrecoverCircuit, largeS bool) { + halfFr := new(big.Int).Sub(fr.Modulus(), big.NewInt(1)) + halfFr.Div(halfFr, big.NewInt(2)) + + sk, err := ecdsa.GenerateKey(rand.Reader) + if err != nil { + t.Fatal("generate", err) + } + pk := sk.PublicKey + msg := []byte("test") + var r, s *big.Int + var v uint + for { + v, r, s, err = sk.SignForRecover(msg, nil) + if err != nil { + t.Fatal("sign", err) + } + if !wantStrict || halfFr.Cmp(s) > 0 { + break + } + } + strict := 0 + if wantStrict { + strict = 1 + } + circuit := ecrecoverCircuit{} + witness := ecrecoverCircuit{ + Message: emulated.ValueOf[emulated.Secp256k1Fr](ecdsa.HashToInt(msg)), + V: v + 27, // EVM constant + R: emulated.ValueOf[emulated.Secp256k1Fr](r), + S: emulated.ValueOf[emulated.Secp256k1Fr](s), + Strict: strict, + Expected: sw_emulated.AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](pk.A.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](pk.A.Y), + }, + } + return &circuit, &witness, halfFr.Cmp(s) <= 0 +} + +func TestECRecoverCircuitShortStrict(t *testing.T) { + assert := test.NewAssert(t) + circuit, witness, _ := testRoutineECRecover(t, true) + err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +func TestECRecoverCircuitShortLax(t *testing.T) { + assert := test.NewAssert(t) + circuit, witness, _ := testRoutineECRecover(t, false) + err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +func 
TestECRecoverCircuitShortMismatch(t *testing.T) { + assert := test.NewAssert(t) + halfFr := new(big.Int).Sub(fr.Modulus(), big.NewInt(1)) + halfFr.Div(halfFr, big.NewInt(2)) + var circuit, witness *ecrecoverCircuit + var largeS bool + for { + circuit, witness, largeS = testRoutineECRecover(t, false) + if largeS { + witness.Strict = 1 + break + } + } + err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()) + assert.Error(err) +} + +func TestECRecoverCircuitFull(t *testing.T) { + t.Skip("skipping very long test") + assert := test.NewAssert(t) + circuit, witness, _ := testRoutineECRecover(t, false) + assert.ProverSucceeded(circuit, witness, + test.NoFuzzing(), test.NoSerialization(), + test.WithBackends(backend.GROTH16, backend.PLONK), test.WithCurves(ecc.BN254), + ) +} diff --git a/std/evmprecompiles/02-sha256.go b/std/evmprecompiles/02-sha256.go new file mode 100644 index 0000000000..ae0c33e733 --- /dev/null +++ b/std/evmprecompiles/02-sha256.go @@ -0,0 +1,3 @@ +package evmprecompiles + +// will implement later diff --git a/std/evmprecompiles/04-id.go b/std/evmprecompiles/04-id.go new file mode 100644 index 0000000000..d554d428c8 --- /dev/null +++ b/std/evmprecompiles/04-id.go @@ -0,0 +1,3 @@ +package evmprecompiles + +// not going to implement. It is trivial. diff --git a/std/evmprecompiles/05-expmod.go b/std/evmprecompiles/05-expmod.go new file mode 100644 index 0000000000..6b1eb16123 --- /dev/null +++ b/std/evmprecompiles/05-expmod.go @@ -0,0 +1 @@ +package evmprecompiles diff --git a/std/evmprecompiles/06-bnadd.go b/std/evmprecompiles/06-bnadd.go new file mode 100644 index 0000000000..ff6c397fc9 --- /dev/null +++ b/std/evmprecompiles/06-bnadd.go @@ -0,0 +1,21 @@ +package evmprecompiles + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" +) + +// ECAdd implements [ALT_BN128_ADD] precompile contract at address 0x06. 
+// +// [ALT_BN128_ADD]: https://ethereum.github.io/execution-specs/autoapi/ethereum/paris/vm/precompiled_contracts/alt_bn128/index.html#alt-bn128-add +func ECAdd(api frontend.API, P, Q *sw_emulated.AffinePoint[emulated.BN254Fp]) *sw_emulated.AffinePoint[emulated.BN254Fp] { + curve, err := sw_emulated.New[emulated.BN254Fp, emulated.BN254Fr](api, sw_emulated.GetBN254Params()) + if err != nil { + panic(err) + } + // Check that P and Q are on the curve (done in the zkEVM ⚠️ ) + // We use AddUnified because P can be equal to Q, -Q and either or both can be (0,0) + res := curve.AddUnified(P, Q) + return res +} diff --git a/std/evmprecompiles/07-bnmul.go b/std/evmprecompiles/07-bnmul.go new file mode 100644 index 0000000000..eb1c02ccda --- /dev/null +++ b/std/evmprecompiles/07-bnmul.go @@ -0,0 +1,20 @@ +package evmprecompiles + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" +) + +// ECMul implements [ALT_BN128_MUL] precompile contract at address 0x07. 
+// +// [ALT_BN128_MUL]: https://ethereum.github.io/execution-specs/autoapi/ethereum/paris/vm/precompiled_contracts/alt_bn128/index.html#alt-bn128-mul +func ECMul(api frontend.API, P *sw_emulated.AffinePoint[emulated.BN254Fp], u *emulated.Element[emulated.BN254Fr]) *sw_emulated.AffinePoint[emulated.BN254Fp] { + curve, err := sw_emulated.New[emulated.BN254Fp, emulated.BN254Fr](api, sw_emulated.GetBN254Params()) + if err != nil { + panic(err) + } + // Check that P is on the curve (done in the zkEVM ⚠️ ) + res := curve.ScalarMul(P, u) + return res +} diff --git a/std/evmprecompiles/08-bnpairing.go b/std/evmprecompiles/08-bnpairing.go new file mode 100644 index 0000000000..01dbbab99c --- /dev/null +++ b/std/evmprecompiles/08-bnpairing.go @@ -0,0 +1,26 @@ +package evmprecompiles + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_bn254" +) + +// ECPair implements [ALT_BN128_PAIRING_CHECK] precompile contract at address 0x08. +// +// [ALT_BN128_PAIRING_CHECK]: https://ethereum.github.io/execution-specs/autoapi/ethereum/paris/vm/precompiled_contracts/alt_bn128/index.html#alt-bn128-pairing-check +func ECPair(api frontend.API, P []*sw_bn254.G1Affine, Q []*sw_bn254.G2Affine) { + pair, err := sw_bn254.NewPairing(api) + if err != nil { + panic(err) + } + // 1- Check that Pᵢ are on G1 (done in the zkEVM ⚠️ ) + // 2- Check that Qᵢ are on G2 + for i := 0; i < len(Q); i++ { + pair.AssertIsOnG2(Q[i]) + } + + // 3- Check that ∏ᵢ e(Pᵢ, Qᵢ) == 1 + if err := pair.PairingCheck(P, Q); err != nil { + panic(err) + } +} diff --git a/std/evmprecompiles/bn_test.go b/std/evmprecompiles/bn_test.go new file mode 100644 index 0000000000..f5efe1baa3 --- /dev/null +++ b/std/evmprecompiles/bn_test.go @@ -0,0 +1,168 @@ +package evmprecompiles + +import ( + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + 
"github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_bn254" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/test" +) + +type ecaddCircuit struct { + X0 sw_emulated.AffinePoint[emulated.BN254Fp] + X1 sw_emulated.AffinePoint[emulated.BN254Fp] + Expected sw_emulated.AffinePoint[emulated.BN254Fp] +} + +func (c *ecaddCircuit) Define(api frontend.API) error { + curve, err := sw_emulated.New[emulated.BN254Fp, emulated.BN254Fr](api, sw_emulated.GetBN254Params()) + if err != nil { + return err + } + res := ECAdd(api, &c.X0, &c.X1) + curve.AssertIsEqual(res, &c.Expected) + return nil +} + +func testRoutineECAdd() (circ, wit frontend.Circuit) { + _, _, G, _ := bn254.Generators() + var u, v fr.Element + u.SetRandom() + v.SetRandom() + var P, Q bn254.G1Affine + P.ScalarMultiplication(&G, u.BigInt(new(big.Int))) + Q.ScalarMultiplication(&G, v.BigInt(new(big.Int))) + var expected bn254.G1Affine + expected.Add(&P, &Q) + circuit := ecaddCircuit{} + witness := ecaddCircuit{ + X0: sw_emulated.AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](P.X), + Y: emulated.ValueOf[emulated.BN254Fp](P.Y), + }, + X1: sw_emulated.AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](Q.X), + Y: emulated.ValueOf[emulated.BN254Fp](Q.Y), + }, + Expected: sw_emulated.AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](expected.X), + Y: emulated.ValueOf[emulated.BN254Fp](expected.Y), + }, + } + return &circuit, &witness +} + +func TestECAddCircuitShort(t *testing.T) { + assert := test.NewAssert(t) + circuit, witness := testRoutineECAdd() + err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +func TestECAddCircuitFull(t *testing.T) { + assert := test.NewAssert(t) + circuit, witness := testRoutineECAdd() + assert.ProverSucceeded(circuit, 
witness, + test.NoFuzzing(), test.NoSerialization(), + test.WithBackends(backend.GROTH16, backend.PLONK), test.WithCurves(ecc.BN254), + ) +} + +type ecmulCircuit struct { + X0 sw_emulated.AffinePoint[emulated.BN254Fp] + U emulated.Element[emulated.BN254Fr] + Expected sw_emulated.AffinePoint[emulated.BN254Fp] +} + +func (c *ecmulCircuit) Define(api frontend.API) error { + curve, err := sw_emulated.New[emulated.BN254Fp, emulated.BN254Fr](api, sw_emulated.GetBN254Params()) + if err != nil { + return err + } + res := ECMul(api, &c.X0, &c.U) + curve.AssertIsEqual(res, &c.Expected) + return nil +} + +func testRoutineECMul(t *testing.T) (circ, wit frontend.Circuit) { + _, _, G, _ := bn254.Generators() + var u, v fr.Element + u.SetRandom() + v.SetRandom() + var P bn254.G1Affine + P.ScalarMultiplication(&G, u.BigInt(new(big.Int))) + var expected bn254.G1Affine + expected.ScalarMultiplication(&P, v.BigInt(new(big.Int))) + circuit := ecmulCircuit{} + witness := ecmulCircuit{ + X0: sw_emulated.AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](P.X), + Y: emulated.ValueOf[emulated.BN254Fp](P.Y), + }, + U: emulated.ValueOf[emulated.BN254Fr](v), + Expected: sw_emulated.AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](expected.X), + Y: emulated.ValueOf[emulated.BN254Fp](expected.Y), + }, + } + return &circuit, &witness +} + +func TestECMulCircuitShort(t *testing.T) { + assert := test.NewAssert(t) + circuit, witness := testRoutineECMul(t) + err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +func TestECMulCircuitFull(t *testing.T) { + t.Skip("skipping very long test") + assert := test.NewAssert(t) + circuit, witness := testRoutineECMul(t) + assert.ProverSucceeded(circuit, witness, + test.NoFuzzing(), test.NoSerialization(), + test.WithBackends(backend.GROTH16, backend.PLONK), test.WithCurves(ecc.BN254), + ) +} + +type ecpairCircuit struct { + P [2]sw_bn254.G1Affine + Q [2]sw_bn254.G2Affine +} + +func 
(c *ecpairCircuit) Define(api frontend.API) error { + P := []*sw_bn254.G1Affine{&c.P[0], &c.P[1]} + Q := []*sw_bn254.G2Affine{&c.Q[0], &c.Q[1]} + ECPair(api, P, Q) + return nil +} + +func TestECPairCircuitShort(t *testing.T) { + assert := test.NewAssert(t) + _, _, p1, q1 := bn254.Generators() + + var u, v fr.Element + u.SetRandom() + v.SetRandom() + + p1.ScalarMultiplication(&p1, u.BigInt(new(big.Int))) + q1.ScalarMultiplication(&q1, v.BigInt(new(big.Int))) + + var p2 bn254.G1Affine + var q2 bn254.G2Affine + p2.Neg(&p1) + q2.Set(&q1) + + err := test.IsSolved(&ecpairCircuit{}, &ecpairCircuit{ + P: [2]sw_bn254.G1Affine{sw_bn254.NewG1Affine(p1), sw_bn254.NewG1Affine(p2)}, + Q: [2]sw_bn254.G2Affine{sw_bn254.NewG2Affine(q1), sw_bn254.NewG2Affine(q2)}, + }, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/evmprecompiles/compose.go b/std/evmprecompiles/compose.go new file mode 100644 index 0000000000..c1ffaff2aa --- /dev/null +++ b/std/evmprecompiles/compose.go @@ -0,0 +1,37 @@ +package evmprecompiles + +import ( + "fmt" + "math/big" +) + +func recompose(inputs []*big.Int, nbBits uint) *big.Int { + res := new(big.Int) + if len(inputs) == 0 { + return res + } + for i := range inputs { + res.Lsh(res, nbBits) + res.Add(res, inputs[len(inputs)-i-1]) + } + return res +} + +func decompose(input *big.Int, nbBits uint, res []*big.Int) error { + // limb modulus + if input.BitLen() > len(res)*int(nbBits) { + return fmt.Errorf("decomposed integer does not fit into res") + } + for _, r := range res { + if r == nil { + return fmt.Errorf("result slice element uninitialized") + } + } + base := new(big.Int).Lsh(big.NewInt(1), nbBits) + tmp := new(big.Int).Set(input) + for i := 0; i < len(res); i++ { + res[i].Mod(tmp, base) + tmp.Rsh(tmp, nbBits) + } + return nil +} diff --git a/std/evmprecompiles/doc.go b/std/evmprecompiles/doc.go new file mode 100644 index 0000000000..7c515eaa51 --- /dev/null +++ b/std/evmprecompiles/doc.go @@ -0,0 +1,18 @@ +// Package evmprecompiles 
implements the Ethereum VM precompile contracts. +// +// This package collects all the precompile functions into a single location for +// easier integration. The main functionality is implemented elsewhere. This +// package right now implements: +// 1. ECRECOVER ✅ -- function [ECRecover] +// 2. SHA256 ❌ -- in progress +// 3. RIPEMD160 ❌ -- postponed +// 4. ID ❌ -- trivial to implement without function +// 5. EXPMOD ❌ -- in progress +// 6. BN_ADD ✅ -- function [ECAdd] +// 7. BN_MUL ✅ -- function [ECMul] +// 8. SNARKV ✅ -- function [ECPair] +// 9. BLAKE2F ❌ -- postponed +// +// This package uses local representation for the arguments. It is up to the +// user to instantiate corresponding types from their application-specific data. +package evmprecompiles diff --git a/std/evmprecompiles/hints.go b/std/evmprecompiles/hints.go new file mode 100644 index 0000000000..0ebc231ae0 --- /dev/null +++ b/std/evmprecompiles/hints.go @@ -0,0 +1,97 @@ +package evmprecompiles + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark-crypto/ecc/secp256k1/ecdsa" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" +) + +func init() { + solver.RegisterHint(GetHints()...) +} + +// GetHints returns all the hints used in this package. +func GetHints() []solver.Hint { + return []solver.Hint{recoverPointHint, recoverPublicKeyHint} +} + +func recoverPointHintArgs(v frontend.Variable, r emulated.Element[emulated.Secp256k1Fr]) []frontend.Variable { + args := []frontend.Variable{v} + args = append(args, r.Limbs...) 
+ return args +} + +func recoverPointHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + var emfp emulated.Secp256k1Fp + if len(inputs) != int(emfp.NbLimbs())+1 { + return fmt.Errorf("expected input %d limbs got %d", emfp.NbLimbs()+1, len(inputs)) + } + if !inputs[0].IsInt64() { + return fmt.Errorf("first input supposed to be in [0,3]") + } + if len(outputs) != 2*int(emfp.NbLimbs()) { + return fmt.Errorf("expected output %d limbs got %d", 2*emfp.NbLimbs(), len(outputs)) + } + v := inputs[0].Uint64() + r := recompose(inputs[1:], emfp.BitsPerLimb()) + P, err := ecdsa.RecoverP(uint(v), r) + if err != nil { + return fmt.Errorf("recover: %s", err) + } + if err := decompose(P.X.BigInt(new(big.Int)), emfp.BitsPerLimb(), outputs[0:emfp.NbLimbs()]); err != nil { + return fmt.Errorf("decompose x: %w", err) + } + if err := decompose(P.Y.BigInt(new(big.Int)), emfp.BitsPerLimb(), outputs[emfp.NbLimbs():]); err != nil { + return fmt.Errorf("decompose y: %w", err) + } + return nil +} + +func recoverPublicKeyHintArgs(msg emulated.Element[emulated.Secp256k1Fr], + v frontend.Variable, r, s emulated.Element[emulated.Secp256k1Fr]) []frontend.Variable { + args := msg.Limbs + args = append(args, v) + args = append(args, r.Limbs...) + args = append(args, s.Limbs...) 
+ return args +} + +func recoverPublicKeyHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + // message -nb limbs + // then v - 1 + // r -- nb limbs + // s -- nb limbs + // return 2x nb limbs + var emfr emulated.Secp256k1Fr + var emfp emulated.Secp256k1Fp + if len(inputs) != int(emfr.NbLimbs())*3+1 { + return fmt.Errorf("expected %d limbs got %d", emfr.NbLimbs()*3+1, len(inputs)) + } + if !inputs[emfr.NbLimbs()].IsInt64() { + return fmt.Errorf("second input input must be in [0,3]") + } + if len(outputs) != 2*int(emfp.NbLimbs()) { + return fmt.Errorf("expected output %d limbs got %d", 2*emfp.NbLimbs(), len(outputs)) + } + msg := recompose(inputs[:emfr.NbLimbs()], emfr.BitsPerLimb()) + v := inputs[emfr.NbLimbs()].Uint64() + r := recompose(inputs[emfr.NbLimbs()+1:2*emfr.NbLimbs()+1], emfr.BitsPerLimb()) + s := recompose(inputs[2*emfr.NbLimbs()+1:3*emfr.NbLimbs()+1], emfr.BitsPerLimb()) + var pk ecdsa.PublicKey + if err := pk.RecoverFrom(msg.Bytes(), uint(v), r, s); err != nil { + return fmt.Errorf("recover public key: %w", err) + } + Px := pk.A.X.BigInt(new(big.Int)) + Py := pk.A.Y.BigInt(new(big.Int)) + if err := decompose(Px, emfp.BitsPerLimb(), outputs[0:emfp.NbLimbs()]); err != nil { + return fmt.Errorf("decompose x: %w", err) + } + if err := decompose(Py, emfp.BitsPerLimb(), outputs[emfp.NbLimbs():2*emfp.NbLimbs()]); err != nil { + return fmt.Errorf("decompose y: %w", err) + } + return nil +} diff --git a/std/fiat-shamir/settings.go b/std/fiat-shamir/settings.go index 7b1f2a31a7..146a64355e 100644 --- a/std/fiat-shamir/settings.go +++ b/std/fiat-shamir/settings.go @@ -9,7 +9,7 @@ type Settings struct { Transcript *Transcript Prefix string BaseChallenges []frontend.Variable - Hash hash.Hash + Hash hash.FieldHasher } func WithTranscript(transcript *Transcript, prefix string, baseChallenges ...frontend.Variable) Settings { @@ -20,7 +20,7 @@ func WithTranscript(transcript *Transcript, prefix string, baseChallenges ...fro } } -func WithHash(hash 
hash.Hash, baseChallenges ...frontend.Variable) Settings { +func WithHash(hash hash.FieldHasher, baseChallenges ...frontend.Variable) Settings { return Settings{ BaseChallenges: baseChallenges, Hash: hash, diff --git a/std/fiat-shamir/transcript.go b/std/fiat-shamir/transcript.go index 59fcb74ba6..63fe8e786d 100644 --- a/std/fiat-shamir/transcript.go +++ b/std/fiat-shamir/transcript.go @@ -34,7 +34,7 @@ var ( // Transcript handles the creation of challenges for Fiat Shamir. type Transcript struct { // hash function that is used. - h hash.Hash + h hash.FieldHasher challenges map[string]challenge previous *challenge @@ -53,7 +53,7 @@ type challenge struct { // NewTranscript returns a new transcript. // h is the hash function that is used to compute the challenges. // challenges are the name of the challenges. The order is important. -func NewTranscript(api frontend.API, h hash.Hash, challengesID ...string) Transcript { +func NewTranscript(api frontend.API, h hash.FieldHasher, challengesID ...string) Transcript { n := len(challengesID) t := Transcript{ challenges: make(map[string]challenge, n), diff --git a/std/gkr/gkr_test.go b/std/gkr/gkr_test.go index fa6646067d..4149ea8022 100644 --- a/std/gkr/gkr_test.go +++ b/std/gkr/gkr_test.go @@ -111,7 +111,7 @@ func (c *GkrVerifierCircuit) Define(api frontend.API) error { } assignment := makeInOutAssignment(testCase.Circuit, c.Input, c.Output) - var hsh hash.Hash + var hsh hash.FieldHasher if c.ToFail { hsh = NewMessageCounter(api, 1, 1) } else { @@ -415,7 +415,7 @@ func SliceEqual[T comparable](expected, seen []T) bool { type HashDescription map[string]interface{} -func HashFromDescription(api frontend.API, d HashDescription) (hash.Hash, error) { +func HashFromDescription(api frontend.API, d HashDescription) (hash.FieldHasher, error) { if _type, ok := d["type"]; ok { switch _type { case "const": @@ -457,13 +457,13 @@ func (m *MessageCounter) Reset() { m.state = m.startState } -func NewMessageCounter(api frontend.API, 
startState, step int) hash.Hash { +func NewMessageCounter(api frontend.API, startState, step int) hash.FieldHasher { transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step), api: api} return transcript } -func NewMessageCounterGenerator(startState, step int) func(frontend.API) hash.Hash { - return func(api frontend.API) hash.Hash { +func NewMessageCounterGenerator(startState, step int) func(frontend.API) hash.FieldHasher { + return func(api frontend.API) hash.FieldHasher { return NewMessageCounter(api, startState, step) } } diff --git a/std/groth16_bls12377/verifier.go b/std/groth16_bls12377/verifier.go index 6a43e585a8..e2f7f37338 100644 --- a/std/groth16_bls12377/verifier.go +++ b/std/groth16_bls12377/verifier.go @@ -1,5 +1,5 @@ /* -Copyright © 2020 ConsenSys +Copyright 2020 ConsenSys Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,10 +22,10 @@ import ( bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark/backend/groth16" + groth16_bls12377 "github.com/consensys/gnark/backend/groth16/bls12-377" "github.com/consensys/gnark/frontend" - groth16_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/groth16" - "github.com/consensys/gnark/std/algebra/fields_bls12377" - "github.com/consensys/gnark/std/algebra/sw_bls12377" + "github.com/consensys/gnark/std/algebra/native/fields_bls12377" + "github.com/consensys/gnark/std/algebra/native/sw_bls12377" ) // Proof represents a Groth16 proof @@ -49,6 +49,7 @@ type VerifyingKey struct { // [Kvk]1 G1 struct { K []sw_bls12377.G1Affine // The indexes correspond to the public wires + } } @@ -57,7 +58,8 @@ type VerifyingKey struct { // publicInputs do NOT contain the ONE_WIRE func Verify(api frontend.API, vk VerifyingKey, proof Proof, publicInputs []frontend.Variable) { if len(vk.G1.K) == 0 { - panic("innver verifying key needs at least one point; 
VerifyingKey.G1 must be initialized before compiling circuit") + panic("inner verifying key needs at least one point; VerifyingKey.G1 must be initialized before compiling circuit") + } // compute kSum = Σx.[Kvk(t)]1 @@ -71,14 +73,15 @@ func Verify(api frontend.API, vk VerifyingKey, proof Proof, publicInputs []front var ki sw_bls12377.G1Affine ki.ScalarMul(api, vk.G1.K[k+1], v) kSum.AddAssign(api, ki) + } // compute e(Σx.[Kvk(t)]1, -[γ]2) * e(Krs,δ) * e(Ar,Bs) - ml, _ := sw_bls12377.MillerLoop(api, []sw_bls12377.G1Affine{kSum, proof.Krs, proof.Ar}, []sw_bls12377.G2Affine{vk.G2.GammaNeg, vk.G2.DeltaNeg, proof.Bs}) - pairing := sw_bls12377.FinalExponentiation(api, ml) + pairing, _ := sw_bls12377.Pair(api, []sw_bls12377.G1Affine{kSum, proof.Krs, proof.Ar}, []sw_bls12377.G2Affine{vk.G2.GammaNeg, vk.G2.DeltaNeg, proof.Bs}) // vk.E must be equal to pairing vk.E.AssertIsEqual(api, pairing) + } // Assign values to the "in-circuit" VerifyingKey from a "out-of-circuit" VerifyingKey @@ -86,21 +89,51 @@ func (vk *VerifyingKey) Assign(_ovk groth16.VerifyingKey) { ovk, ok := _ovk.(*groth16_bls12377.VerifyingKey) if !ok { panic("expected *groth16_bls12377.VerifyingKey, got " + reflect.TypeOf(_ovk).String()) + } e, err := bls12377.Pair([]bls12377.G1Affine{ovk.G1.Alpha}, []bls12377.G2Affine{ovk.G2.Beta}) if err != nil { panic(err) + } vk.E.Assign(&e) vk.G1.K = make([]sw_bls12377.G1Affine, len(ovk.G1.K)) for i := 0; i < len(ovk.G1.K); i++ { vk.G1.K[i].Assign(&ovk.G1.K[i]) + } var deltaNeg, gammaNeg bls12377.G2Affine deltaNeg.Neg(&ovk.G2.Delta) gammaNeg.Neg(&ovk.G2.Gamma) vk.G2.DeltaNeg.Assign(&deltaNeg) vk.G2.GammaNeg.Assign(&gammaNeg) + +} + +// Allocate memory for the "in-circuit" VerifyingKey +// This is exposed so that the slices in the structure can be allocated +// before calling frontend.Compile(). 
+func (vk *VerifyingKey) Allocate(_ovk groth16.VerifyingKey) { + ovk, ok := _ovk.(*groth16_bls12377.VerifyingKey) + if !ok { + panic("expected *groth16_bls12377.VerifyingKey, got " + reflect.TypeOf(_ovk).String()) + + } + vk.G1.K = make([]sw_bls12377.G1Affine, len(ovk.G1.K)) + +} + +// Assign the proof values of Groth16 +func (proof *Proof) Assign(_oproof groth16.Proof) { + oproof, ok := _oproof.(*groth16_bls12377.Proof) + if !ok { + panic("expected *groth16_bls12377.Proof, got " + reflect.TypeOf(oproof).String()) + + } + proof.Ar.Assign(&oproof.Ar) + proof.Krs.Assign(&oproof.Krs) + proof.Bs.Assign(&oproof.Bs) + } diff --git a/std/groth16_bls12377/verifier_test.go b/std/groth16_bls12377/verifier_test.go index 10bcd53d8b..282984139e 100644 --- a/std/groth16_bls12377/verifier_test.go +++ b/std/groth16_bls12377/verifier_test.go @@ -1,5 +1,5 @@ /* -Copyright © 2020 ConsenSys +Copyright 2020 ConsenSys Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -22,13 +22,12 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/hash" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/constraint" - cs_bls12377 "github.com/consensys/gnark/constraint/bls12-377" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" - groth16_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/groth16" - "github.com/consensys/gnark/std/algebra/sw_bls12377" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" ) @@ -47,125 +46,161 @@ func (circuit *mimcCircuit) Define(api frontend.API) error { mimc, err := mimc.NewMiMC(api) if err != nil { return err + } mimc.Write(circuit.PreImage) api.AssertIsEqual(mimc.Sum(), circuit.Hash) return nil + +} + +// Calculate the expected output of MIMC through plain invocation +func preComputeMimc(preImage frontend.Variable) interface{} { + var expectedY fr.Element + expectedY.SetInterface(preImage) + // calc MiMC + goMimc := hash.MIMC_BLS12_377.New() + goMimc.Write(expectedY.Marshal()) + expectedh := goMimc.Sum(nil) + return expectedh + +} + +type verifierCircuit struct { + InnerProof Proof + InnerVk VerifyingKey + Hash frontend.Variable } -// Prepare the data for the inner proof. 
-// Returns the public inputs string of the inner proof -func generateBls12377InnerProof(t *testing.T, vk *groth16_bls12377.VerifyingKey, proof *groth16_bls12377.Proof) { +func (circuit *verifierCircuit) Define(api frontend.API) error { + // create the verifier cs + Verify(api, circuit.InnerVk, circuit.InnerProof, []frontend.Variable{circuit.Hash}) + + return nil + +} + +func TestVerifier(t *testing.T) { // create a mock cs: knowing the preimage of a hash using mimc - var circuit mimcCircuit - r1cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), r1cs.NewBuilder, &circuit) + var MimcCircuit mimcCircuit + r1cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), r1cs.NewBuilder, &MimcCircuit) if err != nil { t.Fatal(err) - } - // build the witness - var assignment mimcCircuit - assignment.PreImage = preImage - assignment.Hash = publicHash + } - witness, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) + var pre_assignment mimcCircuit + pre_assignment.PreImage = preImage + pre_assignment.Hash = publicHash + pre_witness, err := frontend.NewWitness(&pre_assignment, ecc.BLS12_377.ScalarField()) if err != nil { t.Fatal(err) + } - publicWitness, err := witness.Public() + innerPk, innerVk, err := groth16.Setup(r1cs) if err != nil { t.Fatal(err) + } - // generate the data to return for the bls12377 proof - var pk groth16_bls12377.ProvingKey - err = groth16_bls12377.Setup(r1cs.(*cs_bls12377.R1CS), &pk, vk) + proof, err := groth16.Prove(r1cs, innerPk, pre_witness) if err != nil { t.Fatal(err) + } - _proof, err := groth16_bls12377.Prove(r1cs.(*cs_bls12377.R1CS), &pk, witness.Vector().(fr.Vector), backend.ProverConfig{}) + publicWitness, err := pre_witness.Public() if err != nil { t.Fatal(err) + } - proof.Ar = _proof.Ar - proof.Bs = _proof.Bs - proof.Krs = _proof.Krs - // before returning verifies that the proof passes on bls12377 - if err := groth16_bls12377.Verify(proof, vk, publicWitness.Vector().(fr.Vector)); err != nil { + // Check that proof 
verifies before continuing + if err := groth16.Verify(proof, innerVk, publicWitness); err != nil { t.Fatal(err) + } -} + var circuit verifierCircuit + circuit.InnerVk.Allocate(innerVk) -type verifierCircuit struct { - InnerProof Proof - InnerVk VerifyingKey - Hash frontend.Variable -} + var witness verifierCircuit + witness.InnerProof.Assign(proof) + witness.InnerVk.Assign(innerVk) + witness.Hash = preComputeMimc(preImage) -func (circuit *verifierCircuit) Define(api frontend.API) error { - // create the verifier cs - Verify(api, circuit.InnerVk, circuit.InnerProof, []frontend.Variable{circuit.Hash}) + assert := test.NewAssert(t) + + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_761), test.WithBackends(backend.GROTH16)) - return nil } -func TestVerifier(t *testing.T) { +func BenchmarkCompile(b *testing.B) { - // get the data - var innerVk groth16_bls12377.VerifyingKey - var innerProof groth16_bls12377.Proof - generateBls12377InnerProof(t, &innerVk, &innerProof) // get public inputs of the inner proof + // create a mock cs: knowing the preimage of a hash using mimc + var MimcCircuit mimcCircuit + _r1cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), r1cs.NewBuilder, &MimcCircuit) + if err != nil { + b.Fatal(err) - // create an empty cs - var circuit verifierCircuit - circuit.InnerVk.G1.K = make([]sw_bls12377.G1Affine, len(innerVk.G1.K)) + } - // create assignment, the private part consists of the proof, - // the public part is exactly the public part of the inner proof, - // up to the renaming of the inner ONE_WIRE to not conflict with the one wire of the outer proof. 
- var witness verifierCircuit - witness.InnerProof.Ar.Assign(&innerProof.Ar) - witness.InnerProof.Krs.Assign(&innerProof.Krs) - witness.InnerProof.Bs.Assign(&innerProof.Bs) + var pre_assignment mimcCircuit + pre_assignment.PreImage = preImage + pre_assignment.Hash = publicHash + pre_witness, err := frontend.NewWitness(&pre_assignment, ecc.BLS12_377.ScalarField()) + if err != nil { + b.Fatal(err) - witness.InnerVk.Assign(&innerVk) - witness.Hash = publicHash + } - // verifies the cs - assert := test.NewAssert(t) + innerPk, innerVk, err := groth16.Setup(_r1cs) + if err != nil { + b.Fatal(err) - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_761)) -} + } -func BenchmarkCompile(b *testing.B) { - // get the data - var innerVk groth16_bls12377.VerifyingKey - var innerProof groth16_bls12377.Proof - generateBls12377InnerProof(nil, &innerVk, &innerProof) // get public inputs of the inner proof + proof, err := groth16.Prove(_r1cs, innerPk, pre_witness) + if err != nil { + b.Fatal(err) + + } + + publicWitness, err := pre_witness.Public() + if err != nil { + b.Fatal(err) + + } + + // Check that proof verifies before continuing + if err := groth16.Verify(proof, innerVk, publicWitness); err != nil { + b.Fatal(err) + + } - // create an empty cs var circuit verifierCircuit - circuit.InnerVk.G1.K = make([]sw_bls12377.G1Affine, len(innerVk.G1.K)) + circuit.InnerVk.Allocate(innerVk) var ccs constraint.ConstraintSystem - var err error b.ResetTimer() for i := 0; i < b.N; i++ { ccs, err = frontend.Compile(ecc.BW6_761.ScalarField(), r1cs.NewBuilder, &circuit) if err != nil { b.Fatal(err) + } + } + b.Log(ccs.GetNbConstraints()) + } var tVariable reflect.Type func init() { tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() + } diff --git a/std/groth16_bls24315/verifier.go b/std/groth16_bls24315/verifier.go index 09487acbb3..8e474a8979 100644 --- a/std/groth16_bls24315/verifier.go +++ b/std/groth16_bls24315/verifier.go @@ -22,10 +22,10 
@@ import ( bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark/backend/groth16" + groth16_bls24315 "github.com/consensys/gnark/backend/groth16/bls24-315" "github.com/consensys/gnark/frontend" - groth16_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/groth16" - "github.com/consensys/gnark/std/algebra/fields_bls24315" - "github.com/consensys/gnark/std/algebra/sw_bls24315" + "github.com/consensys/gnark/std/algebra/native/fields_bls24315" + "github.com/consensys/gnark/std/algebra/native/sw_bls24315" ) // Proof represents a Groth16 proof @@ -49,6 +49,7 @@ type VerifyingKey struct { // [Kvk]1 G1 struct { K []sw_bls24315.G1Affine // The indexes correspond to the public wires + } } @@ -57,7 +58,8 @@ type VerifyingKey struct { // publicInputs do NOT contain the ONE_WIRE func Verify(api frontend.API, vk VerifyingKey, proof Proof, publicInputs []frontend.Variable) { if len(vk.G1.K) == 0 { - panic("innver verifying key needs at least one point; VerifyingKey.G1 must be initialized before compiling circuit") + panic("inner verifying key needs at least one point; VerifyingKey.G1 must be initialized before compiling circuit") + } // compute kSum = Σx.[Kvk(t)]1 @@ -71,11 +73,11 @@ func Verify(api frontend.API, vk VerifyingKey, proof Proof, publicInputs []front var ki sw_bls24315.G1Affine ki.ScalarMul(api, vk.G1.K[k+1], v) kSum.AddAssign(api, ki) + } // compute e(Σx.[Kvk(t)]1, -[γ]2) * e(Krs,δ) * e(Ar,Bs) - ml, _ := sw_bls24315.MillerLoop(api, []sw_bls24315.G1Affine{kSum, proof.Krs, proof.Ar}, []sw_bls24315.G2Affine{vk.G2.GammaNeg, vk.G2.DeltaNeg, proof.Bs}) - pairing := sw_bls24315.FinalExponentiation(api, ml) + pairing, _ := sw_bls24315.Pair(api, []sw_bls24315.G1Affine{kSum, proof.Krs, proof.Ar}, []sw_bls24315.G2Affine{vk.G2.GammaNeg, vk.G2.DeltaNeg, proof.Bs}) // vk.E must be equal to pairing vk.E.AssertIsEqual(api, pairing) @@ -87,21 +89,51 @@ func (vk *VerifyingKey) Assign(_ovk groth16.VerifyingKey) { ovk, ok := 
_ovk.(*groth16_bls24315.VerifyingKey) if !ok { panic("expected *groth16_bls24315.VerifyingKey, got " + reflect.TypeOf(_ovk).String()) + } e, err := bls24315.Pair([]bls24315.G1Affine{ovk.G1.Alpha}, []bls24315.G2Affine{ovk.G2.Beta}) if err != nil { panic(err) + } vk.E.Assign(&e) vk.G1.K = make([]sw_bls24315.G1Affine, len(ovk.G1.K)) for i := 0; i < len(ovk.G1.K); i++ { vk.G1.K[i].Assign(&ovk.G1.K[i]) + } var deltaNeg, gammaNeg bls24315.G2Affine deltaNeg.Neg(&ovk.G2.Delta) gammaNeg.Neg(&ovk.G2.Gamma) vk.G2.DeltaNeg.Assign(&deltaNeg) vk.G2.GammaNeg.Assign(&gammaNeg) + +} + +// Allocate memory for the "in-circuit" VerifyingKey +// This is exposed so that the slices in the structure can be allocated +// before calling frontend.Compile(). +func (vk *VerifyingKey) Allocate(_ovk groth16.VerifyingKey) { + ovk, ok := _ovk.(*groth16_bls24315.VerifyingKey) + if !ok { + panic("expected *groth16_bls24315.VerifyingKey, got " + reflect.TypeOf(_ovk).String()) + + } + vk.G1.K = make([]sw_bls24315.G1Affine, len(ovk.G1.K)) + +} + +// Assign the proof values of Groth16 +func (proof *Proof) Assign(_oproof groth16.Proof) { + oproof, ok := _oproof.(*groth16_bls24315.Proof) + if !ok { + panic("expected *groth16_bls24315.Proof, got " + reflect.TypeOf(oproof).String()) + + } + proof.Ar.Assign(&oproof.Ar) + proof.Krs.Assign(&oproof.Krs) + proof.Bs.Assign(&oproof.Bs) + } diff --git a/std/groth16_bls24315/verifier_test.go b/std/groth16_bls24315/verifier_test.go index 7144014e8d..c3a5a9f3af 100644 --- a/std/groth16_bls24315/verifier_test.go +++ b/std/groth16_bls24315/verifier_test.go @@ -22,20 +22,16 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/hash" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/constraint" - cs_bls24315 "github.com/consensys/gnark/constraint/bls24-315" "github.com/consensys/gnark/frontend" 
"github.com/consensys/gnark/frontend/cs/r1cs" - groth16_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/groth16" - "github.com/consensys/gnark/std/algebra/sw_bls24315" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" ) -//-------------------------------------------------------------------- -// utils - const ( preImage = "4992816046196248432836492760315135318126925090839638585255611512962528270024" publicHash = "4875439939758844840941638351757981379945701574516438614845550995673793857363" @@ -50,121 +46,161 @@ func (circuit *mimcCircuit) Define(api frontend.API) error { mimc, err := mimc.NewMiMC(api) if err != nil { return err + } mimc.Write(circuit.PreImage) api.AssertIsEqual(mimc.Sum(), circuit.Hash) return nil + +} + +// Calculate the expected output of MIMC through plain invocation +func preComputeMimc(preImage frontend.Variable) interface{} { + var expectedY fr.Element + expectedY.SetInterface(preImage) + // calc MiMC + goMimc := hash.MIMC_BLS24_315.New() + goMimc.Write(expectedY.Marshal()) + expectedh := goMimc.Sum(nil) + return expectedh + +} + +type verifierCircuit struct { + InnerProof Proof + InnerVk VerifyingKey + Hash frontend.Variable } -// Prepare the data for the inner proof. 
-// Returns the public inputs string of the inner proof -func generateBls24315InnerProof(t *testing.T, vk *groth16_bls24315.VerifyingKey, proof *groth16_bls24315.Proof) { +func (circuit *verifierCircuit) Define(api frontend.API) error { + // create the verifier cs + Verify(api, circuit.InnerVk, circuit.InnerProof, []frontend.Variable{circuit.Hash}) + + return nil + +} + +func TestVerifier(t *testing.T) { // create a mock cs: knowing the preimage of a hash using mimc - var circuit, assignment mimcCircuit - r1cs, err := frontend.Compile(ecc.BLS24_315.ScalarField(), r1cs.NewBuilder, &circuit) + var MimcCircuit mimcCircuit + r1cs, err := frontend.Compile(ecc.BLS24_315.ScalarField(), r1cs.NewBuilder, &MimcCircuit) if err != nil { t.Fatal(err) - } - assignment.PreImage = preImage - assignment.Hash = publicHash + } - witness, err := frontend.NewWitness(&assignment, ecc.BLS24_315.ScalarField()) + var pre_assignment mimcCircuit + pre_assignment.PreImage = preImage + pre_assignment.Hash = publicHash + pre_witness, err := frontend.NewWitness(&pre_assignment, ecc.BLS24_315.ScalarField()) if err != nil { t.Fatal(err) + } - publicWitness, err := witness.Public() + innerPk, innerVk, err := groth16.Setup(r1cs) if err != nil { t.Fatal(err) + } - // generate the data to return for the bls24315 proof - var pk groth16_bls24315.ProvingKey - err = groth16_bls24315.Setup(r1cs.(*cs_bls24315.R1CS), &pk, vk) + proof, err := groth16.Prove(r1cs, innerPk, pre_witness) if err != nil { t.Fatal(err) + } - _proof, err := groth16_bls24315.Prove(r1cs.(*cs_bls24315.R1CS), &pk, witness.Vector().(fr.Vector), backend.ProverConfig{}) + publicWitness, err := pre_witness.Public() if err != nil { t.Fatal(err) + } - proof.Ar = _proof.Ar - proof.Bs = _proof.Bs - proof.Krs = _proof.Krs - // before returning verifies that the proof passes on bls24315 - if err := groth16_bls24315.Verify(proof, vk, publicWitness.Vector().(fr.Vector)); err != nil { + // Check that proof verifies before continuing + if err := 
groth16.Verify(proof, innerVk, publicWitness); err != nil { t.Fatal(err) + } -} -type verifierCircuit struct { - InnerProof Proof - InnerVk VerifyingKey - Hash frontend.Variable -} + var circuit verifierCircuit + circuit.InnerVk.Allocate(innerVk) -func (circuit *verifierCircuit) Define(api frontend.API) error { + var witness verifierCircuit + witness.InnerProof.Assign(proof) + witness.InnerVk.Assign(innerVk) + witness.Hash = preComputeMimc(preImage) - // create the verifier cs - Verify(api, circuit.InnerVk, circuit.InnerProof, []frontend.Variable{circuit.Hash}) + assert := test.NewAssert(t) + + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_633), test.WithBackends(backend.GROTH16)) - return nil } -func TestVerifier(t *testing.T) { +func BenchmarkCompile(b *testing.B) { - // get the data - var innerVk groth16_bls24315.VerifyingKey - var innerProof groth16_bls24315.Proof - generateBls24315InnerProof(t, &innerVk, &innerProof) // get public inputs of the inner proof + // create a mock cs: knowing the preimage of a hash using mimc + var MimcCircuit mimcCircuit + _r1cs, err := frontend.Compile(ecc.BLS24_315.ScalarField(), r1cs.NewBuilder, &MimcCircuit) + if err != nil { + b.Fatal(err) - // create an empty cs - var circuit verifierCircuit - circuit.InnerVk.G1.K = make([]sw_bls24315.G1Affine, len(innerVk.G1.K)) + } - // create assignment, the private part consists of the proof, - // the public part is exactly the public part of the inner proof, - // up to the renaming of the inner ONE_WIRE to not conflict with the one wire of the outer proof. 
- var witness verifierCircuit - witness.InnerProof.Ar.Assign(&innerProof.Ar) - witness.InnerProof.Krs.Assign(&innerProof.Krs) - witness.InnerProof.Bs.Assign(&innerProof.Bs) + var pre_assignment mimcCircuit + pre_assignment.PreImage = preImage + pre_assignment.Hash = publicHash + pre_witness, err := frontend.NewWitness(&pre_assignment, ecc.BLS24_315.ScalarField()) + if err != nil { + b.Fatal(err) + + } - witness.InnerVk.Assign(&innerVk) + innerPk, innerVk, err := groth16.Setup(_r1cs) + if err != nil { + b.Fatal(err) - witness.Hash = publicHash + } - // verifies the cs - assert := test.NewAssert(t) + proof, err := groth16.Prove(_r1cs, innerPk, pre_witness) + if err != nil { + b.Fatal(err) - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_633)) + } -} + publicWitness, err := pre_witness.Public() + if err != nil { + b.Fatal(err) -func BenchmarkCompile(b *testing.B) { - // get the data - var innerVk groth16_bls24315.VerifyingKey - var innerProof groth16_bls24315.Proof - generateBls24315InnerProof(nil, &innerVk, &innerProof) // get public inputs of the inner proof + } + + // Check that proof verifies before continuing + if err := groth16.Verify(proof, innerVk, publicWitness); err != nil { + b.Fatal(err) + + } - // create an empty cs var circuit verifierCircuit - circuit.InnerVk.G1.K = make([]sw_bls24315.G1Affine, len(innerVk.G1.K)) + circuit.InnerVk.Allocate(innerVk) var ccs constraint.ConstraintSystem b.ResetTimer() for i := 0; i < b.N; i++ { - ccs, _ = frontend.Compile(ecc.BW6_633.ScalarField(), r1cs.NewBuilder, &circuit) + ccs, err = frontend.Compile(ecc.BW6_633.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + b.Fatal(err) + + } + } + b.Log(ccs.GetNbConstraints()) + } var tVariable reflect.Type func init() { tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() + } diff --git a/std/hash/hash.go b/std/hash/hash.go index 141d0fcf18..307a7fc4b5 100644 --- a/std/hash/hash.go +++ b/std/hash/hash.go @@ -17,14 
+17,20 @@ limitations under the License. // Package hash provides an interface that hash functions (as gadget) should implement. package hash -import "github.com/consensys/gnark/frontend" - -type Hash interface { +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/uints" +) +// FieldHasher hashes inputs into a short digest. This interface mocks +// [BinaryHasher], but is more suitable in-circuit by assuming the inputs are +// scalar field elements and outputs digest as a field element. Such hash +// functions are for example Poseidon, MiMC etc. +type FieldHasher interface { // Sum computes the hash of the internal state of the hash function. Sum() frontend.Variable - // Write populate the internal state of the hash function with data. + // Write populates the internal state of the hash function with data. The inputs are native field elements. Write(data ...frontend.Variable) // Reset empty the internal state and put the intermediate state to zero. @@ -32,3 +38,28 @@ type Hash interface { } var BuilderRegistry = make(map[string]func(api frontend.API) (Hash, error)) + +// BinaryHasher hashes inputs into a short digest. It takes as inputs bytes and +// outputs byte array whose length depends on the underlying hash function. For +// SNARK-native hash functions use [FieldHasher]. +type BinaryHasher interface { + // Sum finalises the current hash and returns the digest. + Sum() []uints.U8 + + // Write writes more bytes into the current hash state. + Write([]uints.U8) + + // Size returns the number of bytes this hash function returns in a call to + // [BinaryHasher.Sum]. + Size() int +} + +// BinaryFixedLengthHasher is like [BinaryHasher], but assumes the length of the +// input is not full length as defined during compile time. This allows to +// compute digest of variable-length input, unlike [BinaryHasher] which assumes +// the length of the input is the total number of bytes written. 
+type BinaryFixedLengthHasher interface { + BinaryHasher + // FixedLengthSum returns digest of the first length bytes. + FixedLengthSum(length frontend.Variable) []uints.U8 +} diff --git a/std/hash/sha2/sha2.go b/std/hash/sha2/sha2.go new file mode 100644 index 0000000000..edb261e621 --- /dev/null +++ b/std/hash/sha2/sha2.go @@ -0,0 +1,89 @@ +// Package sha2 implements SHA2 hash computation. +// +// This package extends the SHA2 permutation function [sha2] into a full SHA2 +// hash. +package sha2 + +import ( + "encoding/binary" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/math/uints" + "github.com/consensys/gnark/std/permutation/sha2" +) + +var _seed = uints.NewU32Array([]uint32{ + 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, +}) + +type digest struct { + uapi *uints.BinaryField[uints.U32] + in []uints.U8 +} + +func New(api frontend.API) (hash.BinaryHasher, error) { + uapi, err := uints.New[uints.U32](api) + if err != nil { + return nil, err + } + return &digest{uapi: uapi}, nil +} + +func (d *digest) Write(data []uints.U8) { + d.in = append(d.in, data...) +} + +func (d *digest) padded(bytesLen int) []uints.U8 { + zeroPadLen := 55 - bytesLen%64 + if zeroPadLen < 0 { + zeroPadLen += 64 + } + if cap(d.in) < len(d.in)+9+zeroPadLen { + // in case this is the first time this method is called increase the + // capacity of the slice to fit the padding. + d.in = append(d.in, make([]uints.U8, 9+zeroPadLen)...) + d.in = d.in[:len(d.in)-9-zeroPadLen] + } + buf := d.in + buf = append(buf, uints.NewU8(0x80)) + buf = append(buf, uints.NewU8Array(make([]uint8, zeroPadLen))...) + lenbuf := make([]uint8, 8) + binary.BigEndian.PutUint64(lenbuf, uint64(8*bytesLen)) + buf = append(buf, uints.NewU8Array(lenbuf)...) 
+ return buf +} + +func (d *digest) Sum() []uints.U8 { + var runningDigest [8]uints.U32 + var buf [64]uints.U8 + copy(runningDigest[:], _seed) + padded := d.padded(len(d.in)) + for i := 0; i < len(padded)/64; i++ { + copy(buf[:], padded[i*64:(i+1)*64]) + runningDigest = sha2.Permute(d.uapi, runningDigest, buf) + } + var ret []uints.U8 + for i := range runningDigest { + ret = append(ret, d.uapi.UnpackMSB(runningDigest[i])...) + } + return ret +} + +func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 { + panic("TODO") + // we need to do two things here -- first the padding has to be put to the + // right place. For that we need to know how many blocks we have used. We + // need to fit at least 9 more bytes (padding byte and 8 bytes for input + // length). Knowing the block, we have to keep running track if the current + // block is the expected one. + // + // idea - have a mask for blocks where 1 is only for the block we want to + // use. +} + +func (d *digest) Reset() { + d.in = nil +} + +func (d *digest) Size() int { return 32 } diff --git a/std/hash/sha2/sha2_test.go b/std/hash/sha2/sha2_test.go new file mode 100644 index 0000000000..d4acf5baf3 --- /dev/null +++ b/std/hash/sha2/sha2_test.go @@ -0,0 +1,50 @@ +package sha2 + +import ( + "crypto/sha256" + "fmt" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/uints" + "github.com/consensys/gnark/test" +) + +type sha2Circuit struct { + In []uints.U8 + Expected [32]uints.U8 +} + +func (c *sha2Circuit) Define(api frontend.API) error { + h, err := New(api) + if err != nil { + return err + } + uapi, err := uints.New[uints.U32](api) + if err != nil { + return err + } + h.Write(c.In) + res := h.Sum() + if len(res) != 32 { + return fmt.Errorf("not 32 bytes") + } + for i := range c.Expected { + uapi.ByteAssertEq(c.Expected[i], res[i]) + } + return nil +} + +func TestSHA2(t *testing.T) { + bts := make([]byte, 310) + dgst := 
sha256.Sum256(bts)
	witness := sha2Circuit{
		In: uints.NewU8Array(bts),
	}
	copy(witness.Expected[:], uints.NewU8Array(dgst[:]))
	err := test.IsSolved(&sha2Circuit{In: make([]uints.U8, len(bts))}, &witness, ecc.BN254.ScalarField())
	if err != nil {
		t.Fatal(err)
	}
}

// ---- file: std/hints.go (post-image of the diff) ----

import (
	"sync"

	"github.com/consensys/gnark/constraint/solver"
	"github.com/consensys/gnark/std/algebra/native/sw_bls12377"
	"github.com/consensys/gnark/std/algebra/native/sw_bls24315"
	"github.com/consensys/gnark/std/evmprecompiles"
	"github.com/consensys/gnark/std/internal/logderivarg"
	"github.com/consensys/gnark/std/math/bits"
	"github.com/consensys/gnark/std/math/bitslice"
	"github.com/consensys/gnark/std/math/emulated"
	"github.com/consensys/gnark/std/rangecheck"
	"github.com/consensys/gnark/std/selector"
)

// registerHints registers every hint function used across the gnark standard
// library with the global solver registry, so a prover/solver service does
// not need to track which gadget needs which hint.
func registerHints() {
	// note that importing these packages may already trigger a call to solver.RegisterHint(...)
	solver.RegisterHint(sw_bls24315.DecomposeScalarG1)
	solver.RegisterHint(sw_bls12377.DecomposeScalarG1)
	solver.RegisterHint(sw_bls24315.DecomposeScalarG2)
	solver.RegisterHint(sw_bls12377.DecomposeScalarG2)
	solver.RegisterHint(bits.GetHints()...)
	solver.RegisterHint(selector.GetHints()...)
	solver.RegisterHint(emulated.GetHints()...)
	solver.RegisterHint(rangecheck.GetHints()...)
	solver.RegisterHint(evmprecompiles.GetHints()...)
	solver.RegisterHint(logderivarg.GetHints()...)
	solver.RegisterHint(bitslice.GetHints()...)
}

// ---- file: std/hints_test.go (comment updated for the solver API rename) ----
//
// since package bits is not imported, the hint NNAF is not registered
// --> solver.RegisterHint(bits.NNAF)

// ---- file: std/internal/logderivarg/logderivarg.go ----

// Package logderivarg implements log-derivative argument.
//
// The log-derivative argument was described in [Haböck22] as an improvement
// over [BCG+18]. In [BCG+18], it was shown that to show inclusion of a multiset
// S in T, one can show
//
//	∏_{f∈F} (x-f)^count(f, S) == ∏_{s∈S} x-s,
//
// where function `count` counts the number of occurrences of f in S. The problem
// with this approach is the high cost for exponentiating the left-hand side of
// the equation. However, in [Haböck22] it was shown that when avoiding the
// poles, we can perform the same check for the log-derivative variant of the
// equation:
//
//	∑_{f∈F} count(f,S)/(x-f) == ∑_{s∈S} 1/(x-s).
//
// Additionally, when the entries of both S and T are vectors, then instead we
// can check random linear combinations.
So, when F is a matrix and S is a
// multiset of its rows, we first generate random linear coefficients (r_1, ...,
// r_n) and check
//
//	∑_{f∈F} count(f,S)/(x-∑_{i∈[n]}r_i*f_i) == ∑_{s∈S} 1/(x-∑_{i∈[n]}r_i*s_i).
//
// This package is a low-level primitive for building more extensive gadgets. It
// only checks the last equation, but the tables and queries should be built by
// the users.
//
// NB! The package doesn't check that the entries in table F are unique.
//
// [BCG+18]: https://eprint.iacr.org/2018/380
// [Haböck22]: https://eprint.iacr.org/2022/1530
package logderivarg

// TODO: we handle both constant and variable tables. But for variable tables we
// have to ensure that all the table entries differ! Right now isn't a problem
// because everywhere we build we also have indices which ensure uniqueness. I
// guess the best approach is to have safe and unsafe versions where the safe
// version performs additional sorting. But that is really really expensive as
// we have to show that all sorted values are monotonically increasing.

import (
	"fmt"
	"math/big"

	"github.com/consensys/gnark/constraint/solver"
	"github.com/consensys/gnark/frontend"
	"github.com/consensys/gnark/std/hash/mimc"
	"github.com/consensys/gnark/std/internal/multicommit"
)

// init registers this package's hints globally so the solver can execute
// countHint without the caller registering it explicitly.
func init() {
	solver.RegisterHint(GetHints()...)
}

// GetHints returns all hints used in this package
func GetHints() []solver.Hint {
	return []solver.Hint{countHint}
}

// Table is a vector of vectors.
type Table [][]frontend.Variable

// AsTable returns a vector as a single-column table.
func AsTable(vector []frontend.Variable) Table {
	ret := make([][]frontend.Variable, len(vector))
	for i := range vector {
		ret[i] = []frontend.Variable{vector[i]}
	}
	return ret
}

// Build builds the argument using the table and queries. If both table and
// queries are multiple-column, then also samples coefficients for the random
// linear combinations.
+func Build(api frontend.API, table Table, queries Table) error { + if len(table) == 0 { + return fmt.Errorf("table empty") + } + nbRow := len(table[0]) + constTable := true + countInputs := []frontend.Variable{len(table), nbRow} + for i := range table { + if len(table[i]) != nbRow { + return fmt.Errorf("table row length mismatch") + } + if constTable { + for j := range table[i] { + if _, isConst := api.Compiler().ConstantValue(table[i][j]); !isConst { + constTable = false + } + } + } + countInputs = append(countInputs, table[i]...) + } + for i := range queries { + if len(queries[i]) != nbRow { + return fmt.Errorf("query row length mismatch") + } + countInputs = append(countInputs, queries[i]...) + } + exps, err := api.NewHint(countHint, len(table), countInputs...) + if err != nil { + return fmt.Errorf("hint: %w", err) + } + + var toCommit []frontend.Variable + if !constTable { + for i := range table { + toCommit = append(toCommit, table[i]...) + } + } + for i := range queries { + toCommit = append(toCommit, queries[i]...) + } + toCommit = append(toCommit, exps...) + + multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + rowCoeffs, challenge := randLinearCoefficients(api, nbRow, commitment) + var lp frontend.Variable = 0 + for i := range table { + tmp := api.DivUnchecked(exps[i], api.Sub(challenge, randLinearCombination(api, rowCoeffs, table[i]))) + lp = api.Add(lp, tmp) + } + var rp frontend.Variable = 0 + for i := range queries { + tmp := api.Inverse(api.Sub(challenge, randLinearCombination(api, rowCoeffs, queries[i]))) + rp = api.Add(rp, tmp) + } + api.AssertIsEqual(lp, rp) + return nil + }, toCommit...) 
+ return nil +} + +func randLinearCoefficients(api frontend.API, nbRow int, commitment frontend.Variable) (rowCoeffs []frontend.Variable, challenge frontend.Variable) { + if nbRow == 1 { + return []frontend.Variable{1}, commitment + } + hasher, err := mimc.NewMiMC(api) + if err != nil { + panic(err) + } + rowCoeffs = make([]frontend.Variable, nbRow) + for i := 0; i < nbRow; i++ { + hasher.Reset() + hasher.Write(i+1, commitment) + rowCoeffs[i] = hasher.Sum() + } + return rowCoeffs, commitment +} + +func randLinearCombination(api frontend.API, rowCoeffs []frontend.Variable, row []frontend.Variable) frontend.Variable { + if len(rowCoeffs) != len(row) { + panic("coefficient count mismatch") + } + var res frontend.Variable = 0 + for i := range rowCoeffs { + res = api.Add(res, api.Mul(rowCoeffs[i], row[i])) + } + return res +} + +func countHint(m *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) <= 2 { + return fmt.Errorf("at least two input required") + } + if !inputs[0].IsInt64() { + return fmt.Errorf("first element must be length of table") + } + nbTable := int(inputs[0].Int64()) + if !inputs[1].IsInt64() { + return fmt.Errorf("first element must be length of row") + } + nbRow := int(inputs[1].Int64()) + if len(inputs) < 2+nbTable { + return fmt.Errorf("input doesn't fit table") + } + if len(outputs) != nbTable { + return fmt.Errorf("output not table size") + } + if (len(inputs)-2-nbTable*nbRow)%nbRow != 0 { + return fmt.Errorf("query count not full integer") + } + nbQueries := (len(inputs) - 2 - nbTable*nbRow) / nbRow + if nbQueries <= 0 { + return fmt.Errorf("at least one query required") + } + nbBytes := (m.BitLen() + 7) / 8 + buf := make([]byte, nbBytes*nbRow) + histo := make(map[string]int64, nbTable) // string key as big ints not comparable + for i := 0; i < nbTable; i++ { + for j := 0; j < nbRow; j++ { + inputs[2+nbRow*i+j].FillBytes(buf[j*nbBytes : (j+1)*nbBytes]) + } + k := string(buf) + if _, ok := histo[k]; ok { + return 
fmt.Errorf("duplicate key") + } + histo[k] = 0 + } + for i := 0; i < nbQueries; i++ { + for j := 0; j < nbRow; j++ { + inputs[2+nbRow*nbTable+nbRow*i+j].FillBytes(buf[j*nbBytes : (j+1)*nbBytes]) + } + k := string(buf) + v, ok := histo[k] + if !ok { + return fmt.Errorf("query element not in table") + } + v++ + histo[k] = v + } + for i := 0; i < nbTable; i++ { + for j := 0; j < nbRow; j++ { + inputs[2+nbRow*i+j].FillBytes(buf[j*nbBytes : (j+1)*nbBytes]) + } + outputs[i].Set(big.NewInt(histo[string(buf)])) + } + return nil +} diff --git a/std/internal/logderivprecomp/logderivprecomp.go b/std/internal/logderivprecomp/logderivprecomp.go new file mode 100644 index 0000000000..d768c10a32 --- /dev/null +++ b/std/internal/logderivprecomp/logderivprecomp.go @@ -0,0 +1,127 @@ +// Package logderivprecomp allows computing functions using precomputation. +// +// Instead of computing binary functions and checking that the result is +// correctly constrained, we instead can precompute all valid values of a +// function and then perform lookup to obtain the result. For example, for the +// XOR function we would naively otherwise have to split the inputs into bits, +// XOR one-by-one and recombine. +// +// With this package, we can instead compute all results for two inputs of +// length 8 bit and then just perform a lookup on the inputs. +// +// We use the [logderivarg] package for the actual log-derivative argument. +package logderivprecomp + +import ( + "fmt" + "math/big" + "reflect" + + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/kvstore" + "github.com/consensys/gnark/std/internal/logderivarg" +) + +type ctxPrecomputedKey struct{ fn uintptr } + +// Precomputed holds all precomputed function values and queries. +type Precomputed struct { + api frontend.API + compute solver.Hint + queries []frontend.Variable + rets []uint +} + +// New returns a new [Precomputed]. 
It defers the log-derivative argument.
func New(api frontend.API, fn solver.Hint, rets []uint) (*Precomputed, error) {
	// reuse an existing instance for the same hint function if one was
	// already created in this builder (cached in the builder's kv store).
	kv, ok := api.Compiler().(kvstore.Store)
	if !ok {
		panic("builder should implement key-value store")
	}
	ch := kv.GetKeyValue(ctxPrecomputedKey{fn: reflect.ValueOf(fn).Pointer()})
	if ch != nil {
		if prt, ok := ch.(*Precomputed); ok {
			return prt, nil
		} else {
			panic("stored rangechecker is not valid")
		}
	}
	// check that the output lengths fit into a single element
	// (16 bits for the two 8-bit inputs plus the declared result widths)
	var s uint = 16
	for _, v := range rets {
		s += v
	}
	if s >= uint(api.Compiler().FieldBitLen()) {
		return nil, fmt.Errorf("result doesn't fit into field element")
	}
	t := &Precomputed{
		api:     api,
		compute: fn,
		queries: nil,
		rets:    rets,
	}
	kv.SetKeyValue(ctxPrecomputedKey{fn: reflect.ValueOf(fn).Pointer()}, t)
	// the lookup argument is built only after the whole circuit is defined
	api.Compiler().Defer(t.build)
	return t, nil
}

// pack packs the 8-bit inputs x and y and the result values rets into one
// field element: x in the low 8 bits, y in the next 8 bits, and result i
// scaled by 2^(8+rets[0]+...+rets[i]).
// NOTE(review): for the first result not to overlap y's byte this scheme
// needs rets[0] >= 8 — confirm whether sub-byte result widths are intended.
func (t *Precomputed) pack(x, y frontend.Variable, rets []frontend.Variable) frontend.Variable {
	shift := big.NewInt(1 << 8)
	packed := t.api.Add(x, t.api.Mul(y, shift))
	for i := range t.rets {
		shift.Lsh(shift, t.rets[i])
		packed = t.api.Add(packed, t.api.Mul(rets[i], shift))
	}
	return packed
}

// Query returns the precomputed function values for the 8-bit inputs x and y,
// recording the packed (x, y, results) tuple for the deferred lookup argument.
func (t *Precomputed) Query(x, y frontend.Variable) []frontend.Variable {
	// we don't have to check here. We assume the inputs are range checked and
	// range check the output.
	rets, err := t.api.Compiler().NewHint(t.compute, len(t.rets), x, y)
	if err != nil {
		panic(err)
	}
	packed := t.pack(x, y, rets)
	t.queries = append(t.queries, packed)
	return rets
}

// buildTable materializes the full 256x256 input domain: for every pair
// (x, y) it evaluates the hint natively and packs inputs and outputs into a
// single element, mirroring the layout produced by pack.
func (t *Precomputed) buildTable() []frontend.Variable {
	tmp := new(big.Int)
	shift := new(big.Int)
	tbl := make([]frontend.Variable, 65536)
	inputs := []*big.Int{big.NewInt(0), big.NewInt(0)}
	outputs := make([]*big.Int, len(t.rets))
	for i := range outputs {
		outputs[i] = new(big.Int)
	}
	for x := int64(0); x < 256; x++ {
		inputs[0].SetInt64(x)
		for y := int64(0); y < 256; y++ {
			// i = x | y<<8 is both the table index and the packed input part
			shift.SetInt64(1 << 8)
			i := x | (y << 8)
			inputs[1].SetInt64(y)
			if err := t.compute(t.api.Compiler().Field(), inputs, outputs); err != nil {
				panic(err)
			}
			tblval := new(big.Int).SetInt64(i)
			for j := range t.rets {
				shift.Lsh(shift, t.rets[j])
				tblval.Add(tblval, tmp.Mul(outputs[j], shift))
			}
			tbl[i] = tblval
		}
	}
	return tbl
}

// build is the deferred callback: it proves via the log-derivative argument
// that every recorded query tuple appears in the precomputed table. It is a
// no-op when the table was never queried.
func (t *Precomputed) build(api frontend.API) error {
	if len(t.queries) == 0 {
		return nil
	}
	table := t.buildTable()
	return logderivarg.Build(t.api, logderivarg.AsTable(table), logderivarg.AsTable(t.queries))
}

// ---- file: std/internal/logderivprecomp/logderivprecomp_test.go ----

package logderivprecomp

import (
	"crypto/rand"
	"math/big"
	"testing"

	"github.com/consensys/gnark-crypto/ecc"
	"github.com/consensys/gnark/backend"
	"github.com/consensys/gnark/constraint/solver"
	"github.com/consensys/gnark/frontend"
	"github.com/consensys/gnark/test"
)

// TestXORCircuit checks 100 byte-wise XORs through the precomputed table.
type TestXORCircuit struct {
	X, Y [100]frontend.Variable
	Res  [100]frontend.Variable
}

func (c *TestXORCircuit) Define(api frontend.API) error {
	tbl, err := New(api, xorHint, []uint{8})
	if err != nil {
		return err
	}
	for i := range c.X {
		res :=
tbl.Query(c.X[i], c.Y[i]) + api.AssertIsEqual(res[0], c.Res[i]) + } + return nil +} + +func xorHint(_ *big.Int, inputs, outputs []*big.Int) error { + outputs[0].Xor(inputs[0], inputs[1]) + return nil +} + +func TestXor(t *testing.T) { + assert := test.NewAssert(t) + bound := big.NewInt(255) + var xs, ys, ress [100]frontend.Variable + for i := range xs { + x, _ := rand.Int(rand.Reader, bound) + y, _ := rand.Int(rand.Reader, bound) + ress[i] = new(big.Int).Xor(x, y) + xs[i] = x + ys[i] = y + } + witness := &TestXORCircuit{X: xs, Y: ys, Res: ress} + assert.ProverSucceeded(&TestXORCircuit{}, witness, + test.WithBackends(backend.GROTH16), + test.WithSolverOpts(solver.WithHints(xorHint)), + test.NoFuzzing(), + test.NoSerialization(), + test.WithCurves(ecc.BN254)) +} diff --git a/std/internal/multicommit/doc_test.go b/std/internal/multicommit/doc_test.go new file mode 100644 index 0000000000..42eb8d9e99 --- /dev/null +++ b/std/internal/multicommit/doc_test.go @@ -0,0 +1,104 @@ +package multicommit_test + +import ( + "fmt" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/internal/multicommit" +) + +// MultipleCommitmentCircuit is an example circuit showing usage of multiple +// independent commitments in-circuit. +type MultipleCommitmentsCircuit struct { + Secrets [4]frontend.Variable +} + +func (c *MultipleCommitmentsCircuit) Define(api frontend.API) error { + // first callback receives first unique commitment derived from the root commitment + multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + // compute (X-s[0]) * (X-s[1]) for a random X + res := api.Mul(api.Sub(commitment, c.Secrets[0]), api.Sub(commitment, c.Secrets[1])) + api.AssertIsDifferent(res, 0) + return nil + }, c.Secrets[:2]...) 
+ + // second callback receives second unique commitment derived from the root commitment + multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + // compute (X-s[2]) * (X-s[3]) for a random X + res := api.Mul(api.Sub(commitment, c.Secrets[2]), api.Sub(commitment, c.Secrets[3])) + api.AssertIsDifferent(res, 0) + return nil + }, c.Secrets[2:4]...) + + // we do not have to pass any variables in if other calls to [WithCommitment] have + multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + // compute (X-s[0]) for a random X + api.AssertIsDifferent(api.Sub(commitment, c.Secrets[0]), 0) + return nil + }) + + // we can share variables between the callbacks + var shared, stored frontend.Variable + multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + shared = api.Add(c.Secrets[0], commitment) + stored = commitment + return nil + }) + multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + api.AssertIsEqual(api.Sub(shared, stored), c.Secrets[0]) + return nil + }) + return nil +} + +// Full written on how to use multiple commitments in a circuit. 
// ExampleWithCommitment runs the full lifecycle of a circuit that uses
// multiple commitments: compile, Groth16 setup, witness creation, proving
// and verification.
func ExampleWithCommitment() {
	circuit := MultipleCommitmentsCircuit{}
	assignment := MultipleCommitmentsCircuit{Secrets: [4]frontend.Variable{1, 2, 3, 4}}
	ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit)
	if err != nil {
		panic(err)
	} else {
		fmt.Println("compiled")
	}
	pk, vk, err := groth16.Setup(ccs)
	if err != nil {
		panic(err)
	} else {
		fmt.Println("setup done")
	}
	secretWitness, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField())
	if err != nil {
		panic(err)
	} else {
		fmt.Println("secret witness")
	}
	publicWitness, err := secretWitness.Public()
	if err != nil {
		panic(err)
	} else {
		fmt.Println("public witness")
	}
	proof, err := groth16.Prove(ccs, pk, secretWitness)
	if err != nil {
		panic(err)
	} else {
		fmt.Println("proof")
	}
	err = groth16.Verify(proof, vk, publicWitness)
	if err != nil {
		panic(err)
	} else {
		fmt.Println("verify")
	}
	// Output:
	// compiled
	// setup done
	// secret witness
	// public witness
	// proof
	// verify
}

// ---- file: std/internal/multicommit/nativecommit.go ----

// Package multicommit implements commitment expansion.
//
// If the builder implements [frontend.Committer] interface, then we can commit
// to the variables and get a commitment which can be used as a unique
// randomness in the circuit. For current builders implementing this interface,
// the function can only be called once in a circuit. This makes it difficult to
// compose different gadgets which require randomness.
//
// This package extends the commitment interface by allowing to receive several
// functions unique commitment multiple times. It does this by collecting all
// variables to commit and the callbacks which want to access a commitment.
Then
// we internally defer a function which computes the commitment over all input
// committed variables and then uses this commitment to derive a per-callback
// unique commitment. The callbacks are then called with these unique derived
// commitments instead.
package multicommit

import (
	"fmt"

	"github.com/consensys/gnark/frontend"
	"github.com/consensys/gnark/internal/kvstore"
	"github.com/consensys/gnark/std/hash/mimc"
)

// multicommiter collects all variables to commit and the callbacks waiting
// for a derived commitment; a single instance lives in the builder's
// key-value store.
type multicommiter struct {
	closed bool // set once the deferred commit has run
	vars   []frontend.Variable
	cbs    []WithCommitmentFn
}

type ctxMulticommiterKey struct{}

// getCached gets the cached committer from the key-value storage. If it is not
// there then creates, stores and defers it, and then returns.
func getCached(api frontend.API) *multicommiter {
	kv, ok := api.(kvstore.Store)
	if !ok {
		// if the builder doesn't implement key-value store then cannot store
		// multi-committer in cache.
		panic("builder should implement key-value store")
	}
	mc := kv.GetKeyValue(ctxMulticommiterKey{})
	if mc != nil {
		if mct, ok := mc.(*multicommiter); ok {
			return mct
		} else {
			panic("stored multicommiter is of invalid type")
		}
	}
	mct := &multicommiter{}
	kv.SetKeyValue(ctxMulticommiterKey{}, mct)
	api.Compiler().Defer(mct.commitAndCall)
	return mct
}

// commitAndCall is the deferred function: it commits once to all collected
// variables and hands every callback a commitment derived from that single
// root (via MiMC when there is more than one callback).
func (mct *multicommiter) commitAndCall(api frontend.API) error {
	// close collecting input in case anyone wants to check more variables to commit to.
	mct.closed = true
	if len(mct.cbs) == 0 {
		// shouldn't happen. we defer this function on creating multicommitter
		// instance. It is probably some race.
		panic("calling commiter with zero callbacks")
	}
	commiter, ok := api.Compiler().(frontend.Committer)
	if !ok {
		panic("compiler doesn't implement frontend.Committer")
	}
	cmt, err := commiter.Commit(mct.vars...)
	if err != nil {
		return fmt.Errorf("commit: %w", err)
	}
	if len(mct.cbs) == 1 {
		if err = mct.cbs[0](api, cmt); err != nil {
			return fmt.Errorf("single callback: %w", err)
		}
	} else {
		hasher, err := mimc.NewMiMC(api)
		if err != nil {
			return fmt.Errorf("new hasher: %w", err)
		}
		for i, cb := range mct.cbs {
			// derive a distinct per-callback commitment as MiMC(i+1, root)
			hasher.Reset()
			hasher.Write(i+1, cmt)
			localcmt := hasher.Sum()
			if err = cb(api, localcmt); err != nil {
				return fmt.Errorf("with commitment callback %d: %w", i, err)
			}
		}
	}
	return nil
}

// WithCommitmentFn is the function which is called asynchronously after all
// variables have been committed to. See [WithCommitment] for scheduling a
// function of this type. Every called function receives a distinct commitment
// built from a single root.
//
// It is invalid to call [WithCommitment] in this method recursively and this
// leads to panic. However, the method can call defer for other callbacks.
type WithCommitmentFn func(api frontend.API, commitment frontend.Variable) error

// WithCommitment schedules the function cb to be called with a unique
// commitment. We append the variables committedVariables to be committed to
// with the native [frontend.Committer] interface.
func WithCommitment(api frontend.API, cb WithCommitmentFn, committedVariables ...frontend.Variable) {
	mct := getCached(api)
	if mct.closed {
		panic("called WithCommitment recursively")
	}
	mct.vars = append(mct.vars, committedVariables...)
	mct.cbs = append(mct.cbs, cb)
}

// ---- file: std/internal/multicommit/nativecommit_test.go ----

package multicommit

import (
	"testing"

	"github.com/consensys/gnark-crypto/ecc"
	"github.com/consensys/gnark/backend"
	"github.com/consensys/gnark/frontend"
	"github.com/consensys/gnark/frontend/cs/r1cs"
	"github.com/consensys/gnark/test"
)

// noRecursionCircuit calls WithCommitment from inside a callback; compiling
// it must fail.
type noRecursionCircuit struct {
	X frontend.Variable
}

func (c *noRecursionCircuit) Define(api frontend.API) error {
	WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error {
		WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { return nil }, commitment)
		return nil
	}, c.X)
	return nil
}

func TestNoRecursion(t *testing.T) {
	circuit := noRecursionCircuit{}
	assert := test.NewAssert(t)
	_, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit)
	assert.Error(err)
}

// multipleCommitmentCircuit checks that two callbacks receive different
// derived commitments.
type multipleCommitmentCircuit struct {
	X frontend.Variable
}

func (c *multipleCommitmentCircuit) Define(api frontend.API) error {
	var stored frontend.Variable
	// first callback receives first unique commitment derived from the root commitment
	WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error {
		api.AssertIsDifferent(c.X, commitment)
		stored = commitment
		return nil
	}, c.X)
	WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error {
		api.AssertIsDifferent(stored, commitment)
		return nil
	}, c.X)
	return nil
}

func TestMultipleCommitments(t *testing.T) {
	circuit := multipleCommitmentCircuit{}
	assignment := multipleCommitmentCircuit{X: 10}
	assert := test.NewAssert(t)
	assert.ProverSucceeded(&circuit, &assignment, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16)) // right now PLONK doesn't
implement commitment
}

// noCommitVariable calls WithCommitment without supplying any variables to
// commit; compiling it must fail.
type noCommitVariable struct {
	X frontend.Variable
}

func (c *noCommitVariable) Define(api frontend.API) error {
	WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { return nil })
	return nil
}

func TestNoCommitVariable(t *testing.T) {
	circuit := noCommitVariable{}
	assert := test.NewAssert(t)
	_, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit)
	assert.Error(err)
}

// ---- file: std/lookup/logderivlookup/blueprint.go ----

package logderivlookup

import (
	"fmt"

	"github.com/consensys/gnark/constraint"
)

// BlueprintLookupHint is a blueprint that facilitates the lookup of values in a table.
// It is essentially a hint to the solver, but enables storing the table entries only once.
type BlueprintLookupHint struct {
	// EntriesCalldata stores the table entries (compressed linear
	// expressions) exactly once; instructions reference them by index.
	EntriesCalldata []uint32
}

// ensures BlueprintLookupHint implements the BlueprintSolvable interface
var _ constraint.BlueprintSolvable = (*BlueprintLookupHint)(nil)

// Reference (hint-style) implementation kept for documentation:
// func lookupHint(_ *big.Int, in []*big.Int, out []*big.Int) error {
// 	nbTable := len(in) - len(out)
// 	for i := 0; i < len(in)-nbTable; i++ {
// 		if !in[nbTable+i].IsInt64() {
// 			return fmt.Errorf("lookup query not integer")
// 		}
// 		ptr := int(in[nbTable+i].Int64())
// 		if ptr >= nbTable {
// 			return fmt.Errorf("lookup query %d outside table size %d", ptr, nbTable)
// 		}
// 		out[i].Set(in[ptr])
// 	}
// 	return nil
// }

// Solve resolves one lookup instruction: it decodes the table entries from
// the blueprint calldata and the query indices from the instruction calldata,
// then writes the selected entry values to the instruction's output wires.
func (b *BlueprintLookupHint) Solve(s constraint.Solver, inst constraint.Instruction) error {
	// instruction calldata layout: [1] = number of table entries considered,
	// [2] = number of inputs, [3:] = the packed query expressions.
	nbEntries := int(inst.Calldata[1])
	entries := make([]constraint.Element, nbEntries)

	// read the static entries from the blueprint
	// TODO @gbotrel cache that.
	offset, delta := 0, 0
	for i := 0; i < nbEntries; i++ {
		entries[i], delta = s.Read(b.EntriesCalldata[offset:])
		offset += delta
	}

	nbInputs := int(inst.Calldata[2])

	// read the inputs from the instruction
	inputs := make([]constraint.Element, nbInputs)
	offset = 3
	for i := 0; i < nbInputs; i++ {
		inputs[i], delta = s.Read(inst.Calldata[offset:])
		offset += delta
	}

	// set the outputs
	nbOutputs := nbInputs

	for i := 0; i < nbOutputs; i++ {
		idx, isUint64 := s.Uint64(inputs[i])
		// rejects both non-integer indices and out-of-range ones
		if !isUint64 || idx >= uint64(len(entries)) {
			return fmt.Errorf("lookup query too large")
		}
		// we set the output wire to the value of the entry
		s.SetValue(uint32(i+int(inst.WireOffset)), entries[idx])
	}
	return nil
}

// CalldataSize reports -1: the calldata size varies per instruction.
func (b *BlueprintLookupHint) CalldataSize() int {
	// variable size
	return -1
}

// NbConstraints returns 0: the blueprint adds no constraints by itself.
func (b *BlueprintLookupHint) NbConstraints() int {
	return 0
}

// NbOutputs return the number of output wires this blueprint creates.
func (b *BlueprintLookupHint) NbOutputs(inst constraint.Instruction) int {
	return int(inst.Calldata[2])
}

// Wires returns a function that walks the wires appearing in the blueprint.
// This is used by the level builder to build a dependency graph between instructions.
func (b *BlueprintLookupHint) WireWalker(inst constraint.Instruction) func(cb func(wire uint32)) {
	return func(cb func(wire uint32)) {
		// depend on the table UP to the number of entries at time of instruction creation.
		nbEntries := int(inst.Calldata[1])

		// invoke the callback on each wire appearing in the table
		j := 0
		for i := 0; i < nbEntries; i++ {
			// first we have the length of the linear expression
			n := int(b.EntriesCalldata[j])
			j++
			for k := 0; k < n; k++ {
				t := constraint.Term{CID: b.EntriesCalldata[j], VID: b.EntriesCalldata[j+1]}
				if !t.IsConstant() {
					cb(t.VID)
				}
				j += 2
			}
		}

		// invoke the callback on each wire appearing in the inputs
		nbInputs := int(inst.Calldata[2])
		j = 3
		for i := 0; i < nbInputs; i++ {
			// first we have the length of the linear expression
			n := int(inst.Calldata[j])
			j++
			for k := 0; k < n; k++ {
				t := constraint.Term{CID: inst.Calldata[j], VID: inst.Calldata[j+1]}
				if !t.IsConstant() {
					cb(t.VID)
				}
				j += 2
			}
		}

		// finally we have the outputs
		for i := 0; i < nbInputs; i++ {
			cb(uint32(i + int(inst.WireOffset)))
		}
	}
}

// ---- file: std/lookup/logderivlookup/doc_test.go ----

package logderivlookup_test

import (
	"crypto/rand"
	"fmt"
	"math/big"

	"github.com/consensys/gnark-crypto/ecc"
	"github.com/consensys/gnark/backend/groth16"
	"github.com/consensys/gnark/frontend"
	"github.com/consensys/gnark/frontend/cs/r1cs"
	"github.com/consensys/gnark/std/lookup/logderivlookup"
)

// LookupCircuit performs 100 lookups into a 1000-entry table and asserts the
// results against the expected values.
type LookupCircuit struct {
	Entries           [1000]frontend.Variable
	Queries, Expected [100]frontend.Variable
}

func (c *LookupCircuit) Define(api frontend.API) error {
	t := logderivlookup.New(api)
	for i := range c.Entries {
		t.Insert(c.Entries[i])
	}
	results := t.Lookup(c.Queries[:]...)
	if len(results) != len(c.Expected) {
		return fmt.Errorf("length mismatch")
	}
	for i := range results {
		api.AssertIsEqual(results[i], c.Expected[i])
	}
	return nil
}

// Example runs the full Groth16 lifecycle (compile, setup, prove, verify)
// for the lookup circuit on random entries and queries.
func Example() {
	field := ecc.BN254.ScalarField()
	witness := LookupCircuit{}
	bound := big.NewInt(int64(len(witness.Entries)))
	for i := range witness.Entries {
		witness.Entries[i], _ = rand.Int(rand.Reader, field)
	}
	for i := range witness.Queries {
		q, _ := rand.Int(rand.Reader, bound)
		witness.Queries[i] = q
		witness.Expected[i] = new(big.Int).Set(witness.Entries[q.Int64()].(*big.Int))
	}
	ccs, err := frontend.Compile(field, r1cs.NewBuilder, &LookupCircuit{})
	if err != nil {
		panic(err)
	} else {
		fmt.Println("compiled")
	}
	pk, vk, err := groth16.Setup(ccs)
	if err != nil {
		panic(err)
	} else {
		fmt.Println("setup done")
	}
	secretWitness, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField())
	if err != nil {
		panic(err)
	} else {
		fmt.Println("secret witness")
	}
	publicWitness, err := secretWitness.Public()
	if err != nil {
		panic(err)
	} else {
		fmt.Println("public witness")
	}
	proof, err := groth16.Prove(ccs, pk, secretWitness)
	if err != nil {
		panic(err)
	} else {
		fmt.Println("proof")
	}
	err = groth16.Verify(proof, vk, publicWitness)
	if err != nil {
		panic(err)
	} else {
		fmt.Println("verify")
	}
	// Output:
	// compiled
	// setup done
	// secret witness
	// public witness
	// proof
	// verify
}

// ---- file: std/lookup/logderivlookup/logderivlookup.go ----

// Package logderivlookup implements append-only lookups using log-derivative
// argument.
//
// The lookup is based on log-derivative argument as described in [logderivarg].
// The lookup table is a matrix where first column is the index and the second
// column the stored values:
//
//	1 x_1
//	2 x_2
//	...
//	n x_n
//
// When performing a query for index i, the prover returns x_i from memory and
// stores (i, x_i) as a query. During the log-derivative argument building we
// check that all queried tuples (i, x_i) are included in the table.
//
// The complexity of the lookups is linear in the size of the table and the
// number of queries (O(n+m)).
package logderivlookup

import (
	"github.com/consensys/gnark/constraint"
	"github.com/consensys/gnark/frontend"
	"github.com/consensys/gnark/std/internal/logderivarg"
)

// Table holds all the entries and queries.
type Table struct {
	api frontend.API

	entries   []frontend.Variable
	immutable bool // set on commit; later Insert/Lookup calls panic
	results   []result

	// each table has a unique blueprint
	// the blueprint stores the lookup table entries once
	// such that each query only need to store the indexes to lookup
	bID       constraint.BlueprintID
	blueprint BlueprintLookupHint
}

// result pairs a queried index with the value returned for it.
type result struct {
	ind frontend.Variable
	val frontend.Variable
}

// New returns a new [*Table]. It additionally defers building the
// log-derivative argument.
func New(api frontend.API) *Table {
	t := &Table{api: api}
	api.Compiler().Defer(t.commit)

	// each table has a unique blueprint
	t.bID = api.Compiler().AddBlueprint(&t.blueprint)
	return t
}

// Insert inserts variable val into the lookup table and returns its index as a
// constant. It panics if the table is already committed.
+func (t *Table) Insert(val frontend.Variable) (index int) { + if t.immutable { + panic("inserting into committed lookup table") + } + t.entries = append(t.entries, val) + + // each time we insert a new entry, we update the blueprint + v := t.api.Compiler().ToCanonicalVariable(val) + v.Compress(&t.blueprint.EntriesCalldata) + + return len(t.entries) - 1 +} + +// Lookup lookups up values from the lookup tables given by the indices inds. It +// returns a variable for every index. It panics during compile time when +// looking up from a committed or empty table. It panics during solving time +// when the index is out of bounds. +func (t *Table) Lookup(inds ...frontend.Variable) (vals []frontend.Variable) { + if t.immutable { + panic("looking up from a committed lookup table") + } + if len(inds) == 0 { + return nil + } + if len(t.entries) == 0 { + panic("looking up from empty table") + } + return t.performLookup(inds) +} + +// performLookup performs the lookup and returns the resulting variables. +// underneath, it does use the blueprint to encode the lookup hint. +func (t *Table) performLookup(inds []frontend.Variable) []frontend.Variable { + // to build the instruction, we need to first encode its dependency as a calldata []uint32 slice. + // * calldata[0] is the length of the calldata, + // * calldata[1] is the number of entries in the table we consider. 
+ // * calldata[2] is the number of queries (which is the number of indices we are looking up and the number of outputs we expect) + compiler := t.api.Compiler() + + calldata := make([]uint32, 3, 3+len(inds)*2+2) + calldata[1] = uint32(len(t.entries)) + calldata[2] = uint32(len(inds)) + + // encode inputs + for _, in := range inds { + v := compiler.ToCanonicalVariable(in) + v.Compress(&calldata) + } + + // by convention, first calldata is len of inputs + calldata[0] = uint32(len(calldata)) + + // now what we are left to do is add an instruction to the constraint system + // such that at solving time the blueprint can properly execute the lookup logic. + outputs := compiler.AddInstruction(t.bID, calldata) + + // sanity check + if len(outputs) != len(inds) { + panic("sanity check") + } + + // we need to return the variables corresponding to the outputs + internalVariables := make([]frontend.Variable, len(inds)) + lookupResult := make([]result, len(inds)) + + // we need to store the result of the lookup in the table + for i := range inds { + internalVariables[i] = compiler.InternalVariable(outputs[i]) + lookupResult[i] = result{ind: inds[i], val: internalVariables[i]} + } + t.results = append(t.results, lookupResult...) 
+ return internalVariables +} + +func (t *Table) entryTable() [][]frontend.Variable { + tbl := make([][]frontend.Variable, len(t.entries)) + for i := range t.entries { + tbl[i] = []frontend.Variable{i, t.entries[i]} + } + return tbl +} + +func (t *Table) resultsTable() [][]frontend.Variable { + tbl := make([][]frontend.Variable, len(t.results)) + for i := range t.results { + tbl[i] = []frontend.Variable{t.results[i].ind, t.results[i].val} + } + return tbl +} + +func (t *Table) commit(api frontend.API) error { + return logderivarg.Build(api, t.entryTable(), t.resultsTable()) +} diff --git a/std/lookup/logderivlookup/logderivlookup_test.go b/std/lookup/logderivlookup/logderivlookup_test.go new file mode 100644 index 0000000000..0070f18352 --- /dev/null +++ b/std/lookup/logderivlookup/logderivlookup_test.go @@ -0,0 +1,65 @@ +package logderivlookup + +import ( + "crypto/rand" + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/test" +) + +type LookupCircuit struct { + Entries [1000]frontend.Variable + Queries, Expected [100]frontend.Variable +} + +func (c *LookupCircuit) Define(api frontend.API) error { + t := New(api) + for i := range c.Entries { + t.Insert(c.Entries[i]) + } + results := t.Lookup(c.Queries[:]...) 
+ if len(results) != len(c.Expected) { + return fmt.Errorf("length mismatch") + } + for i := range results { + api.AssertIsEqual(results[i], c.Expected[i]) + } + return nil +} + +func TestLookup(t *testing.T) { + assert := test.NewAssert(t) + field := ecc.BN254.ScalarField() + witness := LookupCircuit{} + bound := big.NewInt(int64(len(witness.Entries))) + for i := range witness.Entries { + witness.Entries[i], _ = rand.Int(rand.Reader, field) + } + for i := range witness.Queries { + q, _ := rand.Int(rand.Reader, bound) + witness.Queries[i] = q + witness.Expected[i] = new(big.Int).Set(witness.Entries[q.Int64()].(*big.Int)) + } + + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &LookupCircuit{}) + assert.NoError(err) + + w, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + assert.NoError(err) + + _, err = ccs.Solve(w) + assert.NoError(err) + + err = test.IsSolved(&LookupCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + + assert.ProverSucceeded(&LookupCircuit{}, &witness, + test.WithCurves(ecc.BN254), + test.WithBackends(backend.GROTH16, backend.PLONK)) +} diff --git a/std/math/bits/conversion.go b/std/math/bits/conversion.go index 20040b82d8..ca82c2773e 100644 --- a/std/math/bits/conversion.go +++ b/std/math/bits/conversion.go @@ -56,10 +56,10 @@ type baseConversionConfig struct { // BaseConversionOption configures the behaviour of scalar decomposition. type BaseConversionOption func(opt *baseConversionConfig) error -// WithNbDigits set the resulting number of digits to be used in the base conversion. -// nbDigits must be > 0. If nbDigits is lower than the length of full -// decomposition, then nbDigits least significant digits are returned. If the -// option is not set, then the full decomposition is returned. +// WithNbDigits sets the resulting number of digits (nbDigits) to be used in the base conversion. +// nbDigits must be > 0. 
If nbDigits is lower than the length of full decomposition and +// WithUnconstrainedOutputs option is not used, then this function generates an unsatisfiable +// constraint. If WithNbDigits option is not set, then the full decomposition is returned. func WithNbDigits(nbDigits int) BaseConversionOption { return func(opt *baseConversionConfig) error { if nbDigits <= 0 { diff --git a/std/math/bits/conversion_binary.go b/std/math/bits/conversion_binary.go index 5d93af06c6..693f477b3f 100644 --- a/std/math/bits/conversion_binary.go +++ b/std/math/bits/conversion_binary.go @@ -3,16 +3,9 @@ package bits import ( "math/big" - "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend" ) -func init() { - // register hints - hint.Register(IthBit) - hint.Register(NBits) -} - // ToBinary is an alias of ToBase(api, Binary, v, opts) func ToBinary(api frontend.API, v frontend.Variable, opts ...BaseConversionOption) []frontend.Variable { return ToBase(api, Binary, v, opts...) @@ -63,18 +56,16 @@ func toBinary(api frontend.API, v frontend.Variable, opts ...BaseConversionOptio } } - // if a is a constant, work with the big int value. - if c, ok := api.Compiler().ConstantValue(v); ok { - bits := make([]frontend.Variable, cfg.NbDigits) - for i := 0; i < len(bits); i++ { - bits[i] = c.Bit(i) - } - return bits + // when cfg.NbDigits == 1, v itself has to be a binary digit. This if clause + // saves one constraint. + if cfg.NbDigits == 1 { + api.AssertIsBoolean(v) + return []frontend.Variable{v} } c := big.NewInt(1) - bits, err := api.Compiler().NewHint(NBits, cfg.NbDigits, v) + bits, err := api.Compiler().NewHint(nBits, cfg.NbDigits, v) if err != nil { panic(err) } @@ -94,27 +85,3 @@ func toBinary(api frontend.API, v frontend.Variable, opts ...BaseConversionOptio return bits } - -// IthBit returns the i-tb bit the input. The function expects exactly two -// integer inputs i and n, takes the little-endian bit representation of n and -// returns its i-th bit. 
-func IthBit(_ *big.Int, inputs []*big.Int, results []*big.Int) error { - result := results[0] - if !inputs[1].IsUint64() { - result.SetUint64(0) - return nil - } - - result.SetUint64(uint64(inputs[0].Bit(int(inputs[1].Uint64())))) - return nil -} - -// NBits returns the first bits of the input. The number of returned bits is -// defined by the length of the results slice. -func NBits(_ *big.Int, inputs []*big.Int, results []*big.Int) error { - n := inputs[0] - for i := 0; i < len(results); i++ { - results[i].SetUint64(uint64(n.Bit(i))) - } - return nil -} diff --git a/std/math/bits/conversion_ternary.go b/std/math/bits/conversion_ternary.go index 302868361b..d38095d2b2 100644 --- a/std/math/bits/conversion_ternary.go +++ b/std/math/bits/conversion_ternary.go @@ -4,18 +4,9 @@ import ( "math" "math/big" - "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend" ) -// NTrits returns the first trits of the input. The number of returned trits is -// defined by the length of the results slice. -var NTrits = nTrits - -func init() { - hint.Register(NTrits) -} - // ToTernary is an alias of ToBase(api, Ternary, v, opts...) func ToTernary(api frontend.API, v frontend.Variable, opts ...BaseConversionOption) []frontend.Variable { return ToBase(api, Ternary, v, opts...) @@ -67,26 +58,10 @@ func toTernary(api frontend.API, v frontend.Variable, opts ...BaseConversionOpti } } - // if a is a constant, work with the big int value. 
- if c, ok := api.Compiler().ConstantValue(v); ok { - trits := make([]frontend.Variable, cfg.NbDigits) - // TODO using big.Int Text is likely not cheap - base3 := c.Text(3) - i := 0 - for j := len(base3) - 1; j >= 0 && i < len(trits); j-- { - trits[i] = int(base3[j] - 48) - i++ - } - for ; i < len(trits); i++ { - trits[i] = 0 - } - return trits - } - c := big.NewInt(1) b := big.NewInt(3) - trits, err := api.Compiler().NewHint(NTrits, cfg.NbDigits, v) + trits, err := api.Compiler().NewHint(nTrits, cfg.NbDigits, v) if err != nil { panic(err) } @@ -121,19 +96,3 @@ func AssertIsTrit(api frontend.API, v frontend.Variable) { y := api.Mul(api.Sub(1, v), api.Sub(2, v)) api.AssertIsEqual(api.Mul(v, y), 0) } - -func nTrits(_ *big.Int, inputs []*big.Int, results []*big.Int) error { - n := inputs[0] - // TODO using big.Int Text method is likely not cheap - base3 := n.Text(3) - i := 0 - for j := len(base3) - 1; j >= 0 && i < len(results); j-- { - results[i].SetUint64(uint64(base3[j] - 48)) - i++ - } - for ; i < len(results); i++ { - results[i].SetUint64(0) - } - - return nil -} diff --git a/std/math/bits/conversion_test.go b/std/math/bits/conversion_test.go index a43ae6489d..47455b236d 100644 --- a/std/math/bits/conversion_test.go +++ b/std/math/bits/conversion_test.go @@ -31,6 +31,13 @@ func (c *toBinaryCircuit) Define(api frontend.API) error { func TestToBinary(t *testing.T) { assert := test.NewAssert(t) assert.ProverSucceeded(&toBinaryCircuit{}, &toBinaryCircuit{A: 5, B0: 1, B1: 0, B2: 1}) + + assert.ProverSucceeded(&toBinaryCircuit{}, &toBinaryCircuit{A: 3, B0: 1, B1: 1, B2: 0}) + + // prover fails when the binary representation of A has more than 3 bits + assert.ProverFailed(&toBinaryCircuit{}, &toBinaryCircuit{A: 8, B0: 0, B1: 0, B2: 0}) + + assert.ProverFailed(&toBinaryCircuit{}, &toBinaryCircuit{A: 10, B0: 0, B1: 1, B2: 0}) } type toTernaryCircuit struct { diff --git a/std/math/bits/hints.go b/std/math/bits/hints.go new file mode 100644 index 0000000000..bb3da6d13c --- 
/dev/null +++ b/std/math/bits/hints.go @@ -0,0 +1,114 @@ +package bits + +import ( + "errors" + "math/big" + + "github.com/consensys/gnark/constraint/solver" +) + +func GetHints() []solver.Hint { + return []solver.Hint{ + ithBit, + nBits, + nTrits, + nNaf, + } +} + +func init() { + solver.RegisterHint(GetHints()...) +} + +// IthBit returns the i-tb bit the input. The function expects exactly two +// integer inputs i and n, takes the little-endian bit representation of n and +// returns its i-th bit. +func ithBit(_ *big.Int, inputs []*big.Int, results []*big.Int) error { + result := results[0] + if !inputs[1].IsUint64() { + result.SetUint64(0) + return nil + } + + result.SetUint64(uint64(inputs[0].Bit(int(inputs[1].Uint64())))) + return nil +} + +// NBits returns the first bits of the input. The number of returned bits is +// defined by the length of the results slice. +func nBits(_ *big.Int, inputs []*big.Int, results []*big.Int) error { + n := inputs[0] + for i := 0; i < len(results); i++ { + results[i].SetUint64(uint64(n.Bit(i))) + } + return nil +} + +// nTrits returns the first trits of the input. The number of returned trits is +// defined by the length of the results slice. +func nTrits(_ *big.Int, inputs []*big.Int, results []*big.Int) error { + n := inputs[0] + // TODO using big.Int Text method is likely not cheap + base3 := n.Text(3) + i := 0 + for j := len(base3) - 1; j >= 0 && i < len(results); j-- { + results[i].SetUint64(uint64(base3[j] - 48)) + i++ + } + for ; i < len(results); i++ { + results[i].SetUint64(0) + } + + return nil +} + +// NNAF returns the NAF decomposition of the input. The number of digits is +// defined by the number of elements in the results slice. 
+func nNaf(_ *big.Int, inputs []*big.Int, results []*big.Int) error { + n := inputs[0] + return nafDecomposition(n, results) +} + +// nafDecomposition gets the naf decomposition of a big number +func nafDecomposition(a *big.Int, results []*big.Int) error { + if a == nil || a.Sign() == -1 { + return errors.New("invalid input to naf decomposition; negative (or nil) big.Int not supported") + } + + var zero, one, three big.Int + + one.SetUint64(1) + three.SetUint64(3) + + n := 0 + + // some buffers + var buf, aCopy big.Int + aCopy.Set(a) + + for aCopy.Cmp(&zero) != 0 && n < len(results) { + + // if aCopy % 2 == 0 + buf.And(&aCopy, &one) + + // aCopy even + if buf.Cmp(&zero) == 0 { + results[n].SetUint64(0) + } else { // aCopy odd + buf.And(&aCopy, &three) + if buf.IsUint64() && buf.Uint64() == 3 { + results[n].SetInt64(-1) + aCopy.Add(&aCopy, &one) + } else { + results[n].SetUint64(1) + } + } + aCopy.Rsh(&aCopy, 1) + n++ + } + for ; n < len(results); n++ { + results[n].SetUint64(0) + } + + return nil +} diff --git a/std/math/bits/naf.go b/std/math/bits/naf.go index b8e21ed965..56b4b9e468 100644 --- a/std/math/bits/naf.go +++ b/std/math/bits/naf.go @@ -1,21 +1,11 @@ package bits import ( - "errors" "math/big" - "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend" ) -// NNAF returns the NAF decomposition of the input. The number of digits is -// defined by the number of elements in the results slice. -var NNAF = nNaf - -func init() { - hint.Register(NNAF) -} - // ToNAF returns the NAF decomposition of given input. // The non-adjacent form (NAF) of a number is a unique signed-digit representation, // in which non-zero values cannot be adjacent. For example, NAF(13) = [1, 0, -1, 0, 1]. @@ -32,25 +22,9 @@ func ToNAF(api frontend.API, v frontend.Variable, opts ...BaseConversionOption) } } - // if v is a constant, work with the big int value. 
- if c, ok := api.Compiler().ConstantValue(v); ok { - bits := make([]*big.Int, cfg.NbDigits) - for i := 0; i < len(bits); i++ { - bits[i] = big.NewInt(0) - } - if err := nafDecomposition(c, bits); err != nil { - panic(err) - } - res := make([]frontend.Variable, len(bits)) - for i := 0; i < len(bits); i++ { - res[i] = bits[i] - } - return res - } - c := big.NewInt(1) - bits, err := api.Compiler().NewHint(NNAF, cfg.NbDigits, v) + bits, err := api.Compiler().NewHint(nNaf, cfg.NbDigits, v) if err != nil { panic(err) } @@ -74,52 +48,3 @@ func ToNAF(api frontend.API, v frontend.Variable, opts ...BaseConversionOption) return bits } - -func nNaf(_ *big.Int, inputs []*big.Int, results []*big.Int) error { - n := inputs[0] - return nafDecomposition(n, results) -} - -// nafDecomposition gets the naf decomposition of a big number -func nafDecomposition(a *big.Int, results []*big.Int) error { - if a == nil || a.Sign() == -1 { - return errors.New("invalid input to naf decomposition; negative (or nil) big.Int not supported") - } - - var zero, one, three big.Int - - one.SetUint64(1) - three.SetUint64(3) - - n := 0 - - // some buffers - var buf, aCopy big.Int - aCopy.Set(a) - - for aCopy.Cmp(&zero) != 0 && n < len(results) { - - // if aCopy % 2 == 0 - buf.And(&aCopy, &one) - - // aCopy even - if buf.Cmp(&zero) == 0 { - results[n].SetUint64(0) - } else { // aCopy odd - buf.And(&aCopy, &three) - if buf.IsUint64() && buf.Uint64() == 3 { - results[n].SetInt64(-1) - aCopy.Add(&aCopy, &one) - } else { - results[n].SetUint64(1) - } - } - aCopy.Rsh(&aCopy, 1) - n++ - } - for ; n < len(results); n++ { - results[n].SetUint64(0) - } - - return nil -} diff --git a/std/math/bitslice/hints.go b/std/math/bitslice/hints.go new file mode 100644 index 0000000000..d9797c6a9a --- /dev/null +++ b/std/math/bitslice/hints.go @@ -0,0 +1,34 @@ +package bitslice + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark/constraint/solver" +) + +func init() { + solver.RegisterHint(GetHints()...) 
+} + +func GetHints() []solver.Hint { + return []solver.Hint{ + partitionHint, + } +} + +func partitionHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 2 { + return fmt.Errorf("expecting two inputs") + } + if len(outputs) != 2 { + return fmt.Errorf("expecting two outputs") + } + if !inputs[0].IsUint64() { + return fmt.Errorf("split location must be int") + } + split := uint(inputs[0].Uint64()) + div := new(big.Int).Lsh(big.NewInt(1), split) + outputs[0].QuoRem(inputs[1], div, outputs[1]) + return nil +} diff --git a/std/math/bitslice/opts.go b/std/math/bitslice/opts.go new file mode 100644 index 0000000000..fa333efb1a --- /dev/null +++ b/std/math/bitslice/opts.go @@ -0,0 +1,37 @@ +package bitslice + +import "fmt" + +type opt struct { + digits int + nocheck bool +} + +func parseOpts(opts ...Option) (*opt, error) { + o := new(opt) + for _, apply := range opts { + if err := apply(o); err != nil { + return nil, err + } + } + return o, nil +} + +type Option func(*opt) error + +func WithNbDigits(nbDigits int) Option { + return func(o *opt) error { + if nbDigits < 1 { + return fmt.Errorf("given number of digits %d smaller than 1", nbDigits) + } + o.digits = nbDigits + return nil + } +} + +func WithUnconstrainedOutputs() Option { + return func(o *opt) error { + o.nocheck = true + return nil + } +} diff --git a/std/math/bitslice/partition.go b/std/math/bitslice/partition.go new file mode 100644 index 0000000000..f20c6d3513 --- /dev/null +++ b/std/math/bitslice/partition.go @@ -0,0 +1,69 @@ +package bitslice + +import ( + "math/big" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/rangecheck" +) + +// Partition partitions v into two parts splitted at bit numbered split. The +// following holds +// +// v = lower + 2^split * upper. +// +// The method enforces that lower < 2^split and upper < 2^split', where +// split'=nbScalar-split. 
When giving the option [WithNbDigits], we instead use +// the bound split'=nbDigits-split. +func Partition(api frontend.API, v frontend.Variable, split uint, opts ...Option) (lower, upper frontend.Variable) { + opt, err := parseOpts(opts...) + if err != nil { + panic(err) + } + // handle constant case + if vc, ok := api.Compiler().ConstantValue(v); ok { + if opt.digits > 0 && vc.BitLen() > opt.digits { + panic("input larger than bound") + } + if split == 0 { + return 0, vc + } + div := new(big.Int).Lsh(big.NewInt(1), split) + l, u := new(big.Int), new(big.Int) + u.QuoRem(vc, div, l) + return l, u + } + rh := rangecheck.New(api) + if split == 0 { + if opt.digits > 0 { + rh.Check(v, opt.digits) + } + return 0, v + } + ret, err := api.Compiler().NewHint(partitionHint, 2, split, v) + if err != nil { + panic(err) + } + + upper = ret[0] + lower = ret[1] + + if opt.nocheck { + if opt.digits > 0 { + rh.Check(v, opt.digits) + } + return + } + upperBound := api.Compiler().FieldBitLen() + if opt.digits > 0 { + upperBound = opt.digits + } + rh.Check(upper, upperBound) + rh.Check(lower, int(split)) + + m := big.NewInt(1) + m.Lsh(m, split) + composed := api.Add(lower, api.Mul(upper, m)) + api.AssertIsEqual(composed, v) + return +} diff --git a/std/math/bitslice/partition_test.go b/std/math/bitslice/partition_test.go new file mode 100644 index 0000000000..891c01e658 --- /dev/null +++ b/std/math/bitslice/partition_test.go @@ -0,0 +1,30 @@ +package bitslice + +import ( + "testing" + + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type partitionCircuit struct { + Split uint + In, ExpLower, ExpUpper frontend.Variable +} + +func (c *partitionCircuit) Define(api frontend.API) error { + lower, upper := Partition(api, c.In, c.Split) + api.AssertIsEqual(lower, c.ExpLower) + api.AssertIsEqual(upper, c.ExpUpper) + return nil +} + +func TestPartition(t *testing.T) { + assert := test.NewAssert(t) + // TODO: for some 
reason next fails with PLONK+FRI + assert.ProverSucceeded(&partitionCircuit{Split: 16}, &partitionCircuit{Split: 16, ExpUpper: 0xffff, ExpLower: 0x1234, In: 0xffff1234}, test.WithBackends(backend.GROTH16, backend.PLONK)) + assert.ProverSucceeded(&partitionCircuit{Split: 0}, &partitionCircuit{Split: 0, ExpUpper: 0xffff1234, ExpLower: 0, In: 0xffff1234}, test.WithBackends(backend.GROTH16, backend.PLONK)) + assert.ProverSucceeded(&partitionCircuit{Split: 32}, &partitionCircuit{Split: 32, ExpUpper: 0, ExpLower: 0xffff1234, In: 0xffff1234}, test.WithBackends(backend.GROTH16, backend.PLONK)) + assert.ProverSucceeded(&partitionCircuit{Split: 4}, &partitionCircuit{Split: 4, ExpUpper: 0xffff123, ExpLower: 4, In: 0xffff1234}, test.WithBackends(backend.GROTH16, backend.PLONK)) +} diff --git a/std/math/emulated/doc_example_field_test.go b/std/math/emulated/doc_example_field_test.go index b6f3b8a18e..338c18b4ec 100644 --- a/std/math/emulated/doc_example_field_test.go +++ b/std/math/emulated/doc_example_field_test.go @@ -6,6 +6,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/std/math/emulated" @@ -61,7 +62,7 @@ func ExampleField() { } else { fmt.Println("setup done") } - proof, err := groth16.Prove(ccs, pk, witnessData, backend.WithHints(emulated.GetHints()...)) + proof, err := groth16.Prove(ccs, pk, witnessData, backend.WithSolverOptions(solver.WithHints(emulated.GetHints()...))) if err != nil { panic(err) } else { diff --git a/std/math/emulated/element_test.go b/std/math/emulated/element_test.go index aa7d5f8cd7..4f4e3eb4bb 100644 --- a/std/math/emulated/element_test.go +++ b/std/math/emulated/element_test.go @@ -87,6 +87,50 @@ func testAssertIsLessEqualThan[T FieldParams](t *testing.T) { }, testName[T]()) } +type 
AssertIsLessEqualThanConstantCiruit[T FieldParams] struct { + L Element[T] + R *big.Int +} + +func (c *AssertIsLessEqualThanConstantCiruit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + R := f.NewElement(c.R) + f.AssertIsLessOrEqual(&c.L, R) + return nil +} + +func testAssertIsLessEqualThanConstant[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness AssertIsLessEqualThanConstantCiruit[T] + R, _ := rand.Int(rand.Reader, fp.Modulus()) + L, _ := rand.Int(rand.Reader, R) + circuit.R = R + witness.R = R + witness.L = ValueOf[T](L) + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) + assert.Run(func(assert *test.Assert) { + var circuit, witness AssertIsLessEqualThanConstantCiruit[T] + R := new(big.Int).Set(fp.Modulus()) + L, _ := rand.Int(rand.Reader, R) + circuit.R = R + witness.R = R + witness.L = ValueOf[T](L) + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, fmt.Sprintf("overflow/%s", testName[T]())) +} + +func TestAssertIsLessEqualThanConstant(t *testing.T) { + testAssertIsLessEqualThanConstant[Goldilocks](t) + testAssertIsLessEqualThanConstant[Secp256k1Fp](t) + testAssertIsLessEqualThanConstant[BN254Fp](t) +} + type AddCircuit[T FieldParams] struct { A, B, C Element[T] } @@ -800,3 +844,115 @@ func TestIssue348UnconstrainedLimbs(t *testing.T) { // inputs. 
assert.Error(err) } + +type AssertInRangeCircuit[T FieldParams] struct { + X Element[T] +} + +func (c *AssertInRangeCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + f.AssertIsInRange(&c.X) + return nil +} + +func TestAssertInRange(t *testing.T) { + testAssertIsInRange[Goldilocks](t) + testAssertIsInRange[Secp256k1Fp](t) + testAssertIsInRange[BN254Fp](t) +} + +func testAssertIsInRange[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + X, _ := rand.Int(rand.Reader, fp.Modulus()) + circuit := AssertInRangeCircuit[T]{} + witness := AssertInRangeCircuit[T]{X: ValueOf[T](X)} + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + witness2 := AssertInRangeCircuit[T]{X: ValueOf[T](0)} + t := 0 + for i := 0; i < int(fp.NbLimbs())-1; i++ { + L := new(big.Int).Lsh(big.NewInt(1), fp.BitsPerLimb()) + L.Sub(L, big.NewInt(1)) + witness2.X.Limbs[i] = L + t += int(fp.BitsPerLimb()) + } + highlimb := fp.Modulus().BitLen() - t + L := new(big.Int).Lsh(big.NewInt(1), uint(highlimb)) + L.Sub(L, big.NewInt(1)) + witness2.X.Limbs[fp.NbLimbs()-1] = L + assert.ProverFailed(&circuit, &witness2, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +type IsZeroCircuit[T FieldParams] struct { + X, Y Element[T] + Zero frontend.Variable +} + +func (c *IsZeroCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + R := f.Add(&c.X, &c.Y) + api.AssertIsEqual(c.Zero, f.IsZero(R)) + return nil +} + +func TestIsZero(t *testing.T) { + testIsZero[Goldilocks](t) + testIsZero[Secp256k1Fp](t) + testIsZero[BN254Fp](t) +} + +func testIsZero[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + X, _ := 
rand.Int(rand.Reader, fp.Modulus()) + Y := new(big.Int).Sub(fp.Modulus(), X) + circuit := IsZeroCircuit[T]{} + assert.ProverSucceeded(&circuit, &IsZeroCircuit[T]{X: ValueOf[T](X), Y: ValueOf[T](Y), Zero: 1}, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + assert.ProverSucceeded(&circuit, &IsZeroCircuit[T]{X: ValueOf[T](X), Y: ValueOf[T](0), Zero: 0}, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +type SqrtCircuit[T FieldParams] struct { + X, Expected Element[T] +} + +func (c *SqrtCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + res := f.Sqrt(&c.X) + f.AssertIsEqual(res, &c.Expected) + return nil +} + +func TestSqrt(t *testing.T) { + testSqrt[Goldilocks](t) + testSqrt[Secp256k1Fp](t) + testSqrt[BN254Fp](t) +} + +func testSqrt[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var X *big.Int + exp := new(big.Int) + for { + X, _ = rand.Int(rand.Reader, fp.Modulus()) + if exp.ModSqrt(X, fp.Modulus()) != nil { + break + } + } + assert.ProverSucceeded(&SqrtCircuit[T]{}, &SqrtCircuit[T]{X: ValueOf[T](X), Expected: ValueOf[T](exp)}, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} diff --git a/std/math/emulated/emparams/emparams.go b/std/math/emulated/emparams/emparams.go new file mode 100644 index 0000000000..9b888d6a4d --- /dev/null +++ b/std/math/emulated/emparams/emparams.go @@ -0,0 +1,201 @@ +// Package emparams contains emulation parameters for well known fields. +// +// We define some well-known parameters in this package for compatibility and +// ease of use. When needing to use parameters not defined in this package it is +// sufficient to define a new type implementing [FieldParams]. 
For example, as: +// +// type SmallField struct {} +// func (SmallField) NbLimbs() uint { return 1 } +// func (SmallField) BitsPerLimb() uint { return 11 } +// func (SmallField) IsPrime() bool { return true } +// func (SmallField) Modulus() *big.Int { return big.NewInt(1032) } +package emparams + +import ( + "crypto/elliptic" + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/goldilocks" +) + +type fourLimbPrimeField struct{} + +func (fourLimbPrimeField) NbLimbs() uint { return 4 } +func (fourLimbPrimeField) BitsPerLimb() uint { return 64 } +func (fourLimbPrimeField) IsPrime() bool { return true } + +type sixLimbPrimeField struct{} + +func (sixLimbPrimeField) NbLimbs() uint { return 6 } +func (sixLimbPrimeField) BitsPerLimb() uint { return 64 } +func (sixLimbPrimeField) IsPrime() bool { return true } + +// Goldilocks provides type parametrization for field emulation: +// - limbs: 1 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0xffffffff00000001 (base 16) +// 18446744069414584321 (base 10) +type Goldilocks struct{} + +func (fp Goldilocks) NbLimbs() uint { return 1 } +func (fp Goldilocks) BitsPerLimb() uint { return 64 } +func (fp Goldilocks) IsPrime() bool { return true } +func (fp Goldilocks) Modulus() *big.Int { return goldilocks.Modulus() } + +// Secp256k1Fp provides type parametrization for field emulation: +// - limbs: 4 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f (base 16) +// 115792089237316195423570985008687907853269984665640564039457584007908834671663 (base 10) +// +// This is the base field of the SECP256k1 curve. 
+type Secp256k1Fp struct{ fourLimbPrimeField } + +func (fp Secp256k1Fp) Modulus() *big.Int { return ecc.SECP256K1.BaseField() } + +// Secp256k1Fr provides type parametrization for field emulation: +// - limbs: 4 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141 (base 16) +// 115792089237316195423570985008687907852837564279074904382605163141518161494337 (base 10) +// +// This is the scalar field of the SECP256k1 curve. +type Secp256k1Fr struct{ fourLimbPrimeField } + +func (fp Secp256k1Fr) Modulus() *big.Int { return ecc.SECP256K1.ScalarField() } + +// BN254Fp provides type parametrization for field emulation: +// - limbs: 4 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47 (base 16) +// 21888242871839275222246405745257275088696311157297823662689037894645226208583 (base 10) +// +// This is the base field of the BN254 curve. +type BN254Fp struct{ fourLimbPrimeField } + +func (fp BN254Fp) Modulus() *big.Int { return ecc.BN254.BaseField() } + +// BN254Fr provides type parametrization for field emulation: +// - limbs: 4 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 (base 16) +// 21888242871839275222246405745257275088548364400416034343698204186575808495617 (base 10) +// +// This is the scalar field of the BN254 curve. 
+type BN254Fr struct{ fourLimbPrimeField } + +func (fp BN254Fr) Modulus() *big.Int { return ecc.BN254.ScalarField() } + +// BLS12377Fp provides type parametrization for field emulation: +// - limbs: 6 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001 (base 16) +// 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 (base 10) +// +// This is the base field of the BLS12-377 curve. +type BLS12377Fp struct{ sixLimbPrimeField } + +func (fp BLS12377Fp) Modulus() *big.Int { return ecc.BLS12_377.BaseField() } + +// BLS12381Fp provides type parametrization for field emulation: +// - limbs: 6 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab (base 16) +// 4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787 (base 10) +// +// This is the base field of the BLS12-381 curve. +type BLS12381Fp struct{ sixLimbPrimeField } + +func (fp BLS12381Fp) Modulus() *big.Int { return ecc.BLS12_381.BaseField() } + +// BLS12381Fr provides type parametrization for field emulation: +// - limbs: 4 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001 (base 16) +// 52435875175126190479447740508185965837690552500527637822603658699938581184513 (base 10) +// +// This is the scalar field of the BLS12-381 curve. 
+type BLS12381Fr struct{ fourLimbPrimeField } + +func (fp BLS12381Fr) Modulus() *big.Int { return ecc.BLS12_381.ScalarField() } + +// P256Fp provides type parametrization for field emulation: +// - limbs: 4 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff (base 16) +// 115792089210356248762697446949407573530086143415290314195533631308867097853951 (base 10) +// +// This is the base field of the P-256 (also SECP256r1) curve. +type P256Fp struct{ fourLimbPrimeField } + +func (P256Fp) Modulus() *big.Int { return elliptic.P256().Params().P } + +// P256Fr provides type parametrization for field emulation: +// - limbs: 4 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551 (base 16) +// 115792089210356248762697446949407573529996955224135760342422259061068512044369 (base 10) +// +// This is the scalar field of the P-256 (also SECP256r1) curve. +type P256Fr struct{ fourLimbPrimeField } + +func (P256Fr) Modulus() *big.Int { return elliptic.P256().Params().N } + +// P384Fp provides type parametrization for field emulation: +// - limbs: 6 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000ffffffff (base 16) +// 39402006196394479212279040100143613805079739270465446667948293404245721771496870329047266088258938001861606973112319 (base 10) +// +// This is the base field of the P-384 (also SECP384r1) curve. 
+type P384Fp struct{ sixLimbPrimeField } + +func (P384Fp) Modulus() *big.Int { return elliptic.P384().Params().P } + +// P384Fr provides type parametrization for field emulation: +// - limbs: 6 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0xffffffffffffffffffffffffffffffffffffffffffffffffc7634d81f4372ddf581a0db248b0a77aecec196accc52973 (base 16) +// 39402006196394479212279040100143613805079739270465446667946905279627659399113263569398956308152294913554433653942643 (base 10) +// +// This is the scalar field of the P-384 (also SECP384r1) curve. +type P384Fr struct{ sixLimbPrimeField } + +func (P384Fr) Modulus() *big.Int { return elliptic.P384().Params().N } diff --git a/std/math/emulated/field.go b/std/math/emulated/field.go index e4c0c3d357..124f4e9b37 100644 --- a/std/math/emulated/field.go +++ b/std/math/emulated/field.go @@ -6,7 +6,9 @@ import ( "sync" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" + "github.com/consensys/gnark/std/rangecheck" "github.com/rs/zerolog" "golang.org/x/exp/constraints" ) @@ -27,16 +29,19 @@ type Field[T FieldParams] struct { maxOfOnce sync.Once // constants for often used elements n, 0 and 1. 
Allocated only once - nConstOnce sync.Once - nConst *Element[T] - zeroConstOnce sync.Once - zeroConst *Element[T] - oneConstOnce sync.Once - oneConst *Element[T] + nConstOnce sync.Once + nConst *Element[T] + nprevConstOnce sync.Once + nprevConst *Element[T] + zeroConstOnce sync.Once + zeroConst *Element[T] + oneConstOnce sync.Once + oneConst *Element[T] log zerolog.Logger constrainedLimbs map[uint64]struct{} + checker frontend.Rangechecker } // NewField returns an object to be used in-circuit to perform emulated @@ -52,6 +57,7 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) { api: native, log: logger.Logger(), constrainedLimbs: make(map[uint64]struct{}), + checker: rangecheck.New(native), } // ensure prime is correctly set @@ -89,7 +95,9 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) { // NewElement builds a new Element[T] from input v. // - if v is a Element[T] or *Element[T] it clones it // - if v is a constant this is equivalent to calling emulated.ValueOf[T] -// - if this methods interpret v (frontend.Variable or []frontend.Variable) as being the limbs; and constrain the limbs following the parameters of the Field. +// - if this methods interprets v as being the limbs (frontend.Variable or []frontend.Variable), +// it constructs a new Element[T] with v as limbs and constraints the limbs to the parameters +// of the Field[T]. 
func (f *Field[T]) NewElement(v interface{}) *Element[T] { if e, ok := v.(Element[T]); ok { return e.copy() @@ -101,11 +109,6 @@ func (f *Field[T]) NewElement(v interface{}) *Element[T] { return f.packLimbs([]frontend.Variable{v}, true) } if e, ok := v.([]frontend.Variable); ok { - for _, sv := range e { - if !frontend.IsCanonical(sv) { - panic("[]frontend.Variable that are not canonical (known to the compiler) is not a valid input") - } - } return f.packLimbs(e, true) } c := ValueOf[T](v) @@ -136,6 +139,14 @@ func (f *Field[T]) Modulus() *Element[T] { return f.nConst } +// modulusPrev returns modulus-1 as a constant. +func (f *Field[T]) modulusPrev() *Element[T] { + f.nprevConstOnce.Do(func() { + f.nprevConst = newConstElement[T](new(big.Int).Sub(f.fParams.Modulus(), big.NewInt(1))) + }) + return f.nprevConst +} + // packLimbs returns an element from the given limbs. // If strict is true, the most significant limb will be constrained to have width of the most // significant limb of the modulus, which may have less bits than the other limbs. In which case, @@ -157,14 +168,27 @@ func (f *Field[T]) enforceWidthConditional(a *Element[T]) (didConstrain bool) { return false } if _, isConst := f.constantValue(a); isConst { + // enforce constant element limbs not to be large. + for i := range a.Limbs { + val := utils.FromInterface(a.Limbs[i]) + if val.BitLen() > int(f.fParams.BitsPerLimb()) { + panic("constant element limb wider than emulated parameter") + } + } // constant values are constant return false } for i := range a.Limbs { if !frontend.IsCanonical(a.Limbs[i]) { - // this is not a variable. This may happen when some limbs are - // constant and some variables. A strange case but lets try to cover - // it anyway. + // this is not a canonical variable, nor a constant. This may happen + // when some limbs are constant and some variables. Or if we are + // running in a test engine. 
In either case, we must check that if + // this limb is a [*big.Int] that its bitwidth is less than the + // NbBits. + val := utils.FromInterface(a.Limbs[i]) + if val.BitLen() > int(f.fParams.BitsPerLimb()) { + panic("non-canonical integer limb wider than emulated parameter") + } continue } if vv, ok := a.Limbs[i].(interface{ HashCode() uint64 }); ok { @@ -259,7 +283,7 @@ func (f *Field[T]) compactLimbs(e *Element[T], groupSize, bitsPerLimb uint) []fr // then the limbs may overflow the native field. func (f *Field[T]) maxOverflow() uint { f.maxOfOnce.Do(func() { - f.maxOf = uint(f.api.Compiler().FieldBitLen()-1) - f.fParams.BitsPerLimb() + f.maxOf = uint(f.api.Compiler().FieldBitLen()-2) - f.fParams.BitsPerLimb() }) return f.maxOf } diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go index c8daae2bf5..2a02715900 100644 --- a/std/math/emulated/field_assert.go +++ b/std/math/emulated/field_assert.go @@ -5,13 +5,12 @@ import ( "math/big" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/math/bits" ) // assertLimbsEqualitySlow is the main routine in the package. It asserts that the // two slices of limbs represent the same integer value. This is also the most // costly operation in the package as it does bit decomposition of the limbs. -func assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits, nbCarryBits uint) { +func (f *Field[T]) assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits, nbCarryBits uint) { nbLimbs := max(len(l), len(r)) maxValue := new(big.Int).Lsh(big.NewInt(1), nbBits+nbCarryBits) @@ -33,52 +32,29 @@ func assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits, // carry is stored in the highest bits of diff[nbBits:nbBits+nbCarryBits+1] // we know that diff[:nbBits] are 0 bits, but still need to constrain them. 
// to do both; we do a "clean" right shift and only need to boolean constrain the carry part - carry = rsh(api, diff, int(nbBits), int(nbBits+nbCarryBits+1)) + carry = f.rsh(diff, int(nbBits), int(nbBits+nbCarryBits+1)) } api.AssertIsEqual(carry, maxValueShift) } -// rsh right shifts a variable endDigit-startDigit bits and returns it. -func rsh(api frontend.API, v frontend.Variable, startDigit, endDigit int) frontend.Variable { +func (f *Field[T]) rsh(v frontend.Variable, startDigit, endDigit int) frontend.Variable { // if v is a constant, work with the big int value. - if c, ok := api.Compiler().ConstantValue(v); ok { + if c, ok := f.api.Compiler().ConstantValue(v); ok { bits := make([]frontend.Variable, endDigit-startDigit) for i := 0; i < len(bits); i++ { bits[i] = c.Bit(i + startDigit) } return bits } - - bits, err := api.Compiler().NewHint(NBitsShifted, endDigit-startDigit, v, startDigit) + shifted, err := f.api.Compiler().NewHint(RightShift, 1, startDigit, v) if err != nil { - panic(err) - } - - // we compute 2 sums; - // Σbi ensures that "ignoring" the lowest bits (< startDigit) still is a valid bit decomposition. 
- // that is, it ensures that bits from startDigit to endDigit * corresponding coefficients (powers of 2 shifted) - // are equal to the input variable - // ΣbiRShift computes the actual result; that is, the Σ (2**i * b[i]) - Σbi := frontend.Variable(0) - ΣbiRShift := frontend.Variable(0) - - cRShift := big.NewInt(1) - c := big.NewInt(1) - c.Lsh(c, uint(startDigit)) - - for i := 0; i < len(bits); i++ { - Σbi = api.MulAcc(Σbi, bits[i], c) - ΣbiRShift = api.MulAcc(ΣbiRShift, bits[i], cRShift) - - c.Lsh(c, 1) - cRShift.Lsh(cRShift, 1) - api.AssertIsBoolean(bits[i]) + panic(fmt.Sprintf("right shift: %v", err)) } - - // constraint Σ (2**i_shift * b[i]) == v - api.AssertIsEqual(Σbi, v) - return ΣbiRShift - + f.checker.Check(shifted[0], endDigit-startDigit) + shift := new(big.Int).Lsh(big.NewInt(1), uint(startDigit)) + composed := f.api.Mul(shifted[0], shift) + f.api.AssertIsEqual(composed, v) + return shifted[0] } // AssertLimbsEquality asserts that the limbs represent a same integer value. @@ -107,9 +83,9 @@ func (f *Field[T]) AssertLimbsEquality(a, b *Element[T]) { // TODO: we previously assumed that one side was "larger" than the other // side, but I think this assumption is not valid anymore if a.overflow > b.overflow { - assertLimbsEqualitySlow(f.api, ca, cb, bitsPerLimb, a.overflow) + f.assertLimbsEqualitySlow(f.api, ca, cb, bitsPerLimb, a.overflow) } else { - assertLimbsEqualitySlow(f.api, cb, ca, bitsPerLimb, b.overflow) + f.assertLimbsEqualitySlow(f.api, cb, ca, bitsPerLimb, b.overflow) } } @@ -133,16 +109,14 @@ func (f *Field[T]) enforceWidth(a *Element[T], modWidth bool) { // take only required bits from the most significant limb limbNbBits = ((f.fParams.Modulus().BitLen() - 1) % int(f.fParams.BitsPerLimb())) + 1 } - // bits.ToBinary restricts the least significant NbDigits to be equal to - // the limb value. This is sufficient to restrict for the bitlength and - // we can discard the bits themselves. 
- bits.ToBinary(f.api, a.Limbs[i], bits.WithNbDigits(limbNbBits)) + f.checker.Check(a.Limbs[i], limbNbBits) } } // AssertIsEqual ensures that a is equal to b modulo the modulus. func (f *Field[T]) AssertIsEqual(a, b *Element[T]) { - // we omit width assertion as it is done in Sub below + f.enforceWidthConditional(a) + f.enforceWidthConditional(b) ba, aConst := f.constantValue(a) bb, bConst := f.constantValue(b) if aConst && bConst { @@ -154,7 +128,7 @@ func (f *Field[T]) AssertIsEqual(a, b *Element[T]) { return } - diff := f.Sub(b, a) + diff := f.subNoReduce(b, a) // we compute k such that diff / p == k // so essentially, we say "I know an element k such that k*p == diff" @@ -170,7 +144,9 @@ func (f *Field[T]) AssertIsEqual(a, b *Element[T]) { f.AssertLimbsEquality(diff, kp) } -// AssertIsLessOrEqual ensures that e is less or equal than a. +// AssertIsLessOrEqual ensures that e is less or equal than a. For proper +// bitwise comparison first reduce the element using [Reduce] and then assert +// that its value is less than the modulus using [AssertIsInRange]. func (f *Field[T]) AssertIsLessOrEqual(e, a *Element[T]) { // we omit conditional width assertion as is done in ToBits below if e.overflow+a.overflow > 0 { @@ -181,7 +157,7 @@ func (f *Field[T]) AssertIsLessOrEqual(e, a *Element[T]) { ff := func(xbits, ybits []frontend.Variable) []frontend.Variable { diff := len(xbits) - len(ybits) ybits = append(ybits, make([]frontend.Variable, diff)...) - for i := len(ybits) - diff - 1; i < len(ybits); i++ { + for i := len(ybits) - diff; i < len(ybits); i++ { ybits[i] = 0 } return ybits @@ -202,3 +178,62 @@ func (f *Field[T]) AssertIsLessOrEqual(e, a *Element[T]) { f.api.AssertIsEqual(ll, 0) } } + +// AssertIsInRange ensures that a is less than the emulated modulus. When we +// call [Reduce] then we only ensure that the result is width-constrained, but +// not actually less than the modulus. This means that the actual value may be +// either x or x + p. 
For arithmetic it is sufficient, but for binary comparison +// it is not. For binary comparison the values have both to be below the +// modulus. +func (f *Field[T]) AssertIsInRange(a *Element[T]) { + // we omit conditional width assertion as is done in ToBits down the calling stack + f.AssertIsLessOrEqual(a, f.modulusPrev()) +} + +// IsZero returns a boolean indicating if the element is strictly zero. The +// method internally reduces the element and asserts that the value is less than +// the modulus. +func (f *Field[T]) IsZero(a *Element[T]) frontend.Variable { + ca := f.Reduce(a) + f.AssertIsInRange(ca) + res := f.api.IsZero(ca.Limbs[0]) + for i := 1; i < len(ca.Limbs); i++ { + f.api.Mul(res, f.api.IsZero(ca.Limbs[i])) + } + return res +} + +// // Cmp returns: +// // - -1 if a < b +// // - 0 if a = b +// // - 1 if a > b +// // +// // The method internally reduces the element and asserts that the value is less +// // than the modulus. +// func (f *Field[T]) Cmp(a, b *Element[T]) frontend.Variable { +// ca := f.Reduce(a) +// f.AssertIsInRange(ca) +// cb := f.Reduce(b) +// f.AssertIsInRange(cb) +// var res frontend.Variable = 0 +// for i := int(f.fParams.NbLimbs() - 1); i >= 0; i-- { +// lmbCmp := f.api.Cmp(ca.Limbs[i], cb.Limbs[i]) +// res = f.api.Select(f.api.IsZero(res), lmbCmp, res) +// } +// return res +// } + +// TODO(@ivokub) +// func (f *Field[T]) AssertIsDifferent(a, b *Element[T]) { +// ca := f.Reduce(a) +// f.AssertIsInRange(ca) +// cb := f.Reduce(b) +// f.AssertIsInRange(cb) +// var res frontend.Variable = 0 +// for i := 0; i < int(f.fParams.NbLimbs()); i++ { +// cmp := f.api.Cmp(ca.Limbs[i], cb.Limbs[i]) +// cmpsq := f.api.Mul(cmp, cmp) +// res = f.api.Add(res, cmpsq) +// } +// f.api.AssertIsDifferent(res, 0) +// } diff --git a/std/math/emulated/field_binary.go b/std/math/emulated/field_binary.go index 213f1654c2..8c949af2e4 100644 --- a/std/math/emulated/field_binary.go +++ b/std/math/emulated/field_binary.go @@ -14,7 +14,11 @@ func (f *Field[T]) 
ToBits(a *Element[T]) []frontend.Variable { f.enforceWidthConditional(a) ba, aConst := f.constantValue(a) if aConst { - return f.api.ToBinary(ba, int(f.fParams.BitsPerLimb()*f.fParams.NbLimbs())) + res := make([]frontend.Variable, f.fParams.BitsPerLimb()*f.fParams.NbLimbs()) + for i := range res { + res[i] = ba.Bit(i) + } + return res } var carry frontend.Variable = 0 var fullBits []frontend.Variable diff --git a/std/math/emulated/field_hint.go b/std/math/emulated/field_hint.go new file mode 100644 index 0000000000..52b8e9ddfc --- /dev/null +++ b/std/math/emulated/field_hint.go @@ -0,0 +1,109 @@ +package emulated + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" +) + +func (f *Field[T]) wrapHint(nonnativeInputs ...*Element[T]) []frontend.Variable { + res := []frontend.Variable{f.fParams.BitsPerLimb(), f.fParams.NbLimbs()} + res = append(res, f.Modulus().Limbs...) + res = append(res, len(nonnativeInputs)) + for i := range nonnativeInputs { + res = append(res, len(nonnativeInputs[i].Limbs)) + res = append(res, nonnativeInputs[i].Limbs...) + } + return res +} + +// UnwrapHint unwraps the native inputs into nonnative inputs. Then it calls +// nonnativeHint function with nonnative inputs. After nonnativeHint returns, it +// decomposes the outputs into limbs. 
+func UnwrapHint(nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hint) error { + if len(nativeInputs) < 2 { + return fmt.Errorf("hint wrapper header is 2 elements") + } + if !nativeInputs[0].IsInt64() || !nativeInputs[1].IsInt64() { + return fmt.Errorf("header must be castable to int64") + } + nbBits := int(nativeInputs[0].Int64()) + nbLimbs := int(nativeInputs[1].Int64()) + if len(nativeInputs) < 2+nbLimbs { + return fmt.Errorf("hint wrapper header is 2+nbLimbs elements") + } + nonnativeMod := new(big.Int) + if err := recompose(nativeInputs[2:2+nbLimbs], uint(nbBits), nonnativeMod); err != nil { + return fmt.Errorf("cannot recover nonnative mod: %w", err) + } + if !nativeInputs[2+nbLimbs].IsInt64() { + return fmt.Errorf("number of nonnative elements must be castable to int64") + } + nbInputs := int(nativeInputs[2+nbLimbs].Int64()) + nonnativeInputs := make([]*big.Int, nbInputs) + readPtr := 3 + nbLimbs + for i := 0; i < nbInputs; i++ { + if len(nativeInputs) < readPtr+1 { + return fmt.Errorf("can not read %d-th native input", i) + } + if !nativeInputs[readPtr].IsInt64() { + return fmt.Errorf("corrupted %d-th native input", i) + } + currentInputLen := int(nativeInputs[readPtr].Int64()) + if len(nativeInputs) < (readPtr + 1 + currentInputLen) { + return fmt.Errorf("cannot read %d-th nonnative element", i) + } + nonnativeInputs[i] = new(big.Int) + if err := recompose(nativeInputs[readPtr+1:readPtr+1+currentInputLen], uint(nbBits), nonnativeInputs[i]); err != nil { + return fmt.Errorf("recompose %d-th element: %w", i, err) + } + readPtr += 1 + currentInputLen + } + if len(nativeOutputs)%nbLimbs != 0 { + return fmt.Errorf("output count doesn't divide limb count") + } + nonnativeOutputs := make([]*big.Int, len(nativeOutputs)/nbLimbs) + for i := range nonnativeOutputs { + nonnativeOutputs[i] = new(big.Int) + } + if err := nonnativeHint(nonnativeMod, nonnativeInputs, nonnativeOutputs); err != nil { + return fmt.Errorf("nonnative hint: %w", err) + } + for i := 
range nonnativeOutputs { + nonnativeOutputs[i].Mod(nonnativeOutputs[i], nonnativeMod) + if err := decompose(nonnativeOutputs[i], uint(nbBits), nativeOutputs[i*nbLimbs:(i+1)*nbLimbs]); err != nil { + return fmt.Errorf("decompose %d-th element: %w", i, err) + } + } + return nil +} + +// NewHint allows to call the emulation hint function hf on inputs, expecting +// nbOutputs results. This function splits internally the emulated element into +// limbs and passes these to the hint function. There is [UnwrapHint] function +// which performs corresponding recomposition of limbs into integer values (and +// vice versa for output). +// +// The hint function for this method is defined as: +// +// func HintFn(mod *big.Int, inputs, outputs []*big.Int) error { +// return emulated.UnwrapHint(inputs, outputs, func(mod *big.Int, inputs, outputs []*big.Int) error { // NB we shadow initial input, output, mod to avoid accidental overwrite! +// // here all inputs and outputs are modulo nonnative mod. we decompose them automatically +// })} +// +// See the example for a full written example. +func (f *Field[T]) NewHint(hf solver.Hint, nbOutputs int, inputs ...*Element[T]) ([]*Element[T], error) { + nativeInputs := f.wrapHint(inputs...) + nbNativeOutputs := int(f.fParams.NbLimbs()) * nbOutputs + nativeOutputs, err := f.api.Compiler().NewHint(hf, nbNativeOutputs, nativeInputs...) 
+ if err != nil { + return nil, fmt.Errorf("call hint: %w", err) + } + outputs := make([]*Element[T], nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = f.packLimbs(nativeOutputs[i*int(f.fParams.NbLimbs()):(i+1)*int(f.fParams.NbLimbs())], true) + } + return outputs, nil +} diff --git a/std/math/emulated/field_hint_test.go b/std/math/emulated/field_hint_test.go new file mode 100644 index 0000000000..9f2d286ec8 --- /dev/null +++ b/std/math/emulated/field_hint_test.go @@ -0,0 +1,116 @@ +package emulated_test + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/math/emulated" +) + +// HintExample is a hint for field emulation which returns the division of the +// first and second input. +func HintExample(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + // nativeInputs are the limbs of the input non-native elements. We wrap the + // actual hint function with [emulated.UnwrapHint] to get actual [*big.Int] + // values of the non-native elements. + return emulated.UnwrapHint(nativeInputs, nativeOutputs, func(mod *big.Int, inputs, outputs []*big.Int) error { + // this hint computes the division of first and second input and returns it. + nominator := inputs[0] + denominator := inputs[1] + res := new(big.Int).ModInverse(denominator, mod) + if res == nil { + return fmt.Errorf("no modular inverse") + } + res.Mul(res, nominator) + res.Mod(res, mod) + outputs[0].Set(res) + return nil + }) + // when the internal hint function returns, the UnwrapHint function + // decomposes the non-native value into limbs. 
+} + +type emulationHintCircuit[T emulated.FieldParams] struct { + Nominator emulated.Element[T] + Denominator emulated.Element[T] + Expected emulated.Element[T] +} + +func (c *emulationHintCircuit[T]) Define(api frontend.API) error { + field, err := emulated.NewField[T](api) + if err != nil { + return err + } + res, err := field.NewHint(HintExample, 1, &c.Nominator, &c.Denominator) + if err != nil { + return err + } + m := field.Mul(res[0], &c.Denominator) + field.AssertIsEqual(m, &c.Nominator) + field.AssertIsEqual(res[0], &c.Expected) + return nil +} + +// Example of using hints with emulated elements. +func ExampleField_NewHint() { + var a, b, c fr.Element + a.SetRandom() + b.SetRandom() + c.Div(&a, &b) + + circuit := emulationHintCircuit[emulated.BN254Fr]{} + witness := emulationHintCircuit[emulated.BN254Fr]{ + Nominator: emulated.ValueOf[emulated.BN254Fr](a), + Denominator: emulated.ValueOf[emulated.BN254Fr](b), + Expected: emulated.ValueOf[emulated.BN254Fr](c), + } + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + panic(err) + } else { + fmt.Println("compiled") + } + witnessData, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } else { + fmt.Println("secret witness parsed") + } + publicWitnessData, err := witnessData.Public() + if err != nil { + panic(err) + } else { + fmt.Println("public witness parsed") + } + pk, vk, err := groth16.Setup(ccs) + if err != nil { + panic(err) + } else { + fmt.Println("setup done") + } + proof, err := groth16.Prove(ccs, pk, witnessData, backend.WithSolverOptions(solver.WithHints(HintExample))) + if err != nil { + panic(err) + } else { + fmt.Println("proved") + } + err = groth16.Verify(proof, vk, publicWitnessData) + if err != nil { + panic(err) + } else { + fmt.Println("verified") + } + // Output: compiled + // secret witness parsed + // public witness parsed + // setup done + // proved + // verified +} diff --git 
a/std/math/emulated/field_ops.go b/std/math/emulated/field_ops.go index baa7154f5d..26125e7eaf 100644 --- a/std/math/emulated/field_ops.go +++ b/std/math/emulated/field_ops.go @@ -44,6 +44,21 @@ func (f *Field[T]) Inverse(a *Element[T]) *Element[T] { return e } +// Sqrt computes square root of a and returns it. It uses [SqrtHint]. +func (f *Field[T]) Sqrt(a *Element[T]) *Element[T] { + // omit width assertion as is done in Mul below + if !f.fParams.IsPrime() { + panic("modulus not a prime") + } + res, err := f.NewHint(SqrtHint, 1, a) + if err != nil { + panic(fmt.Sprintf("compute sqrt: %v", err)) + } + _a := f.Mul(res[0], res[0]) + f.AssertIsEqual(_a, a) + return res[0] +} + // Add computes a+b and returns it. If the result wouldn't fit into Element, then // first reduces the inputs (larger first) and tries again. Doesn't mutate // inputs. @@ -178,19 +193,19 @@ func (f *Field[T]) mul(a, b *Element[T], nextOverflow uint) *Element[T] { w := new(big.Int) for c := 1; c <= len(mulResult); c++ { w.SetInt64(1) // c^i - l := a.Limbs[0] - r := b.Limbs[0] - o := mulResult[0] + l := f.api.Mul(a.Limbs[0], 1) + r := f.api.Mul(b.Limbs[0], 1) + o := f.api.Mul(mulResult[0], 1) for i := 1; i < len(mulResult); i++ { w.Lsh(w, uint(c)) if i < len(a.Limbs) { - l = f.api.Add(l, f.api.Mul(a.Limbs[i], w)) + l = f.api.MulAcc(l, a.Limbs[i], w) } if i < len(b.Limbs) { - r = f.api.Add(r, f.api.Mul(b.Limbs[i], w)) + r = f.api.MulAcc(r, b.Limbs[i], w) } - o = f.api.Add(o, f.api.Mul(mulResult[i], w)) + o = f.api.MulAcc(o, mulResult[i], w) } f.api.AssertIsEqual(f.api.Mul(l, r), o) } @@ -214,7 +229,6 @@ func (f *Field[T]) Reduce(a *Element[T]) *Element[T] { if err != nil { panic(fmt.Sprintf("reduction hint: %v", err)) } - // TODO @gbotrel fixme: AssertIsEqual(a, e) crashes Pairing test f.AssertIsEqual(e, a) return e } @@ -225,9 +239,20 @@ func (f *Field[T]) Sub(a, b *Element[T]) *Element[T] { return f.reduceAndOp(f.sub, f.subPreCond, a, b) } +// subReduce returns a-b and returns it. 
Contrary to [Field[T].Sub] method this +// method does not reduce the inputs if the result would overflow. This method +// is currently only used as a subroutine in [Field[T].Reduce] method to avoid +// infinite recursion when we are working exactly on the overflow limits. +func (f *Field[T]) subNoReduce(a, b *Element[T]) *Element[T] { + nextOverflow, _ := f.subPreCond(a, b) + // we ignore error as it only indicates if we should reduce or not. But we + // are in non-reducing version of sub. + return f.sub(a, b, nextOverflow) +} + func (f *Field[T]) subPreCond(a, b *Element[T]) (nextOverflow uint, err error) { - reduceRight := a.overflow < b.overflow+2 - nextOverflow = max(b.overflow+2, a.overflow) + reduceRight := a.overflow < (b.overflow + 1) + nextOverflow = max(b.overflow+1, a.overflow) + 1 if nextOverflow > f.maxOverflow() { err = overflowError{op: "sub", nextOverflow: nextOverflow, maxOverflow: f.maxOverflow(), reduceRight: reduceRight} } @@ -263,7 +288,7 @@ func (f *Field[T]) Neg(a *Element[T]) *Element[T] { return f.Sub(f.Zero(), a) } -// Select sets e to a if selector == 0 and to b otherwise. Sets the number of +// Select sets e to a if selector == 1 and to b otherwise. Sets the number of // limbs and overflow of the result to be the maximum of the limb lengths and // overflows. If the inputs are strongly unbalanced, then it would better to // reduce the result after the operation. 
diff --git a/std/math/emulated/field_test.go b/std/math/emulated/field_test.go index 927b23bc59..d1124f4e82 100644 --- a/std/math/emulated/field_test.go +++ b/std/math/emulated/field_test.go @@ -8,7 +8,7 @@ import ( bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/std/algebra/sw_bls12377" + "github.com/consensys/gnark/std/algebra/native/sw_bls12377" "github.com/consensys/gnark/test" ) diff --git a/std/math/emulated/hints.go b/std/math/emulated/hints.go index 9e799e4e17..e0aeff00ab 100644 --- a/std/math/emulated/hints.go +++ b/std/math/emulated/hints.go @@ -4,7 +4,7 @@ import ( "fmt" "math/big" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -12,18 +12,19 @@ import ( // inside a func, then it becomes anonymous and hint identification is screwed. func init() { - hint.Register(GetHints()...) + solver.RegisterHint(GetHints()...) } // GetHints returns all hint functions used in the package. -func GetHints() []hint.Function { - return []hint.Function{ +func GetHints() []solver.Hint { + return []solver.Hint{ DivHint, QuoHint, InverseHint, MultiplicationHint, RemHint, - NBitsShifted, + RightShift, + SqrtHint, } } @@ -287,13 +288,38 @@ func parseHintDivInputs(inputs []*big.Int) (uint, int, *big.Int, *big.Int, error return nbBits, nbLimbs, x, y, nil } -// NBitsShifted returns the first bits of the input, with a shift. The number of returned bits is -// defined by the length of the results slice. -func NBitsShifted(_ *big.Int, inputs []*big.Int, results []*big.Int) error { - n := inputs[0] - shift := inputs[1].Uint64() // TODO @gbotrel validate input vs perf in large circuits. - for i := 0; i < len(results); i++ { - results[i].SetUint64(uint64(n.Bit(i + int(shift)))) - } +// RightShift shifts input by the given number of bits. 
Expects two inputs: +// - first input is the shift, will be represented as uint64; +// - second input is the value to be shifted. +// +// Returns a single output which is the value shifted. Errors if number of +// inputs is not 2 and number of outputs is not 1. +func RightShift(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 2 { + return fmt.Errorf("expecting two inputs") + } + if len(outputs) != 1 { + return fmt.Errorf("expecting single output") + } + shift := inputs[0].Uint64() + outputs[0].Rsh(inputs[1], uint(shift)) return nil } + +// SqrtHint compute square root of the input. +func SqrtHint(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { + return UnwrapHint(inputs, outputs, func(field *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) != 1 { + return fmt.Errorf("expecting single input") + } + if len(outputs) != 1 { + return fmt.Errorf("expecting single output") + } + res := new(big.Int) + if res.ModSqrt(inputs[0], field) == nil { + return fmt.Errorf("no square root") + } + outputs[0].Set(res) + return nil + }) +} diff --git a/std/math/emulated/params.go b/std/math/emulated/params.go index 9f89b0da09..fc3d8abfc1 100644 --- a/std/math/emulated/params.go +++ b/std/math/emulated/params.go @@ -3,10 +3,20 @@ package emulated import ( "math/big" - "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/std/math/emulated/emparams" ) -// FieldParams describes the emulated field characteristics +// FieldParams describes the emulated field characteristics. For a list of +// included built-in emulation params refer to the [emparams] package. 
+// For backwards compatibility, the current package contains the following +// parameters: +// - [Goldilocks] +// - [Secp256k1Fp] and [Secp256k1Fr] +// - [BN254Fp] and [BN254Fr] +// - [BLS12377Fp] +// - [BLS12381Fp] and [BLS12381Fr] +// - [P256Fp] and [P256Fr] +// - [P384Fp] and [P384Fr] type FieldParams interface { NbLimbs() uint // number of limbs to represent field element BitsPerLimb() uint // number of bits per limb. Top limb may contain less than limbSize bits. @@ -14,75 +24,17 @@ type FieldParams interface { Modulus() *big.Int // returns modulus. Do not modify. } -var ( - qSecp256k1, rSecp256k1 *big.Int - qGoldilocks *big.Int +type ( + Goldilocks = emparams.Goldilocks + Secp256k1Fp = emparams.Secp256k1Fp + Secp256k1Fr = emparams.Secp256k1Fr + BN254Fp = emparams.BN254Fp + BN254Fr = emparams.BN254Fr + BLS12377Fp = emparams.BLS12377Fp + BLS12381Fp = emparams.BLS12381Fp + BLS12381Fr = emparams.BLS12381Fr + P256Fp = emparams.P256Fp + P256Fr = emparams.P256Fr + P384Fp = emparams.P384Fp + P384Fr = emparams.P384Fr ) - -func init() { - qSecp256k1, _ = new(big.Int).SetString("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f", 16) - rSecp256k1, _ = new(big.Int).SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16) - qGoldilocks, _ = new(big.Int).SetString("ffffffff00000001", 16) -} - -// Goldilocks provide type parametrization for emulated field on 1 limb of width 64bits -// for modulus 0xffffffff00000001 -type Goldilocks struct{} - -func (fp Goldilocks) NbLimbs() uint { return 1 } -func (fp Goldilocks) BitsPerLimb() uint { return 64 } -func (fp Goldilocks) IsPrime() bool { return true } -func (fp Goldilocks) Modulus() *big.Int { return qGoldilocks } - -// Secp256k1Fp provide type parametrization for emulated field on 4 limb of width 64bits -// for modulus 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f. 
-// This is the base field of secp256k1 curve -type Secp256k1Fp struct{} - -func (fp Secp256k1Fp) NbLimbs() uint { return 4 } -func (fp Secp256k1Fp) BitsPerLimb() uint { return 64 } -func (fp Secp256k1Fp) IsPrime() bool { return true } -func (fp Secp256k1Fp) Modulus() *big.Int { return qSecp256k1 } - -// Secp256k1Fr provides type parametrization for emulated field on 4 limbs of width 64bits -// for modulus 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141. -// This is the scalar field of secp256k1 curve. -type Secp256k1Fr struct{} - -func (fp Secp256k1Fr) NbLimbs() uint { return 4 } -func (fp Secp256k1Fr) BitsPerLimb() uint { return 64 } -func (fp Secp256k1Fr) IsPrime() bool { return true } -func (fp Secp256k1Fr) Modulus() *big.Int { return rSecp256k1 } - -// BN254Fp provide type parametrization for emulated field on 4 limb of width -// 64bits for modulus -// 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47. This is -// the base field of the BN254 curve. -type BN254Fp struct{} - -func (fp BN254Fp) NbLimbs() uint { return 4 } -func (fp BN254Fp) BitsPerLimb() uint { return 64 } -func (fp BN254Fp) IsPrime() bool { return true } -func (fp BN254Fp) Modulus() *big.Int { return ecc.BN254.BaseField() } - -// BN254Fr provides type parametrisation for emulated field on 4 limbs of width -// 64bits for modulus -// 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001. This is -// the scalar field of the BN254 curve. -type BN254Fr struct{} - -func (fp BN254Fr) NbLimbs() uint { return 4 } -func (fp BN254Fr) BitsPerLimb() uint { return 64 } -func (fp BN254Fr) IsPrime() bool { return true } -func (fp BN254Fr) Modulus() *big.Int { return ecc.BN254.ScalarField() } - -// BLS12377Fp provide type parametrization for emulated field on 6 limb of width -// 64bits for modulus -// 0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001. -// This is the base field of the BLS12-377 curve. 
-type BLS12377Fp struct{} - -func (fp BLS12377Fp) NbLimbs() uint { return 6 } -func (fp BLS12377Fp) BitsPerLimb() uint { return 64 } -func (fp BLS12377Fp) IsPrime() bool { return true } -func (fp BLS12377Fp) Modulus() *big.Int { return ecc.BLS12_377.BaseField() } diff --git a/std/math/uints/hints.go b/std/math/uints/hints.go new file mode 100644 index 0000000000..40dbbd9917 --- /dev/null +++ b/std/math/uints/hints.go @@ -0,0 +1,53 @@ +package uints + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark/constraint/solver" +) + +func init() { + solver.RegisterHint(GetHints()...) +} + +func GetHints() []solver.Hint { + return []solver.Hint{ + andHint, + xorHint, + toBytes, + } +} + +func xorHint(_ *big.Int, inputs, outputs []*big.Int) error { + outputs[0].Xor(inputs[0], inputs[1]) + return nil +} + +func andHint(_ *big.Int, inputs, outputs []*big.Int) error { + outputs[0].And(inputs[0], inputs[1]) + return nil +} + +func toBytes(m *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 2 { + return fmt.Errorf("input must be 2 elements") + } + if !inputs[0].IsUint64() { + return fmt.Errorf("first input must be uint64") + } + nbLimbs := int(inputs[0].Uint64()) + if len(outputs) != nbLimbs { + return fmt.Errorf("output must be 8 elements") + } + if !inputs[1].IsUint64() { + return fmt.Errorf("input must be 64 bits") + } + base := new(big.Int).Lsh(big.NewInt(1), uint(8)) + tmp := new(big.Int).Set(inputs[1]) + for i := 0; i < nbLimbs; i++ { + outputs[i].Mod(tmp, base) + tmp.Rsh(tmp, 8) + } + return nil +} diff --git a/std/math/uints/uint8.go b/std/math/uints/uint8.go new file mode 100644 index 0000000000..cec591d10c --- /dev/null +++ b/std/math/uints/uint8.go @@ -0,0 +1,342 @@ +// Package uints implements optimised byte and long integer operations. +// +// Usually arithmetic in a circuit is performed in the native field, which is of +// prime order. 
However, for compatibility with native operations we rely on +// operating on smaller primitive types as 8-bit, 32-bit and 64-bit integer. +// Naively, these operations have to be implemented bitwise as there are no +// closed equations for boolean operations (XOR, AND, OR). +// +// However, the bitwise approach is very inefficient and leads to several +// constraints per bit. Accumulating over a long integer, it leads to very +// inefficients circuits. +// +// This package performs boolean operations using lookup tables on bytes. So, +// long integers are split into 4 or 8 bytes and we perform the operations +// bytewise. In the lookup tables, we store results for all possible 2^8×2^8 +// inputs. With this approach, every bytewise operation costs as single lookup, +// which depending on the backend is relatively cheap (one to three +// constraints). +// +// NB! The package is still work in progress. The interfaces and implementation +// details most certainly changes over time. We cannot ensure the soundness of +// the operations. +package uints + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/internal/logderivprecomp" + "github.com/consensys/gnark/std/math/bitslice" + "github.com/consensys/gnark/std/rangecheck" +) + +// TODO: if internal then enforce range check! + +// TODO: all operations can take rand linear combinations instead. Then instead +// of one check can perform multiple at the same time. + +// TODO: implement versions which take multiple inputs. Maybe can combine multiple together + +// TODO: instantiate tables only when we first query. Maybe do not need to build! + +// TODO: maybe can store everything in a single table? Later! Or if we have a +// lot of queries then makes sense to extract into separate table? + +// TODO: in ValueOf ensure consistency + +// TODO: distinguish between when we set constant in-circuit or witness +// assignment. 
For constant we don't have to range check but for witness +// assignment we have to. + +// TODO: add something which allows to store array in native element + +// TODO: add methods for checking if U8/Long is constant. + +// TODO: should something for byte-only ops. Implement a type and then embed it in BinaryField + +// TODO: add helper method to call hints which allows to pass in uint8s (bytes) +// and returns bytes. Then can to byte array manipluation nicely. It is useful +// for X509. For the implementation we want to pack as much bytes into a field +// element as possible. + +// TODO: methods for converting uint array into emulated element and native +// element. Most probably should add the implementation for non-native in its +// package, but for native we should add it here. + +type U8 struct { + Val frontend.Variable + internal bool +} + +// GnarkInitHook describes how to initialise the element. +func (e *U8) GnarkInitHook() { + if e.Val == nil { + e.Val = 0 + e.internal = false // we need to constrain in later. + } +} + +type U64 [8]U8 +type U32 [4]U8 + +type Long interface{ U32 | U64 } + +type BinaryField[T U32 | U64] struct { + api frontend.API + xorT, andT *logderivprecomp.Precomputed + rchecker frontend.Rangechecker + allOne U8 +} + +func New[T Long](api frontend.API) (*BinaryField[T], error) { + xorT, err := logderivprecomp.New(api, xorHint, []uint{8}) + if err != nil { + return nil, fmt.Errorf("new xor table: %w", err) + } + andT, err := logderivprecomp.New(api, andHint, []uint{8}) + if err != nil { + return nil, fmt.Errorf("new and table: %w", err) + } + rchecker := rangecheck.New(api) + bf := &BinaryField[T]{ + api: api, + xorT: xorT, + andT: andT, + rchecker: rchecker, + } + // TODO: this is const. 
add way to init constants + allOne := bf.ByteValueOf(0xff) + bf.allOne = allOne + return bf, nil +} + +func NewU8(v uint8) U8 { + // TODO: don't have to check constants + return U8{Val: v, internal: true} +} + +func NewU32(v uint32) U32 { + return [4]U8{ + NewU8(uint8((v >> (0 * 8)) & 0xff)), + NewU8(uint8((v >> (1 * 8)) & 0xff)), + NewU8(uint8((v >> (2 * 8)) & 0xff)), + NewU8(uint8((v >> (3 * 8)) & 0xff)), + } +} + +func NewU64(v uint64) U64 { + return [8]U8{ + NewU8(uint8((v >> (0 * 8)) & 0xff)), + NewU8(uint8((v >> (1 * 8)) & 0xff)), + NewU8(uint8((v >> (2 * 8)) & 0xff)), + NewU8(uint8((v >> (3 * 8)) & 0xff)), + NewU8(uint8((v >> (4 * 8)) & 0xff)), + NewU8(uint8((v >> (5 * 8)) & 0xff)), + NewU8(uint8((v >> (6 * 8)) & 0xff)), + NewU8(uint8((v >> (7 * 8)) & 0xff)), + } +} + +func NewU8Array(v []uint8) []U8 { + ret := make([]U8, len(v)) + for i := range v { + ret[i] = NewU8(v[i]) + } + return ret +} + +func NewU32Array(v []uint32) []U32 { + ret := make([]U32, len(v)) + for i := range v { + ret[i] = NewU32(v[i]) + } + return ret +} + +func NewU64Array(v []uint64) []U64 { + ret := make([]U64, len(v)) + for i := range v { + ret[i] = NewU64(v[i]) + } + return ret +} + +func (bf *BinaryField[T]) ByteValueOf(a frontend.Variable) U8 { + bf.rchecker.Check(a, 8) + return U8{Val: a, internal: true} +} + +func (bf *BinaryField[T]) ValueOf(a frontend.Variable) T { + var r T + bts, err := bf.api.Compiler().NewHint(toBytes, len(r), len(r), a) + if err != nil { + panic(err) + } + // TODO: add constraint which ensures that map back to + for i := range bts { + r[i] = bf.ByteValueOf(bts[i]) + } + return r +} + +func (bf *BinaryField[T]) ToValue(a T) frontend.Variable { + v := make([]frontend.Variable, len(a)) + for i := range v { + v[i] = bf.api.Mul(a[i].Val, 1<<(i*8)) + } + vv := bf.api.Add(v[0], v[1], v[2:]...) 
+ return vv +} + +func (bf *BinaryField[T]) PackMSB(a ...U8) T { + var ret T + for i := range a { + ret[len(a)-i-1] = a[i] + } + return ret +} + +func (bf *BinaryField[T]) PackLSB(a ...U8) T { + var ret T + for i := range a { + ret[i] = a[i] + } + return ret +} + +func (bf *BinaryField[T]) UnpackMSB(a T) []U8 { + ret := make([]U8, len(a)) + for i := 0; i < len(a); i++ { + ret[len(a)-i-1] = a[i] + } + return ret +} + +func (bf *BinaryField[T]) UnpackLSB(a T) []U8 { + // cannot deduce that a can be cast to []U8 + ret := make([]U8, len(a)) + for i := 0; i < len(a); i++ { + ret[i] = a[i] + } + return ret +} + +func (bf *BinaryField[T]) twoArgFn(tbl *logderivprecomp.Precomputed, a ...U8) U8 { + ret := tbl.Query(a[0].Val, a[1].Val)[0] + for i := 2; i < len(a); i++ { + ret = tbl.Query(ret, a[i].Val)[0] + } + return U8{Val: ret} +} + +func (bf *BinaryField[T]) twoArgWideFn(tbl *logderivprecomp.Precomputed, a ...T) T { + var r T + for i, v := range reslice(a) { + r[i] = bf.twoArgFn(tbl, v...) + } + return r +} + +func (bf *BinaryField[T]) And(a ...T) T { return bf.twoArgWideFn(bf.andT, a...) } +func (bf *BinaryField[T]) Xor(a ...T) T { return bf.twoArgWideFn(bf.xorT, a...) } + +func (bf *BinaryField[T]) not(a U8) U8 { + ret := bf.xorT.Query(a.Val, bf.allOne.Val) + return U8{Val: ret[0]} +} + +func (bf *BinaryField[T]) Not(a T) T { + var r T + for i := 0; i < len(a); i++ { + r[i] = bf.not(a[i]) + } + return r +} + +func (bf *BinaryField[T]) Add(a ...T) T { + va := make([]frontend.Variable, len(a)) + for i := range a { + va[i] = bf.ToValue(a[i]) + } + vres := bf.api.Add(va[0], va[1], va[2:]...) + res := bf.ValueOf(vres) + // TODO: should also check the that carry we omitted is correct. 
+ return res +} + +func (bf *BinaryField[T]) Lrot(a T, c int) T { + l := len(a) + if c < 0 { + c = l*8 + c + } + shiftBl := c / 8 + shiftBt := c % 8 + revShiftBt := 8 - shiftBt + if revShiftBt == 8 { + revShiftBt = 0 + } + partitioned := make([][2]frontend.Variable, l) + for i := range partitioned { + lower, upper := bitslice.Partition(bf.api, a[i].Val, uint(revShiftBt), bitslice.WithNbDigits(8)) + partitioned[i] = [2]frontend.Variable{lower, upper} + } + var ret T + for i := 0; i < l; i++ { + if shiftBt != 0 { + ret[(i+shiftBl)%l].Val = bf.api.Add(bf.api.Mul(1<<(shiftBt), partitioned[i][0]), partitioned[(i+l-1)%l][1]) + } else { + ret[(i+shiftBl)%l].Val = partitioned[i][1] + } + } + return ret +} + +func (bf *BinaryField[T]) Rshift(a T, c int) T { + shiftBl := c / 8 + shiftBt := c % 8 + partitioned := make([][2]frontend.Variable, len(a)-shiftBl) + for i := range partitioned { + lower, upper := bitslice.Partition(bf.api, a[i+shiftBl].Val, uint(shiftBt), bitslice.WithNbDigits(8)) + partitioned[i] = [2]frontend.Variable{lower, upper} + } + var ret T + for i := 0; i < len(a)-shiftBl-1; i++ { + if shiftBt != 0 { + ret[i].Val = bf.api.Add(partitioned[i][1], bf.api.Mul(1<<(8-shiftBt), partitioned[i+1][0])) + } else { + ret[i].Val = partitioned[i][1] + } + } + ret[len(a)-shiftBl-1].Val = partitioned[len(a)-shiftBl-1][1] + for i := len(a) - shiftBl; i < len(ret); i++ { + ret[i] = NewU8(0) + } + return ret +} + +func (bf *BinaryField[T]) ByteAssertEq(a, b U8) { + bf.api.AssertIsEqual(a.Val, b.Val) +} + +func (bf *BinaryField[T]) AssertEq(a, b T) { + for i := 0; i < len(a); i++ { + bf.ByteAssertEq(a[i], b[i]) + } +} + +func reslice[T U32 | U64](in []T) [][]U8 { + if len(in) == 0 { + panic("zero-length input") + } + ret := make([][]U8, len(in[0])) + for i := range ret { + ret[i] = make([]U8, len(in)) + } + for i := 0; i < len(in); i++ { + for j := 0; j < len(in[0]); j++ { + ret[j][i] = in[i][j] + } + } + return ret +} diff --git a/std/math/uints/uint8_test.go 
b/std/math/uints/uint8_test.go new file mode 100644 index 0000000000..bc24a8f4db --- /dev/null +++ b/std/math/uints/uint8_test.go @@ -0,0 +1,81 @@ +package uints + +import ( + "math/bits" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type lrotCirc struct { + In U32 + Out U32 + Shift int +} + +func (c *lrotCirc) Define(api frontend.API) error { + uapi, err := New[U32](api) + if err != nil { + return err + } + res := uapi.Lrot(c.In, c.Shift) + uapi.AssertEq(c.Out, res) + return nil +} + +func TestLeftRotation(t *testing.T) { + assert := test.NewAssert(t) + var err error + err = test.IsSolved(&lrotCirc{Shift: 4}, &lrotCirc{In: NewU32(0x12345678), Shift: 4, Out: NewU32(bits.RotateLeft32(0x12345678, 4))}, ecc.BN254.ScalarField()) + assert.NoError(err) + err = test.IsSolved(&lrotCirc{Shift: 14}, &lrotCirc{In: NewU32(0x12345678), Shift: 14, Out: NewU32(bits.RotateLeft32(0x12345678, 14))}, ecc.BN254.ScalarField()) + assert.NoError(err) + err = test.IsSolved(&lrotCirc{Shift: 3}, &lrotCirc{In: NewU32(0x12345678), Shift: 3, Out: NewU32(bits.RotateLeft32(0x12345678, 3))}, ecc.BN254.ScalarField()) + assert.NoError(err) + err = test.IsSolved(&lrotCirc{Shift: 11}, &lrotCirc{In: NewU32(0x12345678), Shift: 11, Out: NewU32(bits.RotateLeft32(0x12345678, 11))}, ecc.BN254.ScalarField()) + assert.NoError(err) + // full block + err = test.IsSolved(&lrotCirc{Shift: 16}, &lrotCirc{In: NewU32(0x12345678), Shift: 16, Out: NewU32(bits.RotateLeft32(0x12345678, 16))}, ecc.BN254.ScalarField()) + assert.NoError(err) + // negative rotations + err = test.IsSolved(&lrotCirc{Shift: -4}, &lrotCirc{In: NewU32(0x12345678), Shift: -4, Out: NewU32(bits.RotateLeft32(0x12345678, -4))}, ecc.BN254.ScalarField()) + assert.NoError(err) + err = test.IsSolved(&lrotCirc{Shift: -14}, &lrotCirc{In: NewU32(0x12345678), Shift: -14, Out: NewU32(bits.RotateLeft32(0x12345678, -14))}, ecc.BN254.ScalarField()) + assert.NoError(err) 
+ err = test.IsSolved(&lrotCirc{Shift: -3}, &lrotCirc{In: NewU32(0x12345678), Shift: -3, Out: NewU32(bits.RotateLeft32(0x12345678, -3))}, ecc.BN254.ScalarField()) + assert.NoError(err) + err = test.IsSolved(&lrotCirc{Shift: -11}, &lrotCirc{In: NewU32(0x12345678), Shift: -11, Out: NewU32(bits.RotateLeft32(0x12345678, -11))}, ecc.BN254.ScalarField()) + assert.NoError(err) + err = test.IsSolved(&lrotCirc{Shift: -16}, &lrotCirc{In: NewU32(0x12345678), Shift: -16, Out: NewU32(bits.RotateLeft32(0x12345678, -16))}, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type rshiftCircuit struct { + In, Expected U32 + Shift int +} + +func (c *rshiftCircuit) Define(api frontend.API) error { + uapi, err := New[U32](api) + if err != nil { + return err + } + res := uapi.Rshift(c.In, c.Shift) + uapi.AssertEq(res, c.Expected) + return nil +} + +func TestRshift(t *testing.T) { + assert := test.NewAssert(t) + var err error + err = test.IsSolved(&rshiftCircuit{Shift: 4}, &rshiftCircuit{Shift: 4, In: NewU32(0x12345678), Expected: NewU32(0x12345678 >> 4)}, ecc.BN254.ScalarField()) + assert.NoError(err) + err = test.IsSolved(&rshiftCircuit{Shift: 12}, &rshiftCircuit{Shift: 12, In: NewU32(0x12345678), Expected: NewU32(0x12345678 >> 12)}, ecc.BN254.ScalarField()) + assert.NoError(err) + err = test.IsSolved(&rshiftCircuit{Shift: 3}, &rshiftCircuit{Shift: 3, In: NewU32(0x12345678), Expected: NewU32(0x12345678 >> 3)}, ecc.BN254.ScalarField()) + assert.NoError(err) + err = test.IsSolved(&rshiftCircuit{Shift: 11}, &rshiftCircuit{Shift: 11, In: NewU32(0x12345678), Expected: NewU32(0x12345678 >> 11)}, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/permutation/keccakf/keccak_test.go b/std/permutation/keccakf/keccak_test.go index fe4fca829d..b7151a893c 100644 --- a/std/permutation/keccakf/keccak_test.go +++ b/std/permutation/keccakf/keccak_test.go @@ -16,7 +16,13 @@ type keccakfCircuit struct { } func (c *keccakfCircuit) Define(api frontend.API) error { - res := 
keccakf.Permute(api, c.In) + var res [25]frontend.Variable + for i := range res { + res[i] = c.In[i] + } + for i := 0; i < 2; i++ { + res = keccakf.Permute(api, res) + } for i := range res { api.AssertIsEqual(res[i], c.Expected[i]) } @@ -25,15 +31,22 @@ func (c *keccakfCircuit) Define(api frontend.API) error { func TestKeccakf(t *testing.T) { var nativeIn [25]uint64 + var res [25]uint64 for i := range nativeIn { nativeIn[i] = 2 + res[i] = 2 + } + for i := 0; i < 2; i++ { + res = keccakF1600(res) } - nativeOut := keccakF1600(nativeIn) witness := keccakfCircuit{} for i := range nativeIn { witness.In[i] = nativeIn[i] - witness.Expected[i] = nativeOut[i] + witness.Expected[i] = res[i] } assert := test.NewAssert(t) - assert.ProverSucceeded(&keccakfCircuit{}, &witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16, backend.PLONK)) + assert.ProverSucceeded(&keccakfCircuit{}, &witness, + test.WithCurves(ecc.BN254), + test.WithBackends(backend.GROTH16, backend.PLONK), + test.NoFuzzing()) } diff --git a/std/permutation/keccakf/keccakf.go b/std/permutation/keccakf/keccakf.go index fbaedebb72..8f5e3ae346 100644 --- a/std/permutation/keccakf/keccakf.go +++ b/std/permutation/keccakf/keccakf.go @@ -12,33 +12,34 @@ package keccakf import ( "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/uints" ) -var rc = [24]xuint64{ - constUint64(0x0000000000000001), - constUint64(0x0000000000008082), - constUint64(0x800000000000808A), - constUint64(0x8000000080008000), - constUint64(0x000000000000808B), - constUint64(0x0000000080000001), - constUint64(0x8000000080008081), - constUint64(0x8000000000008009), - constUint64(0x000000000000008A), - constUint64(0x0000000000000088), - constUint64(0x0000000080008009), - constUint64(0x000000008000000A), - constUint64(0x000000008000808B), - constUint64(0x800000000000008B), - constUint64(0x8000000000008089), - constUint64(0x8000000000008003), - constUint64(0x8000000000008002), - constUint64(0x8000000000000080), 
- constUint64(0x000000000000800A), - constUint64(0x800000008000000A), - constUint64(0x8000000080008081), - constUint64(0x8000000000008080), - constUint64(0x0000000080000001), - constUint64(0x8000000080008008), +var rc = [24]uints.U64{ + uints.NewU64(0x0000000000000001), + uints.NewU64(0x0000000000008082), + uints.NewU64(0x800000000000808A), + uints.NewU64(0x8000000080008000), + uints.NewU64(0x000000000000808B), + uints.NewU64(0x0000000080000001), + uints.NewU64(0x8000000080008081), + uints.NewU64(0x8000000000008009), + uints.NewU64(0x000000000000008A), + uints.NewU64(0x0000000000000088), + uints.NewU64(0x0000000080008009), + uints.NewU64(0x000000008000000A), + uints.NewU64(0x000000008000808B), + uints.NewU64(0x800000000000008B), + uints.NewU64(0x8000000000008089), + uints.NewU64(0x8000000000008003), + uints.NewU64(0x8000000000008002), + uints.NewU64(0x8000000000000080), + uints.NewU64(0x000000000000800A), + uints.NewU64(0x800000008000000A), + uints.NewU64(0x8000000080008081), + uints.NewU64(0x8000000000008080), + uints.NewU64(0x0000000080000001), + uints.NewU64(0x8000000080008008), } var rotc = [24]int{ 1, 3, 6, 10, 15, 21, 28, 36, 45, 55, 2, 14, @@ -53,32 +54,34 @@ var piln = [24]int{ // vector. The input array must consist of 64-bit (unsigned) integers. The // returned array also contains 64-bit unsigned integers. 
func Permute(api frontend.API, a [25]frontend.Variable) [25]frontend.Variable { - var in [25]xuint64 - uapi := newUint64API(api) + var in [25]uints.U64 + uapi, err := uints.New[uints.U64](api) + if err != nil { + panic(err) // TODO: return error instead + } for i := range a { - in[i] = uapi.asUint64(a[i]) + in[i] = uapi.ValueOf(a[i]) } - res := permute(api, in) + res := permute(uapi, in) var out [25]frontend.Variable for i := range out { - out[i] = uapi.fromUint64(res[i]) + out[i] = uapi.ToValue(res[i]) } return out } -func permute(api frontend.API, st [25]xuint64) [25]xuint64 { - uapi := newUint64API(api) - var t xuint64 - var bc [5]xuint64 +func permute(uapi *uints.BinaryField[uints.U64], st [25]uints.U64) [25]uints.U64 { + var t uints.U64 + var bc [5]uints.U64 for r := 0; r < 24; r++ { // theta for i := 0; i < 5; i++ { - bc[i] = uapi.xor(st[i], st[i+5], st[i+10], st[i+15], st[i+20]) + bc[i] = uapi.Xor(st[i], st[i+5], st[i+10], st[i+15], st[i+20]) } for i := 0; i < 5; i++ { - t = uapi.xor(bc[(i+4)%5], uapi.lrot(bc[(i+1)%5], 1)) + t = uapi.Xor(bc[(i+4)%5], uapi.Lrot(bc[(i+1)%5], 1)) for j := 0; j < 25; j += 5 { - st[j+i] = uapi.xor(st[j+i], t) + st[j+i] = uapi.Xor(st[j+i], t) } } // rho pi @@ -86,7 +89,7 @@ func permute(api frontend.API, st [25]xuint64) [25]xuint64 { for i := 0; i < 24; i++ { j := piln[i] bc[0] = st[j] - st[j] = uapi.lrot(t, rotc[i]) + st[j] = uapi.Lrot(t, rotc[i]) t = bc[0] } @@ -96,11 +99,11 @@ func permute(api frontend.API, st [25]xuint64) [25]xuint64 { bc[i] = st[j+i] } for i := 0; i < 5; i++ { - st[j+i] = uapi.xor(st[j+i], uapi.and(uapi.not(bc[(i+1)%5]), bc[(i+2)%5])) + st[j+i] = uapi.Xor(st[j+i], uapi.And(uapi.Not(bc[(i+1)%5]), bc[(i+2)%5])) } } // iota - st[0] = uapi.xor(st[0], rc[r]) + st[0] = uapi.Xor(st[0], rc[r]) } return st } diff --git a/std/permutation/keccakf/uint64api.go b/std/permutation/keccakf/uint64api.go deleted file mode 100644 index c7c2246bb2..0000000000 --- a/std/permutation/keccakf/uint64api.go +++ /dev/null @@ -1,101 
+0,0 @@ -package keccakf - -import ( - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/math/bits" -) - -// uint64api performs binary operations on xuint64 variables. In the -// future possibly using lookup tables. -// -// TODO: we could possibly optimise using hints if working over many inputs. For -// example, if we OR many bits, then the result is 0 if the sum of the bits is -// larger than 1. And AND is 1 if the sum of bits is the number of inputs. BUt -// this probably helps only if we have a lot of similar operations in a row -// (more than 4). We could probably unroll the whole permutation and expand all -// the formulas to see. But long term tables are still better. -type uint64api struct { - api frontend.API -} - -func newUint64API(api frontend.API) *uint64api { - return &uint64api{ - api: api, - } -} - -// varUint64 represents 64-bit unsigned integer. We use this type to ensure that -// we work over constrained bits. Do not initialize directly, use [wideBinaryOpsApi.asUint64]. 
-type xuint64 [64]frontend.Variable - -func constUint64(a uint64) xuint64 { - var res xuint64 - for i := 0; i < 64; i++ { - res[i] = (a >> i) & 1 - } - return res -} - -func (w *uint64api) asUint64(in frontend.Variable) xuint64 { - bits := bits.ToBinary(w.api, in, bits.WithNbDigits(64)) - var res xuint64 - copy(res[:], bits) - return res -} - -func (w *uint64api) fromUint64(in xuint64) frontend.Variable { - return bits.FromBinary(w.api, in[:], bits.WithUnconstrainedInputs()) -} - -func (w *uint64api) and(in ...xuint64) xuint64 { - var res xuint64 - for i := range res { - res[i] = 1 - } - for i := range res { - for _, v := range in { - res[i] = w.api.And(res[i], v[i]) - } - } - return res -} - -func (w *uint64api) xor(in ...xuint64) xuint64 { - var res xuint64 - for i := range res { - res[i] = 0 - } - for i := range res { - for _, v := range in { - res[i] = w.api.Xor(res[i], v[i]) - } - } - return res -} - -func (w *uint64api) lrot(in xuint64, shift int) xuint64 { - var res xuint64 - for i := range res { - res[i] = in[(i-shift+64)%64] - } - return res -} - -func (w *uint64api) not(in xuint64) xuint64 { - // TODO: it would be better to have separate method for it. If we have - // native API support, then in R1CS would be free (1-X) and in PLONK 1 - // constraint (1-X). But if we do XOR, then we always have a constraint with - // R1CS (not sure if 1-2 with PLONK). If we do 1-X ourselves, then compiler - // marks as binary which is 1-2 (R1CS-PLONK). 
- var res xuint64 - for i := range res { - res[i] = w.api.Xor(in[i], 1) - } - return res -} - -func (w *uint64api) assertEq(a, b xuint64) { - for i := range a { - w.api.AssertIsEqual(a[i], b[i]) - } -} diff --git a/std/permutation/keccakf/uint64api_test.go b/std/permutation/keccakf/uint64api_test.go deleted file mode 100644 index 0e39794695..0000000000 --- a/std/permutation/keccakf/uint64api_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package keccakf - -import ( - "testing" - - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/test" -) - -type lrotCirc struct { - In frontend.Variable - Shift int - Out frontend.Variable -} - -func (c *lrotCirc) Define(api frontend.API) error { - uapi := newUint64API(api) - in := uapi.asUint64(c.In) - out := uapi.asUint64(c.Out) - res := uapi.lrot(in, c.Shift) - uapi.assertEq(out, res) - return nil -} - -func TestLeftRotation(t *testing.T) { - assert := test.NewAssert(t) - // err := test.IsSolved(&lrotCirc{Shift: 2}, &lrotCirc{In: 6, Shift: 2, Out: 24}, ecc.BN254.ScalarField()) - // assert.NoError(err) - assert.ProverSucceeded(&lrotCirc{Shift: 2}, &lrotCirc{In: 6, Shift: 2, Out: 24}) -} diff --git a/std/permutation/sha2/sha2block.go b/std/permutation/sha2/sha2block.go new file mode 100644 index 0000000000..a3991230b3 --- /dev/null +++ b/std/permutation/sha2/sha2block.go @@ -0,0 +1,90 @@ +package sha2 + +import ( + "github.com/consensys/gnark/std/math/uints" +) + +var _K = uints.NewU32Array([]uint32{ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 
0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +}) + +func Permute(uapi *uints.BinaryField[uints.U32], currentHash [8]uints.U32, p [64]uints.U8) (newHash [8]uints.U32) { + var w [64]uints.U32 + + for i := 0; i < 16; i++ { + w[i] = uapi.PackMSB(p[4*i], p[4*i+1], p[4*i+2], p[4*i+3]) + } + + for i := 16; i < 64; i++ { + v1 := w[i-2] + t1 := uapi.Xor( + uapi.Lrot(v1, -17), + uapi.Lrot(v1, -19), + uapi.Rshift(v1, 10), + ) + v2 := w[i-15] + t2 := uapi.Xor( + uapi.Lrot(v2, -7), + uapi.Lrot(v2, -18), + uapi.Rshift(v2, 3), + ) + + w[i] = uapi.Add(t1, w[i-7], t2, w[i-16]) + } + + a, b, c, d, e, f, g, h := currentHash[0], currentHash[1], currentHash[2], currentHash[3], currentHash[4], currentHash[5], currentHash[6], currentHash[7] + + for i := 0; i < 64; i++ { + t1 := uapi.Add( + h, + uapi.Xor( + uapi.Lrot(e, -6), + uapi.Lrot(e, -11), + uapi.Lrot(e, -25)), + uapi.Xor( + uapi.And(e, f), + uapi.And( + uapi.Not(e), + g)), + _K[i], + w[i], + ) + t2 := uapi.Add( + uapi.Xor( + uapi.Lrot(a, -2), + uapi.Lrot(a, -13), + uapi.Lrot(a, -22)), + uapi.Xor( + uapi.And(a, b), + uapi.And(a, c), + uapi.And(b, c)), + ) + + h = g + g = f + f = e + e = uapi.Add(d, t1) + d = c + c = b + b = a + a = uapi.Add(t1, t2) + } + + currentHash[0] = uapi.Add(currentHash[0], a) + currentHash[1] = uapi.Add(currentHash[1], b) + currentHash[2] = uapi.Add(currentHash[2], c) + currentHash[3] = uapi.Add(currentHash[3], d) + currentHash[4] = uapi.Add(currentHash[4], e) + currentHash[5] = uapi.Add(currentHash[5], f) + currentHash[6] = uapi.Add(currentHash[6], g) + currentHash[7] = uapi.Add(currentHash[7], h) + + return currentHash +} diff --git a/std/permutation/sha2/sha2block_test.go b/std/permutation/sha2/sha2block_test.go new file mode 100644 index 0000000000..b634ab6624 --- /dev/null +++ 
b/std/permutation/sha2/sha2block_test.go @@ -0,0 +1,116 @@ +package sha2_test + +import ( + "math/bits" + "math/rand" + "testing" + "time" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/uints" + "github.com/consensys/gnark/std/permutation/sha2" + "github.com/consensys/gnark/test" +) + +var _K = []uint32{ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +} + +const ( + chunk = 64 +) + +type digest struct { + h [8]uint32 +} + +func blockGeneric(dig *digest, p []byte) { + var w [64]uint32 + h0, h1, h2, h3, h4, h5, h6, h7 := dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4], dig.h[5], dig.h[6], dig.h[7] + for len(p) >= chunk { + // Can interlace the computation of w with the + // rounds below if needed for speed. 
+ for i := 0; i < 16; i++ { + j := i * 4 + w[i] = uint32(p[j])<<24 | uint32(p[j+1])<<16 | uint32(p[j+2])<<8 | uint32(p[j+3]) + } + for i := 16; i < 64; i++ { + v1 := w[i-2] + t1 := (bits.RotateLeft32(v1, -17)) ^ (bits.RotateLeft32(v1, -19)) ^ (v1 >> 10) + v2 := w[i-15] + t2 := (bits.RotateLeft32(v2, -7)) ^ (bits.RotateLeft32(v2, -18)) ^ (v2 >> 3) + w[i] = t1 + w[i-7] + t2 + w[i-16] + } + + a, b, c, d, e, f, g, h := h0, h1, h2, h3, h4, h5, h6, h7 + + for i := 0; i < 64; i++ { + t1 := h + ((bits.RotateLeft32(e, -6)) ^ (bits.RotateLeft32(e, -11)) ^ (bits.RotateLeft32(e, -25))) + ((e & f) ^ (^e & g)) + _K[i] + w[i] + + t2 := ((bits.RotateLeft32(a, -2)) ^ (bits.RotateLeft32(a, -13)) ^ (bits.RotateLeft32(a, -22))) + ((a & b) ^ (a & c) ^ (b & c)) + + h, g, f, e, d, c, b, a = g, f, e, d+t1, c, b, a, t1+t2 + } + + h0 += a + h1 += b + h2 += c + h3 += d + h4 += e + h5 += f + h6 += g + h7 += h + + p = p[chunk:] + } + + dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4], dig.h[5], dig.h[6], dig.h[7] = h0, h1, h2, h3, h4, h5, h6, h7 +} + +type circuitBlock struct { + CurrentDig [8]uints.U32 + In [64]uints.U8 + Expected [8]uints.U32 +} + +func (c *circuitBlock) Define(api frontend.API) error { + uapi, err := uints.New[uints.U32](api) + if err != nil { + return err + } + res := sha2.Permute(uapi, c.CurrentDig, c.In) + for i := range c.Expected { + uapi.AssertEq(c.Expected[i], res[i]) + } + return nil +} + +func TestBlockGeneric(t *testing.T) { + assert := test.NewAssert(t) + s := rand.New(rand.NewSource(time.Now().Unix())) //nolint G404, test code + witness := circuitBlock{} + dig := digest{} + var in [chunk]byte + for i := range dig.h { + dig.h[i] = s.Uint32() + witness.CurrentDig[i] = uints.NewU32(dig.h[i]) + } + for i := range in { + in[i] = byte(s.Uint32() & 0xff) + witness.In[i] = uints.NewU8(in[i]) + } + blockGeneric(&dig, in[:]) + for i := range dig.h { + witness.Expected[i] = uints.NewU32(dig.h[i]) + } + err := test.IsSolved(&circuitBlock{}, &witness, 
ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/rangecheck/rangecheck.go b/std/rangecheck/rangecheck.go new file mode 100644 index 0000000000..8aac734dcb --- /dev/null +++ b/std/rangecheck/rangecheck.go @@ -0,0 +1,37 @@ +// Package rangecheck implements range checking gadget +// +// This package chooses the most optimal path for performing range checks: +// - if the backend supports native range checking and the frontend exports the variables in the proprietary format by implementing [frontend.Rangechecker], then use it directly; +// - if the backend supports creating a commitment of variables by implementing [frontend.Committer], then we use the log-derivative variant [[Haböck22]] of the product argument as in [[BCG+18]] . [r1cs.NewBuilder] returns a builder which implements this interface; +// - lacking these, we perform binary decomposition of variable into bits. +// +// [BCG+18]: https://eprint.iacr.org/2018/380 +// [Haböck22]: https://eprint.iacr.org/2022/1530 +package rangecheck + +import ( + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" +) + +// only for documentation purposes. If we import the package then godoc knows +// how to refer to package r1cs and we get nice links in godoc. We import the +// package anyway in test. +var _ = r1cs.NewBuilder + +// New returns a new range checker depending on the frontend capabilities. 
+func New(api frontend.API) frontend.Rangechecker { + if rc, ok := api.(frontend.Rangechecker); ok { + return rc + } + if _, ok := api.(frontend.Committer); ok { + return newCommitRangechecker(api) + } + return plainChecker{api: api} +} + +// GetHints returns all hints used in this package +func GetHints() []solver.Hint { + return []solver.Hint{DecomposeHint} +} diff --git a/std/rangecheck/rangecheck_commit.go b/std/rangecheck/rangecheck_commit.go new file mode 100644 index 0000000000..6af6eb5be6 --- /dev/null +++ b/std/rangecheck/rangecheck_commit.go @@ -0,0 +1,175 @@ +package rangecheck + +import ( + "fmt" + "math" + "math/big" + + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/frontendtype" + "github.com/consensys/gnark/internal/kvstore" + "github.com/consensys/gnark/std/internal/logderivarg" +) + +type ctxCheckerKey struct{} + +func init() { + solver.RegisterHint(DecomposeHint) +} + +type checkedVariable struct { + v frontend.Variable + bits int +} + +type commitChecker struct { + collected []checkedVariable + closed bool +} + +func newCommitRangechecker(api frontend.API) *commitChecker { + kv, ok := api.Compiler().(kvstore.Store) + if !ok { + panic("builder should implement key-value store") + } + ch := kv.GetKeyValue(ctxCheckerKey{}) + if ch != nil { + if cht, ok := ch.(*commitChecker); ok { + return cht + } else { + panic("stored rangechecker is not valid") + } + } + cht := &commitChecker{} + kv.SetKeyValue(ctxCheckerKey{}, cht) + api.Compiler().Defer(cht.commit) + return cht +} + +func (c *commitChecker) Check(in frontend.Variable, bits int) { + if c.closed { + panic("checker already closed") + } + c.collected = append(c.collected, checkedVariable{v: in, bits: bits}) +} + +func (c *commitChecker) buildTable(nbTable int) []frontend.Variable { + tbl := make([]frontend.Variable, nbTable) + for i := 0; i < nbTable; i++ { + tbl[i] = i + } + return tbl +} + +func (c *commitChecker) 
commit(api frontend.API) error { + if c.closed { + return nil + } + defer func() { c.closed = true }() + if len(c.collected) == 0 { + return nil + } + baseLength := c.getOptimalBasewidth(api) + // decompose into smaller limbs + decomposed := make([]frontend.Variable, 0, len(c.collected)) + collected := make([]frontend.Variable, len(c.collected)) + base := new(big.Int).Lsh(big.NewInt(1), uint(baseLength)) + for i := range c.collected { + // collect all vars for commitment input + collected[i] = c.collected[i].v + // decompose value into limbs + nbLimbs := decompSize(c.collected[i].bits, baseLength) + limbs, err := api.Compiler().NewHint(DecomposeHint, int(nbLimbs), c.collected[i].bits, baseLength, c.collected[i].v) + if err != nil { + panic(fmt.Sprintf("decompose %v", err)) + } + // store all limbs for counting + decomposed = append(decomposed, limbs...) + // check that limbs are correct. We check the sizes of the limbs later + var composed frontend.Variable = 0 + for j := range limbs { + composed = api.Add(composed, api.Mul(limbs[j], new(big.Int).Exp(base, big.NewInt(int64(j)), nil))) + } + api.AssertIsEqual(composed, c.collected[i].v) + } + nbTable := 1 << baseLength + return logderivarg.Build(api, logderivarg.AsTable(c.buildTable(nbTable)), logderivarg.AsTable(decomposed)) +} + +func decompSize(varSize int, limbSize int) int { + return (varSize + limbSize - 1) / limbSize +} + +// DecomposeHint is a hint used for range checking with commitment. It +// decomposes large variables into chunks which can be individually range-check +// in the native range. 
+func DecomposeHint(m *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 3 { + return fmt.Errorf("input must be 3 elements") + } + if !inputs[0].IsUint64() || !inputs[1].IsUint64() { + return fmt.Errorf("first two inputs have to be uint64") + } + varSize := int(inputs[0].Int64()) + limbSize := int(inputs[1].Int64()) + val := inputs[2] + nbLimbs := decompSize(varSize, limbSize) + if len(outputs) != nbLimbs { + return fmt.Errorf("need %d outputs instead to decompose", nbLimbs) + } + base := new(big.Int).Lsh(big.NewInt(1), uint(limbSize)) + tmp := new(big.Int).Set(val) + for i := 0; i < len(outputs); i++ { + outputs[i].Mod(tmp, base) + tmp.Rsh(tmp, uint(limbSize)) + } + return nil +} + +func (c *commitChecker) getOptimalBasewidth(api frontend.API) int { + if ft, ok := api.(frontendtype.FrontendTyper); ok { + switch ft.FrontendType() { + case frontendtype.R1CS: + return optimalWidth(nbR1CSConstraints, c.collected) + case frontendtype.SCS: + return optimalWidth(nbPLONKConstraints, c.collected) + } + } + return optimalWidth(nbR1CSConstraints, c.collected) +} + +func optimalWidth(countFn func(baseLength int, collected []checkedVariable) int, collected []checkedVariable) int { + min := math.MaxInt64 + minVal := 0 + for j := 2; j < 18; j++ { + current := countFn(j, collected) + if current < min { + min = current + minVal = j + } + } + return minVal +} + +func nbR1CSConstraints(baseLength int, collected []checkedVariable) int { + nbDecomposed := 0 + for i := range collected { + nbDecomposed += int(decompSize(collected[i].bits, baseLength)) + } + eqs := len(collected) // correctness of decomposition + nbRight := nbDecomposed // inverse per decomposed + nbleft := (1 << baseLength) // div per table + return nbleft + nbRight + eqs + 1 +} + +func nbPLONKConstraints(baseLength int, collected []checkedVariable) int { + nbDecomposed := 0 + for i := range collected { + nbDecomposed += int(decompSize(collected[i].bits, baseLength)) + } + eqs := nbDecomposed 
// check correctness of every decomposition. this is nbDecomp adds + eq cost per collected + nbRight := 3 * nbDecomposed // denominator sub, inv and large sum per table entry + nbleft := 3 * (1 << baseLength) // denominator sub, div and large sum per table entry + return nbleft + nbRight + eqs + 1 // and the final assert +} diff --git a/std/rangecheck/rangecheck_plain.go b/std/rangecheck/rangecheck_plain.go new file mode 100644 index 0000000000..6f20418f2d --- /dev/null +++ b/std/rangecheck/rangecheck_plain.go @@ -0,0 +1,14 @@ +package rangecheck + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/bits" +) + +type plainChecker struct { + api frontend.API +} + +func (pl plainChecker) Check(v frontend.Variable, nbBits int) { + bits.ToBinary(pl.api, v, bits.WithNbDigits(nbBits)) +} diff --git a/std/rangecheck/rangecheck_test.go b/std/rangecheck/rangecheck_test.go new file mode 100644 index 0000000000..42a26827d8 --- /dev/null +++ b/std/rangecheck/rangecheck_test.go @@ -0,0 +1,46 @@ +package rangecheck + +import ( + "crypto/rand" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/test" +) + +type CheckCircuit struct { + Vals []frontend.Variable + bits int +} + +func (c *CheckCircuit) Define(api frontend.API) error { + r := newCommitRangechecker(api) + for i := range c.Vals { + r.Check(c.Vals[i], c.bits) + } + return nil +} + +func TestCheck(t *testing.T) { + assert := test.NewAssert(t) + var err error + bits := 64 + nbVals := 100000 + bound := new(big.Int).Lsh(big.NewInt(1), uint(bits)) + vals := make([]frontend.Variable, nbVals) + for i := range vals { + vals[i], err = rand.Int(rand.Reader, bound) + if err != nil { + t.Fatal(err) + } + } + witness := CheckCircuit{Vals: vals, bits: bits} + circuit := CheckCircuit{Vals: make([]frontend.Variable, len(vals)), bits: bits} + err = 
test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + _, err = frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit, frontend.WithCompressThreshold(100)) + assert.NoError(err) +} diff --git a/std/selector/doc_map_test.go b/std/selector/doc_map_test.go new file mode 100644 index 0000000000..f70c353f61 --- /dev/null +++ b/std/selector/doc_map_test.go @@ -0,0 +1,79 @@ +package selector_test + +import ( + "fmt" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/selector" +) + +// MapCircuit is a minimal circuit using a selector map. +type MapCircuit struct { + QueryKey frontend.Variable + Keys [10]frontend.Variable // we use array in witness to allocate sufficiently sized vector + Values [10]frontend.Variable // we use array in witness to allocate sufficiently sized vector + ExpectedValue frontend.Variable +} + +// Define defines the arithmetic circuit. +func (c *MapCircuit) Define(api frontend.API) error { + result := selector.Map(api, c.QueryKey, c.Keys[:], c.Values[:]) + api.AssertIsEqual(result, c.ExpectedValue) + return nil +} + +// ExampleMap gives an example on how to use map selector. 
+func ExampleMap() { + circuit := MapCircuit{} + witness := MapCircuit{ + QueryKey: 55, + Keys: [10]frontend.Variable{0, 11, 22, 33, 44, 55, 66, 77, 88, 99}, + Values: [10]frontend.Variable{0, 2, 4, 6, 8, 10, 12, 14, 16, 18}, + ExpectedValue: 10, // element in values which corresponds to the position of 55 in keys + } + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + panic(err) + } else { + fmt.Println("compiled") + } + pk, vk, err := groth16.Setup(ccs) + if err != nil { + panic(err) + } else { + fmt.Println("setup done") + } + secretWitness, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } else { + fmt.Println("secret witness") + } + publicWitness, err := secretWitness.Public() + if err != nil { + panic(err) + } else { + fmt.Println("public witness") + } + proof, err := groth16.Prove(ccs, pk, secretWitness) + if err != nil { + panic(err) + } else { + fmt.Println("proof") + } + err = groth16.Verify(proof, vk, publicWitness) + if err != nil { + panic(err) + } else { + fmt.Println("verify") + } + // Output: compiled + // setup done + // secret witness + // public witness + // proof + // verify +} diff --git a/std/selector/doc_mux_test.go b/std/selector/doc_mux_test.go new file mode 100644 index 0000000000..5a7b75ba36 --- /dev/null +++ b/std/selector/doc_mux_test.go @@ -0,0 +1,77 @@ +package selector_test + +import ( + "fmt" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/selector" +) + +// MuxCircuit is a minimal circuit using a selector mux. +type MuxCircuit struct { + Selector frontend.Variable + In [10]frontend.Variable // we use array in witness to allocate sufficiently sized vector + Expected frontend.Variable +} + +// Define defines the arithmetic circuit. 
+func (c *MuxCircuit) Define(api frontend.API) error { + result := selector.Mux(api, c.Selector, c.In[:]...) // Note Mux takes var-arg input, need to expand the input vector + api.AssertIsEqual(result, c.Expected) + return nil +} + +// ExampleMux gives an example on how to use mux selector. +func ExampleMux() { + circuit := MuxCircuit{} + witness := MuxCircuit{ + Selector: 5, + In: [10]frontend.Variable{0, 2, 4, 6, 8, 10, 12, 14, 16, 18}, + Expected: 10, // 5-th element in vector + } + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + panic(err) + } else { + fmt.Println("compiled") + } + pk, vk, err := groth16.Setup(ccs) + if err != nil { + panic(err) + } else { + fmt.Println("setup done") + } + secretWitness, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } else { + fmt.Println("secret witness") + } + publicWitness, err := secretWitness.Public() + if err != nil { + panic(err) + } else { + fmt.Println("public witness") + } + proof, err := groth16.Prove(ccs, pk, secretWitness) + if err != nil { + panic(err) + } else { + fmt.Println("proof") + } + err = groth16.Verify(proof, vk, publicWitness) + if err != nil { + panic(err) + } else { + fmt.Println("verify") + } + // Output: compiled + // setup done + // secret witness + // public witness + // proof + // verify +} diff --git a/std/selector/doc_partition_test.go b/std/selector/doc_partition_test.go new file mode 100644 index 0000000000..2a9b4ea548 --- /dev/null +++ b/std/selector/doc_partition_test.go @@ -0,0 +1,80 @@ +package selector_test + +import "github.com/consensys/gnark/frontend" + +import ( + "fmt" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/selector" +) + +// adderCircuit adds first Count number of its input array In. 
+type adderCircuit struct { + Count frontend.Variable + In [10]frontend.Variable + ExpectedSum frontend.Variable +} + +// Define defines the arithmetic circuit. +func (c *adderCircuit) Define(api frontend.API) error { + selectedPart := selector.Partition(api, c.Count, false, c.In[:]) + sum := api.Add(selectedPart[0], selectedPart[1], selectedPart[2:]...) + api.AssertIsEqual(sum, c.ExpectedSum) + return nil +} + +// ExamplePartition gives an example on how to use selector.Partition to make a circuit that accepts a Count and an +// input array In, and then calculates the sum of first Count numbers of the input array. +func ExamplePartition() { + circuit := adderCircuit{} + witness := adderCircuit{ + Count: 6, + In: [10]frontend.Variable{0, 2, 4, 6, 8, 10, 12, 14, 16, 18}, + ExpectedSum: 30, + } + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + panic(err) + } else { + fmt.Println("compiled") + } + pk, vk, err := groth16.Setup(ccs) + if err != nil { + panic(err) + } else { + fmt.Println("setup done") + } + secretWitness, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } else { + fmt.Println("secret witness") + } + publicWitness, err := secretWitness.Public() + if err != nil { + panic(err) + } else { + fmt.Println("public witness") + } + proof, err := groth16.Prove(ccs, pk, secretWitness) + if err != nil { + panic(err) + } else { + fmt.Println("proof") + } + err = groth16.Verify(proof, vk, publicWitness) + if err != nil { + panic(err) + } else { + fmt.Println("verify") + } + // Output: compiled + // setup done + // secret witness + // public witness + // proof + // verify +} diff --git a/std/selector/multiplexer.go b/std/selector/multiplexer.go new file mode 100644 index 0000000000..5f9b0d7114 --- /dev/null +++ b/std/selector/multiplexer.go @@ -0,0 +1,169 @@ +// Package selector provides a lookup table and map, based on linear scan. 
+// +// The native [frontend.API] provides 1- and 2-bit lookups through the interface +// methods Select and Lookup2. This package extends the lookups to +// arbitrary-sized vectors. The lookups can be performed using the index of the +// elements (function [Mux]) or using a key, for which the user needs to provide +// the slice of keys (function [Map]). +// +// The implementation uses linear scan over all inputs. +package selector + +import ( + "fmt" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/bits" + "math/big" + binary "math/bits" +) + +func init() { + // register hints + solver.RegisterHint(GetHints()...) +} + +// GetHints returns all hint functions used in this package. This method is +// useful for registering all hints in the solver. +func GetHints() []solver.Hint { + return []solver.Hint{stepOutput, muxIndicators, mapIndicators} +} + +// Map is a key value associative array: the output will be values[i] such that +// keys[i] == queryKey. If keys does not contain queryKey, no proofs can be +// generated. If keys has more than one key that equals to queryKey, the output +// will be undefined, and the output could be any linear combination of all the +// corresponding values with that queryKey. +// +// In case keys and values do not have the same length, this function will +// panic. +func Map(api frontend.API, queryKey frontend.Variable, + keys []frontend.Variable, values []frontend.Variable) frontend.Variable { + // we don't need this check, but we added it to produce more informative errors + // and disallow len(keys) < len(values) which is supported by generateSelector. + if len(keys) != len(values) { + panic(fmt.Sprintf("The number of keys and values must be equal (%d != %d)", len(keys), len(values))) + } + return dotProduct(api, values, KeyDecoder(api, queryKey, keys)) +} + +// Mux is an n to 1 multiplexer: out = inputs[sel]. 
In other words, it selects +// exactly one of its inputs based on sel. The index of inputs starts from zero. +// +// sel needs to be between 0 and n - 1 (inclusive), where n is the number of +// inputs, otherwise the proof will fail. +func Mux(api frontend.API, sel frontend.Variable, inputs ...frontend.Variable) frontend.Variable { + // we use BinaryMux when len(inputs) is a power of 2. + if binary.OnesCount(uint(len(inputs))) == 1 { + selBits := bits.ToBinary(api, sel, bits.WithNbDigits(binary.Len(uint(len(inputs)))-1)) + return BinaryMux(api, selBits, inputs) + } + return dotProduct(api, inputs, Decoder(api, len(inputs), sel)) +} + +// KeyDecoder is a decoder that associates keys to its output wires. It outputs +// 1 on the wire that is associated to a key that equals to queryKey. In other +// words: +// +// if keys[i] == queryKey +// out[i] = 1 +// else +// out[i] = 0 +// +// If keys has more than one key that equals to queryKey, the output is +// undefined. However, the output is guaranteed to be zero for the wires that +// are associated with a key which is not equal to queryKey. +func KeyDecoder(api frontend.API, queryKey frontend.Variable, keys []frontend.Variable) []frontend.Variable { + return generateDecoder(api, false, 0, queryKey, keys) +} + +// Decoder is a decoder with n outputs. It outputs 1 on the wire with index sel, +// and 0 otherwise. Indices start from zero. In other words: +// +// if i == sel +// out[i] = 1 +// else +// out[i] = 0 +// +// sel needs to be between 0 and n - 1 (inclusive) otherwise no proof can be +// generated. +func Decoder(api frontend.API, n int, sel frontend.Variable) []frontend.Variable { + return generateDecoder(api, true, n, sel, nil) +} + +// generateDecoder generates a circuit for a decoder which indicates the +// selected index. If sequential is true, an ordinary decoder of size n is +// generated, and keys are ignored. 
If sequential is false, a key based decoder +// is generated, and len(keys) is used to determine the size of the output. n +// will be ignored in this case. +func generateDecoder(api frontend.API, sequential bool, n int, sel frontend.Variable, + keys []frontend.Variable) []frontend.Variable { + + var indicators []frontend.Variable + var err error + if sequential { + indicators, err = api.Compiler().NewHint(muxIndicators, n, sel) + } else { + indicators, err = api.Compiler().NewHint(mapIndicators, len(keys), append(keys, sel)...) + } + if err != nil { + panic(fmt.Sprintf("error in calling Mux/Map hint: %v", err)) + } + + indicatorsSum := frontend.Variable(0) + for i := 0; i < len(indicators); i++ { + // Check that all indicators for inputs that are not selected, are zero. + if sequential { + // indicators[i] * (sel - i) == 0 + api.AssertIsEqual(api.Mul(indicators[i], api.Sub(sel, i)), 0) + } else { + // indicators[i] * (sel - keys[i]) == 0 + api.AssertIsEqual(api.Mul(indicators[i], api.Sub(sel, keys[i])), 0) + } + indicatorsSum = api.Add(indicatorsSum, indicators[i]) + } + // We need to check that the indicator of the selected input is exactly 1. We + // use a sum constraint, because usually it is cheap. + api.AssertIsEqual(indicatorsSum, 1) + return indicators +} + +func dotProduct(api frontend.API, a, b []frontend.Variable) frontend.Variable { + out := frontend.Variable(0) + for i := 0; i < len(a); i++ { + // out += indicators[i] * values[i] + out = api.MulAcc(out, a[i], b[i]) + } + return out +} + +// muxIndicators is a hint function used within [Mux] function. It must be +// provided to the prover when circuit uses it. +func muxIndicators(_ *big.Int, inputs []*big.Int, results []*big.Int) error { + sel := inputs[0] + for i := 0; i < len(results); i++ { + // i is an int which can be int32 or int64. We convert i to int64 then to + // bigInt, which is safe. We should not convert sel to int64. 
+ if sel.Cmp(big.NewInt(int64(i))) == 0 { + results[i].SetUint64(1) + } else { + results[i].SetUint64(0) + } + } + return nil +} + +// mapIndicators is a hint function used within [Map] function. It must be +// provided to the prover when circuit uses it. +func mapIndicators(_ *big.Int, inputs []*big.Int, results []*big.Int) error { + key := inputs[len(inputs)-1] + // We must make sure that we are initializing all elements of results + for i := 0; i < len(results); i++ { + if key.Cmp(inputs[i]) == 0 { + results[i].SetUint64(1) + } else { + results[i].SetUint64(0) + } + } + return nil +} diff --git a/std/selector/multiplexer_test.go b/std/selector/multiplexer_test.go new file mode 100644 index 0000000000..2c01f08a01 --- /dev/null +++ b/std/selector/multiplexer_test.go @@ -0,0 +1,240 @@ +package selector_test + +import ( + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend/cs/r1cs" + "testing" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/selector" + "github.com/consensys/gnark/test" +) + +type muxCircuit struct { + SEL frontend.Variable + I0, I1, I2, I3, I4 frontend.Variable + OUT frontend.Variable +} + +func (c *muxCircuit) Define(api frontend.API) error { + + out := selector.Mux(api, c.SEL, c.I0, c.I1, c.I2, c.I3, c.I4) + + api.AssertIsEqual(out, c.OUT) + + return nil +} + +// The output of this circuit is ignored and the only way its proof can fail is by providing invalid inputs. 
+type ignoredOutputMuxCircuit struct { + SEL frontend.Variable + I0, I1, I2 frontend.Variable +} + +func (c *ignoredOutputMuxCircuit) Define(api frontend.API) error { + // We ignore the output + _ = selector.Mux(api, c.SEL, c.I0, c.I1, c.I2) + + return nil +} + +type mux2to1Circuit struct { + SEL frontend.Variable + I0, I1 frontend.Variable + OUT frontend.Variable +} + +func (c *mux2to1Circuit) Define(api frontend.API) error { + // We ignore the output + out := selector.Mux(api, c.SEL, c.I0, c.I1) + api.AssertIsEqual(out, c.OUT) + return nil +} + +type mux4to1Circuit struct { + SEL frontend.Variable + In [4]frontend.Variable + OUT frontend.Variable +} + +func (c *mux4to1Circuit) Define(api frontend.API) error { + out := selector.Mux(api, c.SEL, c.In[:]...) + api.AssertIsEqual(out, c.OUT) + return nil +} + +func TestMux(t *testing.T) { + assert := test.NewAssert(t) + + assert.ProverSucceeded(&muxCircuit{}, &muxCircuit{SEL: 2, I0: 10, I1: 11, I2: 12, I3: 13, I4: 14, OUT: 12}) + + assert.ProverSucceeded(&muxCircuit{}, &muxCircuit{SEL: 0, I0: 10, I1: 11, I2: 12, I3: 13, I4: 14, OUT: 10}) + + assert.ProverSucceeded(&muxCircuit{}, &muxCircuit{SEL: 4, I0: 20, I1: 21, I2: 22, I3: 23, I4: 24, OUT: 24}) + + // Failures + assert.ProverFailed(&muxCircuit{}, &muxCircuit{SEL: 5, I0: 20, I1: 21, I2: 22, I3: 23, I4: 24, OUT: 24}) + + assert.ProverFailed(&muxCircuit{}, &muxCircuit{SEL: 0, I0: 20, I1: 21, I2: 22, I3: 23, I4: 24, OUT: 21}) + + // Ignoring the circuit's output: + assert.ProverSucceeded(&ignoredOutputMuxCircuit{}, &ignoredOutputMuxCircuit{SEL: 0, I0: 0, I1: 1, I2: 2}) + + assert.ProverSucceeded(&ignoredOutputMuxCircuit{}, &ignoredOutputMuxCircuit{SEL: 2, I0: 0, I1: 1, I2: 2}) + + // Failures + assert.ProverFailed(&ignoredOutputMuxCircuit{}, &ignoredOutputMuxCircuit{SEL: 3, I0: 0, I1: 1, I2: 2}) + + assert.ProverFailed(&ignoredOutputMuxCircuit{}, &ignoredOutputMuxCircuit{SEL: -1, I0: 0, I1: 1, I2: 2}) + + // 2 to 1 mux + assert.ProverSucceeded(&mux2to1Circuit{}, 
&mux2to1Circuit{SEL: 1, I0: 10, I1: 20, OUT: 20}) + + assert.ProverSucceeded(&mux2to1Circuit{}, &mux2to1Circuit{SEL: 0, I0: 10, I1: 20, OUT: 10}) + + assert.ProverFailed(&mux2to1Circuit{}, &mux2to1Circuit{SEL: 2, I0: 10, I1: 20, OUT: 20}) + + // 4 to 1 mux + assert.ProverSucceeded(&mux4to1Circuit{}, &mux4to1Circuit{ + SEL: 3, + In: [4]frontend.Variable{11, 22, 33, 44}, + OUT: 44, + }) + + assert.ProverSucceeded(&mux4to1Circuit{}, &mux4to1Circuit{ + SEL: 1, + In: [4]frontend.Variable{11, 22, 33, 44}, + OUT: 22, + }) + + assert.ProverSucceeded(&mux4to1Circuit{}, &mux4to1Circuit{ + SEL: 0, + In: [4]frontend.Variable{11, 22, 33, 44}, + OUT: 11, + }) + + assert.ProverFailed(&mux4to1Circuit{}, &mux4to1Circuit{ + SEL: 4, + In: [4]frontend.Variable{11, 22, 33, 44}, + OUT: 44, + }) + + cs, _ := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &mux4to1Circuit{}) + // (4 - 1) + (2 + 1) + 1 == 7 + assert.Equal(7, cs.GetNbConstraints()) +} + +// Map tests: +type mapCircuit struct { + SEL frontend.Variable + K0, K1, K2, K3 frontend.Variable + V0, V1, V2, V3 frontend.Variable + OUT frontend.Variable +} + +func (c *mapCircuit) Define(api frontend.API) error { + + out := selector.Map(api, c.SEL, + []frontend.Variable{c.K0, c.K1, c.K2, c.K3}, + []frontend.Variable{c.V0, c.V1, c.V2, c.V3}) + + api.AssertIsEqual(out, c.OUT) + + return nil +} + +type ignoredOutputMapCircuit struct { + SEL frontend.Variable + K0, K1 frontend.Variable + V0, V1 frontend.Variable +} + +func (c *ignoredOutputMapCircuit) Define(api frontend.API) error { + + _ = selector.Map(api, c.SEL, + []frontend.Variable{c.K0, c.K1}, + []frontend.Variable{c.V0, c.V1}) + + return nil +} + +func TestMap(t *testing.T) { + assert := test.NewAssert(t) + assert.ProverSucceeded(&mapCircuit{}, + &mapCircuit{ + SEL: 100, + K0: 100, K1: 111, K2: 222, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 0, + }) + + assert.ProverSucceeded(&mapCircuit{}, + &mapCircuit{ + SEL: 222, + K0: 100, K1: 111, K2: 222, K3: 333, + V0: 0, V1: 
1, V2: 2, V3: 3, + OUT: 2, + }) + + assert.ProverSucceeded(&mapCircuit{}, + &mapCircuit{ + SEL: 333, + K0: 100, K1: 111, K2: 222, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 3, + }) + + // Duplicated key, success: + assert.ProverSucceeded(&mapCircuit{}, + &mapCircuit{ + SEL: 333, + K0: 222, K1: 222, K2: 222, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 3, + }) + + // Duplicated key, UNDEFINED behavior: (with our hint implementation it fails) + assert.ProverFailed(&mapCircuit{}, + &mapCircuit{ + SEL: 333, + K0: 100, K1: 111, K2: 333, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 3, + }) + + assert.ProverFailed(&mapCircuit{}, + &mapCircuit{ + SEL: 77, + K0: 100, K1: 111, K2: 222, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 3, + }) + + assert.ProverFailed(&mapCircuit{}, + &mapCircuit{ + SEL: 111, + K0: 100, K1: 111, K2: 222, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 2, + }) + + // Ignoring the circuit's output: + assert.ProverSucceeded(&ignoredOutputMapCircuit{}, + &ignoredOutputMapCircuit{SEL: 5, + K0: 5, K1: 7, + V0: 10, V1: 11, + }) + + assert.ProverFailed(&ignoredOutputMapCircuit{}, + &ignoredOutputMapCircuit{SEL: 5, + K0: 5, K1: 5, + V0: 10, V1: 11, + }) + + assert.ProverFailed(&ignoredOutputMapCircuit{}, + &ignoredOutputMapCircuit{SEL: 6, + K0: 5, K1: 7, + V0: 10, V1: 11, + }) + +} diff --git a/std/selector/mux.go b/std/selector/mux.go new file mode 100644 index 0000000000..d437430353 --- /dev/null +++ b/std/selector/mux.go @@ -0,0 +1,45 @@ +package selector + +import ( + "fmt" + "github.com/consensys/gnark/frontend" +) + +// BinaryMux is a 2^k to 1 multiplexer which uses a binary selector. selBits are +// the selector bits, and the input at the index equal to the binary number +// represented by the selector bits will be selected. More precisely the output +// will be: +// +// inputs[selBits[0]+selBits[1]*(1<<1)+selBits[2]*(1<<2)+...] +// +// len(inputs) must be 2^len(selBits). 
+func BinaryMux(api frontend.API, selBits, inputs []frontend.Variable) frontend.Variable { + if len(inputs) != 1<= len(inputs) { + return binaryMuxRecursive(api, nextSelBits, inputs) + } + + left := binaryMuxRecursive(api, nextSelBits, inputs[:pivot]) + right := binaryMuxRecursive(api, nextSelBits, inputs[pivot:]) + return api.Add(left, api.Mul(msb, api.Sub(right, left))) +} diff --git a/std/selector/mux_test.go b/std/selector/mux_test.go new file mode 100644 index 0000000000..f590c23f02 --- /dev/null +++ b/std/selector/mux_test.go @@ -0,0 +1,96 @@ +package selector + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" + "testing" +) + +type binaryMuxCircuit struct { + Sel [4]frontend.Variable + In [10]frontend.Variable + Out frontend.Variable +} + +func (c *binaryMuxCircuit) Define(api frontend.API) error { + + out := binaryMuxRecursive(api, c.Sel[:], c.In[:]) + + api.AssertIsEqual(out, c.Out) + + return nil +} + +type binary7to1MuxCircuit struct { + Sel [3]frontend.Variable + In [7]frontend.Variable + Out frontend.Variable +} + +func (c *binary7to1MuxCircuit) Define(api frontend.API) error { + + out := binaryMuxRecursive(api, c.Sel[:], c.In[:]) + + api.AssertIsEqual(out, c.Out) + + return nil +} + +func Test_binaryMuxRecursive(t *testing.T) { + assert := test.NewAssert(t) + + assert.ProverSucceeded(&binaryMuxCircuit{}, &binaryMuxCircuit{ + Sel: [4]frontend.Variable{0, 0, 1, 0}, + In: [10]frontend.Variable{100, 111, 122, 133, 144, 155, 166, 177, 188, 199}, + Out: 144, + }) + + assert.ProverSucceeded(&binaryMuxCircuit{}, &binaryMuxCircuit{ + Sel: [4]frontend.Variable{0, 0, 0, 0}, + In: [10]frontend.Variable{100, 111, 122, 133, 144, 155, 166, 177, 188, 199}, + Out: 100, + }) + + assert.ProverSucceeded(&binaryMuxCircuit{}, &binaryMuxCircuit{ + Sel: [4]frontend.Variable{1, 0, 0, 1}, + In: [10]frontend.Variable{100, 111, 122, 133, 144, 155, 166, 177, 188, 199}, + Out: 199, + }) + + assert.ProverSucceeded(&binaryMuxCircuit{}, 
&binaryMuxCircuit{ + Sel: [4]frontend.Variable{0, 1, 0, 0}, + In: [10]frontend.Variable{100, 111, 122, 133, 144, 155, 166, 177, 188, 199}, + Out: 122, + }) + + assert.ProverSucceeded(&binaryMuxCircuit{}, &binaryMuxCircuit{ + Sel: [4]frontend.Variable{0, 0, 0, 1}, + In: [10]frontend.Variable{100, 111, 122, 133, 144, 155, 166, 177, 188, 199}, + Out: 188, + }) + + // 7 to 1 + assert.ProverSucceeded(&binary7to1MuxCircuit{}, &binary7to1MuxCircuit{ + Sel: [3]frontend.Variable{0, 0, 1}, + In: [7]frontend.Variable{5, 3, 10, 6, 0, 9, 1}, + Out: 0, + }) + + assert.ProverSucceeded(&binary7to1MuxCircuit{}, &binary7to1MuxCircuit{ + Sel: [3]frontend.Variable{0, 1, 1}, + In: [7]frontend.Variable{5, 3, 10, 6, 0, 9, 1}, + Out: 1, + }) + + assert.ProverSucceeded(&binary7to1MuxCircuit{}, &binary7to1MuxCircuit{ + Sel: [3]frontend.Variable{0, 0, 0}, + In: [7]frontend.Variable{5, 3, 10, 6, 0, 9, 1}, + Out: 5, + }) + + assert.ProverSucceeded(&binary7to1MuxCircuit{}, &binary7to1MuxCircuit{ + Sel: [3]frontend.Variable{1, 0, 1}, + In: [7]frontend.Variable{5, 3, 10, 6, 0, 9, 1}, + Out: 9, + }) +} diff --git a/std/selector/slice.go b/std/selector/slice.go new file mode 100644 index 0000000000..b21be6059a --- /dev/null +++ b/std/selector/slice.go @@ -0,0 +1,105 @@ +package selector + +import ( + "fmt" + "github.com/consensys/gnark/frontend" + "math/big" +) + +// Slice selects a slice of the input array at indices [start, end), and zeroes the array at other +// indices. More precisely, for each i we have: +// +// if i >= start and i < end +// out[i] = input[i] +// else +// out[i] = 0 +// +// We must have start >= 0 and end <= len(input), otherwise a proof cannot be generated. +func Slice(api frontend.API, start, end frontend.Variable, input []frontend.Variable) []frontend.Variable { + // it appears that this is the most efficient implementation. 
There is also another implementation + // which creates the mask by adding two stepMask outputs, however that would not work correctly when + // end < start. + out := Partition(api, end, false, input) + out = Partition(api, start, true, out) + return out +} + +// Partition selects left or right side of the input array, with respect to the pivotPosition. +// More precisely when rightSide is false, for each i we have: +// +// if i < pivotPosition +// out[i] = input[i] +// else +// out[i] = 0 +// +// and when rightSide is true, for each i we have: +// +// if i >= pivotPosition +// out[i] = input[i] +// else +// out[i] = 0 +// +// We must have pivotPosition >= 0 and pivotPosition <= len(input), otherwise a proof cannot be generated. +func Partition(api frontend.API, pivotPosition frontend.Variable, rightSide bool, + input []frontend.Variable) (out []frontend.Variable) { + out = make([]frontend.Variable, len(input)) + var mask []frontend.Variable + // we create a bit mask to multiply with the input. + if rightSide { + mask = stepMask(api, len(input), pivotPosition, 0, 1) + } else { + mask = stepMask(api, len(input), pivotPosition, 1, 0) + } + for i := 0; i < len(out); i++ { + out[i] = api.Mul(mask[i], input[i]) + } + return +} + +// stepMask generates a step like function into an output array of a given length. +// The output is an array of length outputLen, +// such that its first stepPosition elements are equal to startValue and the remaining elements are equal to +// endValue. Note that outputLen cannot be a circuit variable. +// +// We must have stepPosition >= 0 and stepPosition <= outputLen, otherwise a proof cannot be generated. +// This function panics when outputLen is less than 2. 
+func stepMask(api frontend.API, outputLen int, + stepPosition, startValue, endValue frontend.Variable) []frontend.Variable { + if outputLen < 2 { + panic("the output len of StepMask must be >= 2") + } + // get the output as a hint + out, err := api.Compiler().NewHint(stepOutput, outputLen, stepPosition, startValue, endValue) + if err != nil { + panic(fmt.Sprintf("error in calling StepMask hint: %v", err)) + } + + // add the boundary constraints: + // (out[0] - startValue) * stepPosition == 0 + api.AssertIsEqual(api.Mul(api.Sub(out[0], startValue), stepPosition), 0) + // (out[len(out)-1] - endValue) * (len(out) - stepPosition) == 0 + api.AssertIsEqual(api.Mul(api.Sub(out[len(out)-1], endValue), api.Sub(len(out), stepPosition)), 0) + + // add constraints for the correct form of a step function that steps at the stepPosition + for i := 1; i < len(out); i++ { + // (out[i] - out[i-1]) * (i - stepPosition) == 0 + api.AssertIsEqual(api.Mul(api.Sub(out[i], out[i-1]), api.Sub(i, stepPosition)), 0) + } + return out +} + +// stepOutput is a hint function used within [StepMask] function. It must be +// provided to the prover when circuit uses it. 
+func stepOutput(_ *big.Int, inputs, results []*big.Int) error { + stepPos := inputs[0] + startValue := inputs[1] + endValue := inputs[2] + for i := 0; i < len(results); i++ { + if i < int(stepPos.Int64()) { + results[i].Set(startValue) + } else { + results[i].Set(endValue) + } + } + return nil +} diff --git a/std/selector/slice_test.go b/std/selector/slice_test.go new file mode 100644 index 0000000000..35715c886e --- /dev/null +++ b/std/selector/slice_test.go @@ -0,0 +1,205 @@ +package selector_test + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/selector" + "github.com/consensys/gnark/test" + "testing" +) + +type partitionerCircuit struct { + Pivot frontend.Variable `gnark:",public"` + In [6]frontend.Variable `gnark:",public"` + WantLeft [6]frontend.Variable `gnark:",public"` + WantRight [6]frontend.Variable `gnark:",public"` +} + +func (c *partitionerCircuit) Define(api frontend.API) error { + gotLeft := selector.Partition(api, c.Pivot, false, c.In[:]) + for i, want := range c.WantLeft { + api.AssertIsEqual(gotLeft[i], want) + } + + gotRight := selector.Partition(api, c.Pivot, true, c.In[:]) + for i, want := range c.WantRight { + api.AssertIsEqual(gotRight[i], want) + } + + return nil +} + +type ignoredOutputPartitionerCircuit struct { + Pivot frontend.Variable `gnark:",public"` + In [2]frontend.Variable `gnark:",public"` +} + +func (c *ignoredOutputPartitionerCircuit) Define(api frontend.API) error { + _ = selector.Partition(api, c.Pivot, false, c.In[:]) + _ = selector.Partition(api, c.Pivot, true, c.In[:]) + return nil +} + +func TestPartition(t *testing.T) { + assert := test.NewAssert(t) + + assert.ProverSucceeded(&partitionerCircuit{}, &partitionerCircuit{ + Pivot: 3, + In: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantLeft: [6]frontend.Variable{10, 20, 30, 0, 0, 0}, + WantRight: [6]frontend.Variable{0, 0, 0, 40, 50, 60}, + }) + + assert.ProverSucceeded(&partitionerCircuit{}, &partitionerCircuit{ + Pivot: 1, + In: 
[6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantLeft: [6]frontend.Variable{10, 0, 0, 0, 0, 0}, + WantRight: [6]frontend.Variable{0, 20, 30, 40, 50, 60}, + }) + + assert.ProverSucceeded(&partitionerCircuit{}, &partitionerCircuit{ + Pivot: 5, + In: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantLeft: [6]frontend.Variable{10, 20, 30, 40, 50, 0}, + WantRight: [6]frontend.Variable{0, 0, 0, 0, 0, 60}, + }) + + assert.ProverFailed(&partitionerCircuit{}, &partitionerCircuit{ + Pivot: 5, + In: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantLeft: [6]frontend.Variable{10, 20, 30, 40, 0, 0}, + WantRight: [6]frontend.Variable{0, 0, 0, 0, 0, 0}, + }) + + assert.ProverSucceeded(&partitionerCircuit{}, &partitionerCircuit{ + Pivot: 6, + In: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantLeft: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantRight: [6]frontend.Variable{0, 0, 0, 0, 0, 0}, + }) + + assert.ProverSucceeded(&partitionerCircuit{}, &partitionerCircuit{ + Pivot: 0, + In: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantLeft: [6]frontend.Variable{0, 0, 0, 0, 0, 0}, + WantRight: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + }) + + // Pivot is outside and the prover fails: + assert.ProverFailed(&partitionerCircuit{}, &partitionerCircuit{ + Pivot: 7, + In: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantLeft: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantRight: [6]frontend.Variable{0, 0, 0, 0, 0, 0}, + }) + + assert.ProverFailed(&partitionerCircuit{}, &partitionerCircuit{ + Pivot: -1, + In: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantLeft: [6]frontend.Variable{0, 0, 0, 0, 0, 0}, + WantRight: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + }) + + // tests by ignoring the output: + assert.ProverSucceeded(&ignoredOutputPartitionerCircuit{}, &ignoredOutputPartitionerCircuit{ + Pivot: 1, + In: [2]frontend.Variable{10, 20}, + }) + + assert.ProverFailed(&ignoredOutputPartitionerCircuit{}, &ignoredOutputPartitionerCircuit{ + Pivot: -1, + In: 
[2]frontend.Variable{10, 20}, + }) + + assert.ProverFailed(&ignoredOutputPartitionerCircuit{}, &ignoredOutputPartitionerCircuit{ + Pivot: 3, + In: [2]frontend.Variable{10, 20}, + }) +} + +type slicerCircuit struct { + Start, End frontend.Variable `gnark:",public"` + In [7]frontend.Variable `gnark:",public"` + WantSlice [7]frontend.Variable `gnark:",public"` +} + +func (c *slicerCircuit) Define(api frontend.API) error { + gotSlice := selector.Slice(api, c.Start, c.End, c.In[:]) + for i, want := range c.WantSlice { + api.AssertIsEqual(gotSlice[i], want) + } + return nil +} + +func TestSlice(t *testing.T) { + assert := test.NewAssert(t) + + assert.ProverSucceeded(&slicerCircuit{}, &slicerCircuit{ + Start: 2, + End: 5, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 0, 2, 3, 4, 0, 0}, + }) + + assert.ProverSucceeded(&slicerCircuit{}, &slicerCircuit{ + Start: 3, + End: 4, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 0, 0, 3, 0, 0, 0}, + }) + + assert.ProverSucceeded(&slicerCircuit{}, &slicerCircuit{ + Start: 3, + End: 3, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 0, 0, 0, 0, 0, 0}, + }) + + assert.ProverSucceeded(&slicerCircuit{}, &slicerCircuit{ + Start: 3, + End: 1, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 0, 0, 0, 0, 0, 0}, + }) + + assert.ProverSucceeded(&slicerCircuit{}, &slicerCircuit{ + Start: 3, + End: 6, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 0, 0, 3, 4, 5, 0}, + }) + + assert.ProverSucceeded(&slicerCircuit{}, &slicerCircuit{ + Start: 3, + End: 7, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 0, 0, 3, 4, 5, 6}, + }) + + assert.ProverSucceeded(&slicerCircuit{}, &slicerCircuit{ + Start: 0, + End: 2, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 1, 0, 0, 0, 0, 0}, + }) + + 
assert.ProverSucceeded(&slicerCircuit{}, &slicerCircuit{ + Start: 0, + End: 7, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + }) + + assert.ProverFailed(&slicerCircuit{}, &slicerCircuit{ + Start: 3, + End: 8, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 0, 0, 3, 4, 5, 6}, + }) + + assert.ProverFailed(&slicerCircuit{}, &slicerCircuit{ + Start: -1, + End: 2, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 1, 0, 0, 0, 0, 0}, + }) +} diff --git a/std/signature/ecdsa/ecdsa.go b/std/signature/ecdsa/ecdsa.go index 5279f69930..cac32d7db1 100644 --- a/std/signature/ecdsa/ecdsa.go +++ b/std/signature/ecdsa/ecdsa.go @@ -1,7 +1,7 @@ /* Package ecdsa implements ECDSA signature verification over any elliptic curve. -The package depends on the [weierstrass] package for elliptic curve group +The package depends on the [emulated/sw_emulated] package for elliptic curve group operations using non-native arithmetic. Thus we can verify ECDSA signatures over any curve. The cost for a single secp256k1 signature verification is approximately 4M constraints in R1CS and 10M constraints in PLONKish. @@ -15,7 +15,7 @@ package ecdsa import ( "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/weierstrass" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" "github.com/consensys/gnark/std/math/emulated" ) @@ -25,14 +25,14 @@ type Signature[Scalar emulated.FieldParams] struct { } // PublicKey represents the public key to verify the signature for. -type PublicKey[Base, Scalar emulated.FieldParams] weierstrass.AffinePoint[Base] +type PublicKey[Base, Scalar emulated.FieldParams] sw_emulated.AffinePoint[Base] // Verify asserts that the signature sig verifies for the message msg and public // key pk. The curve parameters params define the elliptic curve. 
// // We assume that the message msg is already hashed to the scalar field. -func (pk PublicKey[T, S]) Verify(api frontend.API, params weierstrass.CurveParams, msg *emulated.Element[S], sig *Signature[S]) { - cr, err := weierstrass.New[T, S](api, params) +func (pk PublicKey[T, S]) Verify(api frontend.API, params sw_emulated.CurveParams, msg *emulated.Element[S], sig *Signature[S]) { + cr, err := sw_emulated.New[T, S](api, params) if err != nil { // TODO: softer handling. panic(err) @@ -45,14 +45,13 @@ func (pk PublicKey[T, S]) Verify(api frontend.API, params weierstrass.CurveParam if err != nil { panic(err) } - pkpt := weierstrass.AffinePoint[T](pk) + pkpt := sw_emulated.AffinePoint[T](pk) sInv := scalarApi.Inverse(&sig.S) msInv := scalarApi.MulMod(msg, sInv) rsInv := scalarApi.MulMod(&sig.R, sInv) - qa := cr.ScalarMul(cr.Generator(), msInv) - qb := cr.ScalarMul(&pkpt, rsInv) - q := cr.Add(qa, qb) + // q = [rsInv]pkpt + [msInv]g + q := cr.JointScalarMulBase(&pkpt, rsInv, msInv) qx := baseApi.Reduce(&q.X) qxBits := baseApi.ToBits(qx) rbits := scalarApi.ToBits(&sig.R) diff --git a/std/signature/ecdsa/ecdsa_secpr_test.go b/std/signature/ecdsa/ecdsa_secpr_test.go new file mode 100644 index 0000000000..139865e48b --- /dev/null +++ b/std/signature/ecdsa/ecdsa_secpr_test.go @@ -0,0 +1,111 @@ +package ecdsa + +import ( + cryptoecdsa "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/sha512" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/test" + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" +) + +func TestEcdsaP256PreHashed(t *testing.T) { + + // generate parameters + privKey, _ := cryptoecdsa.GenerateKey(elliptic.P256(), rand.Reader) + publicKey := privKey.PublicKey + + // sign + msg := []byte("testing ECDSA (pre-hashed)") + msgHash := sha256.Sum256(msg) + sigBin, _ := privKey.Sign(rand.Reader, msgHash[:], nil) + + // 
check that the signature is correct + var ( + r, s = &big.Int{}, &big.Int{} + inner cryptobyte.String + ) + input := cryptobyte.String(sigBin) + if !input.ReadASN1(&inner, asn1.SEQUENCE) || + !input.Empty() || + !inner.ReadASN1Integer(r) || + !inner.ReadASN1Integer(s) || + !inner.Empty() { + panic("invalid sig") + } + flag := cryptoecdsa.Verify(&publicKey, msgHash[:], r, s) + if !flag { + t.Errorf("can't verify signature") + } + + circuit := EcdsaCircuit[emulated.P256Fp, emulated.P256Fr]{} + witness := EcdsaCircuit[emulated.P256Fp, emulated.P256Fr]{ + Sig: Signature[emulated.P256Fr]{ + R: emulated.ValueOf[emulated.P256Fr](r), + S: emulated.ValueOf[emulated.P256Fr](s), + }, + Msg: emulated.ValueOf[emulated.P256Fr](msgHash[:]), + Pub: PublicKey[emulated.P256Fp, emulated.P256Fr]{ + X: emulated.ValueOf[emulated.P256Fp](privKey.PublicKey.X), + Y: emulated.ValueOf[emulated.P256Fp](privKey.PublicKey.Y), + }, + } + assert := test.NewAssert(t) + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +func TestEcdsaP384PreHashed(t *testing.T) { + + // generate parameters + privKey, _ := cryptoecdsa.GenerateKey(elliptic.P384(), rand.Reader) + publicKey := privKey.PublicKey + + // sign + msg := []byte("testing ECDSA (pre-hashed)") + msgHash := sha512.Sum384(msg) + sigBin, _ := privKey.Sign(rand.Reader, msgHash[:], nil) + + // check that the signature is correct + var ( + r, s = &big.Int{}, &big.Int{} + inner cryptobyte.String + ) + input := cryptobyte.String(sigBin) + if !input.ReadASN1(&inner, asn1.SEQUENCE) || + !input.Empty() || + !inner.ReadASN1Integer(r) || + !inner.ReadASN1Integer(s) || + !inner.Empty() { + panic("invalid sig") + } + flag := cryptoecdsa.Verify(&publicKey, msgHash[:], r, s) + if !flag { + t.Errorf("can't verify signature") + } + + circuit := EcdsaCircuit[emulated.P384Fp, emulated.P384Fr]{} + witness := EcdsaCircuit[emulated.P384Fp, emulated.P384Fr]{ + Sig: Signature[emulated.P384Fr]{ + R: 
emulated.ValueOf[emulated.P384Fr](r), + S: emulated.ValueOf[emulated.P384Fr](s), + }, + Msg: emulated.ValueOf[emulated.P384Fr](msgHash[:]), + Pub: PublicKey[emulated.P384Fp, emulated.P384Fr]{ + X: emulated.ValueOf[emulated.P384Fp](privKey.PublicKey.X), + Y: emulated.ValueOf[emulated.P384Fp](privKey.PublicKey.Y), + }, + } + assert := test.NewAssert(t) + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} diff --git a/std/signature/ecdsa/ecdsa_test.go b/std/signature/ecdsa/ecdsa_test.go index 711ec69554..57ff1a4406 100644 --- a/std/signature/ecdsa/ecdsa_test.go +++ b/std/signature/ecdsa/ecdsa_test.go @@ -9,7 +9,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/secp256k1/ecdsa" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/weierstrass" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" "github.com/consensys/gnark/std/math/emulated" "github.com/consensys/gnark/test" ) @@ -21,7 +21,7 @@ type EcdsaCircuit[T, S emulated.FieldParams] struct { } func (c *EcdsaCircuit[T, S]) Define(api frontend.API) error { - c.Pub.Verify(api, weierstrass.GetCurveParams[T](), &c.Msg, &c.Sig) + c.Pub.Verify(api, sw_emulated.GetCurveParams[T](), &c.Msg, &c.Sig) return nil } @@ -134,7 +134,7 @@ func ExamplePublicKey_Verify() { Y: emulated.ValueOf[emulated.Secp256k1Fp](puby), } // signature verification assertion is done in-circuit - Pub.Verify(api, weierstrass.GetCurveParams[emulated.Secp256k1Fp](), &Msg, &Sig) + Pub.Verify(api, sw_emulated.GetCurveParams[emulated.Secp256k1Fp](), &Msg, &Sig) } // Example how to create a valid signature for secp256k1 diff --git a/std/signature/eddsa/eddsa.go b/std/signature/eddsa/eddsa.go index e285f00496..b0841de4ad 100644 --- a/std/signature/eddsa/eddsa.go +++ b/std/signature/eddsa/eddsa.go @@ -24,7 +24,7 @@ import ( "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/frontend" - 
"github.com/consensys/gnark/std/algebra/twistededwards" + "github.com/consensys/gnark/std/algebra/native/twistededwards" tedwards "github.com/consensys/gnark-crypto/ecc/twistededwards" @@ -55,7 +55,7 @@ type Signature struct { // Verify verifies an eddsa signature using MiMC hash function // cf https://en.wikipedia.org/wiki/EdDSA -func Verify(curve twistededwards.Curve, sig Signature, msg frontend.Variable, pubKey PublicKey, hash hash.Hash) error { +func Verify(curve twistededwards.Curve, sig Signature, msg frontend.Variable, pubKey PublicKey, hash hash.FieldHasher) error { // compute H(R, A, M) hash.Write(sig.R.X) diff --git a/std/signature/eddsa/eddsa_test.go b/std/signature/eddsa/eddsa_test.go index 657d337010..762bfb5ef0 100644 --- a/std/signature/eddsa/eddsa_test.go +++ b/std/signature/eddsa/eddsa_test.go @@ -27,7 +27,7 @@ import ( "github.com/consensys/gnark-crypto/signature/eddsa" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/std/algebra/twistededwards" + "github.com/consensys/gnark/std/algebra/native/twistededwards" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" ) diff --git a/test/assert.go b/test/assert.go index b6a077d2a1..10d1ebf977 100644 --- a/test/assert.go +++ b/test/assert.go @@ -17,8 +17,10 @@ limitations under the License. package test import ( + "bytes" "errors" "fmt" + "io" "reflect" "strings" "testing" @@ -35,6 +37,7 @@ import ( "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/frontend/schema" + gnarkio "github.com/consensys/gnark/io" "github.com/stretchr/testify/require" ) @@ -44,6 +47,10 @@ var ( ErrInvalidWitnessVerified = errors.New("invalid witness resulted in a valid proof") ) +// SerializationThreshold is the number of constraints above which we don't +// do a systematic round-trip serialization check for the proving and verifying keys. 
+const SerializationThreshold = 1000 + // Assert is a helper to test circuits type Assert struct { t *testing.T @@ -149,6 +156,16 @@ func (assert *Assert) ProverSucceeded(circuit frontend.Circuit, validAssignment case backend.GROTH16: pk, vk, err := groth16.Setup(ccs) checkError(err) + if ccs.GetNbConstraints() <= SerializationThreshold { + pkReconstructed := groth16.NewProvingKey(curve) + roundTripCheck(assert.t, pk, pkReconstructed) + pkReconstructed = groth16.NewProvingKey(curve) + roundTripCheckRaw(assert.t, pk, pkReconstructed) + vkReconstructed := groth16.NewVerifyingKey(curve) + roundTripCheck(assert.t, vk, vkReconstructed) + vkReconstructed = groth16.NewVerifyingKey(curve) + roundTripCheckRaw(assert.t, vk, vkReconstructed) + } // ensure prove / verify works well with valid witnesses @@ -158,19 +175,35 @@ func (assert *Assert) ProverSucceeded(circuit frontend.Circuit, validAssignment err = groth16.Verify(proof, vk, validPublicWitness) checkError(err) + if opt.solidity && curve == ecc.BN254 && vk.NbPublicWitness() > 0 { + // check that the proof can be verified by gnark-solidity-checker + assert.solidityVerification(b, vk, proof, validPublicWitness) + } + case backend.PLONK: srs, err := NewKZGSRS(ccs) checkError(err) pk, vk, err := plonk.Setup(ccs, srs) checkError(err) + if ccs.GetNbConstraints() <= SerializationThreshold { + pkReconstructed := plonk.NewProvingKey(curve) + roundTripCheck(assert.t, pk, pkReconstructed) + vkReconstructed := plonk.NewVerifyingKey(curve) + roundTripCheck(assert.t, vk, vkReconstructed) + } - correctProof, err := plonk.Prove(ccs, pk, validWitness, opt.proverOpts...) + proof, err := plonk.Prove(ccs, pk, validWitness, opt.proverOpts...) 
checkError(err) - err = plonk.Verify(correctProof, vk, validPublicWitness) + err = plonk.Verify(proof, vk, validPublicWitness) checkError(err) + if opt.solidity && curve == ecc.BN254 { + // check that the proof can be verified by gnark-solidity-checker + assert.solidityVerification(b, vk, proof, validPublicWitness) + } + case backend.PLONKFRI: pk, vk, err := plonkfri.Setup(ccs) checkError(err) @@ -208,15 +241,11 @@ func (assert *Assert) ProverFailed(circuit frontend.Circuit, invalidAssignment f opt := assert.options(opts...) - popts := append(opt.proverOpts, backend.IgnoreSolverError()) - for _, curve := range opt.curves { // parse assignment invalidWitness, err := frontend.NewWitness(invalidAssignment, curve.ScalarField()) assert.NoError(err, "can't parse invalid assignment") - invalidPublicWitness, err := frontend.NewWitness(invalidAssignment, curve.ScalarField(), frontend.PublicOnly()) - assert.NoError(err, "can't parse invalid assignment") for _, b := range opt.backends { curve := curve @@ -235,42 +264,8 @@ func (assert *Assert) ProverFailed(circuit frontend.Circuit, invalidAssignment f mustError(err) assert.t.Parallel() - err = ccs.IsSolved(invalidPublicWitness) + err = ccs.IsSolved(invalidWitness) mustError(err) - - switch b { - case backend.GROTH16: - pk, vk, err := groth16.Setup(ccs) - checkError(err) - - proof, _ := groth16.Prove(ccs, pk, invalidWitness, popts...) - - err = groth16.Verify(proof, vk, invalidPublicWitness) - mustError(err) - - case backend.PLONK: - srs, err := NewKZGSRS(ccs) - checkError(err) - - pk, vk, err := plonk.Setup(ccs, srs) - checkError(err) - - incorrectProof, _ := plonk.Prove(ccs, pk, invalidWitness, popts...) - err = plonk.Verify(incorrectProof, vk, invalidPublicWitness) - mustError(err) - - case backend.PLONKFRI: - - pk, vk, err := plonkfri.Setup(ccs) - checkError(err) - - incorrectProof, _ := plonkfri.Prove(ccs, pk, invalidWitness, popts...) 
- err = plonkfri.Verify(incorrectProof, vk, invalidPublicWitness) - mustError(err) - - default: - panic("backend not implemented") - } }, curve.String(), b.String()) } } @@ -306,7 +301,7 @@ func (assert *Assert) solvingSucceeded(circuit frontend.Circuit, validAssignment err = IsSolved(circuit, validAssignment, curve.ScalarField()) checkError(err) - err = ccs.IsSolved(validWitness, opt.proverOpts...) + err = ccs.IsSolved(validWitness, opt.solverOpts...) checkError(err) } @@ -352,7 +347,7 @@ func (assert *Assert) solvingFailed(circuit frontend.Circuit, invalidAssignment err = IsSolved(circuit, invalidAssignment, curve.ScalarField()) mustError(err) - err = ccs.IsSolved(invalidWitness, opt.proverOpts...) + err = ccs.IsSolved(invalidWitness, opt.solverOpts...) mustError(err) } @@ -405,8 +400,22 @@ func (assert *Assert) fuzzer(fuzzer filler, circuit, w frontend.Circuit, b backe errConsts := IsSolved(circuit, w, curve.ScalarField(), SetAllVariablesAsConstants()) if (errVars == nil) != (errConsts == nil) { + w, err := frontend.NewWitness(w, curve.ScalarField()) + if err != nil { + panic(err) + } + s, err := frontend.NewSchema(circuit) + if err != nil { + panic(err) + } + bb, err := w.ToJSON(s) + if err != nil { + panic(err) + } + assert.Log("errVars", errVars) assert.Log("errConsts", errConsts) + assert.Log("fuzzer witness", string(bb)) assert.FailNow("solving circuit with values as constants vs non-constants mismatched result") } @@ -578,3 +587,45 @@ func (assert *Assert) marshalWitnessJSON(w witness.Witness, s *schema.Schema, cu witnessMatch := reflect.DeepEqual(w, witness) assert.True(witnessMatch, "round trip marshaling failed") } + +func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { + var buf bytes.Buffer + written, err := from.WriteTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, 
reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} + +func roundTripCheckRaw(t *testing.T, from gnarkio.WriterRawTo, reconstructed io.ReaderFrom) { + var buf bytes.Buffer + written, err := from.WriteRawTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} diff --git a/test/assert_solidity.go b/test/assert_solidity.go new file mode 100644 index 0000000000..1f027ac2b7 --- /dev/null +++ b/test/assert_solidity.go @@ -0,0 +1,95 @@ +package test + +import ( + "bytes" + "encoding/hex" + "io" + "os" + "os/exec" + "path/filepath" + "strconv" + + "github.com/consensys/gnark/backend" + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + plonk_bn254 "github.com/consensys/gnark/backend/plonk/bn254" + "github.com/consensys/gnark/backend/witness" +) + +// solidityVerification checks that the exported solidity contract can verify the proof +// and that the proof is valid. +// It uses gnark-solidity-checker see test.WithSolidity option. +func (assert *Assert) solidityVerification(b backend.ID, vk interface { + NbPublicWitness() int + ExportSolidity(io.Writer) error +}, + proof any, + validPublicWitness witness.Witness) { + if !solcCheck { + return // nothing to check, will make solc fail. 
+ } + assert.t.Helper() + + // make temp dir + tmpDir, err := os.MkdirTemp("", "gnark-solidity-check*") + assert.NoError(err) + defer os.RemoveAll(tmpDir) + + // export solidity contract + fSolidity, err := os.Create(filepath.Join(tmpDir, "gnark_verifier.sol")) + assert.NoError(err) + + err = vk.ExportSolidity(fSolidity) + assert.NoError(err) + + err = fSolidity.Close() + assert.NoError(err) + + // generate assets + // gnark-solidity-checker generate --dir tmpdir --solidity contract_g16.sol + cmd := exec.Command("gnark-solidity-checker", "generate", "--dir", tmpDir, "--solidity", "gnark_verifier.sol") + assert.t.Log("running ", cmd.String()) + out, err := cmd.CombinedOutput() + assert.NoError(err, string(out)) + + // proof to hex + var proofStr string + var optBackend string + + if b == backend.GROTH16 { + optBackend = "--groth16" + var buf bytes.Buffer + _proof := proof.(*groth16_bn254.Proof) + _, err = _proof.WriteRawTo(&buf) + assert.NoError(err) + proofStr = hex.EncodeToString(buf.Bytes()) + } else if b == backend.PLONK { + optBackend = "--plonk" + _proof := proof.(*plonk_bn254.Proof) + // TODO @gbotrel make a single Marshal function for PlonK proof. + proofStr = hex.EncodeToString(_proof.MarshalSolidity()) + } else { + panic("not implemented") + } + + // public witness to hex + bPublicWitness, err := validPublicWitness.MarshalBinary() + assert.NoError(err) + // that's quite dirty... 
+ // first 4 bytes -> nbPublic + // next 4 bytes -> nbSecret + // next 4 bytes -> nb elements in the vector (== nbPublic + nbSecret) + bPublicWitness = bPublicWitness[12:] + publicWitnessStr := hex.EncodeToString(bPublicWitness) + + // verify proof + // gnark-solidity-checker verify --dir tmdir --groth16 --nb-public-inputs 1 --proof 1234 --public-inputs dead + cmd = exec.Command("gnark-solidity-checker", "verify", + "--dir", tmpDir, + optBackend, + "--nb-public-inputs", strconv.Itoa(vk.NbPublicWitness()), + "--proof", proofStr, + "--public-inputs", publicWitnessStr) + assert.t.Log("running ", cmd.String()) + out, err = cmd.CombinedOutput() + assert.NoError(err, string(out)) +} diff --git a/test/blueprint_solver.go b/test/blueprint_solver.go new file mode 100644 index 0000000000..9ced90ee6a --- /dev/null +++ b/test/blueprint_solver.go @@ -0,0 +1,140 @@ +package test + +import ( + "math/big" + + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/internal/utils" +) + +// blueprintSolver is a constraint.Solver that can be used to test a circuit +// it is a separate type to avoid method collisions with the engine. 
+type blueprintSolver struct { + internalVariables []*big.Int + q *big.Int +} + +// implements constraint.Solver + +func (s *blueprintSolver) SetValue(vID uint32, f constraint.Element) { + if int(vID) > len(s.internalVariables) { + panic("out of bounds") + } + v := s.ToBigInt(f) + s.internalVariables[vID].Set(v) +} + +func (s *blueprintSolver) GetValue(cID, vID uint32) constraint.Element { + panic("not implemented in test.Engine") +} +func (s *blueprintSolver) GetCoeff(cID uint32) constraint.Element { + panic("not implemented in test.Engine") +} + +func (s *blueprintSolver) IsSolved(vID uint32) bool { + panic("not implemented in test.Engine") +} + +// implements constraint.Field + +func (s *blueprintSolver) FromInterface(i interface{}) constraint.Element { + b := utils.FromInterface(i) + return s.toElement(&b) +} + +func (s *blueprintSolver) ToBigInt(f constraint.Element) *big.Int { + r := new(big.Int) + fBytes := f.Bytes() + r.SetBytes(fBytes[:]) + return r +} +func (s *blueprintSolver) Mul(a, b constraint.Element) constraint.Element { + ba, bb := s.ToBigInt(a), s.ToBigInt(b) + ba.Mul(ba, bb).Mod(ba, s.q) + return s.toElement(ba) +} +func (s *blueprintSolver) Add(a, b constraint.Element) constraint.Element { + ba, bb := s.ToBigInt(a), s.ToBigInt(b) + ba.Add(ba, bb).Mod(ba, s.q) + return s.toElement(ba) +} +func (s *blueprintSolver) Sub(a, b constraint.Element) constraint.Element { + ba, bb := s.ToBigInt(a), s.ToBigInt(b) + ba.Sub(ba, bb).Mod(ba, s.q) + return s.toElement(ba) +} +func (s *blueprintSolver) Neg(a constraint.Element) constraint.Element { + ba := s.ToBigInt(a) + ba.Neg(ba).Mod(ba, s.q) + return s.toElement(ba) +} +func (s *blueprintSolver) Inverse(a constraint.Element) (constraint.Element, bool) { + ba := s.ToBigInt(a) + r := ba.ModInverse(ba, s.q) + return s.toElement(ba), r != nil +} +func (s *blueprintSolver) One() constraint.Element { + b := new(big.Int).SetUint64(1) + return s.toElement(b) +} +func (s *blueprintSolver) IsOne(a constraint.Element) 
bool { + b := s.ToBigInt(a) + return b.IsUint64() && b.Uint64() == 1 +} + +func (s *blueprintSolver) String(a constraint.Element) string { + b := s.ToBigInt(a) + return b.String() +} + +func (s *blueprintSolver) Uint64(a constraint.Element) (uint64, bool) { + b := s.ToBigInt(a) + return b.Uint64(), b.IsUint64() +} + +func (s *blueprintSolver) Read(calldata []uint32) (constraint.Element, int) { + // We encoded big.Int as constraint.Element on 12 uint32 words. + var r constraint.Element + for i := 0; i < len(r); i++ { + index := i * 2 + r[i] = uint64(calldata[index])<<32 | uint64(calldata[index+1]) + } + return r, len(r) * 2 +} + +func (s *blueprintSolver) toElement(b *big.Int) constraint.Element { + return bigIntToElement(b) +} + +func bigIntToElement(b *big.Int) constraint.Element { + if b.Sign() == -1 { + panic("negative value") + } + bytes := b.Bytes() + if len(bytes) > 48 { + panic("value too big") + } + var paddedBytes [48]byte + copy(paddedBytes[48-len(bytes):], bytes[:]) + + var r constraint.Element + r.SetBytes(paddedBytes) + + return r +} + +// wrappedBigInt is a wrapper around big.Int to implement the frontend.CanonicalVariable interface +type wrappedBigInt struct { + *big.Int +} + +func (w wrappedBigInt) Compress(to *[]uint32) { + // convert to Element. 
+ e := bigIntToElement(w.Int) + + // append the uint32 words to the slice + for i := 0; i < len(e); i++ { + *to = append(*to, uint32(e[i]>>32)) + *to = append(*to, uint32(e[i]&0xffffffff)) + } +} diff --git a/test/blueprint_solver_test.go b/test/blueprint_solver_test.go new file mode 100644 index 0000000000..0cc42b6a27 --- /dev/null +++ b/test/blueprint_solver_test.go @@ -0,0 +1,64 @@ +package test + +import ( + "math/big" + "math/rand" + "testing" + "time" + + "github.com/consensys/gnark-crypto/ecc" +) + +func TestBigIntToElement(t *testing.T) { + t.Parallel() + // sample a random big.Int, convert it to an element, and back + // to a big.Int, and check that it's the same + s := blueprintSolver{q: ecc.BN254.ScalarField()} + b := big.NewInt(0) + for i := 0; i < 50; i++ { + b.Rand(rand.New(rand.NewSource(time.Now().Unix())), s.q) //#nosec G404 -- This is a false positive + e := s.toElement(b) + b2 := s.ToBigInt(e) + if b.Cmp(b2) != 0 { + t.Fatal("b != b2") + } + } + +} + +func TestBigIntToUint32Slice(t *testing.T) { + t.Parallel() + // sample a random big.Int, write it to a uint32 slice, and back + // to a big.Int, and check that it's the same + s := blueprintSolver{q: ecc.BN254.ScalarField()} + b1 := big.NewInt(0) + b2 := big.NewInt(0) + + for i := 0; i < 50; i++ { + b1.Rand(rand.New(rand.NewSource(time.Now().Unix())), s.q) //#nosec G404 -- This is a false positive + b2.Rand(rand.New(rand.NewSource(time.Now().Unix())), s.q) //#nosec G404 -- This is a false positive + wb1 := wrappedBigInt{b1} + wb2 := wrappedBigInt{b2} + var to []uint32 + wb1.Compress(&to) + wb2.Compress(&to) + + if len(to) != 24 { + t.Fatal("wrong length: expected 2*len of constraint.Element (uint32 words)") + } + + e1, n := s.Read(to) + if n != 12 { + t.Fatal("wrong length: expected 1 len of constraint.Element (uint32 words)") + } + e2, n := s.Read(to[n:]) + if n != 12 { + t.Fatal("wrong length: expected 1 len of constraint.Element (uint32 words)") + } + rb1, rb2 := s.ToBigInt(e1), s.ToBigInt(e2) + 
if rb1.Cmp(b1) != 0 || rb2.Cmp(b2) != 0 { + t.Fatal("rb1 != b1 || rb2 != b2") + } + } + +} diff --git a/test/commitments_test.go b/test/commitments_test.go new file mode 100644 index 0000000000..2b62bb8f09 --- /dev/null +++ b/test/commitments_test.go @@ -0,0 +1,260 @@ +package test + +import ( + "fmt" + "github.com/consensys/gnark/backend" + groth16 "github.com/consensys/gnark/backend/groth16/bn254" + "github.com/consensys/gnark/backend/witness" + cs "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/stretchr/testify/require" + "reflect" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/stretchr/testify/assert" +) + +type noCommitmentCircuit struct { + X frontend.Variable +} + +func (c *noCommitmentCircuit) Define(api frontend.API) error { + api.AssertIsEqual(c.X, 1) + api.AssertIsEqual(c.X, 1) + return nil +} + +type commitmentCircuit struct { + Public []frontend.Variable `gnark:",public"` + X []frontend.Variable +} + +func (c *commitmentCircuit) Define(api frontend.API) error { + + commitment, err := api.(frontend.Committer).Commit(c.X...) 
+ if err != nil { + return err + } + sum := frontend.Variable(0) + for i, x := range c.X { + sum = api.Add(sum, api.Mul(x, i+1)) + } + for _, p := range c.Public { + sum = api.Add(sum, p) + } + api.AssertIsDifferent(commitment, sum) + return nil +} + +type committedConstantCircuit struct { + X frontend.Variable +} + +func (c *committedConstantCircuit) Define(api frontend.API) error { + commitment, err := api.(frontend.Committer).Commit(1, c.X) + if err != nil { + return err + } + api.AssertIsDifferent(commitment, c.X) + return nil +} + +type committedPublicCircuit struct { + X frontend.Variable `gnark:",public"` +} + +func (c *committedPublicCircuit) Define(api frontend.API) error { + commitment, err := api.(frontend.Committer).Commit(c.X) + if err != nil { + return err + } + api.AssertIsDifferent(commitment, c.X) + return nil +} + +type independentCommitsCircuit struct { + X []frontend.Variable +} + +func (c *independentCommitsCircuit) Define(api frontend.API) error { + committer := api.(frontend.Committer) + for i := range c.X { + if ch, err := committer.Commit(c.X[i]); err != nil { + return err + } else { + api.AssertIsDifferent(ch, c.X[i]) + } + } + return nil +} + +type twoCommitCircuit struct { + X []frontend.Variable + Y frontend.Variable +} + +func (c *twoCommitCircuit) Define(api frontend.API) error { + c0, err := api.(frontend.Committer).Commit(c.X...) 
+ if err != nil { + return err + } + var c1 frontend.Variable + if c1, err = api.(frontend.Committer).Commit(c0, c.Y); err != nil { + return err + } + api.AssertIsDifferent(c1, c.Y) + return nil +} + +type doubleCommitCircuit struct { + X, Y frontend.Variable +} + +func (c *doubleCommitCircuit) Define(api frontend.API) error { + var c0, c1 frontend.Variable + var err error + if c0, err = api.(frontend.Committer).Commit(c.X); err != nil { + return err + } + if c1, err = api.(frontend.Committer).Commit(c.X, c.Y); err != nil { + return err + } + api.AssertIsDifferent(c0, c1) + return nil +} + +func TestHollow(t *testing.T) { + + run := func(c, expected frontend.Circuit) func(t *testing.T) { + return func(t *testing.T) { + seen := hollow(c) + assert.Equal(t, expected, seen) + } + } + + assignments := []frontend.Circuit{ + &committedConstantCircuit{1}, + &commitmentCircuit{X: []frontend.Variable{1}, Public: []frontend.Variable{}}, + } + + expected := []frontend.Circuit{ + &committedConstantCircuit{nil}, + &commitmentCircuit{X: []frontend.Variable{nil}, Public: []frontend.Variable{}}, + } + + for i := range assignments { + t.Run(removePackageName(reflect.TypeOf(assignments[i]).String()), run(assignments[i], expected[i])) + } +} + +type commitUniquenessCircuit struct { + X []frontend.Variable +} + +func (c *commitUniquenessCircuit) Define(api frontend.API) error { + var err error + + ch := make([]frontend.Variable, len(c.X)) + for i := range c.X { + if ch[i], err = api.(frontend.Committer).Commit(c.X[i]); err != nil { + return err + } + for j := 0; j < i; j++ { + api.AssertIsDifferent(ch[i], ch[j]) + } + } + return nil +} + +func TestCommitUniquenessZerosScs(t *testing.T) { // TODO @Tabaie Randomize Groth16 commitments for real + + w, err := frontend.NewWitness(&commitUniquenessCircuit{[]frontend.Variable{0, 0}}, ecc.BN254.ScalarField()) + assert.NoError(t, err) + + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, 
&commitUniquenessCircuit{[]frontend.Variable{nil, nil}}) + assert.NoError(t, err) + + _, err = ccs.Solve(w) + assert.NoError(t, err) +} + +var commitmentTestCircuits []frontend.Circuit + +func init() { + commitmentTestCircuits = []frontend.Circuit{ + &noCommitmentCircuit{1}, + &commitmentCircuit{X: []frontend.Variable{1}, Public: []frontend.Variable{}}, // single commitment + &commitmentCircuit{X: []frontend.Variable{1, 2}, Public: []frontend.Variable{}}, // two commitments + &commitmentCircuit{X: []frontend.Variable{1, 2, 3, 4, 5}, Public: []frontend.Variable{}}, // five commitments + &commitmentCircuit{X: []frontend.Variable{0}, Public: []frontend.Variable{1}}, // single commitment single public + &commitmentCircuit{X: []frontend.Variable{0, 1, 2, 3, 4}, Public: []frontend.Variable{1, 2, 3, 4, 5}}, // five commitments five public + &committedConstantCircuit{1}, // single committed constant + &committedPublicCircuit{1}, // single committed public + &independentCommitsCircuit{X: []frontend.Variable{1, 1}}, // two independent commitments + &twoCommitCircuit{X: []frontend.Variable{1, 2}, Y: 3}, // two commitments, second depending on first + &doubleCommitCircuit{X: 1, Y: 2}, // double committing to the same variable + } +} + +func TestCommitment(t *testing.T) { + t.Parallel() + + for _, assignment := range commitmentTestCircuits { + NewAssert(t).ProverSucceeded(hollow(assignment), assignment, WithBackends(backend.GROTH16, backend.PLONK)) + } +} + +func TestCommitmentDummySetup(t *testing.T) { + t.Parallel() + + run := func(assignment frontend.Circuit) func(t *testing.T) { + return func(t *testing.T) { + // just test the prover + _cs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, hollow(assignment)) + require.NoError(t, err) + _r1cs := _cs.(*cs.R1CS) + var ( + dPk, pk groth16.ProvingKey + vk groth16.VerifyingKey + w witness.Witness + ) + require.NoError(t, groth16.Setup(_r1cs, &pk, &vk)) + require.NoError(t, groth16.DummySetup(_r1cs, &dPk)) + + 
comparePkSizes(t, dPk, pk) + + w, err = frontend.NewWitness(assignment, ecc.BN254.ScalarField()) + require.NoError(t, err) + _, err = groth16.Prove(_r1cs, &pk, w) + require.NoError(t, err) + } + } + + for _, assignment := range commitmentTestCircuits { + name := removePackageName(reflect.TypeOf(assignment).String()) + if c, ok := assignment.(*commitmentCircuit); ok { + name += fmt.Sprintf(":%dprivate %dpublic", len(c.X), len(c.Public)) + } + t.Run(name, run(assignment)) + } +} + +func comparePkSizes(t *testing.T, pk1, pk2 groth16.ProvingKey) { + // skipping the domain + require.Equal(t, len(pk1.G1.A), len(pk2.G1.A)) + require.Equal(t, len(pk1.G1.B), len(pk2.G1.B)) + require.Equal(t, len(pk1.G1.Z), len(pk2.G1.Z)) + require.Equal(t, len(pk1.G1.K), len(pk2.G1.K)) + + require.Equal(t, len(pk1.G2.B), len(pk2.G2.B)) + + require.Equal(t, len(pk1.InfinityA), len(pk2.InfinityA)) + require.Equal(t, len(pk1.InfinityB), len(pk2.InfinityB)) + require.Equal(t, pk1.NbInfinityA, pk2.NbInfinityA) + require.Equal(t, pk1.NbInfinityB, pk2.NbInfinityB) + + require.Equal(t, len(pk1.CommitmentKeys), len(pk2.CommitmentKeys)) // TODO @Tabaie Compare the commitment keys +} diff --git a/test/end_to_end.go b/test/end_to_end.go new file mode 100644 index 0000000000..5da06b75c9 --- /dev/null +++ b/test/end_to_end.go @@ -0,0 +1,57 @@ +package test + +import ( + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + "reflect" + "strings" + "testing" +) + +func makeOpts(opt TestingOption, curves []ecc.ID) []TestingOption { + if len(curves) > 0 { + return []TestingOption{opt, WithCurves(curves[0], curves[1:]...)} + } + return []TestingOption{opt} +} + +func testPlonk(t *testing.T, assignment frontend.Circuit, curves ...ecc.ID) { + NewAssert(t).ProverSucceeded(hollow(assignment), assignment, makeOpts(WithBackends(backend.PLONK), curves)...) 
+} + +func testGroth16(t *testing.T, assignment frontend.Circuit, curves ...ecc.ID) { + NewAssert(t).ProverSucceeded(hollow(assignment), assignment, makeOpts(WithBackends(backend.GROTH16), curves)...) +} + +func testAll(t *testing.T, assignment frontend.Circuit) { + NewAssert(t).ProverSucceeded(hollow(assignment), assignment, WithBackends(backend.GROTH16, backend.PLONK)) +} + +// hollow takes a gnark circuit and removes all the witness data. The resulting circuit can be used for compilation purposes +// Its purpose is to make testing more convenient. For example, as opposed to SolvingSucceeded(circuit, assignment), +// one can write SolvingSucceeded(hollow(assignment), assignment), obviating the creation of a separate circuit object. +func hollow(c frontend.Circuit) frontend.Circuit { + cV := reflect.ValueOf(c).Elem() + t := reflect.TypeOf(c).Elem() + res := reflect.New(t) // a new object of the same type as c + resE := res.Elem() + resC := res.Interface().(frontend.Circuit) + + frontendVar := reflect.TypeOf((*frontend.Variable)(nil)).Elem() + + for i := 0; i < t.NumField(); i++ { + fieldT := t.Field(i).Type + if fieldT.Kind() == reflect.Slice && fieldT.Elem().Implements(frontendVar) { // create empty slices for witness slices + resE.Field(i).Set(reflect.ValueOf(make([]frontend.Variable, cV.Field(i).Len()))) + } else if fieldT != frontendVar { // copy non-witness variables + resE.Field(i).Set(cV.Field(i)) + } + } + + return resC +} + +func removePackageName(s string) string { + return s[strings.LastIndex(s, ".")+1:] +} diff --git a/test/engine.go b/test/engine.go index 09e1b05b7b..afdcd1457c 100644 --- a/test/engine.go +++ b/test/engine.go @@ -18,24 +18,27 @@ package test import ( "fmt" + "github.com/consensys/gnark/constraint" "math/big" "path/filepath" "reflect" "runtime" "strconv" "strings" + "sync/atomic" - "github.com/consensys/gnark/constraint" - + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/debug" 
"github.com/consensys/gnark/frontend/schema" "github.com/consensys/gnark/logger" + "golang.org/x/crypto/sha3" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/field/pool" "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/circuitdefer" + "github.com/consensys/gnark/internal/kvstore" "github.com/consensys/gnark/internal/utils" ) @@ -50,25 +53,15 @@ type engine struct { q *big.Int opt backend.ProverConfig // mHintsFunctions map[hint.ID]hintFunction - constVars bool - apiWrapper ApiWrapper + constVars bool + kvstore.Store + blueprints []constraint.Blueprint + internalVariables []*big.Int } // TestEngineOption defines an option for the test engine. type TestEngineOption func(e *engine) error -// ApiWrapper defines a function which wraps the API given to the circuit. -type ApiWrapper func(frontend.API) frontend.API - -// WithApiWrapper is a test engine option which which wraps the API before -// calling the Define method in circuit. If not set, then API is not wrapped. -func WithApiWrapper(wrapper ApiWrapper) TestEngineOption { - return func(e *engine) error { - e.apiWrapper = wrapper - return nil - } -} - // SetAllVariablesAsConstants is a test engine option which makes the calls to // IsConstant() and ConstantValue() always return true. If this test engine // option is not set, then all variables are considered as non-constant, @@ -101,10 +94,10 @@ func WithBackendProverOptions(opts ...backend.ProverOption) TestEngineOption { // This is an experimental feature. 
func IsSolved(circuit, witness frontend.Circuit, field *big.Int, opts ...TestEngineOption) (err error) { e := &engine{ - curveID: utils.FieldToCurve(field), - q: new(big.Int).Set(field), - apiWrapper: func(a frontend.API) frontend.API { return a }, - constVars: false, + curveID: utils.FieldToCurve(field), + q: new(big.Int).Set(field), + constVars: false, + Store: kvstore.New(), } for _, opt := range opts { if err := opt(e); err != nil { @@ -133,8 +126,13 @@ func IsSolved(circuit, witness frontend.Circuit, field *big.Int, opts ...TestEng log := logger.Logger() log.Debug().Msg("running circuit in test engine") cptAdd, cptMul, cptSub, cptToBinary, cptFromBinary, cptAssertIsEqual = 0, 0, 0, 0, 0, 0 - api := e.apiWrapper(e) - err = c.Define(api) + if err = c.Define(e); err != nil { + return fmt.Errorf("define: %w", err) + } + if err = callDeferred(e); err != nil { + return fmt.Errorf("deferred: %w", err) + } + log.Debug().Uint64("add", cptAdd). Uint64("sub", cptSub). Uint64("mul", cptMul). @@ -145,14 +143,23 @@ func IsSolved(circuit, witness frontend.Circuit, field *big.Int, opts ...TestEng return } +func callDeferred(builder *engine) error { + for i := 0; i < len(circuitdefer.GetAll[func(frontend.API) error](builder)); i++ { + if err := circuitdefer.GetAll[func(frontend.API) error](builder)[i](builder); err != nil { + return fmt.Errorf("defer fn %d: %w", i, err) + } + } + return nil +} + var cptAdd, cptMul, cptSub, cptToBinary, cptFromBinary, cptAssertIsEqual uint64 func (e *engine) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - cptAdd++ + atomic.AddUint64(&cptAdd, 1) res := new(big.Int) res.Add(e.toBigInt(i1), e.toBigInt(i2)) for i := 0; i < len(in); i++ { - cptAdd++ + atomic.AddUint64(&cptAdd, 1) res.Add(res, e.toBigInt(in[i])) } res.Mod(res, e.modulus()) @@ -163,19 +170,20 @@ func (e *engine) MulAcc(a, b, c frontend.Variable) frontend.Variable { bc := pool.BigInt.Get() bc.Mul(e.toBigInt(b), e.toBigInt(c)) + res := new(big.Int) _a := 
e.toBigInt(a) - _a.Add(_a, bc).Mod(_a, e.modulus()) + res.Add(_a, bc).Mod(res, e.modulus()) pool.BigInt.Put(bc) - return _a + return res } func (e *engine) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - cptSub++ + atomic.AddUint64(&cptSub, 1) res := new(big.Int) res.Sub(e.toBigInt(i1), e.toBigInt(i2)) for i := 0; i < len(in); i++ { - cptSub++ + atomic.AddUint64(&cptSub, 1) res.Sub(res, e.toBigInt(in[i])) } res.Mod(res, e.modulus()) @@ -190,7 +198,7 @@ func (e *engine) Neg(i1 frontend.Variable) frontend.Variable { } func (e *engine) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - cptMul++ + atomic.AddUint64(&cptMul, 1) b2 := e.toBigInt(i2) if len(in) == 0 && b2.IsUint64() && b2.Uint64() <= 1 { // special path to avoid useless allocations @@ -204,7 +212,7 @@ func (e *engine) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend res.Mul(b1, b2) res.Mod(res, e.modulus()) for i := 0; i < len(in); i++ { - cptMul++ + atomic.AddUint64(&cptMul, 1) res.Mul(res, e.toBigInt(in[i])) res.Mod(res, e.modulus()) } @@ -244,7 +252,7 @@ func (e *engine) Inverse(i1 frontend.Variable) frontend.Variable { } func (e *engine) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { - cptToBinary++ + atomic.AddUint64(&cptToBinary, 1) nbBits := e.FieldBitLen() if len(n) == 1 { nbBits = n[0] @@ -276,7 +284,7 @@ func (e *engine) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { } func (e *engine) FromBinary(v ...frontend.Variable) frontend.Variable { - cptFromBinary++ + atomic.AddUint64(&cptFromBinary, 1) bits := make([]bool, len(v)) for i := 0; i < len(v); i++ { be := e.toBigInt(v[i]) @@ -373,7 +381,7 @@ func (e *engine) Cmp(i1, i2 frontend.Variable) frontend.Variable { } func (e *engine) AssertIsEqual(i1, i2 frontend.Variable) { - cptAssertIsEqual++ + atomic.AddUint64(&cptAssertIsEqual, 1) b1, b2 := e.toBigInt(i1), e.toBigInt(i2) if b1.Cmp(b2) != 0 { panic(fmt.Sprintf("[assertIsEqual] %s == %s", 
b1.String(), b2.String())) @@ -450,7 +458,7 @@ func (e *engine) print(sbb *strings.Builder, x interface{}) { } } -func (e *engine) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { +func (e *engine) NewHint(f solver.Hint, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { if nbOutputs <= 0 { return nil, fmt.Errorf("hint function must return at least one output") @@ -481,6 +489,14 @@ func (e *engine) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Vari return out, nil } +func (e *engine) NewHintForId(id solver.HintID, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { + if f := solver.GetRegisteredHint(id); f != nil { + return e.NewHint(f, nbOutputs, inputs...) + } + + return nil, fmt.Errorf("no hint registered with id #%d. Use solver.RegisterHint or solver.RegisterNamedHint", id) +} + // IsConstant returns true if v is a constant known at compile time func (e *engine) IsConstant(v frontend.Variable) bool { return e.constVars @@ -516,7 +532,7 @@ func (e *engine) toBigInt(i1 frontend.Variable) *big.Int { } } -// bitLen returns the number of bits needed to represent a fr.Element +// FieldBitLen returns the number of bits needed to represent a fr.Element func (e *engine) FieldBitLen() int { return e.q.BitLen() } @@ -587,7 +603,80 @@ func (e *engine) Compiler() frontend.Compiler { } func (e *engine) Commit(v ...frontend.Variable) (frontend.Variable, error) { - panic("not implemented") + nb := (e.FieldBitLen() + 7) / 8 + buf := make([]byte, nb) + hasher := sha3.NewCShake128(nil, []byte("gnark test engine")) + for i := range v { + vs := e.toBigInt(v[i]) + bs := vs.FillBytes(buf) + hasher.Write(bs) + } + hasher.Read(buf) + res := new(big.Int).SetBytes(buf) + res.Mod(res, e.modulus()) + return res, nil +} + +func (e *engine) Defer(cb func(frontend.API) error) { + circuitdefer.Put(e, cb) +} + +// AddInstruction is used to add custom instructions to the constraint 
system. +// In constraint system, this is asynchronous. In here, we do it synchronously. +func (e *engine) AddInstruction(bID constraint.BlueprintID, calldata []uint32) []uint32 { + blueprint := e.blueprints[bID].(constraint.BlueprintSolvable) + + // create a dummy instruction + inst := constraint.Instruction{ + Calldata: calldata, + WireOffset: uint32(len(e.internalVariables)), + } + + // blueprint declared nbOutputs; add as many internal variables + // and return their indices + nbOutputs := blueprint.NbOutputs(inst) + var r []uint32 + for i := 0; i < nbOutputs; i++ { + r = append(r, uint32(len(e.internalVariables))) + e.internalVariables = append(e.internalVariables, new(big.Int)) + } + + // solve the blueprint synchronously + s := blueprintSolver{ + internalVariables: e.internalVariables, + q: e.q, + } + if err := blueprint.Solve(&s, inst); err != nil { + panic(err) + } + + return r +} + +// AddBlueprint adds a custom blueprint to the constraint system. +func (e *engine) AddBlueprint(b constraint.Blueprint) constraint.BlueprintID { + if _, ok := b.(constraint.BlueprintSolvable); !ok { + panic("unsupported blueprint in test engine") + } + e.blueprints = append(e.blueprints, b) + return constraint.BlueprintID(len(e.blueprints) - 1) +} + +// InternalVariable returns the value of an internal variable. This is used in custom blueprints. +// The variableID is the index of the variable in the internalVariables slice, as +// filled by AddInstruction. 
+func (e *engine) InternalVariable(vID uint32) frontend.Variable { + if vID >= uint32(len(e.internalVariables)) { + panic("internal variable not found") + } + return new(big.Int).Set(e.internalVariables[vID]) +} + +// ToCanonicalVariable converts a frontend.Variable to a frontend.CanonicalVariable +// this is used in custom blueprints to return a variable than can be encoded in blueprints +func (e *engine) ToCanonicalVariable(v frontend.Variable) frontend.CanonicalVariable { + r := e.toBigInt(v) + return wrappedBigInt{r} } func (e *engine) SetGkrInfo(info constraint.GkrInfo) error { diff --git a/test/engine_test.go b/test/engine_test.go index 26b232b46e..16ed830c12 100644 --- a/test/engine_test.go +++ b/test/engine_test.go @@ -6,7 +6,9 @@ import ( "testing" "github.com/consensys/gnark" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark-crypto/ecc" + + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/math/bits" ) @@ -16,24 +18,24 @@ type hintCircuit struct { } func (circuit *hintCircuit) Define(api frontend.API) error { - res, err := api.Compiler().NewHint(bits.IthBit, 1, circuit.A, 3) + res, err := api.Compiler().NewHint(bits.GetHints()[0], 1, circuit.A, 3) if err != nil { return fmt.Errorf("IthBit circuitA 3: %w", err) } a3b := res[0] - res, err = api.Compiler().NewHint(bits.IthBit, 1, circuit.A, 25) + res, err = api.Compiler().NewHint(bits.GetHints()[0], 1, circuit.A, 25) if err != nil { return fmt.Errorf("IthBit circuitA 25: %w", err) } a25b := res[0] - res, err = api.Compiler().NewHint(hint.InvZero, 1, circuit.A) + res, err = api.Compiler().NewHint(solver.InvZeroHint, 1, circuit.A) if err != nil { return fmt.Errorf("IsZero CircuitA: %w", err) } aInvZero := res[0] - res, err = api.Compiler().NewHint(hint.InvZero, 1, circuit.B) + res, err = api.Compiler().NewHint(solver.InvZeroHint, 1, circuit.B) if err != nil { return fmt.Errorf("IsZero, CircuitB") } @@ -69,3 +71,33 @@ 
func TestBuiltinHints(t *testing.T) { } } + +var isDeferCalled bool + +type EmptyCircuit struct { + X frontend.Variable +} + +func (c *EmptyCircuit) Define(api frontend.API) error { + api.AssertIsEqual(c.X, 0) + api.Compiler().Defer(func(api frontend.API) error { + isDeferCalled = true + return nil + }) + return nil +} + +func TestPreCompileHook(t *testing.T) { + c := &EmptyCircuit{} + w := &EmptyCircuit{ + X: 0, + } + isDeferCalled = false + err := IsSolved(c, w, ecc.BN254.ScalarField()) + if err != nil { + t.Fatal(err) + } + if !isDeferCalled { + t.Error("callback not called") + } +} diff --git a/test/fuzz.go b/test/fuzz.go index a4c59eb69f..60a43a7471 100644 --- a/test/fuzz.go +++ b/test/fuzz.go @@ -74,7 +74,7 @@ func zeroFiller(w frontend.Circuit, curve ecc.ID) { } func binaryFiller(w frontend.Circuit, curve ecc.ID) { - mrand.Seed(time.Now().Unix()) + mrand := mrand.New(mrand.NewSource(time.Now().Unix())) //#nosec G404 weak rng is fine here fill(w, func() interface{} { return int(mrand.Uint32() % 2) //#nosec G404 weak rng is fine here @@ -83,7 +83,7 @@ func binaryFiller(w frontend.Circuit, curve ecc.ID) { func seedFiller(w frontend.Circuit, curve ecc.ID) { - mrand.Seed(time.Now().Unix()) + mrand := mrand.New(mrand.NewSource(time.Now().Unix())) //#nosec G404 weak rng is fine here m := curve.ScalarField() @@ -96,8 +96,6 @@ func seedFiller(w frontend.Circuit, curve ecc.ID) { func randomFiller(w frontend.Circuit, curve ecc.ID) { - mrand.Seed(time.Now().Unix()) - r := mrand.New(mrand.NewSource(time.Now().Unix())) //#nosec G404 weak rng is fine here m := curve.ScalarField() diff --git a/test/hint_test.go b/test/hint_test.go new file mode 100644 index 0000000000..d2c6650f7f --- /dev/null +++ b/test/hint_test.go @@ -0,0 +1,58 @@ +package test + +import ( + "fmt" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "math/big" + "testing" +) + +const id = solver.HintID(123454321) + +func identityHint(_ *big.Int, in, out []*big.Int) 
error { + if len(in) != len(out) { + return fmt.Errorf("len(in) = %d ≠ %d = len(out)", len(in), len(out)) + } + for i := range in { + out[i].Set(in[i]) + } + return nil +} + +type customNamedHintCircuit struct { + X []frontend.Variable +} + +func (c *customNamedHintCircuit) Define(api frontend.API) error { + y, err := api.Compiler().NewHintForId(id, len(c.X), c.X...) + + if err != nil { + return err + } + for i := range y { + api.AssertIsEqual(c.X[i], y[i]) + } + + return nil +} + +var assignment customNamedHintCircuit + +func init() { + solver.RegisterNamedHint(identityHint, id) + assignment = customNamedHintCircuit{X: []frontend.Variable{1, 2, 3, 4, 5}} +} + +func TestHintWithCustomNamePlonk(t *testing.T) { + testPlonk(t, &assignment) +} + +func TestHintWithCustomNameGroth16(t *testing.T) { + testGroth16(t, &assignment) +} + +func TestHintWithCustomNameEngine(t *testing.T) { + circuit := hollow(&assignment) + NewAssert(t).SolvingSucceeded(circuit, &assignment) +} diff --git a/test/kzg_srs.go b/test/kzg_srs.go index 4002762919..c67b89be7f 100644 --- a/test/kzg_srs.go +++ b/test/kzg_srs.go @@ -51,7 +51,6 @@ func NewKZGSRS(ccs constraint.ConstraintSystem) (kzg.SRS, error) { } return newKZGSRS(utils.FieldToCurve(ccs.Field()), kzgSize) - } var srsCache map[ecc.ID]kzg.SRS diff --git a/test/options.go b/test/options.go index bcdae0eee8..65aa8a00c2 100644 --- a/test/options.go +++ b/test/options.go @@ -19,10 +19,11 @@ package test import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) -// TestingOption defines option for altering the behaviour of Assert methods. +// TestingOption defines option for altering the behavior of Assert methods. // See the descriptions of functions returning instances of this type for // particular options. 
type TestingOption func(*testingConfig) error @@ -31,9 +32,11 @@ type testingConfig struct { backends []backend.ID curves []ecc.ID witnessSerialization bool + solverOpts []solver.Option proverOpts []backend.ProverOption compileOpts []frontend.CompileOption fuzzing bool + solidity bool } // WithBackends is testing option which restricts the backends the assertions are @@ -83,6 +86,16 @@ func WithProverOpts(proverOpts ...backend.ProverOption) TestingOption { } } +// WithSolverOpts is a testing option which uses the given solverOpts when +// calling constraint system solver. +func WithSolverOpts(solverOpts ...solver.Option) TestingOption { + return func(opt *testingConfig) error { + opt.proverOpts = append(opt.proverOpts, backend.WithSolverOptions(solverOpts...)) + opt.solverOpts = solverOpts + return nil + } +} + // WithCompileOpts is a testing option which uses the given compileOpts when // calling frontend.Compile in assertions. func WithCompileOpts(compileOpts ...frontend.CompileOption) TestingOption { @@ -91,3 +104,16 @@ func WithCompileOpts(compileOpts ...frontend.CompileOption) TestingOption { return nil } } + +// WithSolidity is a testing option which enables solidity tests in assertions. +// If the build tag "solccheck" is not set, this option is ignored. +// When the tag is set; this requires gnark-solidity-checker to be installed, which in turns +// requires solc and abigen to be reachable in the PATH. +// +// See https://github.com/ConsenSys/gnark-solidity-checker for more details. 
+func WithSolidity() TestingOption { + return func(opt *testingConfig) error { + opt.solidity = true && solcCheck + return nil + } +} diff --git a/test/solccheck_with.go b/test/solccheck_with.go new file mode 100644 index 0000000000..9bf23579e1 --- /dev/null +++ b/test/solccheck_with.go @@ -0,0 +1,5 @@ +//go:build solccheck + +package test + +const solcCheck = true diff --git a/test/solccheck_without.go b/test/solccheck_without.go new file mode 100644 index 0000000000..a6414c27de --- /dev/null +++ b/test/solccheck_without.go @@ -0,0 +1,5 @@ +//go:build !solccheck + +package test + +const solcCheck = false diff --git a/test/solver_test.go b/test/solver_test.go index 0c11baed7b..88e0d39b92 100644 --- a/test/solver_test.go +++ b/test/solver_test.go @@ -1,23 +1,24 @@ package test import ( - "errors" "fmt" + "io" "math/big" "reflect" "strconv" "strings" "testing" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/hint" - cs "github.com/consensys/gnark/constraint/tinyfield" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/frontend/schema" "github.com/consensys/gnark/internal/backend/circuits" + "github.com/consensys/gnark/internal/kvstore" "github.com/consensys/gnark/internal/tinyfield" "github.com/consensys/gnark/internal/utils" ) @@ -25,6 +26,11 @@ import ( // ignore witness size larger than this bound const permutterBound = 3 +// r1cs + sparser1cs +const nbSystems = 2 + +var builders [2]frontend.NewBuilder + func TestSolverConsistency(t *testing.T) { if testing.Short() { t.Skip("skipping R1CS solver test with testing.Short() flag set") @@ -49,15 +55,59 @@ func TestSolverConsistency(t *testing.T) { } } -type permutter struct { - circuit frontend.Circuit - r1cs 
*cs.R1CS - scs *cs.SparseR1CS - witness []tinyfield.Element - hints []hint.Function +// witness used for the permutter. It implements the Witness interface +// using mock methods (only the underlying vector is required). +type permutterWitness struct { + vector any +} + +func (pw *permutterWitness) WriteTo(w io.Writer) (int64, error) { + return 0, nil +} + +func (pw *permutterWitness) ReadFrom(r io.Reader) (int64, error) { + return 0, nil +} + +func (pw *permutterWitness) MarshalBinary() ([]byte, error) { + return nil, nil +} - // used to avoid allocations in R1CS solver - a, b, c []tinyfield.Element +func (pw *permutterWitness) UnmarshalBinary([]byte) error { + return nil +} + +func (pw *permutterWitness) Public() (witness.Witness, error) { + return pw, nil +} + +func (pw *permutterWitness) Vector() any { + return pw.vector +} + +func (pw *permutterWitness) ToJSON(s *schema.Schema) ([]byte, error) { + return nil, nil +} + +func (pw *permutterWitness) FromJSON(s *schema.Schema, data []byte) error { + return nil +} + +func (pw *permutterWitness) Fill(nbPublic, nbSecret int, values <-chan any) error { + return nil +} + +func newPermutterWitness(pv tinyfield.Vector) witness.Witness { + return &permutterWitness{ + vector: pv, + } +} + +type permutter struct { + circuit frontend.Circuit + constraintSystems [2]constraint.ConstraintSystem + witness []tinyfield.Element + hints []solver.Hint } // note that circuit will be mutated and this is not thread safe @@ -66,30 +116,36 @@ func (p *permutter) permuteAndTest(index int) error { for i := 0; i < len(tinyfieldElements); i++ { p.witness[index].SetUint64(tinyfieldElements[i]) if index == len(p.witness)-1 { + // we have a unique permutation + var errorSystems [2]error + var errorEngines [2]error + + // 2 constraints systems + for k := 0; k < nbSystems; k++ { - // solve the cs using R1CS solver - errR1CS := p.solveR1CS() - errSCS := p.solveSCS() + errorSystems[k] = p.solve(k) - // solve the cs using test engine - // first copy 
the witness in the circuit - copyWitnessFromVector(p.circuit, p.witness) - errEngine1 := isSolvedEngine(p.circuit, tinyfield.Modulus()) + // solve the cs using test engine + // first copy the witness in the circuit + copyWitnessFromVector(p.circuit, p.witness) + errorEngines[0] = isSolvedEngine(p.circuit, tinyfield.Modulus()) - copyWitnessFromVector(p.circuit, p.witness) - errEngine2 := isSolvedEngine(p.circuit, tinyfield.Modulus(), SetAllVariablesAsConstants()) + copyWitnessFromVector(p.circuit, p.witness) + errorEngines[1] = isSolvedEngine(p.circuit, tinyfield.Modulus(), SetAllVariablesAsConstants()) - if (errR1CS == nil) != (errEngine1 == nil) || - (errSCS == nil) != (errEngine1 == nil) || - (errEngine1 == nil) != (errEngine2 == nil) { + } + if (errorSystems[0] == nil) != (errorEngines[0] == nil) || + (errorSystems[1] == nil) != (errorEngines[0] == nil) || + (errorEngines[0] == nil) != (errorEngines[1] == nil) { return fmt.Errorf("errSCS :%s\nerrR1CS :%s\nerrEngine(const=false): %s\nerrEngine(const=true): %s\nwitness: %s", - formatError(errSCS), - formatError(errR1CS), - formatError(errEngine1), - formatError(errEngine2), + formatError(errorSystems[0]), + formatError(errorSystems[1]), + formatError(errorEngines[0]), + formatError(errorEngines[1]), formatWitness(p.witness)) } + } else { // recurse if err := p.permuteAndTest(index + 1); err != nil { @@ -123,38 +179,19 @@ func formatWitness(witness []tinyfield.Element) string { return sbb.String() } -func (p *permutter) solveSCS() error { - opt, err := backend.NewProverConfig(backend.WithHints(p.hints...)) - if err != nil { - return err - } - - _, err = p.scs.Solve(p.witness, opt) - return err -} - -func (p *permutter) solveR1CS() error { - opt, err := backend.NewProverConfig(backend.WithHints(p.hints...)) - if err != nil { - return err - } - - for i := 0; i < len(p.r1cs.Constraints); i++ { - p.a[i].SetZero() - p.b[i].SetZero() - p.c[i].SetZero() - } - _, err = p.r1cs.Solve(p.witness, p.a, p.b, p.c, opt) +func (p 
*permutter) solve(i int) error { + pw := newPermutterWitness(p.witness) + _, err := p.constraintSystems[i].Solve(pw, solver.WithHints(p.hints...)) return err } // isSolvedEngine behaves like test.IsSolved except it doesn't clone the circuit func isSolvedEngine(c frontend.Circuit, field *big.Int, opts ...TestEngineOption) (err error) { e := &engine{ - curveID: utils.FieldToCurve(field), - q: new(big.Int).Set(field), - apiWrapper: func(a frontend.API) frontend.API { return a }, - constVars: false, + curveID: utils.FieldToCurve(field), + q: new(big.Int).Set(field), + constVars: false, + Store: kvstore.New(), } for _, opt := range opts { if err := opt(e); err != nil { @@ -168,8 +205,12 @@ func isSolvedEngine(c frontend.Circuit, field *big.Int, opts ...TestEngineOption } }() - api := e.apiWrapper(e) - err = c.Define(api) + if err = c.Define(e); err != nil { + return fmt.Errorf("define: %w", err) + } + if err = callDeferred(e); err != nil { + return fmt.Errorf("") + } return } @@ -180,7 +221,7 @@ func copyWitnessFromVector(to frontend.Circuit, from []tinyfield.Element) { i := 0 schema.Walk(to, tVariable, func(f schema.LeafInfo, tInput reflect.Value) error { if f.Visibility == schema.Public { - tInput.Set(reflect.ValueOf((from[i]))) + tInput.Set(reflect.ValueOf(from[i])) i++ } return nil @@ -188,7 +229,7 @@ func copyWitnessFromVector(to frontend.Circuit, from []tinyfield.Element) { schema.Walk(to, tVariable, func(f schema.LeafInfo, tInput reflect.Value) error { if f.Visibility == schema.Secret { - tInput.Set(reflect.ValueOf((from[i]))) + tInput.Set(reflect.ValueOf(from[i])) i++ } return nil @@ -198,41 +239,30 @@ func copyWitnessFromVector(to frontend.Circuit, from []tinyfield.Element) { // ConsistentSolver solves given circuit with all possible witness combinations using internal/tinyfield // // Since the goal of this method is to flag potential solver issues, it is not exposed as an API for now -func consistentSolver(circuit frontend.Circuit, hintFunctions 
[]hint.Function) error { +func consistentSolver(circuit frontend.Circuit, hintFunctions []solver.Hint) error { p := permutter{ circuit: circuit, hints: hintFunctions, } - // compile R1CS - ccs, err := frontend.Compile(tinyfield.Modulus(), r1cs.NewBuilder, circuit) - if err != nil { - return err - } + // compile the systems + for i := 0; i < nbSystems; i++ { - p.r1cs = ccs.(*cs.R1CS) - - // witness len - n := len(p.r1cs.Public) - 1 + len(p.r1cs.Secret) - if n > permutterBound { - return nil - } - - p.a = make([]tinyfield.Element, p.r1cs.GetNbConstraints()) - p.b = make([]tinyfield.Element, p.r1cs.GetNbConstraints()) - p.c = make([]tinyfield.Element, p.r1cs.GetNbConstraints()) - p.witness = make([]tinyfield.Element, n) + ccs, err := frontend.Compile(tinyfield.Modulus(), builders[i], circuit) + if err != nil { + return err + } + p.constraintSystems[i] = ccs - // compile SparseR1CS - ccs, err = frontend.Compile(tinyfield.Modulus(), scs.NewBuilder, circuit) - if err != nil { - return err - } + if i == 0 { // the -1 is only for r1cs... + n := ccs.GetNbPublicVariables() - 1 + ccs.GetNbSecretVariables() + if n > permutterBound { + return nil + } + p.witness = make([]tinyfield.Element, n) + } - p.scs = ccs.(*cs.SparseR1CS) - if (len(p.scs.Public) + len(p.scs.Secret)) != n { - return errors.New("mismatch of witness size for same circuit") } return p.permuteAndTest(0) @@ -247,4 +277,7 @@ func init() { for i := uint64(0); i < n; i++ { tinyfieldElements[i] = i } + + builders[0] = r1cs.NewBuilder + builders[1] = scs.NewBuilder }