diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index e717f60c0a..02f542f8db 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,13 +1,13 @@ # Description -Please include a summary of the changes and the related issue. Please also include relevant motivation and context. -Note that in most cases the PR should be against the `develop` branch. + Fixes # (issue) ## Type of change -Please delete options that are not relevant. + - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) @@ -16,14 +16,14 @@ Please delete options that are not relevant. # How has this been tested? -Please describe the tests that you ran or implemented to verify your changes. Provide instructions so we can reproduce. + - [ ] Test A - [ ] Test B # How has this been benchmarked? -Please describe the benchmarks that you ran to verify your changes. + - [ ] Benchmark A, on Macbook pro M1, 32GB RAM - [ ] Benchmark B, on x86 Intel xxx, 16GB RAM diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index efd14e8cbe..c0e2c60046 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -5,14 +5,14 @@ jobs: runs-on: ubuntu-latest steps: - name: install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: - go-version: 1.19.x + go-version: 1.20.x - name: checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: fetch-depth: 0 - - uses: actions/cache@v2 + - uses: actions/cache@v3 with: path: | ~/go/pkg/mod @@ -22,10 +22,6 @@ jobs: key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go- - - name: golangci-lint - uses: golangci/golangci-lint-action@v3 - with: - args: --timeout=5m - name: install deps run: go install golang.org/x/tools/cmd/goimports@latest && go install github.com/klauspost/asmfmt/cmd/asmfmt@latest - name: gofmt @@ -35,24 +31,35 @@ jobs: go generate ./... 
git update-index --assume-unchanged go.mod git update-index --assume-unchanged go.sum - if [[ -n $(git status --porcelain) ]]; then echo "git repo is dirty after runing go generate -- please don't modify generated files"; echo $(git diff);echo $(git status --porcelain); exit 1; fi + if [[ -n $(git status --porcelain) ]]; then echo "git repo is dirty after running go generate -- please don't modify generated files"; echo $(git diff);echo $(git status --porcelain); exit 1; fi + + # hack to ensure golanglint process generated files + - name: remove "generated by" comments from generated files + run: | + find . -type f -name '*.go' -exec sed -i 's/Code generated by .* DO NOT EDIT/FOO/g' {} \; + # on macos: find . -type f -name '*.go' -exec sed -i '' -E 's/Code generated by .* DO NOT EDIT/FOO/g' {} \; + + - name: golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + args: --timeout=5m test: strategy: matrix: - go-version: [1.19.x] + go-version: [1.20.x] os: [ubuntu-latest] runs-on: ${{ matrix.os }} needs: - staticcheck steps: - name: install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go-version }} - name: checkout code - uses: actions/checkout@v2 - - uses: actions/cache@v2 + uses: actions/checkout@v3 + - uses: actions/cache@v3 with: path: | ~/go/pkg/mod @@ -63,10 +70,19 @@ jobs: restore-keys: | ${{ runner.os }}-go- - name: install deps - run: go install golang.org/x/tools/cmd/goimports@latest && go install github.com/klauspost/asmfmt/cmd/asmfmt@latest + run: | + go install golang.org/x/tools/cmd/goimports@latest && go install github.com/klauspost/asmfmt/cmd/asmfmt@latest + go install github.com/ethereum/go-ethereum/cmd/abigen@v1.12.0 + go install github.com/consensys/gnark-solidity-checker@latest + sudo add-apt-repository ppa:ethereum/ethereum + sudo apt-get update + sudo apt-get install solc - name: Test run: | - go test -v -short -timeout=30m ./... + go test -v -short -tags=solccheck -timeout=30m ./... 
+ - name: Test race + run: | + go test -v -short -race -timeout=30m slack-workflow-status-failed: if: failure() @@ -78,7 +94,7 @@ jobs: steps: - name: Notify slack -- workflow failed id: slack - uses: slackapi/slack-github-action@v1.19.0 + uses: slackapi/slack-github-action@v1.23.0 with: payload: | { @@ -86,7 +102,8 @@ jobs: "repo": "${{ github.repository }}", "status": "FAIL", "title": "${{ github.event.pull_request.title }}", - "pr": "${{ github.event.pull_request.head.ref }}" + "pr": "${{ github.event.pull_request.head.ref }}", + "failed_step_url": "${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}/" } env: SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} @@ -101,7 +118,7 @@ jobs: steps: - name: Notify slack -- workflow succeeded id: slack - uses: slackapi/slack-github-action@v1.19.0 + uses: slackapi/slack-github-action@v1.23.0 with: payload: | { diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 3ef5179a51..d894abb043 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -9,14 +9,14 @@ jobs: runs-on: ubuntu-latest steps: - name: install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: - go-version: 1.19.x + go-version: 1.20.x - name: checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: fetch-depth: 0 - - uses: actions/cache@v2 + - uses: actions/cache@v3 with: path: | ~/go/pkg/mod @@ -26,10 +26,8 @@ jobs: key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go- - - name: golangci-lint - uses: golangci/golangci-lint-action@v3 - with: - args: --timeout=5m + + - name: install deps run: go install golang.org/x/tools/cmd/goimports@latest && go install github.com/klauspost/asmfmt/cmd/asmfmt@latest - name: gofmt @@ -39,24 +37,36 @@ jobs: go generate ./... 
git update-index --assume-unchanged go.mod git update-index --assume-unchanged go.sum - if [[ -n $(git status --porcelain) ]]; then echo "git repo is dirty after runing go generate -- please don't modify generated files"; echo $(git diff);echo $(git status --porcelain); exit 1; fi + if [[ -n $(git status --porcelain) ]]; then echo "git repo is dirty after running go generate -- please don't modify generated files"; echo $(git diff);echo $(git status --porcelain); exit 1; fi + + # hack to ensure golanglint process generated files + - name: remove "generated by" comments from generated files + run: | + find . -type f -name '*.go' -exec sed -i 's/Code generated by .* DO NOT EDIT/FOO/g' {} \; + # on macos: find . -type f -name '*.go' -exec sed -i '' -E 's/Code generated by .* DO NOT EDIT/FOO/g' {} \; + + - name: golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + args: --timeout=5m + test: strategy: matrix: - go-version: [1.19.x] + go-version: [1.20.x] os: [ubuntu-latest, windows-latest, macos-latest] runs-on: ${{ matrix.os }} needs: - staticcheck steps: - name: install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go-version }} - name: checkout code - uses: actions/checkout@v2 - - uses: actions/cache@v2 + uses: actions/checkout@v3 + - uses: actions/cache@v3 with: path: | ~/go/pkg/mod @@ -86,7 +96,7 @@ jobs: steps: - name: Notify slack -- workflow failed id: slack - uses: slackapi/slack-github-action@v1.19.0 + uses: slackapi/slack-github-action@v1.23.0 with: payload: | { @@ -94,7 +104,8 @@ jobs: "repo": "${{ github.repository }}", "status": "FAIL", "title": "push to ${{ github.event.push.base_ref }}", - "pr": "" + "pr": "", + "failed_step_url": "${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}/" } env: SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} @@ -109,7 +120,7 @@ jobs: steps: - name: Notify slack -- workflow succeeded id: slack - uses: 
slackapi/slack-github-action@v1.19.0 + uses: slackapi/slack-github-action@v1.23.0 with: payload: | { diff --git a/.gitignore b/.gitignore index 5f2754da70..1be9af1ba2 100644 --- a/.gitignore +++ b/.gitignore @@ -50,3 +50,9 @@ gnarkd/circuits/** # Jetbrains stuff .idea/ + +# go workspace +go.work +go.work.sum + +examples/gbotrel/** \ No newline at end of file diff --git a/README.md b/README.md index 91e014ab2c..c43caee473 100644 --- a/README.md +++ b/README.md @@ -122,17 +122,17 @@ If you use `gnark` in your research a citation would be appreciated. Please use the following BibTeX to cite the most recent release. ```bib -@software{gnark-v0.8.0, +@software{gnark-v0.9.0, author = {Gautam Botrel and Thomas Piellard and Youssef El Housni and Ivo Kubjas and Arya Tabaie}, - title = {ConsenSys/gnark: v0.8.0}, + title = {ConsenSys/gnark: v0.9.0}, month = feb, year = 2023, publisher = {Zenodo}, - version = {v0.8.0}, + version = {v0.9.0}, doi = {10.5281/zenodo.5819104}, url = {https://doi.org/10.5281/zenodo.5819104} } diff --git a/backend/backend.go b/backend/backend.go index 73c3f05bae..84a05271ab 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -15,11 +15,7 @@ // Package backend implements Zero Knowledge Proof systems: it consumes circuit compiled with gnark/frontend. package backend -import ( - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/logger" - "github.com/rs/zerolog" -) +import "github.com/consensys/gnark/constraint/solver" // ID represent a unique ID for a proving scheme type ID uint16 @@ -50,26 +46,20 @@ func (id ID) String() string { } } -// ProverOption defines option for altering the behaviour of the prover in +// ProverOption defines option for altering the behavior of the prover in // Prove, ReadAndProve and IsSolved methods. See the descriptions of functions // returning instances of this type for implemented options. 
type ProverOption func(*ProverConfig) error // ProverConfig is the configuration for the prover with the options applied. type ProverConfig struct { - Force bool // defaults to false - HintFunctions map[hint.ID]hint.Function // defaults to all built-in hint functions - CircuitLogger zerolog.Logger // defaults to gnark.Logger + SolverOpts []solver.Option } // NewProverConfig returns a default ProverConfig with given prover options opts // applied. func NewProverConfig(opts ...ProverOption) (ProverConfig, error) { - log := logger.Logger() - opt := ProverConfig{CircuitLogger: log, HintFunctions: make(map[hint.ID]hint.Function)} - for _, v := range hint.GetRegistered() { - opt.HintFunctions[hint.UUID(v)] = v - } + opt := ProverConfig{} for _, option := range opts { if err := option(&opt); err != nil { return ProverConfig{}, err @@ -78,42 +68,10 @@ func NewProverConfig(opts ...ProverOption) (ProverConfig, error) { return opt, nil } -// IgnoreSolverError is a prover option that indicates that the Prove algorithm -// should complete even if constraint system is not solved. In that case, Prove -// will output an invalid Proof, but will execute all algorithms which is useful -// for test and benchmarking purposes. -func IgnoreSolverError() ProverOption { - return func(opt *ProverConfig) error { - opt.Force = true - return nil - } -} - -// WithHints is a prover option that specifies additional hint functions to be used -// by the constraint solver. -func WithHints(hintFunctions ...hint.Function) ProverOption { - log := logger.Logger() - return func(opt *ProverConfig) error { - // it is an error to register hint function several times, but as the - // prover already checks it then omit here. 
- for _, h := range hintFunctions { - uuid := hint.UUID(h) - if _, ok := opt.HintFunctions[uuid]; ok { - log.Warn().Int("hintID", int(uuid)).Str("name", hint.Name(h)).Msg("duplicate hint function") - } else { - opt.HintFunctions[uuid] = h - } - } - return nil - } -} - -// WithCircuitLogger is a prover option that specifies zerolog.Logger as a destination for the -// logs printed by api.Println(). By default, uses gnark/logger. -// zerolog.Nop() will disable logging -func WithCircuitLogger(l zerolog.Logger) ProverOption { +// WithSolverOptions specifies the constraint system solver options. +func WithSolverOptions(solverOpts ...solver.Option) ProverOption { return func(opt *ProverConfig) error { - opt.CircuitLogger = l + opt.SolverOpts = solverOpts return nil } } diff --git a/backend/groth16/bellman_test.go b/backend/groth16/bellman_test.go index af74b4e47d..eb4649d34b 100644 --- a/backend/groth16/bellman_test.go +++ b/backend/groth16/bellman_test.go @@ -23,62 +23,62 @@ func TestVerifyBellmanProof(t *testing.T) { ok bool }{ { - "hwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bhwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bo5ViaDBdO7ZBxAhLSe5k/5TFQyF5Lv7KN2tLKnwgoWMqB16OL8WdbePIwTCuPtJNAFKoTZylLDbSf02kckMcZQDPF9iGh+JC99Pio74vDpwTEjUx5tQ99gNQwxULtztsqDRsPnEvKvLmsxHt8LQVBkEBm2PBJFY+OXf1MNW021viDBpR10mX4WQ6zrsGL5L0GY4cwf4tlbh+Obit+LnN/SQTnREf8fPpdKZ1sa/ui3pGi8lMT6io4D7Ujlwx2RdChwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bkBF+isfMf77HCEGsZANw0hSrO2FGg14Sl26xLAIohdaW8O7gEaag8JdVAZ3OVLd5Df1NkZBEr753Xb8WwaXsJjE7qxwINL1KdqA4+EiYW4edb7+a9bbBeOPtb67ZxmFqAAAAAoMkzUv+KG8WoXszZI5NNMrbMLBDYP/xHunVgSWcix/kBrGlNozv1uFr0cmYZiij3YqToYs+EZa3dl2ILHx7H1n+b+Bjky/td2QduHVtf5t/Z9sKCfr+vOn12zVvOVz/6w==", + 
"hwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bhwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bo5ViaDBdO7ZBxAhLSe5k/5TFQyF5Lv7KN2tLKnwgoWMqB16OL8WdbePIwTCuPtJNAFKoTZylLDbSf02kckMcZQDPF9iGh+JC99Pio74vDpwTEjUx5tQ99gNQwxULtztsqDRsPnEvKvLmsxHt8LQVBkEBm2PBJFY+OXf1MNW021viDBpR10mX4WQ6zrsGL5L0GY4cwf4tlbh+Obit+LnN/SQTnREf8fPpdKZ1sa/ui3pGi8lMT6io4D7Ujlwx2RdChwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bkBF+isfMf77HCEGsZANw0hSrO2FGg14Sl26xLAIohdaW8O7gEaag8JdVAZ3OVLd5Df1NkZBEr753Xb8WwaXsJjE7qxwINL1KdqA4+EiYW4edb7+a9bbBeOPtb67ZxmFqAAAAAoMkzUv+KG8WoXszZI5NNMrbMLBDYP/xHunVgSWcix/kBrGlNozv1uFr0cmYZiij3YqToYs+EZa3dl2ILHx7H1n+b+Bjky/td2QduHVtf5t/Z9sKCfr+vOn12zVvOVz/6wAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", "lvQLU/KqgFhsLkt/5C/scqs7nWR+eYtyPdWiLVBux9GblT4AhHYMdCgwQfSJcudvsgV6fXoK+DUSRgJ++Nqt+Wvb7GlYlHpxCysQhz26TTu8Nyo7zpmVPH92+UYmbvbQCSvX2BhWtvkfHmqDVjmSIQ4RUMfeveA1KZbSf999NE4qKK8Do+8oXcmTM4LZVmh1rlyqznIdFXPN7x3pD4E0gb6/y69xtWMChv9654FMg05bAdueKt9uA4BEcAbpkdHF", "LcMT3OOlkHLzJBKCKjjzzVMg+r+FVgd52LlhZPB4RFg=", true, }, { - 
"hwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bhwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bo5ViaDBdO7ZBxAhLSe5k/5TFQyF5Lv7KN2tLKnwgoWMqB16OL8WdbePIwTCuPtJNAFKoTZylLDbSf02kckMcZQDPF9iGh+JC99Pio74vDpwTEjUx5tQ99gNQwxULtztsqDRsPnEvKvLmsxHt8LQVBkEBm2PBJFY+OXf1MNW021viDBpR10mX4WQ6zrsGL5L0GY4cwf4tlbh+Obit+LnN/SQTnREf8fPpdKZ1sa/ui3pGi8lMT6io4D7Ujlwx2RdChwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bkBF+isfMf77HCEGsZANw0hSrO2FGg14Sl26xLAIohdaW8O7gEaag8JdVAZ3OVLd5Df1NkZBEr753Xb8WwaXsJjE7qxwINL1KdqA4+EiYW4edb7+a9bbBeOPtb67ZxmFqAAAAAoMkzUv+KG8WoXszZI5NNMrbMLBDYP/xHunVgSWcix/kBrGlNozv1uFr0cmYZiij3YqToYs+EZa3dl2ILHx7H1n+b+Bjky/td2QduHVtf5t/Z9sKCfr+vOn12zVvOVz/6w==", + "hwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bhwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bo5ViaDBdO7ZBxAhLSe5k/5TFQyF5Lv7KN2tLKnwgoWMqB16OL8WdbePIwTCuPtJNAFKoTZylLDbSf02kckMcZQDPF9iGh+JC99Pio74vDpwTEjUx5tQ99gNQwxULtztsqDRsPnEvKvLmsxHt8LQVBkEBm2PBJFY+OXf1MNW021viDBpR10mX4WQ6zrsGL5L0GY4cwf4tlbh+Obit+LnN/SQTnREf8fPpdKZ1sa/ui3pGi8lMT6io4D7Ujlwx2RdChwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bkBF+isfMf77HCEGsZANw0hSrO2FGg14Sl26xLAIohdaW8O7gEaag8JdVAZ3OVLd5Df1NkZBEr753Xb8WwaXsJjE7qxwINL1KdqA4+EiYW4edb7+a9bbBeOPtb67ZxmFqAAAAAoMkzUv+KG8WoXszZI5NNMrbMLBDYP/xHunVgSWcix/kBrGlNozv1uFr0cmYZiij3YqToYs+EZa3dl2ILHx7H1n+b+Bjky/td2QduHVtf5t/Z9sKCfr+vOn12zVvOVz/6wAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", "lvQLU/KqgFhsLkt/5C/scqs7nWR+eYtyPdWiLVBux9GblT4AhHYMdCgwQfSJcudvsgV6fXoK+DUSRgJ++Nqt+Wvb7GlYlHpxCysQhz26TTu8Nyo7zpmVPH92+UYmbvbQCSvX2BhWtvkfHmqDVjmSIQ4RUMfeveA1KZbSf999NE4qKK8Do+8oXcmTM4LZVmh1rlyqznIdFXPN7x3pD4E0gb6/y69xtWMChv9654FMg05bAdueKt9uA4BEcAbpkdHF", 
"cmzVCcRVnckw3QUPhmG4Bkppeg4K50oDQwQ9EH+Fq1s=", false, }, { - "hwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bhwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bo5ViaDBdO7ZBxAhLSe5k/5TFQyF5Lv7KN2tLKnwgoWMqB16OL8WdbePIwTCuPtJNAFKoTZylLDbSf02kckMcZQDPF9iGh+JC99Pio74vDpwTEjUx5tQ99gNQwxULtztsqDRsPnEvKvLmsxHt8LQVBkEBm2PBJFY+OXf1MNW021viDBpR10mX4WQ6zrsGL5L0GY4cwf4tlbh+Obit+LnN/SQTnREf8fPpdKZ1sa/ui3pGi8lMT6io4D7Ujlwx2RdChwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bkBF+isfMf77HCEGsZANw0hSrO2FGg14Sl26xLAIohdaW8O7gEaag8JdVAZ3OVLd5Df1NkZBEr753Xb8WwaXsJjE7qxwINL1KdqA4+EiYW4edb7+a9bbBeOPtb67ZxmFqAAAAAoMkzUv+KG8WoXszZI5NNMrbMLBDYP/xHunVgSWcix/kBrGlNozv1uFr0cmYZiij3YqToYs+EZa3dl2ILHx7H1n+b+Bjky/td2QduHVtf5t/Z9sKCfr+vOn12zVvOVz/6w==", + "hwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bhwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bo5ViaDBdO7ZBxAhLSe5k/5TFQyF5Lv7KN2tLKnwgoWMqB16OL8WdbePIwTCuPtJNAFKoTZylLDbSf02kckMcZQDPF9iGh+JC99Pio74vDpwTEjUx5tQ99gNQwxULtztsqDRsPnEvKvLmsxHt8LQVBkEBm2PBJFY+OXf1MNW021viDBpR10mX4WQ6zrsGL5L0GY4cwf4tlbh+Obit+LnN/SQTnREf8fPpdKZ1sa/ui3pGi8lMT6io4D7Ujlwx2RdChwk883gUlTKCyXYA6XWZa8H9/xKIYZaJ0xEs0M5hQOMxiGpxocuX/8maSDmeCk3bkBF+isfMf77HCEGsZANw0hSrO2FGg14Sl26xLAIohdaW8O7gEaag8JdVAZ3OVLd5Df1NkZBEr753Xb8WwaXsJjE7qxwINL1KdqA4+EiYW4edb7+a9bbBeOPtb67ZxmFqAAAAAoMkzUv+KG8WoXszZI5NNMrbMLBDYP/xHunVgSWcix/kBrGlNozv1uFr0cmYZiij3YqToYs+EZa3dl2ILHx7H1n+b+Bjky/td2QduHVtf5t/Z9sKCfr+vOn12zVvOVz/6wAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", 
"lvQLU/KqgFhsLkt/5C/scqs7nWR+eYtyPdWiLVBux9GblT4AhHYMdCgwQfSJcudvsgV6fXoK+DUSRgJ++Nqt+Wvb7GlYlHpxCysQhz26TTu8Nyo7zpmVPH92+UYmbvbQCSvX2BhWtvkfHmqDVjmSIQ4RUMfeveA1KZbSf999NE4qKK8Do+8oXcmTM4LZVmh1rlyqznIdFXPN7x3pD4E0gb6/y69xtWMChv9654FMg05bAdueKt9uA4BEcAbpkdHF", "cmzVCcRVnckw3QUPhmG4Bkppeg4K50oDQwQ9EH+Fq1s=", false, }, { - "kYYCAS8vM2T99GeCr4toQ+iQzvl5fI89mPrncYqx3C1d75BQbFk8LMtcnLWwntd6kYYCAS8vM2T99GeCr4toQ+iQzvl5fI89mPrncYqx3C1d75BQbFk8LMtcnLWwntd6knkzSwcsialcheg69eZYPK8EzKRVI5FrRHKi8rgB+R5jyPV70ejmYEx1neTmfYKODRmARr/ld6pZTzBWYDfrCkiS1QB+3q3M08OQgYcLzs/vjW4epetDCmk0K1CEGcWdh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0kYYCAS8vM2T99GeCr4toQ+iQzvl5fI89mPrncYqx3C1d75BQbFk8LMtcnLWwntd6jgld4oAppAOzvQ7eoIx2tbuuKVSdbJm65KDxl/T+boaYnjRm3omdETYnYRk3HAhrAeWpefX+dM/k7PrcheInnxHUyjzSzqlN03xYjg28kdda9FZJaVsQKqdEJ/St9ivXAAAAAZae/nTwyDn5u+4WkhZ76991cGB/ymyGpXziT0bwS86pRw/AcbpzXmzK+hq+kvrvpw==", + "kYYCAS8vM2T99GeCr4toQ+iQzvl5fI89mPrncYqx3C1d75BQbFk8LMtcnLWwntd6kYYCAS8vM2T99GeCr4toQ+iQzvl5fI89mPrncYqx3C1d75BQbFk8LMtcnLWwntd6knkzSwcsialcheg69eZYPK8EzKRVI5FrRHKi8rgB+R5jyPV70ejmYEx1neTmfYKODRmARr/ld6pZTzBWYDfrCkiS1QB+3q3M08OQgYcLzs/vjW4epetDCmk0K1CEGcWdh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0kYYCAS8vM2T99GeCr4toQ+iQzvl5fI89mPrncYqx3C1d75BQbFk8LMtcnLWwntd6jgld4oAppAOzvQ7eoIx2tbuuKVSdbJm65KDxl/T+boaYnjRm3omdETYnYRk3HAhrAeWpefX+dM/k7PrcheInnxHUyjzSzqlN03xYjg28kdda9FZJaVsQKqdEJ/St9ivXAAAAAZae/nTwyDn5u+4WkhZ76991cGB/ymyGpXziT0bwS86pRw/AcbpzXmzK+hq+kvrvpwAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", 
"sStVLdyxqInmv76iaNnRFB464lGq48iVeqYWSi2linE9DST0fTNhxSnvSXAoPpt8tFsanj5vPafC+ij/Fh98dOUlMbO42bf280pOZ4lm+zr63AWUpOOIugST+S6pq9zeB0OHp2NY8XFmriOEKhxeabhuV89ljqCDjlhXBeNZwM5zti4zg89Hd8TbKcw46jAsjIJe2Siw3Th7ELQQKR5ucX50f0GISmnOSceePPdvjbGJ8fSFOnSmSp8dK7uyehrU", "", true, }, { - "mY//hEITCBCZUJUN/wsOlw1iUSSOESL6PFSbN1abGK80t5jPNICNlPuSorio4mmWmY//hEITCBCZUJUN/wsOlw1iUSSOESL6PFSbN1abGK80t5jPNICNlPuSorio4mmWpf+4uOyv3gPZe54SYGM4pfhteqJpwFQxdlpwXWyYxMTNaSLDj8VtSn/EJaSu+P6nFmWsda3mTYUPYMZzWE4hMqpDgFPcJhw3prArMThDPbR3Hx7E6NRAAR0LqcrdtsbDqu2T0tto1rpnFILdvHL4PqEUfTmF2mkM+DKj7lKwvvZUbukqBwLrnnbdfyqZJryzGAMIa2JvMEMYszGsYyiPXZvYx6Luk54oWOlOrwEKrCY4NMPwch6DbFq6KpnNSQwOmY//hEITCBCZUJUN/wsOlw1iUSSOESL6PFSbN1abGK80t5jPNICNlPuSorio4mmWpgRYCz7wpjk57X+NGJmo85tYKc+TNa1rT4/DxG9v6SHkpXmmPeHhzIIW8MOdkFjxB5o6Qn8Fa0c6Tt6br2gzkrGr1eK5/+RiIgEzVhcRrqdY/p7PLmKXqawrEvIv9QZ3AAAAAoo8rTzcIp5QvF3USzv2Lz99z43CPVkjHB1ejzj/SjzKNa54GiDzHoCoAL0xKLjRSqeL98AF0V1+cRI8FwJjOcMgf0gDmjzwiv3ppbPZKqJR7Go+57k02670lfG6s1MM0A==", + "mY//hEITCBCZUJUN/wsOlw1iUSSOESL6PFSbN1abGK80t5jPNICNlPuSorio4mmWmY//hEITCBCZUJUN/wsOlw1iUSSOESL6PFSbN1abGK80t5jPNICNlPuSorio4mmWpf+4uOyv3gPZe54SYGM4pfhteqJpwFQxdlpwXWyYxMTNaSLDj8VtSn/EJaSu+P6nFmWsda3mTYUPYMZzWE4hMqpDgFPcJhw3prArMThDPbR3Hx7E6NRAAR0LqcrdtsbDqu2T0tto1rpnFILdvHL4PqEUfTmF2mkM+DKj7lKwvvZUbukqBwLrnnbdfyqZJryzGAMIa2JvMEMYszGsYyiPXZvYx6Luk54oWOlOrwEKrCY4NMPwch6DbFq6KpnNSQwOmY//hEITCBCZUJUN/wsOlw1iUSSOESL6PFSbN1abGK80t5jPNICNlPuSorio4mmWpgRYCz7wpjk57X+NGJmo85tYKc+TNa1rT4/DxG9v6SHkpXmmPeHhzIIW8MOdkFjxB5o6Qn8Fa0c6Tt6br2gzkrGr1eK5/+RiIgEzVhcRrqdY/p7PLmKXqawrEvIv9QZ3AAAAAoo8rTzcIp5QvF3USzv2Lz99z43CPVkjHB1ejzj/SjzKNa54GiDzHoCoAL0xKLjRSqeL98AF0V1+cRI8FwJjOcMgf0gDmjzwiv3ppbPZKqJR7Go+57k02670lfG6s1MM0AAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", 
"g53N8ecorvG2sDgNv8D7quVhKMIIpdP9Bqk/8gmV5cJ5Rhk9gKvb4F0ll8J/ZZJVqa27OyciJwx6lym6QpVK9q1ASrqio7rD5POMDGm64Iay/ixXXn+//F+uKgDXADj9AySri2J1j3qEkqqe3kxKthw94DzAfUBPncHfTPazVtE48AfzB1KWZA7Vf/x/3phYs4ckcP7ZrdVViJVLbUgFy543dpKfEH2MD30ZLLYRhw8SatRCyIJuTZcMlluEKG+d", "aZ8tqrOeEJKt4AMqiRF/WJhIKTDC0HeDTgiJVLZ8OEs=", true, }, { - "tRpqHB4HADuHAUvHTcrzxmq1awdwEBA0GOJfebYTODyUqXBQ7FkYrz1oDvPyx5Z3tRpqHB4HADuHAUvHTcrzxmq1awdwEBA0GOJfebYTODyUqXBQ7FkYrz1oDvPyx5Z3sUmODSJXAQmAFBVnS2t+Xzf5ZCr1gCtMiJVjQ48/nob/SkrS4cTHHjbKIVS9cdD/BG/VDrZvBt/dPqXmdUFyFuTTMrViagR57YRrDmm1qm5LQ/A8VwUBdiArwgRQXH9jsYhgVmfcRAjJytrbYeR6ck4ZfmGr6x6akKiBLY4B1l9LaHTyz/6KSM5t8atpuR3HBJZfbBm2/K8nnYTl+mAU/EnIN3YQdUd65Hsd4Gtf6VT2qfz6hcrSgHutxR1usIL2tRpqHB4HADuHAUvHTcrzxmq1awdwEBA0GOJfebYTODyUqXBQ7FkYrz1oDvPyx5Z3kyU9X4Kqjx6I6zYwVbn7PWbiy3OtY277z4ggIqW6AuDgzUeIyG9a4stMeQ07mOV/Ef4faj+eh4GJRKjJm7aUTYJCSAGY6klOXNoEzB54XF4EY5pkMPfW73SmxJi9B0aHAAAAEJGVg8trc1JcL8WfwX7A5FGZ7epiPqnQzrUxuiRSLUkGaLWBgwZusz3M8KN2QBqa/IIm0xOg40+xhjQxJduo4ACd2gHQa3+2G9L1hGIsziSuEjv1HfuP1sVw28u8W8JRWJIBLWGzDuj16M4Uag4qLSdAn3UhMTRwPQN+5kf26TTisoQK38r0gSCZ1EIDsOcDAavhjj+Z+/BPfWua2OBVxlJjNyxnafwr5BiE2H9OElh5GQBLnmB/emLOY6x5SGUANpPY9NiYvki/NgyRR/Cw4e+34Ifc4dMAIwgKmO/6+9uN+EQwPe23xGSWr0ZgBDbIH5bElW/Hfa0DAaVpd15G/JjZVDkn/iwF3l2EEeNmeMrlI8AFL5P//oprobFhfGQjJKW/cEP+nK1R+BORN3+iH/zLfw3Hp1pTzbb7tgiRWrXPKt9WknZ1oTDfFOuUl9wwaLg3PBFwxXebcMuFVjEuZYWOlW1P5UvE/KMoa/jSKbLbClJkodBDNaxslIdjzYCGM6Hgc5x1moKdljt5yGzWCHFxETgU/EKagOA6s8b+uuY8Goxl5gGsEb3Wasy6rwpHro3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywY4hgSvkxh3x/DmnXsYXXrexMciFom0pmplkL332ZMpd1pzKFW00N5TpLODHGt7FOYadYjHqXbtKyqAYdDV3MPRfYZIoIAyJETSDatq5b1MqT4kpqfFPLQ0jHhtGFUqLxZQOA7IwcJ7SR+OTYDW2P0W7v4X3u0LJE5AYk6NgPpJmEh++VL39lAF8AQE9T6BNLKrWJ3Rdim7YZehspVd4/TSCrSMx3fxsJXhMbjOypSNR9tj+/G6+c5ofA+1dWhXMT6X7UIY3IeGm17sRD+GoYHzZrYXYaEr0MpC4Y25hvQa78Jsmg0Xf7Bwk3kl5i6iMm63erQPGSyatgy8yiRDpZnzPWiIoS3vxS53X500pDStnreKPLzN2IwJoDp6wKOySKQ==", + 
"tRpqHB4HADuHAUvHTcrzxmq1awdwEBA0GOJfebYTODyUqXBQ7FkYrz1oDvPyx5Z3tRpqHB4HADuHAUvHTcrzxmq1awdwEBA0GOJfebYTODyUqXBQ7FkYrz1oDvPyx5Z3sUmODSJXAQmAFBVnS2t+Xzf5ZCr1gCtMiJVjQ48/nob/SkrS4cTHHjbKIVS9cdD/BG/VDrZvBt/dPqXmdUFyFuTTMrViagR57YRrDmm1qm5LQ/A8VwUBdiArwgRQXH9jsYhgVmfcRAjJytrbYeR6ck4ZfmGr6x6akKiBLY4B1l9LaHTyz/6KSM5t8atpuR3HBJZfbBm2/K8nnYTl+mAU/EnIN3YQdUd65Hsd4Gtf6VT2qfz6hcrSgHutxR1usIL2tRpqHB4HADuHAUvHTcrzxmq1awdwEBA0GOJfebYTODyUqXBQ7FkYrz1oDvPyx5Z3kyU9X4Kqjx6I6zYwVbn7PWbiy3OtY277z4ggIqW6AuDgzUeIyG9a4stMeQ07mOV/Ef4faj+eh4GJRKjJm7aUTYJCSAGY6klOXNoEzB54XF4EY5pkMPfW73SmxJi9B0aHAAAAEJGVg8trc1JcL8WfwX7A5FGZ7epiPqnQzrUxuiRSLUkGaLWBgwZusz3M8KN2QBqa/IIm0xOg40+xhjQxJduo4ACd2gHQa3+2G9L1hGIsziSuEjv1HfuP1sVw28u8W8JRWJIBLWGzDuj16M4Uag4qLSdAn3UhMTRwPQN+5kf26TTisoQK38r0gSCZ1EIDsOcDAavhjj+Z+/BPfWua2OBVxlJjNyxnafwr5BiE2H9OElh5GQBLnmB/emLOY6x5SGUANpPY9NiYvki/NgyRR/Cw4e+34Ifc4dMAIwgKmO/6+9uN+EQwPe23xGSWr0ZgBDbIH5bElW/Hfa0DAaVpd15G/JjZVDkn/iwF3l2EEeNmeMrlI8AFL5P//oprobFhfGQjJKW/cEP+nK1R+BORN3+iH/zLfw3Hp1pTzbb7tgiRWrXPKt9WknZ1oTDfFOuUl9wwaLg3PBFwxXebcMuFVjEuZYWOlW1P5UvE/KMoa/jSKbLbClJkodBDNaxslIdjzYCGM6Hgc5x1moKdljt5yGzWCHFxETgU/EKagOA6s8b+uuY8Goxl5gGsEb3Wasy6rwpHro3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywY4hgSvkxh3x/DmnXsYXXrexMciFom0pmplkL332ZMpd1pzKFW00N5TpLODHGt7FOYadYjHqXbtKyqAYdDV3MPRfYZIoIAyJETSDatq5b1MqT4kpqfFPLQ0jHhtGFUqLxZQOA7IwcJ7SR+OTYDW2P0W7v4X3u0LJE5AYk6NgPpJmEh++VL39lAF8AQE9T6BNLKrWJ3Rdim7YZehspVd4/TSCrSMx3fxsJXhMbjOypSNR9tj+/G6+c5ofA+1dWhXMT6X7UIY3IeGm17sRD+GoYHzZrYXYaEr0MpC4Y25hvQa78Jsmg0Xf7Bwk3kl5i6iMm63erQPGSyatgy8yiRDpZnzPWiIoS3vxS53X500pDStnreKPLzN2IwJoDp6wKOySKQAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", 
"lgFU4Jyo9GdHL7w31u3zXc8RQRnHVarZWNfd0lD45GvvQtwrZ1Y1OKB4T29a79UagPHOdk1S0k0hYAYQyyNAfRUzde1HP8R+2dms75gGZEnx2tXexEN+BVjRJfC8PR1lFJa6xvsEx5uSrOZzKmoMfCwcA55SMT5jFo4+KyWg2wP5OnFPx7XTdEKvf5YhpY0krQKiq3OUu79EwjNF1xV1+iLxx2KEIyK7RSYxO1BHrKOGOEzxSUK00MA+YVHe+DvW", "aZ8tqrOeEJKt4AMqiRF/WJhIKTDC0HeDTgiJVLZ8OEtiLNj7hflFeVnNXPguxyoqkI/V7pGJtXBpH5N+RswQNA0b23aM33aH0HKHOWoGY/T/L7TQzYFGJ3vTLiXDFZg1OVqkGOMvqAgonOrHGi6IgcALyUMyCKlL5BQY23SeILJpYKolybJNwJfbjxpg0Oz+D2fr7r9XL1GMvgblu52bVQT1fR8uCRJfSsgA2OGw6k/MpKDCfMcjbR8jnZa8ROEvF4cohm7iV1788Vp2/2bdcEZRQSoaGV8pOmA9EkqzJVRABjkDso40fnQcm2IzjBUOsX+uFExVan56/vl9VZVwB0wnee3Uxiredn0kOayiPB16yimxXCDet+M+0UKjmIlmXYpkrCDrH0dn53w+U3OHqMQxPDnUpYBxadM1eI8xWFFxzaLkvega0q0DmEquyY02yiTqo+7Q4qaJVTLgu6/8ekzPxGKRi845NL8gRgaTtM3kidDzIQpyODZD0yeEZDY1M+3sUKHcVkhoxTQBTMyKJPc+M5DeBL3uaWMrvxuL6q8+X0xeBt+9kguPUNtIYqUgPAaXvM2i041bWHTJ0dZLyDJVOyzGaXRaF4mNkAuh4Et6Zw5PuOpMM2mI1oFKEZj7", true, }, { - "kY4NWaOoYItWtLKVQnxDh+XTsa0Yev5Ae3Q9vlQSKp6+IUtwS7GH5ZrZefmBEwWEkY4NWaOoYItWtLKVQnxDh+XTsa0Yev5Ae3Q9vlQSKp6+IUtwS7GH5ZrZefmBEwWEqvAtYaSs5qW3riOiiRFoLp7MThW4vCEhK0j8BZY5ZM/tnjB7mrLB59kGvzpW8PM/AoQRIWzyvO3Dxxfyj/UQcQRw+KakVRvrFca3Vy2K5cFwxYHwl6PFDM+OmGrlgOCoqZtY1SLOd+ovmFOODKiHBZzDZhC/lRfjKVy4LzI7AXDuFn4tlWoT7IsJyy6lYNaWFfLjYZPAsrv1gXJ1NYat5B6E0Pnz5C67u2Uigmlol2D91re3oAqIo+r8kiyFKOSBkY4NWaOoYItWtLKVQnxDh+XTsa0Yev5Ae3Q9vlQSKp6+IUtwS7GH5ZrZefmBEwWEooG0cMN47zQor6qj0owuxJjn5Ymrcd/FCQ1ud4cKoUlNaGWIekSjxJEB87elMy5oEUlUzVI9ObMm+2SE3Udgws7pkMM8fgQUQUqUVyc7sNCE9m/hQzlwtbXrNSS5Pb+6AAAAEaMO2hzDmr41cml4ktH+m9acCaUtck/ivOVANQi6qsQmhMOfvFgIzMwTqypsQVKWAKSOBQQCGv0o3lP8GJ5Y1FDEzH5wXwkPDEtNYRUkGUqD8dXaPGcZ+WNzT4KWqJlw36clSvUNFNDZKkKj7JPk/gK6MUBsavX/xzl+SOWYmxdu3Wd9rQm0yqNthoLKarQL9or9Oj7m8kmOZUSBGJQo6+Y/AIZfYBbzfttnCIdYhorsAcoT4xg4D+Ye/MWVwgaCXpBGD3CNgtC7QXIeaWQvyaUtZBfZQS53aHYSJbrRK95pCqIUfg/3MzfxU3cVNm/NkKn2Th3Puq79m4hF8vAHaADMpI9XbCMO/3eTPF3ID4lMzZKOB+qdNkbdkdNTDmG6E6IGkB/JclOqPHPojkURhKGQ06uIbQvkGuwF06Hb0pht9yK8CVRjigzLb1iNVWHYVrN9kgdFtXfgaxW9DmwRrM/lJ+z/lfKnqwjKrvdOgZG43VprCmykAR
QvP2A03UovdqyKtSElEFp/PAIFv6vruij8cm1ORGYGwPhGcAgwejMgTYR3KwL1RXl/pI9UWNRsdZMwhN5XbE9+7Am2shbcjDGy+oA0AqE2nSV/44bPcIKdHWbo8DpNFnn4YMtQVB15f6vtp1wCj7yppYulqO/6WK/6tdxnLI+2e1kilZ+BZuF35CQ+tquqWgsTudQZSUBHJ6TTyku/s44ZkJU0YhK8g/L3uykM5NtHm+E4CDEdYSOaZ0Joxnk+esWckqdpw52A7KrJ1webkGPJcn+iGAvzx8xG960sfdZNGRLucOSDK1SvKLTc2R61LjNGj3SJqS0CeKhIL5nszkaXFAquEkafWxpd/8s1xObVmYJ90OpF8oxTIbvn6E2MtTVfhyWySNZ2DI3k693/kcUqYSGFsjGe7A90YA80ZOYkKg9SfvK3TiGZYjm365lmq6PwQcTb3dXzwJRRD4g3oAXA2lVh0tgNRTyAvXfg1NOb4s6wX5YurLvawr0gTVZ6A0gRds3lPtjY14+8nB2MQrmYJfHQbvBWY745Q1GQqn3atz7M0HqNl+ebawyRB3lVmkaCHIIhtoX0zQ==", + "kY4NWaOoYItWtLKVQnxDh+XTsa0Yev5Ae3Q9vlQSKp6+IUtwS7GH5ZrZefmBEwWEkY4NWaOoYItWtLKVQnxDh+XTsa0Yev5Ae3Q9vlQSKp6+IUtwS7GH5ZrZefmBEwWEqvAtYaSs5qW3riOiiRFoLp7MThW4vCEhK0j8BZY5ZM/tnjB7mrLB59kGvzpW8PM/AoQRIWzyvO3Dxxfyj/UQcQRw+KakVRvrFca3Vy2K5cFwxYHwl6PFDM+OmGrlgOCoqZtY1SLOd+ovmFOODKiHBZzDZhC/lRfjKVy4LzI7AXDuFn4tlWoT7IsJyy6lYNaWFfLjYZPAsrv1gXJ1NYat5B6E0Pnz5C67u2Uigmlol2D91re3oAqIo+r8kiyFKOSBkY4NWaOoYItWtLKVQnxDh+XTsa0Yev5Ae3Q9vlQSKp6+IUtwS7GH5ZrZefmBEwWEooG0cMN47zQor6qj0owuxJjn5Ymrcd/FCQ1ud4cKoUlNaGWIekSjxJEB87elMy5oEUlUzVI9ObMm+2SE3Udgws7pkMM8fgQUQUqUVyc7sNCE9m/hQzlwtbXrNSS5Pb+6AAAAEaMO2hzDmr41cml4ktH+m9acCaUtck/ivOVANQi6qsQmhMOfvFgIzMwTqypsQVKWAKSOBQQCGv0o3lP8GJ5Y1FDEzH5wXwkPDEtNYRUkGUqD8dXaPGcZ+WNzT4KWqJlw36clSvUNFNDZKkKj7JPk/gK6MUBsavX/xzl+SOWYmxdu3Wd9rQm0yqNthoLKarQL9or9Oj7m8kmOZUSBGJQo6+Y/AIZfYBbzfttnCIdYhorsAcoT4xg4D+Ye/MWVwgaCXpBGD3CNgtC7QXIeaWQvyaUtZBfZQS53aHYSJbrRK95pCqIUfg/3MzfxU3cVNm/NkKn2Th3Puq79m4hF8vAHaADMpI9XbCMO/3eTPF3ID4lMzZKOB+qdNkbdkdNTDmG6E6IGkB/JclOqPHPojkURhKGQ06uIbQvkGuwF06Hb0pht9yK8CVRjigzLb1iNVWHYVrN9kgdFtXfgaxW9DmwRrM/lJ+z/lfKnqwjKrvdOgZG43VprCmykARQvP2A03UovdqyKtSElEFp/PAIFv6vruij8cm1ORGYGwPhGcAgwejMgTYR3KwL1RXl/pI9UWNRsdZMwhN5XbE9+7Am2shbcjDGy+oA0AqE2nSV/44bPcIKdHWbo8DpNFnn4YMtQVB15f6vtp1wCj7yppYulqO/6WK/6tdxnLI+2e1kilZ+BZuF35CQ+tquqWgsTudQZSUBHJ6TTyku/s44ZkJU0YhK8g/L3uykM5NtHm+E4CDEdYSOaZ0Joxnk+esWckqdpw52A7KrJ1webkGPJcn+iGAvzx8xG960sfdZNGRLucOSDK1SvKLTc2R61LjNG
j3SJqS0CeKhIL5nszkaXFAquEkafWxpd/8s1xObVmYJ90OpF8oxTIbvn6E2MtTVfhyWySNZ2DI3k693/kcUqYSGFsjGe7A90YA80ZOYkKg9SfvK3TiGZYjm365lmq6PwQcTb3dXzwJRRD4g3oAXA2lVh0tgNRTyAvXfg1NOb4s6wX5YurLvawr0gTVZ6A0gRds3lPtjY14+8nB2MQrmYJfHQbvBWY745Q1GQqn3atz7M0HqNl+ebawyRB3lVmkaCHIIhtoX0zQAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", "jqPSA/XKqZDJnRSmM0sJxbrFv7GUcA45QMysIx1xTsI3+2iysF5Tr68565ZuO65qjo2lklZpQo+wtyKSA/56EaKOJZCZhSvDdBEdvVYJCjmWusuK5qav7xZO0w5W1qRiEgIdcGUz5V7JHqfRf4xI6/uUD846alyzzNjxQtKErqJbRw6yyBO6j6box363pinjiMTzU4w/qltzFuOEpKxy/H3vyH8RcsF24Ou/Rb6vfR7cSLtLwCsf/BMtPcsQfdRK", "aZ8tqrOeEJKt4AMqiRF/WJhIKTDC0HeDTgiJVLZ8OEtiLNj7hflFeVnNXPguxyoqkI/V7pGJtXBpH5N+RswQNA0b23aM33aH0HKHOWoGY/T/L7TQzYFGJ3vTLiXDFZg1OVqkGOMvqAgonOrHGi6IgcALyUMyCKlL5BQY23SeILJpYKolybJNwJfbjxpg0Oz+D2fr7r9XL1GMvgblu52bVQT1fR8uCRJfSsgA2OGw6k/MpKDCfMcjbR8jnZa8ROEvF4cohm7iV1788Vp2/2bdcEZRQSoaGV8pOmA9EkqzJVRABjkDso40fnQcm2IzjBUOsX+uFExVan56/vl9VZVwB0wnee3Uxiredn0kOayiPB16yimxXCDet+M+0UKjmIlmXYpkrCDrH0dn53w+U3OHqMQxPDnUpYBxadM1eI8xWFFxzaLkvega0q0DmEquyY02yiTqo+7Q4qaJVTLgu6/8ekzPxGKRi845NL8gRgaTtM3kidDzIQpyODZD0yeEZDY1M+3sUKHcVkhoxTQBTMyKJPc+M5DeBL3uaWMrvxuL6q8+X0xeBt+9kguPUNtIYqUgPAaXvM2i041bWHTJ0dZLyDJVOyzGaXRaF4mNkAuh4Et6Zw5PuOpMM2mI1oFKEZj7Xqf/yAmy/Le3GfJnMg5vNgE7QxmVsjuKUP28iN8rdi4=", true, }, { - 
"pQUlLSBu9HmVa9hB0rEu1weeBv2RKQQ8yCHpwXTHeSkcQqmSOuzednF8o0+MdyNupQUlLSBu9HmVa9hB0rEu1weeBv2RKQQ8yCHpwXTHeSkcQqmSOuzednF8o0+MdyNuhKgxmPN2c94UBtlYc0kZS6CwyMEEV/nVGSjajEZPdnpbK7fEcPd0hWNcOxKWq8qBBPfT69Ore74buf8C26ZTyKnjgMsGCvoDAMOsA07DjjQ1nIkkwIGFFUT3iMO83TdEpWgV/2z7WT9axNH/QFPOjXvwQJFnC7hLxHnX6pgKOdAaioKdi6FX3Y2SwWEO3UuxFd3KwsrZ2+mma/W3KP/cPpSzqyHa5VaJwOCw6vSM4wHSGKmDF4TSrrnMxzIYiTbTpQUlLSBu9HmVa9hB0rEu1weeBv2RKQQ8yCHpwXTHeSkcQqmSOuzednF8o0+MdyNulrwLi5GjMxD6BKzMMN9+7xFuO7txLCEIhGrIMFIvqTw1QFAO4rmAgyG+ljlYTfWHAkzqvImL1o8dMHhGOTsMLLMg39KsZVqalZwwL3ckpdAf81OJJeWCpCuaSgSXnWhJAAAAEph8ULgPc1Ia5pUdcBzvXnoB4f6dNaLD9MVNN62NaBqJzmdvGnGBujEjn2QZCk/jaKjnBFrS+EQj+rewVlx4CFJpYQhI/6cDVcfdlXN2cxPMzId1NfeiAh800mc9KzMCZJk9JdZu0HbwaalHqgscl4GPumn6rHQHRo2XrlDjwdkQ2ptwpto9meVcoL3SASNdqpSKBAYZ64QscekzfssIpXyNmgY807Z9KwnuyAPbGLXGJ910qKFO0wTQd/TvHGxoJ5hmEMoMQbEPxJo9igwqkOANTEZ0nt6urUIY06Kg4x0VxCs5VpGv+PoVjZyaYnKvy5k948Qh/f8q3vKhVF8vh6tsnIrY7966IMPocl5St6SKEJg7JCZ6gZN4cYrI90EK0Ir9Oj7m8kmOZUSBGJQo6+Y/AIZfYBbzfttnCIdYhorsAcoT4xg4D+Ye/MWVwgaCXpBGD3CNgtC7QXIeaWQvyaUtZBfZQS53aHYSJbrRK95pCqIUfg/3MzfxU3cVNm/NkKn2Th3Puq79m4hF8vAHaADMpI9XbCMO/3eTPF3ID4lMzZKOB+qdNkbdkdNTDmG6E6IGkB/JclOqPHPojkURhKGQ06uIbQvkGuwF06Hb0pht9yK8CVRjigzLb1iNVWHYVrN9kgdFtXfgaxW9DmwRrM/lJ+z/lfKnqwjKrvdOgZG43VprCmykARQvP2A03UovdqyKtSElEFp/PAIFv6vruij8cm1ORGYGwPhGcAgwejMgTYR3KwL1RXl/pI9UWNRsdZMwhN5XbE9+7Am2shbcjDGy+oA0AqE2nSV/44bPcIKdHWbo8DpNFnn4YMtQVB15f6vtp1wCj7yppYulqO/6WK/6tdxnLI+2e1kilZ+BZuF35CQ+tquqWgsTudQZSUBHJ6TTyku/s44ZkJU0YhK8g/L3uykM5NtHm+E4CDEdYSOaZ0Joxnk+esWckqdpw52A7KrJ1webkGPJcn+iGAvzx8xG960sfdZNGRLucOSDK1SvKLTc2R61LjNGj3SJqS0CeKhIL5nszkaXFAquEkafWxpd/8s1xObVmYJ90OpF8oxTIbvn6E2MtTVfhyWySNZ2DI3k693/kcUqYSGFsjGe7A90YA80ZOYkKg9SfvK3TiGZYjm365lmq6PwQcTb3dXzwA==", + 
"pQUlLSBu9HmVa9hB0rEu1weeBv2RKQQ8yCHpwXTHeSkcQqmSOuzednF8o0+MdyNupQUlLSBu9HmVa9hB0rEu1weeBv2RKQQ8yCHpwXTHeSkcQqmSOuzednF8o0+MdyNuhKgxmPN2c94UBtlYc0kZS6CwyMEEV/nVGSjajEZPdnpbK7fEcPd0hWNcOxKWq8qBBPfT69Ore74buf8C26ZTyKnjgMsGCvoDAMOsA07DjjQ1nIkkwIGFFUT3iMO83TdEpWgV/2z7WT9axNH/QFPOjXvwQJFnC7hLxHnX6pgKOdAaioKdi6FX3Y2SwWEO3UuxFd3KwsrZ2+mma/W3KP/cPpSzqyHa5VaJwOCw6vSM4wHSGKmDF4TSrrnMxzIYiTbTpQUlLSBu9HmVa9hB0rEu1weeBv2RKQQ8yCHpwXTHeSkcQqmSOuzednF8o0+MdyNulrwLi5GjMxD6BKzMMN9+7xFuO7txLCEIhGrIMFIvqTw1QFAO4rmAgyG+ljlYTfWHAkzqvImL1o8dMHhGOTsMLLMg39KsZVqalZwwL3ckpdAf81OJJeWCpCuaSgSXnWhJAAAAEph8ULgPc1Ia5pUdcBzvXnoB4f6dNaLD9MVNN62NaBqJzmdvGnGBujEjn2QZCk/jaKjnBFrS+EQj+rewVlx4CFJpYQhI/6cDVcfdlXN2cxPMzId1NfeiAh800mc9KzMCZJk9JdZu0HbwaalHqgscl4GPumn6rHQHRo2XrlDjwdkQ2ptwpto9meVcoL3SASNdqpSKBAYZ64QscekzfssIpXyNmgY807Z9KwnuyAPbGLXGJ910qKFO0wTQd/TvHGxoJ5hmEMoMQbEPxJo9igwqkOANTEZ0nt6urUIY06Kg4x0VxCs5VpGv+PoVjZyaYnKvy5k948Qh/f8q3vKhVF8vh6tsnIrY7966IMPocl5St6SKEJg7JCZ6gZN4cYrI90EK0Ir9Oj7m8kmOZUSBGJQo6+Y/AIZfYBbzfttnCIdYhorsAcoT4xg4D+Ye/MWVwgaCXpBGD3CNgtC7QXIeaWQvyaUtZBfZQS53aHYSJbrRK95pCqIUfg/3MzfxU3cVNm/NkKn2Th3Puq79m4hF8vAHaADMpI9XbCMO/3eTPF3ID4lMzZKOB+qdNkbdkdNTDmG6E6IGkB/JclOqPHPojkURhKGQ06uIbQvkGuwF06Hb0pht9yK8CVRjigzLb1iNVWHYVrN9kgdFtXfgaxW9DmwRrM/lJ+z/lfKnqwjKrvdOgZG43VprCmykARQvP2A03UovdqyKtSElEFp/PAIFv6vruij8cm1ORGYGwPhGcAgwejMgTYR3KwL1RXl/pI9UWNRsdZMwhN5XbE9+7Am2shbcjDGy+oA0AqE2nSV/44bPcIKdHWbo8DpNFnn4YMtQVB15f6vtp1wCj7yppYulqO/6WK/6tdxnLI+2e1kilZ+BZuF35CQ+tquqWgsTudQZSUBHJ6TTyku/s44ZkJU0YhK8g/L3uykM5NtHm+E4CDEdYSOaZ0Joxnk+esWckqdpw52A7KrJ1webkGPJcn+iGAvzx8xG960sfdZNGRLucOSDK1SvKLTc2R61LjNGj3SJqS0CeKhIL5nszkaXFAquEkafWxpd/8s1xObVmYJ90OpF8oxTIbvn6E2MtTVfhyWySNZ2DI3k693/kcUqYSGFsjGe7A90YA80ZOYkKg9SfvK3TiGZYjm365lmq6PwQcTb3dXzwAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", 
"qV2FNaBFqWeL6n9q9OUbCSTcIQvwO0vfaA/f/SxEtLSIaOGIOx8r+WVGFdxmC6i3oOaoEkJWvML7PpKBDtqiK7pKDIaMV5PkV/kQl6UgxZv9OInTwpVPtYcgeeTokG/eBi1qKzJwDoEHVqKeLqrLXJHXhBVQLdoIUOeKj8YMkagVniO9EtK0fW0/9QnRIxXoilxSj5HBEpYwFBitJXRk1ftFGWZFxJXU5PXdRmC+pomyo5Scx+UJQ2NLRWHjKlV0", "aZ8tqrOeEJKt4AMqiRF/WJhIKTDC0HeDTgiJVLZ8OEtiLNj7hflFeVnNXPguxyoqkI/V7pGJtXBpH5N+RswQNA0b23aM33aH0HKHOWoGY/T/L7TQzYFGJ3vTLiXDFZg1OVqkGOMvqAgonOrHGi6IgcALyUMyCKlL5BQY23SeILJpYKolybJNwJfbjxpg0Oz+D2fr7r9XL1GMvgblu52bVQT1fR8uCRJfSsgA2OGw6k/MpKDCfMcjbR8jnZa8ROEvF4cohm7iV1788Vp2/2bdcEZRQSoaGV8pOmA9EkqzJVRABjkDso40fnQcm2IzjBUOsX+uFExVan56/vl9VZVwB0wnee3Uxiredn0kOayiPB16yimxXCDet+M+0UKjmIlmXYpkrCDrH0dn53w+U3OHqMQxPDnUpYBxadM1eI8xWFFxzaLkvega0q0DmEquyY02yiTqo+7Q4qaJVTLgu6/8ekzPxGKRi845NL8gRgaTtM3kidDzIQpyODZD0yeEZDY1M+3sUKHcVkhoxTQBTMyKJPc+M5DeBL3uaWMrvxuL6q8+X0xeBt+9kguPUNtIYqUgPAaXvM2i041bWHTJ0dZLyDJVOyzGaXRaF4mNkAuh4Et6Zw5PuOpMM2mI1oFKEZj7Xqf/yAmy/Le3GfJnMg5vNgE7QxmVsjuKUP28iN8rdi4bUp7c0KJpqLXE6evfRrdZBDRYp+rmOLLDg55ggNuwog==", true, }, { - "lp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0iISo2JdNY1vPXlpwhlL2fVpW/WlREkF0bKlBadDIbNJBgM4niJGuEZDru3wqrGueETKHPv7hQ8em+p6vQolp7c0iknjXrGnvlpf4QtUtpg3z/D+snWjRPbVqRgKXWtihlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nuIvPFaM6dt7HZEbkeMnXWwSINeYC/j3lqYnce8Jq+XkuF42stVNiooI+TuXECnFdFi9Ib25b9wtyz3H/oKg48He1ftntj5uIRCOBvzkFHGUF6Ty214v3JYvXJjdS4uS2AAAAAY3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywQ==", + 
"lp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0iISo2JdNY1vPXlpwhlL2fVpW/WlREkF0bKlBadDIbNJBgM4niJGuEZDru3wqrGueETKHPv7hQ8em+p6vQolp7c0iknjXrGnvlpf4QtUtpg3z/D+snWjRPbVqRgKXWtihlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nuIvPFaM6dt7HZEbkeMnXWwSINeYC/j3lqYnce8Jq+XkuF42stVNiooI+TuXECnFdFi9Ib25b9wtyz3H/oKg48He1ftntj5uIRCOBvzkFHGUF6Ty214v3JYvXJjdS4uS2AAAAAY3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywQAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", "jiGBK+TGHfH8Oadexhdet7ExyIWibSmamWQvffZkyl3WnMoVbTQ3lOks4Mca3sU5qgcaLyQQ1FjFW4g6vtoMapZ43hTGKaWO7bQHsOCvdwHCdwJDulVH16cMTyS9F0BfBJxa88F+JKZc4qMTJjQhspmq755SrKhN9Jf+7uPUhgB4hJTSrmlOkTatgW+/HAf5kZKhv2oRK5p5kS4sU48oqlG1azhMtcHEXDQdcwf9ANel4Z9cb+MQyp2RzI/3hlIx", "", false, }, { - "lp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0iISo2JdNY1vPXlpwhlL2fVpW/WlREkF0bKlBadDIbNJBgM4niJGuEZDru3wqrGueETKHPv7hQ8em+p6vQolp7c0iknjXrGnvlpf4QtUtpg3z/D+snWjRPbVqRgKXWtihlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nuIvPFaM6dt7HZEbkeMnXWwSINeYC/j3lqYnce8Jq+XkuF42stVNiooI+TuXECnFdFi9Ib25b9wtyz3H/oKg48He1ftntj5uIRCOBvzkFHGUF6Ty214v3JYvXJjdS4uS2AAAAAo3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywY4hgSvkxh3x/DmnXsYXXrexMciFom0pmplkL332ZMpd1pzKFW00N5TpLODHGt7FOQ==", + 
"lp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0iISo2JdNY1vPXlpwhlL2fVpW/WlREkF0bKlBadDIbNJBgM4niJGuEZDru3wqrGueETKHPv7hQ8em+p6vQolp7c0iknjXrGnvlpf4QtUtpg3z/D+snWjRPbVqRgKXWtihlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nuIvPFaM6dt7HZEbkeMnXWwSINeYC/j3lqYnce8Jq+XkuF42stVNiooI+TuXECnFdFi9Ib25b9wtyz3H/oKg48He1ftntj5uIRCOBvzkFHGUF6Ty214v3JYvXJjdS4uS2AAAAAo3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywY4hgSvkxh3x/DmnXsYXXrexMciFom0pmplkL332ZMpd1pzKFW00N5TpLODHGt7FOQAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", "hp1iMepdu0rKoBh0NXcw9F9hkiggDIkRNINq2rlvUypPiSmp8U8tDSMeG0YVSovFteecr3THhBJj0qNeEe9jA2Ci64fKG9WT1heMYzEAQKebOErYXYCm9d72n97mYn1XBq+g1Y730XEDv4BIDI1hBDntJcgcj/cSvcILB1+60axJvtyMyuizxUr1JUBUq9njtmJ9m8zK6QZLNqMiKh0f2jokQb5mVhu6v5guW3KIjwQc/oFK/l5ehKAOPKUUggNh", "c9BSUPtO0xjPxWVNkEMfXe7O4UZKpaH/nLIyQJj7iA4=", false, }, { - 
"lp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0iISo2JdNY1vPXlpwhlL2fVpW/WlREkF0bKlBadDIbNJBgM4niJGuEZDru3wqrGueETKHPv7hQ8em+p6vQolp7c0iknjXrGnvlpf4QtUtpg3z/D+snWjRPbVqRgKXWtihlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nuIvPFaM6dt7HZEbkeMnXWwSINeYC/j3lqYnce8Jq+XkuF42stVNiooI+TuXECnFdFi9Ib25b9wtyz3H/oKg48He1ftntj5uIRCOBvzkFHGUF6Ty214v3JYvXJjdS4uS2AAAAEI3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywY4hgSvkxh3x/DmnXsYXXrexMciFom0pmplkL332ZMpd1pzKFW00N5TpLODHGt7FOYadYjHqXbtKyqAYdDV3MPRfYZIoIAyJETSDatq5b1MqT4kpqfFPLQ0jHhtGFUqLxZQOA7IwcJ7SR+OTYDW2P0W7v4X3u0LJE5AYk6NgPpJmEh++VL39lAF8AQE9T6BNLKrWJ3Rdim7YZehspVd4/TSCrSMx3fxsJXhMbjOypSNR9tj+/G6+c5ofA+1dWhXMT6X7UIY3IeGm17sRD+GoYHzZrYXYaEr0MpC4Y25hvQa78Jsmg0Xf7Bwk3kl5i6iMm63erQPGSyatgy8yiRDpZnzPWiIoS3vxS53X500pDStnreKPLzN2IwJoDp6wKOySKZdGaLefTVIk3ceY9uyFfvLIkVkUc/VN77b8+NtaoDxOZJ+mKyUlZ2CgqFngdxTX6YdVfjIfPeKN0i3Y/Z+6IH5R/7H14rGI+b5XkjZXSYv+u0yjOAYCNWmhnV7k7Xh6irYuuq10PvWvXfkpsCoJ0VKY1btbZK1mQkhW1broGWBGfQWY4VQkmOt+sAbhuihb+7AyoomdL10aqVI1AjhTH5ExvZyXaDrWrY5YgHn+/g0197VE5dZlXTXM5gJxIHomSat5jCsXGyonDl0LHKlPyYHdfmNm7MkLAyIMDf5Nt8u4wLmhISD5THi8y/OCZJeTfLGwCId+al2c+7XrMmHbfBbiV+hgruqlyjhbPGhZ/EVdsfQWvM+YhwQsEu0DgpmZ2pMsFPy29pBRGqrANivFv92Q8NrVuZjUKi5R/zEaBqeEjC7OmtAijtj4dOd9qHj6Q5YEKBdZF/acn/VAUGjSH65FwxkBkv69sui2U3T4r2LOpfa+gEVMYrEUc6m3vFr8VaD2ib6/F4P3akFs9pWILQnYhlm47zVIQ2KSnc0fvL/CEXq2JR+i/EaaQ0YYgs0E1A==", + 
"lp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0iISo2JdNY1vPXlpwhlL2fVpW/WlREkF0bKlBadDIbNJBgM4niJGuEZDru3wqrGueETKHPv7hQ8em+p6vQolp7c0iknjXrGnvlpf4QtUtpg3z/D+snWjRPbVqRgKXWtihlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nuIvPFaM6dt7HZEbkeMnXWwSINeYC/j3lqYnce8Jq+XkuF42stVNiooI+TuXECnFdFi9Ib25b9wtyz3H/oKg48He1ftntj5uIRCOBvzkFHGUF6Ty214v3JYvXJjdS4uS2AAAAEI3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywY4hgSvkxh3x/DmnXsYXXrexMciFom0pmplkL332ZMpd1pzKFW00N5TpLODHGt7FOYadYjHqXbtKyqAYdDV3MPRfYZIoIAyJETSDatq5b1MqT4kpqfFPLQ0jHhtGFUqLxZQOA7IwcJ7SR+OTYDW2P0W7v4X3u0LJE5AYk6NgPpJmEh++VL39lAF8AQE9T6BNLKrWJ3Rdim7YZehspVd4/TSCrSMx3fxsJXhMbjOypSNR9tj+/G6+c5ofA+1dWhXMT6X7UIY3IeGm17sRD+GoYHzZrYXYaEr0MpC4Y25hvQa78Jsmg0Xf7Bwk3kl5i6iMm63erQPGSyatgy8yiRDpZnzPWiIoS3vxS53X500pDStnreKPLzN2IwJoDp6wKOySKZdGaLefTVIk3ceY9uyFfvLIkVkUc/VN77b8+NtaoDxOZJ+mKyUlZ2CgqFngdxTX6YdVfjIfPeKN0i3Y/Z+6IH5R/7H14rGI+b5XkjZXSYv+u0yjOAYCNWmhnV7k7Xh6irYuuq10PvWvXfkpsCoJ0VKY1btbZK1mQkhW1broGWBGfQWY4VQkmOt+sAbhuihb+7AyoomdL10aqVI1AjhTH5ExvZyXaDrWrY5YgHn+/g0197VE5dZlXTXM5gJxIHomSat5jCsXGyonDl0LHKlPyYHdfmNm7MkLAyIMDf5Nt8u4wLmhISD5THi8y/OCZJeTfLGwCId+al2c+7XrMmHbfBbiV+hgruqlyjhbPGhZ/EVdsfQWvM+YhwQsEu0DgpmZ2pMsFPy29pBRGqrANivFv92Q8NrVuZjUKi5R/zEaBqeEjC7OmtAijtj4dOd9qHj6Q5YEKBdZF/acn/VAUGjSH65FwxkBkv69sui2U3T4r2LOpfa+gEVMYrEUc6m3vFr8VaD2ib6/F4P3akFs9pWILQnYhlm47zVIQ2KSnc0fvL/CEXq2JR+i/EaaQ0YYgs0E1AAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", 
"pNeWbxzzJPMsPpuXBXWZgtLic1s0KL8UeLDGBhEjygrv8m1eMM12pzd+r/scvBEHrnEoQHanlNTlWPywaXaFtB5Hd5RMrnbfLbpe16tvtlH2SRbJbGXSpib5uiuSa6z1ExLtXs9nNWiu10eupG6Pq4SNOacCEVvUgSzCzhyLIlz62gq4DlBBWKmEFI7KiFs7kr2EPBjj2m83dbA/GGVgoYYjgBmFX6/srvLADxerZTKG2moOQrmAx9GJ99nwhRbW", "I8C5RcBDPi2n4omt9oOV2rZk9T9xlSV8PQvLeVHjGb00fCVz7AHOIjLJ03ZCTLQwEKkAk9tQWJ6gFTBnG2+0DDHlXcVkwpMafcpS2diKFe0T4fRb0t9mxNzOFiRVcJoeMU1zb/rE4dIMm9rbEPSDnVSOd8tHNnJDkT+/NcNsQ2w0UEVJJRAEnC7G0Y3522RlDLxpTZ6w0U/9V0pLNkFgDCkFBKvpaEfPDJjoEVyCUWDC1ts9LIR43xh3ZZBdcO/HATHoLzxM3Ef11qF+riV7WDPEJfK11u8WGazzCAFhsx0aKkkbnKl7LnypBzwRvrG2JxdLI/oXL0eoIw9woVjqrg6elHudnHDXezDVXjRWMPaU+L3tOW9aqN+OdP4AhtpgT2CoRCjrOIU3MCFqsrCK9bh33PW1gtNeHC78mIetQM5LWZHtw4KNwafTrQ+GCKPelJhiC2x7ygBtat5rtBsJAVF5wjssLPZx/7fqNqifXB7WyMV7J1M8LBQVXj5kLoS9bpmNHlERRSadC0DEUbY9xhIG2xo7R88R0sq04a299MFv8XJNd+IdueYiMiGF5broHD4UUhPxRBlBO3lOfDTPnRSUGS3Sr6GxwCjKO3MObz/6RNxCk9SnQ4NccD17hS/m", false, }, { - "lp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0iISo2JdNY1vPXlpwhlL2fVpW/WlREkF0bKlBadDIbNJBgM4niJGuEZDru3wqrGueETKHPv7hQ8em+p6vQolp7c0iknjXrGnvlpf4QtUtpg3z/D+snWjRPbVqRgKXWtihlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nuIvPFaM6dt7HZEbkeMnXWwSINeYC/j3lqYnce8Jq+XkuF42stVNiooI+TuXECnFdFi9Ib25b9wtyz3H/oKg48He1ftntj5uIRCOBvzkFHGUF6Ty214v3JYvXJjdS4uS2AAAAEY3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywY4hgSvkxh3x/DmnXsYXXrexMciFom0pmplkL332ZMpd1pzKFW00N5TpLODHGt7FOYadYjHqXbtKyqAYdDV3MPRfYZIoIAyJETSDatq5b1MqT4kpqfFPLQ0jHhtGFUqLxZQOA7IwcJ7SR+OTYDW2P0W7v4X3u0LJE5AYk6NgPpJmEh++VL39lAF8AQE9T6BNLKrWJ3Rdim7YZehspVd4/TSCrSMx3fxsJXhMbjOypSNR9tj+/G6+c5ofA+1dWhXMT6X7UIY3IeGm17sRD+GoYHzZrYXYaEr0MpC4Y25hvQa78Jsmg0Xf7Bwk3kl5i6iMm63erQPGSyatgy8yiRDpZnzPWiIoS3vxS53X500pDStnreKPLzN2IwJoDp6wKOySKZdGaLefTVIk3ceY9uyFfvLIkVkUc/VN77b8+NtaoDxOZJ+mKyUlZ
2CgqFngdxTX6YdVfjIfPeKN0i3Y/Z+6IH5R/7H14rGI+b5XkjZXSYv+u0yjOAYCNWmhnV7k7Xh6irYuuq10PvWvXfkpsCoJ0VKY1btbZK1mQkhW1broGWBGfQWY4VQkmOt+sAbhuihb+7AyoomdL10aqVI1AjhTH5ExvZyXaDrWrY5YgHn+/g0197VE5dZlXTXM5gJxIHomSat5jCsXGyonDl0LHKlPyYHdfmNm7MkLAyIMDf5Nt8u4wLmhISD5THi8y/OCZJeTfLGwCId+al2c+7XrMmHbfBbiV+hgruqlyjhbPGhZ/EVdsfQWvM+YhwQsEu0DgpmZ2pMsFPy29pBRGqrANivFv92Q8NrVuZjUKi5R/zEaBqeEjC7OmtAijtj4dOd9qHj6Q5YEKBdZF/acn/VAUGjSH65FwxkBkv69sui2U3T4r2LOpfa+gEVMYrEUc6m3vFr8VaD2ib6/F4P3akFs9pWILQnYhlm47zVIQ2KSnc0fvL/CEXq2JR+i/EaaQ0YYgs0E1KTXlm8c8yTzLD6blwV1mYLS4nNbNCi/FHiwxgYRI8oK7/JtXjDNdqc3fq/7HLwRBw==", + "lp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nh7yLzdqr7HHQNOpZI8mdj/7lR0IBqB9zvRfyTr+guUG22kZo4y2KINDp272xGglKEeTglTxyDUriZJNF/+T6F8w70MR/rV+flvuo6EJ0+HA+A2ZnBbTjOIl9wjisBV+0iISo2JdNY1vPXlpwhlL2fVpW/WlREkF0bKlBadDIbNJBgM4niJGuEZDru3wqrGueETKHPv7hQ8em+p6vQolp7c0iknjXrGnvlpf4QtUtpg3z/D+snWjRPbVqRgKXWtihlp7+dPDIOfm77haSFnvr33VwYH/KbIalfOJPRvBLzqlHD8BxunNebMr6Gr6S+u+nuIvPFaM6dt7HZEbkeMnXWwSINeYC/j3lqYnce8Jq+XkuF42stVNiooI+TuXECnFdFi9Ib25b9wtyz3H/oKg48He1ftntj5uIRCOBvzkFHGUF6Ty214v3JYvXJjdS4uS2AAAAEY3pKZWWKFdGlxJ2BDiL3xe3eMWst6pMdjbKaKDOB3maYh5JyFpFFYRlSQcwQy4ywY4hgSvkxh3x/DmnXsYXXrexMciFom0pmplkL332ZMpd1pzKFW00N5TpLODHGt7FOYadYjHqXbtKyqAYdDV3MPRfYZIoIAyJETSDatq5b1MqT4kpqfFPLQ0jHhtGFUqLxZQOA7IwcJ7SR+OTYDW2P0W7v4X3u0LJE5AYk6NgPpJmEh++VL39lAF8AQE9T6BNLKrWJ3Rdim7YZehspVd4/TSCrSMx3fxsJXhMbjOypSNR9tj+/G6+c5ofA+1dWhXMT6X7UIY3IeGm17sRD+GoYHzZrYXYaEr0MpC4Y25hvQa78Jsmg0Xf7Bwk3kl5i6iMm63erQPGSyatgy8yiRDpZnzPWiIoS3vxS53X500pDStnreKPLzN2IwJoDp6wKOySKZdGaLefTVIk3ceY9uyFfvLIkVkUc/VN77b8+NtaoDxOZJ+mKyUlZ2CgqFngdxTX6YdVfjIfPeKN0i3Y/Z+6IH5R/7H14rGI+b5XkjZXSYv+u0yjOAYCNWmhnV7k7Xh6irYuuq10PvWvXfkpsCoJ0VKY1btbZK1mQkhW1broGWBGfQWY4VQkmOt+sAbhuihb+7AyoomdL10aqVI1AjhTH5ExvZyXaDrWrY5YgHn+/g0197VE5dZlXTXM5gJxIHomSat5jCsXGyonDl0LHKlPyYHdfmNm7MkLAyIMDf5Nt8u4wLmhISD5THi8y/OCZJeTfLGwCId+al2c+7XrMmHbfBbiV+hgruqlyjhbPGhZ/EVdsfQWvM+YhwQ
sEu0DgpmZ2pMsFPy29pBRGqrANivFv92Q8NrVuZjUKi5R/zEaBqeEjC7OmtAijtj4dOd9qHj6Q5YEKBdZF/acn/VAUGjSH65FwxkBkv69sui2U3T4r2LOpfa+gEVMYrEUc6m3vFr8VaD2ib6/F4P3akFs9pWILQnYhlm47zVIQ2KSnc0fvL/CEXq2JR+i/EaaQ0YYgs0E1KTXlm8c8yTzLD6blwV1mYLS4nNbNCi/FHiwxgYRI8oK7/JtXjDNdqc3fq/7HLwRBwAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", "iw5yhCCarVRq/h0Klq4tHNdF1j7PxaDn0AfHTxc2hb//Acav53QStwQShQ0BpQJ7sdchkTTJLkhM13+JpPY/I2WIc6DMZdRzw3pRjLSdMUmce7LYbBJOI+/IyuLZH5IXA7sX4r+xrPssIaMiKR3twmmReN9NrSoovLepDsNmzDVraO71B4rkx7uPXvkqvt3Zkr2EPBjj2m83dbA/GGVgoYYjgBmFX6/srvLADxerZTKG2moOQrmAx9GJ99nwhRbW", "I8C5RcBDPi2n4omt9oOV2rZk9T9xlSV8PQvLeVHjGb00fCVz7AHOIjLJ03ZCTLQwEKkAk9tQWJ6gFTBnG2+0DDHlXcVkwpMafcpS2diKFe0T4fRb0t9mxNzOFiRVcJoeMU1zb/rE4dIMm9rbEPSDnVSOd8tHNnJDkT+/NcNsQ2w0UEVJJRAEnC7G0Y3522RlDLxpTZ6w0U/9V0pLNkFgDCkFBKvpaEfPDJjoEVyCUWDC1ts9LIR43xh3ZZBdcO/HATHoLzxM3Ef11qF+riV7WDPEJfK11u8WGazzCAFhsx0aKkkbnKl7LnypBzwRvrG2JxdLI/oXL0eoIw9woVjqrg6elHudnHDXezDVXjRWMPaU+L3tOW9aqN+OdP4AhtpgT2CoRCjrOIU3MCFqsrCK9bh33PW1gtNeHC78mIetQM5LWZHtw4KNwafTrQ+GCKPelJhiC2x7ygBtat5rtBsJAVF5wjssLPZx/7fqNqifXB7WyMV7J1M8LBQVXj5kLoS9bpmNHlERRSadC0DEUbY9xhIG2xo7R88R0sq04a299MFv8XJNd+IdueYiMiGF5broHD4UUhPxRBlBO3lOfDTPnRSUGS3Sr6GxwCjKO3MObz/6RNxCk9SnQ4NccD17hS/mEFt8d4ERZOfmuvD3A0RCPCnx3Fr6rHdm6j+cfn/NM6o=", false, }, diff --git a/internal/backend/bls12-377/groth16/commitment.go b/backend/groth16/bls12-377/commitment.go similarity index 74% rename from internal/backend/bls12-377/groth16/commitment.go rename to backend/groth16/bls12-377/commitment.go index 72506760a4..c267e8a99b 100644 --- a/internal/backend/bls12-377/groth16/commitment.go +++ b/backend/groth16/bls12-377/commitment.go @@ -23,7 +23,7 @@ import ( "math/big" ) -func solveCommitmentWire(commitmentInfo *constraint.Commitment, commitment *curve.G1Affine, 
publicCommitted []*big.Int) (fr.Element, error) { - res, err := fr.Hash(commitmentInfo.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) +func solveCommitmentWire(commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { + res, err := fr.Hash(constraint.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) return res[0], err } diff --git a/internal/backend/bls12-377/groth16/commitment_test.go b/backend/groth16/bls12-377/commitment_test.go similarity index 91% rename from internal/backend/bls12-377/groth16/commitment_test.go rename to backend/groth16/bls12-377/commitment_test.go index 18f6f142d1..f18eefb837 100644 --- a/internal/backend/bls12-377/groth16/commitment_test.go +++ b/backend/groth16/bls12-377/commitment_test.go @@ -17,6 +17,9 @@ package groth16_test import ( + "fmt" + "testing" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/witness" @@ -24,7 +27,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/assert" - "testing" ) type singleSecretCommittedCircuit struct { @@ -33,7 +35,11 @@ type singleSecretCommittedCircuit struct { func (c *singleSecretCommittedCircuit) Define(api frontend.API) error { api.AssertIsEqual(c.One, 1) - commit, err := api.Compiler().Commit(c.One) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One) if err != nil { return err } @@ -119,8 +125,11 @@ type oneSecretOnePublicCommittedCircuit struct { } func (c *oneSecretOnePublicCommittedCircuit) Define(api frontend.API) error { - - commit, err := api.Compiler().Commit(c.One, c.Two) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does 
not commit") + } + commit, err := commitCompiler.Commit(c.One, c.Two) if err != nil { return err } diff --git a/internal/backend/bls12-377/groth16/marshal.go b/backend/groth16/bls12-377/marshal.go similarity index 81% rename from internal/backend/bls12-377/groth16/marshal.go rename to backend/groth16/bls12-377/marshal.go index 787d450165..bec919254c 100644 --- a/internal/backend/bls12-377/groth16/marshal.go +++ b/backend/groth16/bls12-377/marshal.go @@ -18,6 +18,9 @@ package groth16 import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/pedersen" + "github.com/consensys/gnark/internal/utils" "io" ) @@ -78,14 +81,24 @@ func (proof *Proof) ReadFrom(r io.Reader) (n int64, err error) { // points are compressed // use WriteRawTo(...) to encode the key without point compression func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, false) + if n, err = vk.writeTo(w, false); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteTo(w) + return m + n, err } // WriteRawTo writes binary encoding of the key elements to writer // points are not compressed // use WriteTo(...) 
to encode the key with point compression func (vk *VerifyingKey) WriteRawTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, true) + if n, err = vk.writeTo(w, true); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteRawTo(w) + return m + n, err } // writeTo serialization format: @@ -124,6 +137,14 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { if err := enc.Encode(vk.G1.K); err != nil { return enc.BytesWritten(), err } + + if vk.PublicAndCommitmentCommitted == nil { + vk.PublicAndCommitmentCommitted = [][]int{} // only matters in tests + } + if err := enc.Encode(utils.IntSliceSliceToUint64SliceSlice(vk.PublicAndCommitmentCommitted)); err != nil { + return enc.BytesWritten(), err + } + return enc.BytesWritten(), nil } @@ -133,13 +154,25 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { // https://github.com/zkcrypto/bellman/blob/fa9be45588227a8c6ec34957de3f68705f07bd92/src/groth16/mod.rs#L143 // [α]1,[β]1,[β]2,[γ]2,[δ]1,[δ]2,uint32(len(Kvk)),[Kvk]1 func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r) + n, err := vk.readFrom(r) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.ReadFrom(r) + return m + n, err } // UnsafeReadFrom has the same behavior as ReadFrom, except that it will not check that decode points // are on the curve and in the correct subgroup. 
func (vk *VerifyingKey) UnsafeReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r, curve.NoSubgroupChecks()) + n, err := vk.readFrom(r, curve.NoSubgroupChecks()) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.UnsafeReadFrom(r) + return m + n, err } func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { @@ -169,15 +202,16 @@ func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder) if err := dec.Decode(&vk.G1.K); err != nil { return dec.BytesRead(), err } + var publicCommitted [][]uint64 + if err := dec.Decode(&publicCommitted); err != nil { + return dec.BytesRead(), err + } + vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 - var err error - vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return dec.BytesRead(), err } - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) return dec.BytesRead(), nil } @@ -226,6 +260,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { pk.NbInfinityB, pk.InfinityA, pk.InfinityB, + uint32(len(pk.CommitmentKeys)), } for _, v := range toEncode { @@ -234,6 +269,23 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { } } + for i := range pk.CommitmentKeys { + var ( + n2 int64 + err error + ) + if raw { + n2, err = pk.CommitmentKeys[i].WriteRawTo(w) + } else { + n2, err = pk.CommitmentKeys[i].WriteTo(w) + } + + n += n2 + if err != nil { + return n, err + } + } + return n + enc.BytesWritten(), nil } @@ -260,6 +312,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) dec := curve.NewDecoder(r, decOptions...) 
var nbWires uint64 + var nbCommitments uint32 toDecode := []interface{}{ &pk.G1.Alpha, @@ -291,6 +344,18 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) if err := dec.Decode(&pk.InfinityB); err != nil { return n + dec.BytesRead(), err } + if err := dec.Decode(&nbCommitments); err != nil { + return n + dec.BytesRead(), err + } + + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) + for i := range pk.CommitmentKeys { + n2, err := pk.CommitmentKeys[i].ReadFrom(r) + n += n2 + if err != nil { + return n, err + } + } return n + dec.BytesRead(), nil } diff --git a/internal/backend/bls12-377/groth16/marshal_test.go b/backend/groth16/bls12-377/marshal_test.go similarity index 78% rename from internal/backend/bls12-377/groth16/marshal_test.go rename to backend/groth16/bls12-377/marshal_test.go index 901d6f885f..b5fe628704 100644 --- a/internal/backend/bls12-377/groth16/marshal_test.go +++ b/backend/groth16/bls12-377/marshal_test.go @@ -21,11 +21,17 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal/test_utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "bytes" "math/big" "reflect" "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" "github.com/leanovate/gopter/prop" "testing" @@ -87,13 +93,9 @@ func TestProofSerialization(t *testing.T) { } func TestVerifyingKeySerialization(t *testing.T) { - parameters := gopter.DefaultTestParameters() - parameters.MinSuccessfulTests = 10 - properties := gopter.NewProperties(parameters) - - properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + roundTrip := func(withCommitment bool) func(curve.G1Affine, curve.G2Affine) bool { + return func(p1 curve.G1Affine, p2 curve.G2Affine) bool { var vk, vkCompressed, 
vkRaw VerifyingKey // create a random vk @@ -121,6 +123,21 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } + if withCommitment { + vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization + bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) + elem := p1 + for i := 0; i < len(vk.PublicAndCommitmentCommitted); i++ { + bases[i] = make([]curve.G1Affine, len(vk.PublicAndCommitmentCommitted[i])) + for j := range bases[i] { + bases[i][j] = elem + elem.Add(&elem, &p1) + } + } + _, vk.CommitmentKey, err = pedersen.Setup(bases...) + assert.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := vk.WriteTo(&bufCompressed) if err != nil { @@ -158,7 +175,22 @@ func TestVerifyingKeySerialization(t *testing.T) { } return reflect.DeepEqual(&vk, &vkCompressed) && reflect.DeepEqual(&vk, &vkRaw) - }, + } + } + + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 10 + + properties := gopter.NewProperties(parameters) + + properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(false), + GenG1(), + GenG2(), + )) + + properties.Property("VerifyingKey (with commitments) -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(true), GenG1(), GenG2(), )) @@ -173,7 +205,7 @@ func TestProvingKeySerialization(t *testing.T) { properties := gopter.NewProperties(parameters) properties.Property("ProvingKey -> writer -> reader -> ProvingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + func(p1 curve.G1Affine, p2 curve.G2Affine, nbCommitment int) bool { var pk, pkCompressed, pkRaw ProvingKey // create a random pk @@ -202,6 +234,19 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, 
nbCommitment) + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases...) + require.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := pk.WriteTo(&bufCompressed) if err != nil { @@ -242,6 +287,7 @@ func TestProvingKeySerialization(t *testing.T) { }, GenG1(), GenG2(), + gen.IntRange(0, 2), )) properties.TestingRun(t, gopter.ConsoleReporter(false)) diff --git a/backend/groth16/bls12-377/mpcsetup/lagrange.go b/backend/groth16/bls12-377/mpcsetup/lagrange.go new file mode 100644 index 0000000000..92b6acb948 --- /dev/null +++ b/backend/groth16/bls12-377/mpcsetup/lagrange.go @@ -0,0 +1,216 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" + "github.com/consensys/gnark/internal/utils" +) + +func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { + coeffs := make([]curve.G1Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { + coeffs := make([]curve.G2Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func butterflyG1(a *curve.G1Affine, b *curve.G1Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +func butterflyG2(a *curve.G2Affine, b *curve.G2Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G1(a []curve.G1Affine, twiddles [][]fr.Element, stage int) { + butterflyG1(&a[0], &a[4]) + butterflyG1(&a[1], &a[5]) + 
butterflyG1(&a[2], &a[6]) + butterflyG1(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[2]) + butterflyG1(&a[1], &a[3]) + butterflyG1(&a[4], &a[6]) + butterflyG1(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[1]) + butterflyG1(&a[2], &a[3]) + butterflyG1(&a[4], &a[5]) + butterflyG1(&a[6], &a[7]) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G2(a []curve.G2Affine, twiddles [][]fr.Element, stage int) { + butterflyG2(&a[0], &a[4]) + butterflyG2(&a[1], &a[5]) + butterflyG2(&a[2], &a[6]) + butterflyG2(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[2]) + butterflyG2(&a[1], &a[3]) + butterflyG2(&a[4], &a[6]) + butterflyG2(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[1]) + butterflyG2(&a[2], &a[3]) + butterflyG2(&a[4], &a[5]) + butterflyG2(&a[6], &a[7]) +} + +func difFFTG1(a []curve.G1Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G1(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG1(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ 
{ + butterflyG1(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} +func difFFTG2(a []curve.G2Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G2(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG2(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + butterflyG2(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG2(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG2(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} diff --git a/backend/groth16/bls12-377/mpcsetup/marshal.go b/backend/groth16/bls12-377/mpcsetup/marshal.go new file mode 100644 index 0000000000..35ceece58f --- /dev/null +++ b/backend/groth16/bls12-377/mpcsetup/marshal.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + "io" +) + +// WriteTo implements io.WriterTo +func (phase1 *Phase1) WriteTo(writer io.Writer) (int64, error) { + n, err := phase1.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase1.Hash) + return int64(nBytes) + n, err +} + +func (phase1 *Phase1) writeTo(writer io.Writer) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + phase1.Parameters.G1.Tau, + phase1.Parameters.G1.AlphaTau, + phase1.Parameters.G1.BetaTau, + phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + enc := curve.NewEncoder(writer) + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (phase1 *Phase1) ReadFrom(reader io.Reader) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + &phase1.Parameters.G1.Tau, + 
&phase1.Parameters.G1.AlphaTau, + &phase1.Parameters.G1.BetaTau, + &phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + dec := curve.NewDecoder(reader) + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + phase1.Hash = make([]byte, 32) + nBytes, err := reader.Read(phase1.Hash) + return dec.BytesRead() + int64(nBytes), err +} + +// WriteTo implements io.WriterTo +func (phase2 *Phase2) WriteTo(writer io.Writer) (int64, error) { + n, err := phase2.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase2.Hash) + return int64(nBytes) + n, err +} + +func (c *Phase2) writeTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + c.Parameters.G1.L, + c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c *Phase2) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + &c.Parameters.G1.L, + &c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + c.Hash = make([]byte, 32) + n, err := reader.Read(c.Hash) + return int64(n) + dec.BytesRead(), err + +} + +// WriteTo implements io.WriterTo +func (c *Phase2Evaluations) WriteTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + c.G1.A, + c.G1.B, + c.G2.B, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom 
implements io.ReaderFrom +func (c *Phase2Evaluations) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.G1.A, + &c.G1.B, + &c.G2.B, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + return dec.BytesRead(), nil +} diff --git a/backend/groth16/bls12-377/mpcsetup/marshal_test.go b/backend/groth16/bls12-377/mpcsetup/marshal_test.go new file mode 100644 index 0000000000..ab7e6956cc --- /dev/null +++ b/backend/groth16/bls12-377/mpcsetup/marshal_test.go @@ -0,0 +1,79 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + cs "github.com/consensys/gnark/constraint/bls12-377" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/stretchr/testify/require" + "io" + "reflect" + "testing" +) + +func TestContributionSerialization(t *testing.T) { + assert := require.New(t) + + // Phase 1 + srs1 := InitPhase1(9) + srs1.Contribute() + { + var reconstructed Phase1 + roundTripCheck(t, &srs1, &reconstructed) + } + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + r1cs := ccs.(*cs.R1CS) + + // Phase 2 + srs2, _ := InitPhase2(r1cs, &srs1) + srs2.Contribute() + + { + var reconstructed Phase2 + roundTripCheck(t, &srs2, &reconstructed) + } +} + +func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { + t.Helper() + + var buf bytes.Buffer + written, err := from.WriteTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} diff --git a/backend/groth16/bls12-377/mpcsetup/phase1.go b/backend/groth16/bls12-377/mpcsetup/phase1.go new file mode 100644 index 0000000000..573aaec996 --- /dev/null +++ b/backend/groth16/bls12-377/mpcsetup/phase1.go @@ -0,0 +1,203 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "math" + "math/big" +) + +// Phase1 represents the Phase1 of the MPC described in +// https://eprint.iacr.org/2017/1050.pdf +// +// Also known as "Powers of Tau" +type Phase1 struct { + Parameters struct { + G1 struct { + Tau []curve.G1Affine // {[τ⁰]₁, [τ¹]₁, [τ²]₁, …, [τ²ⁿ⁻²]₁} + AlphaTau []curve.G1Affine // {α[τ⁰]₁, α[τ¹]₁, α[τ²]₁, …, α[τⁿ⁻¹]₁} + BetaTau []curve.G1Affine // {β[τ⁰]₁, β[τ¹]₁, β[τ²]₁, …, β[τⁿ⁻¹]₁} + } + G2 struct { + Tau []curve.G2Affine // {[τ⁰]₂, [τ¹]₂, [τ²]₂, …, [τⁿ⁻¹]₂} + Beta curve.G2Affine // [β]₂ + } + } + PublicKeys struct { + Tau, Alpha, Beta PublicKey + } + Hash []byte // sha256 hash +} + +// InitPhase1 initialize phase 1 of the MPC. This is called once by the coordinator before +// any randomness contribution is made (see Contribute()). 
+func InitPhase1(power int) (phase1 Phase1) { + N := int(math.Pow(2, float64(power))) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetOne() + alpha.SetOne() + beta.SetOne() + phase1.PublicKeys.Tau = newPublicKey(tau, nil, 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, nil, 2) + phase1.PublicKeys.Beta = newPublicKey(beta, nil, 3) + + // First contribution use generators + _, _, g1, g2 := curve.Generators() + phase1.Parameters.G2.Beta.Set(&g2) + phase1.Parameters.G1.Tau = make([]curve.G1Affine, 2*N-1) + phase1.Parameters.G2.Tau = make([]curve.G2Affine, N) + phase1.Parameters.G1.AlphaTau = make([]curve.G1Affine, N) + phase1.Parameters.G1.BetaTau = make([]curve.G1Affine, N) + for i := 0; i < len(phase1.Parameters.G1.Tau); i++ { + phase1.Parameters.G1.Tau[i].Set(&g1) + } + for i := 0; i < len(phase1.Parameters.G2.Tau); i++ { + phase1.Parameters.G2.Tau[i].Set(&g2) + phase1.Parameters.G1.AlphaTau[i].Set(&g1) + phase1.Parameters.G1.BetaTau[i].Set(&g1) + } + + phase1.Parameters.G2.Beta.Set(&g2) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() + + return +} + +// Contribute contributes randomness to the phase1 object. This mutates phase1. +func (phase1 *Phase1) Contribute() { + N := len(phase1.Parameters.G2.Tau) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetRandom() + alpha.SetRandom() + beta.SetRandom() + phase1.PublicKeys.Tau = newPublicKey(tau, phase1.Hash[:], 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, phase1.Hash[:], 2) + phase1.PublicKeys.Beta = newPublicKey(beta, phase1.Hash[:], 3) + + // Compute powers of τ, ατ, and βτ + taus := powers(tau, 2*N-1) + alphaTau := make([]fr.Element, N) + betaTau := make([]fr.Element, N) + for i := 0; i < N; i++ { + alphaTau[i].Mul(&taus[i], &alpha) + betaTau[i].Mul(&taus[i], &beta) + } + + // Update using previous parameters + // TODO @gbotrel working with jacobian points here will help with perf. 
+ scaleG1InPlace(phase1.Parameters.G1.Tau, taus) + scaleG2InPlace(phase1.Parameters.G2.Tau, taus[0:N]) + scaleG1InPlace(phase1.Parameters.G1.AlphaTau, alphaTau) + scaleG1InPlace(phase1.Parameters.G1.BetaTau, betaTau) + var betaBI big.Int + beta.BigInt(&betaBI) + phase1.Parameters.G2.Beta.ScalarMultiplication(&phase1.Parameters.G2.Beta, &betaBI) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() +} + +func VerifyPhase1(c0, c1 *Phase1, c ...*Phase1) error { + contribs := append([]*Phase1{c0, c1}, c...) + for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase1(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +// verifyPhase1 checks that a contribution is based on a known previous Phase1 state. +func verifyPhase1(current, contribution *Phase1) error { + // Compute R for τ, α, β + tauR := genR(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, current.Hash[:], 1) + alphaR := genR(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, current.Hash[:], 2) + betaR := genR(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, current.Hash[:], 3) + + // Check for knowledge of toxic parameters + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.PublicKeys.Tau.XR, tauR) { + return errors.New("couldn't verify public key of τ") + } + if !sameRatio(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, contribution.PublicKeys.Alpha.XR, alphaR) { + return errors.New("couldn't verify public key of α") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.PublicKeys.Beta.XR, betaR) { + return errors.New("couldn't verify public key of β") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Tau[1], current.Parameters.G1.Tau[1], tauR, contribution.PublicKeys.Tau.XR) { + return errors.New("couldn't verify that [τ]₁ is based on previous 
contribution")
+ if !sameRatio(contribution.Parameters.G1.AlphaTau[0], current.Parameters.G1.AlphaTau[0], alphaR, contribution.PublicKeys.Alpha.XR) {
+ return errors.New("couldn't verify that [α]₁ is based on previous contribution")
+ }
+ if !sameRatio(contribution.Parameters.G1.BetaTau[0], current.Parameters.G1.BetaTau[0], betaR, contribution.PublicKeys.Beta.XR) {
+ return errors.New("couldn't verify that [β]₁ is based on previous contribution")
+ }
+ if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.Parameters.G2.Tau[1], current.Parameters.G2.Tau[1]) {
+ return errors.New("couldn't verify that [τ]₂ is based on previous contribution")
+ }
+ if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.Parameters.G2.Beta, current.Parameters.G2.Beta) {
+ return errors.New("couldn't verify that [β]₂ is based on previous contribution")
+ }
+
+ // Check for valid updates using powers of τ
+ _, _, g1, g2 := curve.Generators()
+ tauL1, tauL2 := linearCombinationG1(contribution.Parameters.G1.Tau)
+ if !sameRatio(tauL1, tauL2, contribution.Parameters.G2.Tau[1], g2) {
+ return errors.New("couldn't verify valid powers of τ in G₁")
+ }
+ alphaL1, alphaL2 := linearCombinationG1(contribution.Parameters.G1.AlphaTau)
+ if !sameRatio(alphaL1, alphaL2, contribution.Parameters.G2.Tau[1], g2) {
+ return errors.New("couldn't verify valid powers of α(τ) in G₁")
+ }
+ betaL1, betaL2 := linearCombinationG1(contribution.Parameters.G1.BetaTau)
+ if !sameRatio(betaL1, betaL2, contribution.Parameters.G2.Tau[1], g2) {
+ return errors.New("couldn't verify valid powers of β(τ) in G₁")
+ }
+ tau2L1, tau2L2 := linearCombinationG2(contribution.Parameters.G2.Tau)
+ if !sameRatio(contribution.Parameters.G1.Tau[1], g1, tau2L1, tau2L2) {
+ return errors.New("couldn't verify valid powers of τ in G₂")
+ }
+
+ // Check hash of the contribution
+ h := contribution.hash()
+ for i := 0; i < len(h); i++ {
+ if h[i] != 
contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (phase1 *Phase1) hash() []byte { + sha := sha256.New() + phase1.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bls12-377/mpcsetup/phase2.go b/backend/groth16/bls12-377/mpcsetup/phase2.go new file mode 100644 index 0000000000..e3816d65ca --- /dev/null +++ b/backend/groth16/bls12-377/mpcsetup/phase2.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + "math/big" + + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark/constraint" + cs "github.com/consensys/gnark/constraint/bls12-377" +) + +type Phase2Evaluations struct { + G1 struct { + A, B, VKK []curve.G1Affine + } + G2 struct { + B []curve.G2Affine + } +} + +type Phase2 struct { + Parameters struct { + G1 struct { + Delta curve.G1Affine + L, Z []curve.G1Affine + } + G2 struct { + Delta curve.G2Affine + } + } + PublicKey PublicKey + Hash []byte +} + +func InitPhase2(r1cs *cs.R1CS, srs1 *Phase1) (Phase2, Phase2Evaluations) { + srs := srs1.Parameters + size := len(srs.G1.AlphaTau) + if size < r1cs.GetNbConstraints() { + panic("Number of constraints is larger than expected") + } + + c2 := Phase2{} + + accumulateG1 := func(res *curve.G1Affine, t constraint.Term, value *curve.G1Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G1Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + accumulateG2 := func(res *curve.G2Affine, t constraint.Term, value *curve.G2Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G2Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + // Prepare Lagrange coefficients of [τ...]₁, [τ...]₂, [ατ...]₁, [βτ...]₁ + coeffTau1 
:= lagrangeCoeffsG1(srs.G1.Tau, size) + coeffTau2 := lagrangeCoeffsG2(srs.G2.Tau, size) + coeffAlphaTau1 := lagrangeCoeffsG1(srs.G1.AlphaTau, size) + coeffBetaTau1 := lagrangeCoeffsG1(srs.G1.BetaTau, size) + + internal, secret, public := r1cs.GetNbVariables() + nWires := internal + secret + public + var evals Phase2Evaluations + evals.G1.A = make([]curve.G1Affine, nWires) + evals.G1.B = make([]curve.G1Affine, nWires) + evals.G2.B = make([]curve.G2Affine, nWires) + bA := make([]curve.G1Affine, nWires) + aB := make([]curve.G1Affine, nWires) + C := make([]curve.G1Affine, nWires) + + // TODO @gbotrel use constraint iterator when available. + + i := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + // A + for _, t := range c.L { + accumulateG1(&evals.G1.A[t.WireID()], t, &coeffTau1[i]) + accumulateG1(&bA[t.WireID()], t, &coeffBetaTau1[i]) + } + // B + for _, t := range c.R { + accumulateG1(&evals.G1.B[t.WireID()], t, &coeffTau1[i]) + accumulateG2(&evals.G2.B[t.WireID()], t, &coeffTau2[i]) + accumulateG1(&aB[t.WireID()], t, &coeffAlphaTau1[i]) + } + // C + for _, t := range c.O { + accumulateG1(&C[t.WireID()], t, &coeffTau1[i]) + } + i++ + } + + // Prepare default contribution + _, _, g1, g2 := curve.Generators() + c2.Parameters.G1.Delta = g1 + c2.Parameters.G2.Delta = g2 + + // Build Z in PK as τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + // τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + n := len(srs.G1.AlphaTau) + c2.Parameters.G1.Z = make([]curve.G1Affine, n) + for i := 0; i < n-1; i++ { + c2.Parameters.G1.Z[i].Sub(&srs.G1.Tau[i+n], &srs.G1.Tau[i]) + } + bitReverse(c2.Parameters.G1.Z) + c2.Parameters.G1.Z = c2.Parameters.G1.Z[:n-1] + + // Evaluate L + nPrivate := internal + secret + c2.Parameters.G1.L = make([]curve.G1Affine, nPrivate) + evals.G1.VKK = make([]curve.G1Affine, public) + offset := public + for i := 0; i < nWires; i++ { + var tmp curve.G1Affine + tmp.Add(&bA[i], &aB[i]) + tmp.Add(&tmp, &C[i]) + if i < public { + 
evals.G1.VKK[i].Set(&tmp) + } else { + c2.Parameters.G1.L[i-offset].Set(&tmp) + } + } + // Set δ public key + var delta fr.Element + delta.SetOne() + c2.PublicKey = newPublicKey(delta, nil, 1) + + // Hash initial contribution + c2.Hash = c2.hash() + return c2, evals +} + +func (c *Phase2) Contribute() { + // Sample toxic δ + var delta, deltaInv fr.Element + var deltaBI, deltaInvBI big.Int + delta.SetRandom() + deltaInv.Inverse(&delta) + + delta.BigInt(&deltaBI) + deltaInv.BigInt(&deltaInvBI) + + // Set δ public key + c.PublicKey = newPublicKey(delta, c.Hash, 1) + + // Update δ + c.Parameters.G1.Delta.ScalarMultiplication(&c.Parameters.G1.Delta, &deltaBI) + c.Parameters.G2.Delta.ScalarMultiplication(&c.Parameters.G2.Delta, &deltaBI) + + // Update Z using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.Z); i++ { + c.Parameters.G1.Z[i].ScalarMultiplication(&c.Parameters.G1.Z[i], &deltaInvBI) + } + + // Update L using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.L); i++ { + c.Parameters.G1.L[i].ScalarMultiplication(&c.Parameters.G1.L[i], &deltaInvBI) + } + + // 4. Hash contribution + c.Hash = c.hash() +} + +func VerifyPhase2(c0, c1 *Phase2, c ...*Phase2) error { + contribs := append([]*Phase2{c0, c1}, c...) 
+ for i := 0; i < len(contribs)-1; i++ {
+ if err := verifyPhase2(contribs[i], contribs[i+1]); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func verifyPhase2(current, contribution *Phase2) error {
+ // Compute R for δ
+ deltaR := genR(contribution.PublicKey.SG, contribution.PublicKey.SXG, current.Hash[:], 1)
+
+ // Check for knowledge of δ
+ if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.PublicKey.XR, deltaR) {
+ return errors.New("couldn't verify knowledge of δ")
+ }
+
+ // Check for valid updates using previous parameters
+ if !sameRatio(contribution.Parameters.G1.Delta, current.Parameters.G1.Delta, deltaR, contribution.PublicKey.XR) {
+ return errors.New("couldn't verify that [δ]₁ is based on previous contribution")
+ }
+ if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) {
+ return errors.New("couldn't verify that [δ]₂ is based on previous contribution")
+ }
+
+ // Check for valid updates of L and Z using δ⁻¹
+ L, prevL := merge(contribution.Parameters.G1.L, current.Parameters.G1.L)
+ if !sameRatio(L, prevL, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) {
+ return errors.New("couldn't verify valid updates of L using δ⁻¹")
+ }
+ Z, prevZ := merge(contribution.Parameters.G1.Z, current.Parameters.G1.Z)
+ if !sameRatio(Z, prevZ, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) {
+ return errors.New("couldn't verify valid updates of Z using δ⁻¹")
+ }
+
+ // Check hash of the contribution
+ h := contribution.hash()
+ for i := 0; i < len(h); i++ {
+ if h[i] != contribution.Hash[i] {
+ return errors.New("couldn't verify hash of contribution")
+ }
+ }
+
+ return nil
+}
+
+func (c *Phase2) hash() []byte {
+ sha := sha256.New()
+ c.writeTo(sha)
+ return sha.Sum(nil)
+} diff --git a/backend/groth16/bls12-377/mpcsetup/setup.go b/backend/groth16/bls12-377/mpcsetup/setup.go new file mode 100644 index 0000000000..683369c871 --- 
/dev/null +++ b/backend/groth16/bls12-377/mpcsetup/setup.go @@ -0,0 +1,97 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" + groth16 "github.com/consensys/gnark/backend/groth16/bls12-377" +) + +func ExtractKeys(srs1 *Phase1, srs2 *Phase2, evals *Phase2Evaluations, nConstraints int) (pk groth16.ProvingKey, vk groth16.VerifyingKey) { + _, _, _, g2 := curve.Generators() + + // Initialize PK + pk.Domain = *fft.NewDomain(uint64(nConstraints)) + pk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + pk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + pk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + pk.G1.Z = srs2.Parameters.G1.Z + bitReverse(pk.G1.Z) + + pk.G1.K = srs2.Parameters.G1.L + pk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + pk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + + // Filter out infinity points + nWires := len(evals.G1.A) + pk.InfinityA = make([]bool, nWires) + A := make([]curve.G1Affine, nWires) + j := 0 + for i, e := range evals.G1.A { + if e.IsInfinity() { + pk.InfinityA[i] = true + continue + } + A[j] = evals.G1.A[i] + j++ + } + pk.G1.A = A[:j] + pk.NbInfinityA = uint64(nWires - j) + + pk.InfinityB = make([]bool, nWires) + B := make([]curve.G1Affine, nWires) + j = 0 + for i, e := range evals.G1.B { + if e.IsInfinity() { + 
pk.InfinityB[i] = true + continue + } + B[j] = evals.G1.B[i] + j++ + } + pk.G1.B = B[:j] + pk.NbInfinityB = uint64(nWires - j) + + B2 := make([]curve.G2Affine, nWires) + j = 0 + for i, e := range evals.G2.B { + if e.IsInfinity() { + // pk.InfinityB[i] = true should be the same as in B + continue + } + B2[j] = evals.G2.B[i] + j++ + } + pk.G2.B = B2[:j] + + // Initialize VK + vk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + vk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + vk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + vk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + vk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + vk.G2.Gamma.Set(&g2) + vk.G1.K = evals.G1.VKK + + // sets e, -[δ]2, -[γ]2 + if err := vk.Precompute(); err != nil { + panic(err) + } + + return pk, vk +} diff --git a/backend/groth16/bls12-377/mpcsetup/setup_test.go b/backend/groth16/bls12-377/mpcsetup/setup_test.go new file mode 100644 index 0000000000..ca8cca346f --- /dev/null +++ b/backend/groth16/bls12-377/mpcsetup/setup_test.go @@ -0,0 +1,199 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + cs "github.com/consensys/gnark/constraint/bls12-377" + "testing" + + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/stretchr/testify/require" + + native_mimc "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" +) + +func TestSetupCircuit(t *testing.T) { + if testing.Short() { + t.Skip() + } + const ( + nContributionsPhase1 = 3 + nContributionsPhase2 = 3 + power = 9 + ) + + assert := require.New(t) + + srs1 := InitPhase1(power) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase1; i++ { + // we clone test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. + prev := srs1.clone() + + srs1.Contribute() + assert.NoError(VerifyPhase1(&prev, &srs1)) + } + + // Compile the circuit + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + var evals Phase2Evaluations + r1cs := ccs.(*cs.R1CS) + + // Prepare for phase-2 + srs2, evals := InitPhase2(r1cs, &srs1) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase2; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. 
+ prev := srs2.clone() + + srs2.Contribute() + assert.NoError(VerifyPhase2(&prev, &srs2)) + } + + // Extract the proving and verifying keys + pk, vk := ExtractKeys(&srs1, &srs2, &evals, ccs.GetNbConstraints()) + + // Build the witness + var preImage, hash fr.Element + { + m := native_mimc.NewMiMC() + m.Write(preImage.Marshal()) + hash.SetBytes(m.Sum(nil)) + } + + witness, err := frontend.NewWitness(&Circuit{PreImage: preImage, Hash: hash}, curve.ID.ScalarField()) + assert.NoError(err) + + pubWitness, err := witness.Public() + assert.NoError(err) + + // groth16: ensure proof is verified + proof, err := groth16.Prove(ccs, &pk, witness) + assert.NoError(err) + + err = groth16.Verify(proof, &vk, pubWitness) + assert.NoError(err) +} + +func BenchmarkPhase1(b *testing.B) { + const power = 14 + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = InitPhase1(power) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs1 := InitPhase1(power) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs1.Contribute() + } + }) + +} + +func BenchmarkPhase2(b *testing.B) { + const power = 14 + srs1 := InitPhase1(power) + srs1.Contribute() + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + if err != nil { + b.Fatal(err) + } + + r1cs := ccs.(*cs.R1CS) + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = InitPhase2(r1cs, &srs1) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs2, _ := InitPhase2(r1cs, &srs1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs2.Contribute() + } + }) + +} + +// Circuit defines a pre-image knowledge proof +// mimc(secret preImage) = public hash +type Circuit struct { + PreImage frontend.Variable + Hash frontend.Variable `gnark:",public"` +} + +// Define declares the circuit's constraints +// Hash = mimc(PreImage) +func (circuit *Circuit) Define(api frontend.API) error { + // hash function + mimc, _ := 
mimc.NewMiMC(api) + + // specify constraints + mimc.Write(circuit.PreImage) + api.AssertIsEqual(circuit.Hash, mimc.Sum()) + + return nil +} + +func (phase1 *Phase1) clone() Phase1 { + r := Phase1{} + r.Parameters.G1.Tau = append(r.Parameters.G1.Tau, phase1.Parameters.G1.Tau...) + r.Parameters.G1.AlphaTau = append(r.Parameters.G1.AlphaTau, phase1.Parameters.G1.AlphaTau...) + r.Parameters.G1.BetaTau = append(r.Parameters.G1.BetaTau, phase1.Parameters.G1.BetaTau...) + + r.Parameters.G2.Tau = append(r.Parameters.G2.Tau, phase1.Parameters.G2.Tau...) + r.Parameters.G2.Beta = phase1.Parameters.G2.Beta + + r.PublicKeys = phase1.PublicKeys + r.Hash = append(r.Hash, phase1.Hash...) + + return r +} + +func (phase2 *Phase2) clone() Phase2 { + r := Phase2{} + r.Parameters.G1.Delta = phase2.Parameters.G1.Delta + r.Parameters.G1.L = append(r.Parameters.G1.L, phase2.Parameters.G1.L...) + r.Parameters.G1.Z = append(r.Parameters.G1.Z, phase2.Parameters.G1.Z...) + r.Parameters.G2.Delta = phase2.Parameters.G2.Delta + r.PublicKey = phase2.PublicKey + r.Hash = append(r.Hash, phase2.Hash...) + + return r +} diff --git a/backend/groth16/bls12-377/mpcsetup/utils.go b/backend/groth16/bls12-377/mpcsetup/utils.go new file mode 100644 index 0000000000..978b2ecbde --- /dev/null +++ b/backend/groth16/bls12-377/mpcsetup/utils.go @@ -0,0 +1,170 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark/internal/utils" +) + +type PublicKey struct { + SG curve.G1Affine + SXG curve.G1Affine + XR curve.G2Affine +} + +func newPublicKey(x fr.Element, challenge []byte, dst byte) PublicKey { + var pk PublicKey + _, _, g1, _ := curve.Generators() + + var s fr.Element + var sBi big.Int + s.SetRandom() + s.BigInt(&sBi) + pk.SG.ScalarMultiplication(&g1, &sBi) + + // compute x*sG1 + var xBi big.Int + x.BigInt(&xBi) + pk.SXG.ScalarMultiplication(&pk.SG, &xBi) + + // generate R based on sG1, sxG1, challenge, and domain separation tag (tau, alpha or beta) + R := genR(pk.SG, pk.SXG, challenge, dst) + + // compute x*spG2 + pk.XR.ScalarMultiplication(&R, &xBi) + return pk +} + +func bitReverse[T any](a []T) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + irev := bits.Reverse64(i) >> nn + if irev > i { + a[i], a[irev] = a[irev], a[i] + } + } +} + +// Returns [1, a, a², ..., aⁿ⁻¹ ] in Montgomery form +func powers(a fr.Element, n int) []fr.Element { + result := make([]fr.Element, n) + result[0] = fr.NewElement(1) + for i := 1; i < n; i++ { + result[i].Mul(&result[i-1], &a) + } + return result +} + +// Returns [aᵢAᵢ, ...] in G1 +func scaleG1InPlace(A []curve.G1Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Returns [aᵢAᵢ, ...] 
in G2 +func scaleG2InPlace(A []curve.G2Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Check e(a₁, a₂) = e(b₁, b₂) +func sameRatio(a1, b1 curve.G1Affine, a2, b2 curve.G2Affine) bool { + if !a1.IsInSubGroup() || !b1.IsInSubGroup() || !a2.IsInSubGroup() || !b2.IsInSubGroup() { + panic("invalid point not in subgroup") + } + var na2 curve.G2Affine + na2.Neg(&a2) + res, err := curve.PairingCheck( + []curve.G1Affine{a1, b1}, + []curve.G2Affine{na2, b2}) + if err != nil { + panic(err) + } + return res +} + +// returns a = ∑ rᵢAᵢ, b = ∑ rᵢBᵢ +func merge(A, B []curve.G1Affine) (a, b curve.G1Affine) { + nc := runtime.NumCPU() + r := make([]fr.Element, len(A)) + for i := 0; i < len(A); i++ { + r[i].SetRandom() + } + a.MultiExp(A, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + b.MultiExp(B, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G1 +func linearCombinationG1(A []curve.G1Affine) (L1, L2 curve.G1Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G2 +func linearCombinationG2(A []curve.G2Affine) (L1, L2 curve.G2Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// Generate R in G₂ as Hash(gˢ, gˢˣ, challenge, dst) +func genR(sG1, sxG1 curve.G1Affine, challenge []byte, dst byte) curve.G2Affine { + var buf bytes.Buffer + buf.Grow(len(challenge) + curve.SizeOfG1AffineUncompressed*2) + buf.Write(sG1.Marshal()) + 
buf.Write(sxG1.Marshal()) + buf.Write(challenge) + spG2, err := curve.HashToG2(buf.Bytes(), []byte{dst}) + if err != nil { + panic(err) + } + return spG2 +} diff --git a/internal/backend/bls12-377/groth16/prove.go b/backend/groth16/bls12-377/prove.go similarity index 62% rename from internal/backend/bls12-377/groth16/prove.go rename to backend/groth16/bls12-377/prove.go index 4590530c00..ed7124a557 100644 --- a/internal/backend/bls12-377/groth16/prove.go +++ b/backend/groth16/bls12-377/prove.go @@ -17,13 +17,17 @@ package groth16 import ( - "fmt" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/pedersen" "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/constraint/bls12-377" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + cs "github.com/consensys/gnark/constraint/bls12-377" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" "math/big" @@ -35,9 +39,10 @@ import ( // with a valid statement and a VerifyingKey // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type Proof struct { - Ar, Krs curve.G1Affine - Bs curve.G2Affine - Commitment, CommitmentPok curve.G1Affine + Ar, Krs curve.G1Affine + Bs curve.G2Affine + Commitments []curve.G1Affine // Pedersen commitments a la https://eprint.iacr.org/2022/1072 + CommitmentPok curve.G1Affine // Batched proof of knowledge of the above commitments } // isValid ensures proof elements are in the correct subgroup @@ -51,72 +56,78 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). 
-func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - // TODO @gbotrel witness size check is done by R1CS, doesn't mean we shouldn't sanitize here. - // if len(witness) != r1cs.NbPublicVariables-1+r1cs.NbSecretVariables { - // return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), r1cs.NbPublicVariables-1+r1cs.NbSecretVariables, r1cs.NbPublicVariables, r1cs.NbSecretVariables) - // } - - log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", len(r1cs.Constraints)).Str("backend", "groth16").Logger() - - // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - - proof := &Proof{} - if r1cs.CommitmentInfo.Is() { - opt.HintFunctions[r1cs.CommitmentInfo.HintID] = func(_ *big.Int, in []*big.Int, out []*big.Int) error { - // Perf-TODO: Converting these values to big.Int and back may be a performance bottleneck. - // If that is the case, figure out a way to feed the solution vector into this function - if len(in) != r1cs.CommitmentInfo.NbCommitted() { // TODO: Remove - return fmt.Errorf("unexpected number of committed variables") - } - values := make([]fr.Element, r1cs.CommitmentInfo.NbPrivateCommitted) - nbPublicCommitted := len(in) - len(values) - inPrivate := in[nbPublicCommitted:] - for i, inI := range inPrivate { - values[i].SetBigInt(inI) - } +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } + + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] - var err error - proof.Commitment, proof.CommitmentPok, err = pk.CommitmentKey.Commit(values) - if err != nil { + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + solverOpts = append(solverOpts, solver.OverrideHint(commitmentInfo[i].HintID, func(i int) solver.Hint { + return func(_ *big.Int, in []*big.Int, out []*big.Int) error { + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[len(hashed):] + for j, inJ := range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } + + var res fr.Element + res, err = solveCommitmentWire(&proof.Commitments[i], hashed) + res.BigInt(out[0]) return err } + }(i))) + } - var res fr.Element - res, err = solveCommitmentWire(&r1cs.CommitmentInfo, &proof.Commitment, in[:r1cs.CommitmentInfo.NbPublicCommitted()]) - res.BigInt(out[0]) //Perf-TODO: Regular (non-mont) hashToField to obviate this conversion? 
- return err - } + if r1cs.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + solverOpts = append(solverOpts, + solver.OverrideHint(r1cs.GkrInfo.SolveHintID, cs.GkrSolveHint(r1cs.GkrInfo, &gkrData)), + solver.OverrideHint(r1cs.GkrInfo.ProveHintID, cs.GkrProveHint(r1cs.GkrInfo.HashName, &gkrData))) } - var wireValues []fr.Element - var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill wireValues with random values else multi exps don't do much - var r fr.Element - _, _ = r.SetRandom() - for i := r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables(); i < len(wireValues); i++ { - wireValues[i] = r - r.Double(&r) - } - } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + start := time.Now() + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + + if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil { + return nil, err + } + // H (witness reduction / FFT part) var h []fr.Element chHDone := make(chan struct{}, 1) go func() { - h = computeH(a, b, c, &pk.Domain) - a = nil - b = nil - c = nil + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + solution.A = nil + solution.B = nil + solution.C = nil chHDone <- struct{}{} }() @@ -203,15 +214,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC var krs, krs2, p1 curve.G1Jac chKrs2Done := make(chan error, 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 go func() { - _, err := krs2.MultiExp(pk.G1.Z, h, ecc.MultiExpConfig{NbTasks: n / 2}) + _, err := krs2.MultiExp(pk.G1.Z, 
h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) chKrs2Done <- err }() - // filter the wire values if needed; - _wireValues := filter(wireValues, r1cs.CommitmentInfo.PrivateToPublic()) + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - if _, err := krs.MultiExp(pk.G1.K, _wireValues[r1cs.GetNbPublicVariables():], ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chKrsDone <- err return } @@ -292,26 +307,32 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC } // if len(toRemove) == 0, returns slice -// else, returns a new slice without the indexes in toRemove -// this assumes toRemove indexes are sorted and len(slice) > len(toRemove) -func filter(slice []fr.Element, toRemove []int) (r []fr.Element) { +// else, returns a new slice without the indexes in toRemove. 
The first value in the slice is taken as indexes as sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { if len(toRemove) == 0 { return slice } - r = make([]fr.Element, 0, len(slice)-len(toRemove)) - j := 0 + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) + // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) for i := 0; i < len(slice); i++ { - if j < len(toRemove) && i == toRemove[j] { - j++ + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } continue } r = append(r, slice[i]) } - return r + return } func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { @@ -334,9 +355,9 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { domain.FFTInverse(b, fft.DIF) domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, true) - domain.FFT(b, fft.DIT, true) - domain.FFT(c, fft.DIT, true) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) var den, one fr.Element one.SetOne() @@ -344,7 +365,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) - // reusing a to avoid unecessary memalloc + // reusing a to avoid unnecessary memory allocation utils.Parallelize(n, func(start, end int) { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). 
@@ -354,7 +375,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { }) // ifft_coset - domain.FFTInverse(a, fft.DIF, true) + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) return a } diff --git a/internal/backend/bls12-377/groth16/setup.go b/backend/groth16/bls12-377/setup.go similarity index 75% rename from internal/backend/bls12-377/groth16/setup.go rename to backend/groth16/bls12-377/setup.go index 8b3beb485e..393d6a802e 100644 --- a/internal/backend/bls12-377/groth16/setup.go +++ b/backend/groth16/bls12-377/setup.go @@ -17,13 +17,15 @@ package groth16 import ( + "errors" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls12-377" + cs "github.com/consensys/gnark/constraint/bls12-377" "math/big" "math/bits" ) @@ -34,15 +36,15 @@ type ProvingKey struct { // domain Domain fft.Domain - // [α]1, [β]1, [δ]1 - // [A(t)]1, [B(t)]1, [Kpk(t)]1, [Z(t)]1 + // [α]₁, [β]₁, [δ]₁ + // [A(t)]₁, [B(t)]₁, [Kpk(t)]₁, [Z(t)]₁ G1 struct { Alpha, Beta, Delta curve.G1Affine A, B, Z []curve.G1Affine K []curve.G1Affine // the indexes correspond to the private wires } - // [β]2, [δ]2, [B(t)]2 + // [β]₂, [δ]₂, [B(t)]₂ G2 struct { Beta, Delta curve.G2Affine B []curve.G2Affine @@ -52,21 +54,21 @@ type ProvingKey struct { InfinityA, InfinityB []bool NbInfinityA, NbInfinityB uint64 - CommitmentKey pedersen.Key + CommitmentKeys []pedersen.ProvingKey } // VerifyingKey is used by a Groth16 verifier to verify the validity of a proof and a statement // Notation follows Figure 4. 
in DIZK paper https://eprint.iacr.org/2018/691.pdf type VerifyingKey struct { - // [α]1, [Kvk]1 + // [α]₁, [Kvk]₁ G1 struct { Alpha curve.G1Affine Beta, Delta curve.G1Affine // unused, here for compatibility purposes K []curve.G1Affine // The indexes correspond to the public wires } - // [β]2, [δ]2, [γ]2, - // -[δ]2, -[γ]2: see proof.Verify() for more details + // [β]₂, [δ]₂, [γ]₂, + // -[δ]₂, -[γ]₂: see proof.Verify() for more details G2 struct { Beta, Delta, Gamma curve.G2Affine deltaNeg, gammaNeg curve.G2Affine // not serialized @@ -75,8 +77,8 @@ type VerifyingKey struct { // e(α, β) e curve.GT // not serialized - CommitmentKey pedersen.Key - CommitmentInfo constraint.Commitment // since the verifier doesn't input a constraint system, this needs to be provided here + CommitmentKey pedersen.VerifyingKey + PublicAndCommitmentCommitted [][]int // indexes of public/commitment committed variables } // Setup constructs the SRS @@ -93,17 +95,20 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbPrivateCommittedWires := r1cs.CommitmentInfo.NbPrivateCommitted - nbPublicWires := r1cs.GetNbPublicVariables() - nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - if r1cs.CommitmentInfo.Is() { // the commitment itself is defined by a hint so the prover considers it private - nbPublicWires++ // but the verifier will need to inject the value itself so on the groth16 - nbPrivateWires-- // level it must be considered public - } + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + commitmentWires := commitmentInfo.CommitmentIndexes() + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateCommittedWires := internal.NbElements(privateCommitted) + + // a commitment is itself defined by a hint so the prover considers it 
private + // but the verifier will need to inject the value itself so on the groth16 + // level it must be considered public + nbPublicWires := r1cs.GetNbPublicVariables() + len(commitmentInfo) + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - len(commitmentInfo) // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints))) + domain := fft.NewDomain(uint64(r1cs.GetNbConstraints())) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -137,7 +142,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // compute scalars for pkK, vkK and ckK pkK := make([]fr.Element, nbPrivateWires) vkK := make([]fr.Element, nbPublicWires) - ckK := make([]fr.Element, nbPrivateCommittedWires) + ckK := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + ckK[i] = make([]fr.Element, len(privateCommitted[i])) + } var t0, t1 fr.Element @@ -148,28 +156,42 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { Add(&t1, &C[i]). 
Mul(&t1, coeff) } - - vI, cI := 0, 0 - privateCommitted := r1cs.CommitmentInfo.PrivateCommitted() + vI := 0 // number of public wires seen so far + cI := make([]int, len(commitmentInfo)) // number of private committed wires seen so far for each commitment + nbPrivateCommittedSeen := 0 // = ∑ᵢ cI[i] + nbCommitmentsSeen := 0 for i := range A { - isCommittedPrivate := cI < len(privateCommitted) && i == privateCommitted[cI] - isCommitment := r1cs.CommitmentInfo.Is() && i == r1cs.CommitmentInfo.CommitmentIndex - isPublic := i < r1cs.GetNbPublicVariables() + commitment := -1 // index of the commitment that commits to this variable as a private or commitment value + var isCommitment, isPublic bool + if isPublic = i < r1cs.GetNbPublicVariables(); !isPublic { + if nbCommitmentsSeen < len(commitmentWires) && commitmentWires[nbCommitmentsSeen] == i { + isCommitment = true + nbCommitmentsSeen++ + } - if isPublic || isCommittedPrivate || isCommitment { + for j := range commitmentInfo { // does commitment j commit to i? + if cI[j] < len(privateCommitted[j]) && privateCommitted[j][cI[j]] == i { + commitment = j + break // frontend guarantees that no private variable is committed to more than once + } + } + } + + if isPublic || commitment != -1 || isCommitment { computeK(i, &toxicWaste.gammaInv) - if isCommittedPrivate { - ckK[cI] = t1 - cI++ - } else { + if isPublic || isCommitment { vkK[vI] = t1 vI++ + } else { // committed and private + ckK[commitment][cI[commitment]] = t1 + cI[commitment]++ + nbPrivateCommittedSeen++ } } else { computeK(i, &toxicWaste.deltaInv) - pkK[i-vI-cI] = t1 + pkK[i-vI-nbPrivateCommittedSeen] = t1 // vI = nbPublicSeen + nbCommitmentsSeen } } @@ -222,11 +244,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { g1Scalars = append(g1Scalars, Z...) g1Scalars = append(g1Scalars, vkK...) g1Scalars = append(g1Scalars, pkK...) - g1Scalars = append(g1Scalars, ckK...) + for i := range ckK { + g1Scalars = append(g1Scalars, ckK[i]...) 
+ } g1PointsAff := curve.BatchScalarMultiplicationG1(&g1, g1Scalars) - // sets pk: [α]1, [β]1, [δ]1 + // sets pk: [α]₁, [β]₁, [δ]₁ pk.G1.Alpha = g1PointsAff[0] pk.G1.Beta = g1PointsAff[1] pk.G1.Delta = g1PointsAff[2] @@ -238,8 +262,9 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G1.B = g1PointsAff[offset : offset+len(B)] offset += len(B) - pk.G1.Z = g1PointsAff[offset : offset+int(domain.Cardinality)] - bitReverse(pk.G1.Z) + bitReverse(g1PointsAff[offset : offset+int(domain.Cardinality)]) + sizeZ := int(domain.Cardinality) - 1 // deg(H)=deg(A*B-C/X^n-1)=(n-1)+(n-1)-n=n-2 + pk.G1.Z = g1PointsAff[offset : offset+sizeZ] offset += int(domain.Cardinality) @@ -252,17 +277,22 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // --------------------------------------------------------------------------------------------- // Commitment setup - if nbPrivateCommittedWires != 0 { - commitmentBasis := g1PointsAff[offset:] + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(ckK[i]) + commitmentBases[i] = g1PointsAff[offset : offset+size] + offset += size + } + if offset != len(g1PointsAff) { + return errors.New("didn't consume all G1 points") // TODO @Tabaie Remove this + } - vk.CommitmentKey, err = pedersen.Setup(commitmentBasis) - if err != nil { - return err - } - pk.CommitmentKey = vk.CommitmentKey + pk.CommitmentKeys, vk.CommitmentKey, err = pedersen.Setup(commitmentBases...) 
+ if err != nil { + return err } - vk.CommitmentInfo = r1cs.CommitmentInfo // unfortunate but necessary + vk.PublicAndCommitmentCommitted = commitmentInfo.GetPublicAndCommitmentCommitted(commitmentWires, r1cs.GetNbPublicVariables()) // --------------------------------------------------------------------------------------------- // G2 scalars @@ -279,15 +309,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G2.B = g2PointsAff[:len(B)] - // sets pk: [β]2, [δ]2 + // sets pk: [β]₂, [δ]₂ pk.G2.Beta = g2PointsAff[len(B)+0] pk.G2.Delta = g2PointsAff[len(B)+1] - // sets vk: [δ]2, [γ]2, -[δ]2, -[γ]2 + // sets vk: [δ]₂, [γ]₂ vk.G2.Delta = g2PointsAff[len(B)+1] vk.G2.Gamma = g2PointsAff[len(B)+2] - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) // --------------------------------------------------------------------------------------------- // Pairing: vk.e @@ -298,16 +326,29 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { vk.G1.Beta = pk.G1.Beta vk.G1.Delta = pk.G1.Delta - vk.e, err = curve.Pair([]curve.G1Affine{pk.G1.Alpha}, []curve.G2Affine{pk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return err } + // set domain pk.Domain = *domain return nil } +// Precompute sets e, -[δ]₂, -[γ]₂ +// This is meant to be called internally during setup or deserialization. +func (vk *VerifyingKey) Precompute() error { + var err error + vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) + if err != nil { + return err + } + vk.G2.deltaNeg.Neg(&vk.G2.Delta) + vk.G2.gammaNeg.Neg(&vk.G2.Gamma) + return nil +} + func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr.Element, B []fr.Element, C []fr.Element) { nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() @@ -322,7 +363,7 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. 
var w fr.Element w.Set(&domain.Generator) wi := fr.One() - t := make([]fr.Element, len(r1cs.Constraints)+1) + t := make([]fr.Element, r1cs.GetNbConstraints()+1) for i := 0; i < len(t); i++ { t[i].Sub(&toxicWaste.t, &wi) wi.Mul(&wi, &w) // TODO this is already pre computed in fft.Domain @@ -366,8 +407,10 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // for each term appearing in the linear expression, // we compute term.Coefficient * L, and cumulate it in // A, B or C at the index of the variable - for i, c := range r1cs.Constraints { + j := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } @@ -380,9 +423,12 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // Li+1 = w*Li*(t-w^i)/(t-w^(i+1)) L.Mul(&L, &w) - L.Mul(&L, &t[i]) - L.Mul(&L, &tInv[i+1]) + L.Mul(&L, &t[j]) + L.Mul(&L, &tInv[j+1]) + + j++ } + return } @@ -436,7 +482,10 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbConstraints := len(r1cs.Constraints) + nbConstraints := r1cs.GetNbConstraints() + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - internal.NbElements(privateCommitted) - len(commitmentInfo) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints)) @@ -448,8 +497,8 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // initialize proving key pk.G1.A = make([]curve.G1Affine, nbWires-nbZeroesA) pk.G1.B = make([]curve.G1Affine, nbWires-nbZeroesB) - pk.G1.K = make([]curve.G1Affine, nbWires-r1cs.GetNbPublicVariables()) - pk.G1.Z = make([]curve.G1Affine, 
domain.Cardinality) + pk.G1.K = make([]curve.G1Affine, nbPrivateWires) + pk.G1.Z = make([]curve.G1Affine, domain.Cardinality-1) pk.G2.B = make([]curve.G2Affine, nbWires-nbZeroesB) // set infinity markers @@ -503,6 +552,22 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { pk.Domain = *domain + // --------------------------------------------------------------------------------------------- + // Commitment setup + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(privateCommitted[i]) + commitmentBases[i] = make([]curve.G1Affine, size) + for j := range commitmentBases[i] { + commitmentBases[i][j] = r1Aff + } + } + + pk.CommitmentKeys, _, err = pedersen.Setup(commitmentBases...) + if err != nil { + return err + } + return nil } @@ -514,7 +579,9 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) - for _, c := range r1cs.Constraints { + + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { A[t.WireID()] = true } @@ -522,6 +589,7 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { B[t.WireID()] = true } } + for i := 0; i < nbWires; i++ { if !A[i] { nbZeroesA++ @@ -606,7 +674,7 @@ func (pk *ProvingKey) NbG2() int { return 2 + len(pk.G2.B) } -// bitRerverse permutation as in fft.BitReverse , but with []curve.G1Affine +// bitReverse permutation as in fft.BitReverse , but with []curve.G1Affine func bitReverse(a []curve.G1Affine) { n := uint(len(a)) nn := uint(bits.UintSize - bits.TrailingZeros(n)) diff --git a/internal/backend/bls12-377/groth16/verify.go b/backend/groth16/bls12-377/verify.go similarity index 64% rename from internal/backend/bls12-377/groth16/verify.go rename to backend/groth16/bls12-377/verify.go index 6a9819c093..3da51fcaee 100644 --- a/internal/backend/bls12-377/groth16/verify.go +++ b/backend/groth16/bls12-377/verify.go @@ -22,9 +22,11 @@ import ( 
"github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/pedersen" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/logger" "io" - "math/big" "time" ) @@ -36,10 +38,8 @@ var ( // Verify verifies a proof with given VerifyingKey and publicWitness func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - nbPublicVars := len(vk.G1.K) - if vk.CommitmentInfo.Is() { - nbPublicVars-- - } + nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted) + if len(publicWitness) != nbPublicVars-1 { return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(publicWitness), len(vk.G1.K)-1) } @@ -62,21 +62,32 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { close(chDone) }() - if vk.CommitmentInfo.Is() { - - if err := vk.CommitmentKey.VerifyKnowledgeProof(proof.Commitment, proof.CommitmentPok); err != nil { - return err + maxNbPublicCommitted := 0 + for _, s := range vk.PublicAndCommitmentCommitted { // iterate over commitments + maxNbPublicCommitted = utils.Max(maxNbPublicCommitted, len(s)) + } + commitmentsSerialized := make([]byte, len(vk.PublicAndCommitmentCommitted)*fr.Bytes) + commitmentPrehashSerialized := make([]byte, curve.SizeOfG1AffineUncompressed+maxNbPublicCommitted*fr.Bytes) + for i := range vk.PublicAndCommitmentCommitted { // solveCommitmentWire + copy(commitmentPrehashSerialized, proof.Commitments[i].Marshal()) + offset := curve.SizeOfG1AffineUncompressed + for j := range vk.PublicAndCommitmentCommitted[i] { + copy(commitmentPrehashSerialized[offset:], publicWitness[vk.PublicAndCommitmentCommitted[i][j]-1].Marshal()) + offset += fr.Bytes } - - publicCommitted := make([]*big.Int, vk.CommitmentInfo.NbPublicCommitted()) - for i := range publicCommitted { - var b big.Int - 
publicWitness[vk.CommitmentInfo.Committed[i]-1].BigInt(&b) - publicCommitted[i] = &b + if res, err := fr.Hash(commitmentPrehashSerialized[:offset], []byte(constraint.CommitmentDst), 1); err != nil { + return err + } else { + publicWitness = append(publicWitness, res[0]) + copy(commitmentsSerialized[i*fr.Bytes:], res[0].Marshal()) } + } - if res, err := solveCommitmentWire(&vk.CommitmentInfo, &proof.Commitment, publicCommitted); err == nil { - publicWitness = append(publicWitness, res) + if folded, err := pedersen.FoldCommitments(proof.Commitments, commitmentsSerialized); err != nil { + return err + } else { + if err = vk.CommitmentKey.Verify(folded, proof.CommitmentPok); err != nil { + return err } } @@ -87,8 +98,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { } kSum.AddMixed(&vk.G1.K[0]) - if vk.CommitmentInfo.Is() { - kSum.AddMixed(&proof.Commitment) + for i := range proof.Commitments { + kSum.AddMixed(&proof.Commitments[i]) } var kSumAff curve.G1Affine diff --git a/internal/backend/bls12-381/groth16/commitment.go b/backend/groth16/bls12-381/commitment.go similarity index 74% rename from internal/backend/bls12-381/groth16/commitment.go rename to backend/groth16/bls12-381/commitment.go index 680f862419..6fcc533b5b 100644 --- a/internal/backend/bls12-381/groth16/commitment.go +++ b/backend/groth16/bls12-381/commitment.go @@ -23,7 +23,7 @@ import ( "math/big" ) -func solveCommitmentWire(commitmentInfo *constraint.Commitment, commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { - res, err := fr.Hash(commitmentInfo.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) +func solveCommitmentWire(commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { + res, err := fr.Hash(constraint.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) return res[0], err } diff --git 
a/internal/backend/bls12-381/groth16/commitment_test.go b/backend/groth16/bls12-381/commitment_test.go similarity index 91% rename from internal/backend/bls12-381/groth16/commitment_test.go rename to backend/groth16/bls12-381/commitment_test.go index 9279f8449e..8f6882aeb1 100644 --- a/internal/backend/bls12-381/groth16/commitment_test.go +++ b/backend/groth16/bls12-381/commitment_test.go @@ -17,6 +17,9 @@ package groth16_test import ( + "fmt" + "testing" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/witness" @@ -24,7 +27,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/assert" - "testing" ) type singleSecretCommittedCircuit struct { @@ -33,7 +35,11 @@ type singleSecretCommittedCircuit struct { func (c *singleSecretCommittedCircuit) Define(api frontend.API) error { api.AssertIsEqual(c.One, 1) - commit, err := api.Compiler().Commit(c.One) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One) if err != nil { return err } @@ -119,8 +125,11 @@ type oneSecretOnePublicCommittedCircuit struct { } func (c *oneSecretOnePublicCommittedCircuit) Define(api frontend.API) error { - - commit, err := api.Compiler().Commit(c.One, c.Two) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One, c.Two) if err != nil { return err } diff --git a/internal/backend/bls12-381/groth16/marshal.go b/backend/groth16/bls12-381/marshal.go similarity index 81% rename from internal/backend/bls12-381/groth16/marshal.go rename to backend/groth16/bls12-381/marshal.go index 8f1cf81f4a..5a51575c78 100644 --- a/internal/backend/bls12-381/groth16/marshal.go +++ b/backend/groth16/bls12-381/marshal.go @@ -18,6 +18,9 @@ package 
groth16 import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-381" + + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/pedersen" + "github.com/consensys/gnark/internal/utils" "io" ) @@ -78,14 +81,24 @@ func (proof *Proof) ReadFrom(r io.Reader) (n int64, err error) { // points are compressed // use WriteRawTo(...) to encode the key without point compression func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, false) + if n, err = vk.writeTo(w, false); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteTo(w) + return m + n, err } // WriteRawTo writes binary encoding of the key elements to writer // points are not compressed // use WriteTo(...) to encode the key with point compression func (vk *VerifyingKey) WriteRawTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, true) + if n, err = vk.writeTo(w, true); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteRawTo(w) + return m + n, err } // writeTo serialization format: @@ -124,6 +137,14 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { if err := enc.Encode(vk.G1.K); err != nil { return enc.BytesWritten(), err } + + if vk.PublicAndCommitmentCommitted == nil { + vk.PublicAndCommitmentCommitted = [][]int{} // only matters in tests + } + if err := enc.Encode(utils.IntSliceSliceToUint64SliceSlice(vk.PublicAndCommitmentCommitted)); err != nil { + return enc.BytesWritten(), err + } + return enc.BytesWritten(), nil } @@ -133,13 +154,25 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { // https://github.com/zkcrypto/bellman/blob/fa9be45588227a8c6ec34957de3f68705f07bd92/src/groth16/mod.rs#L143 // [α]1,[β]1,[β]2,[γ]2,[δ]1,[δ]2,uint32(len(Kvk)),[Kvk]1 func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r) + n, err := vk.readFrom(r) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.ReadFrom(r) + return m + 
n, err } // UnsafeReadFrom has the same behavior as ReadFrom, except that it will not check that decode points // are on the curve and in the correct subgroup. func (vk *VerifyingKey) UnsafeReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r, curve.NoSubgroupChecks()) + n, err := vk.readFrom(r, curve.NoSubgroupChecks()) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.UnsafeReadFrom(r) + return m + n, err } func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { @@ -169,15 +202,16 @@ func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder) if err := dec.Decode(&vk.G1.K); err != nil { return dec.BytesRead(), err } + var publicCommitted [][]uint64 + if err := dec.Decode(&publicCommitted); err != nil { + return dec.BytesRead(), err + } + vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 - var err error - vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return dec.BytesRead(), err } - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) return dec.BytesRead(), nil } @@ -226,6 +260,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { pk.NbInfinityB, pk.InfinityA, pk.InfinityB, + uint32(len(pk.CommitmentKeys)), } for _, v := range toEncode { @@ -234,6 +269,23 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { } } + for i := range pk.CommitmentKeys { + var ( + n2 int64 + err error + ) + if raw { + n2, err = pk.CommitmentKeys[i].WriteRawTo(w) + } else { + n2, err = pk.CommitmentKeys[i].WriteTo(w) + } + + n += n2 + if err != nil { + return n, err + } + } + return n + enc.BytesWritten(), nil } @@ -260,6 +312,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) dec := curve.NewDecoder(r, decOptions...) 
var nbWires uint64 + var nbCommitments uint32 toDecode := []interface{}{ &pk.G1.Alpha, @@ -291,6 +344,18 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) if err := dec.Decode(&pk.InfinityB); err != nil { return n + dec.BytesRead(), err } + if err := dec.Decode(&nbCommitments); err != nil { + return n + dec.BytesRead(), err + } + + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) + for i := range pk.CommitmentKeys { + n2, err := pk.CommitmentKeys[i].ReadFrom(r) + n += n2 + if err != nil { + return n, err + } + } return n + dec.BytesRead(), nil } diff --git a/internal/backend/bls12-381/groth16/marshal_test.go b/backend/groth16/bls12-381/marshal_test.go similarity index 78% rename from internal/backend/bls12-381/groth16/marshal_test.go rename to backend/groth16/bls12-381/marshal_test.go index 990e3ee6b1..55d8c0856f 100644 --- a/internal/backend/bls12-381/groth16/marshal_test.go +++ b/backend/groth16/bls12-381/marshal_test.go @@ -21,11 +21,17 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal/test_utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "bytes" "math/big" "reflect" "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" "github.com/leanovate/gopter/prop" "testing" @@ -87,13 +93,9 @@ func TestProofSerialization(t *testing.T) { } func TestVerifyingKeySerialization(t *testing.T) { - parameters := gopter.DefaultTestParameters() - parameters.MinSuccessfulTests = 10 - properties := gopter.NewProperties(parameters) - - properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + roundTrip := func(withCommitment bool) func(curve.G1Affine, curve.G2Affine) bool { + return func(p1 curve.G1Affine, p2 curve.G2Affine) bool { var vk, vkCompressed, 
vkRaw VerifyingKey // create a random vk @@ -121,6 +123,21 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } + if withCommitment { + vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization + bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) + elem := p1 + for i := 0; i < len(vk.PublicAndCommitmentCommitted); i++ { + bases[i] = make([]curve.G1Affine, len(vk.PublicAndCommitmentCommitted[i])) + for j := range bases[i] { + bases[i][j] = elem + elem.Add(&elem, &p1) + } + } + _, vk.CommitmentKey, err = pedersen.Setup(bases...) + assert.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := vk.WriteTo(&bufCompressed) if err != nil { @@ -158,7 +175,22 @@ func TestVerifyingKeySerialization(t *testing.T) { } return reflect.DeepEqual(&vk, &vkCompressed) && reflect.DeepEqual(&vk, &vkRaw) - }, + } + } + + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 10 + + properties := gopter.NewProperties(parameters) + + properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(false), + GenG1(), + GenG2(), + )) + + properties.Property("VerifyingKey (with commitments) -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(true), GenG1(), GenG2(), )) @@ -173,7 +205,7 @@ func TestProvingKeySerialization(t *testing.T) { properties := gopter.NewProperties(parameters) properties.Property("ProvingKey -> writer -> reader -> ProvingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + func(p1 curve.G1Affine, p2 curve.G2Affine, nbCommitment int) bool { var pk, pkCompressed, pkRaw ProvingKey // create a random pk @@ -202,6 +234,19 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, 
nbCommitment) + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases...) + require.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := pk.WriteTo(&bufCompressed) if err != nil { @@ -242,6 +287,7 @@ func TestProvingKeySerialization(t *testing.T) { }, GenG1(), GenG2(), + gen.IntRange(0, 2), )) properties.TestingRun(t, gopter.ConsoleReporter(false)) diff --git a/backend/groth16/bls12-381/mpcsetup/lagrange.go b/backend/groth16/bls12-381/mpcsetup/lagrange.go new file mode 100644 index 0000000000..8da7a42f2b --- /dev/null +++ b/backend/groth16/bls12-381/mpcsetup/lagrange.go @@ -0,0 +1,216 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" + "github.com/consensys/gnark/internal/utils" +) + +func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { + coeffs := make([]curve.G1Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { + coeffs := make([]curve.G2Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func butterflyG1(a *curve.G1Affine, b *curve.G1Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +func butterflyG2(a *curve.G2Affine, b *curve.G2Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G1(a []curve.G1Affine, twiddles [][]fr.Element, stage int) { + butterflyG1(&a[0], &a[4]) + butterflyG1(&a[1], &a[5]) + 
butterflyG1(&a[2], &a[6]) + butterflyG1(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[2]) + butterflyG1(&a[1], &a[3]) + butterflyG1(&a[4], &a[6]) + butterflyG1(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[1]) + butterflyG1(&a[2], &a[3]) + butterflyG1(&a[4], &a[5]) + butterflyG1(&a[6], &a[7]) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G2(a []curve.G2Affine, twiddles [][]fr.Element, stage int) { + butterflyG2(&a[0], &a[4]) + butterflyG2(&a[1], &a[5]) + butterflyG2(&a[2], &a[6]) + butterflyG2(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[2]) + butterflyG2(&a[1], &a[3]) + butterflyG2(&a[4], &a[6]) + butterflyG2(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[1]) + butterflyG2(&a[2], &a[3]) + butterflyG2(&a[4], &a[5]) + butterflyG2(&a[6], &a[7]) +} + +func difFFTG1(a []curve.G1Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G1(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG1(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ 
{ + butterflyG1(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} +func difFFTG2(a []curve.G2Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G2(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG2(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + butterflyG2(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG2(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG2(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} diff --git a/backend/groth16/bls12-381/mpcsetup/marshal.go b/backend/groth16/bls12-381/mpcsetup/marshal.go new file mode 100644 index 0000000000..0aa64ea1f0 --- /dev/null +++ b/backend/groth16/bls12-381/mpcsetup/marshal.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls12-381" + "io" +) + +// WriteTo implements io.WriterTo +func (phase1 *Phase1) WriteTo(writer io.Writer) (int64, error) { + n, err := phase1.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase1.Hash) + return int64(nBytes) + n, err +} + +func (phase1 *Phase1) writeTo(writer io.Writer) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + phase1.Parameters.G1.Tau, + phase1.Parameters.G1.AlphaTau, + phase1.Parameters.G1.BetaTau, + phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + enc := curve.NewEncoder(writer) + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (phase1 *Phase1) ReadFrom(reader io.Reader) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + &phase1.Parameters.G1.Tau, + 
&phase1.Parameters.G1.AlphaTau, + &phase1.Parameters.G1.BetaTau, + &phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + dec := curve.NewDecoder(reader) + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + phase1.Hash = make([]byte, 32) + nBytes, err := reader.Read(phase1.Hash) + return dec.BytesRead() + int64(nBytes), err +} + +// WriteTo implements io.WriterTo +func (phase2 *Phase2) WriteTo(writer io.Writer) (int64, error) { + n, err := phase2.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase2.Hash) + return int64(nBytes) + n, err +} + +func (c *Phase2) writeTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + c.Parameters.G1.L, + c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c *Phase2) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + &c.Parameters.G1.L, + &c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + c.Hash = make([]byte, 32) + n, err := reader.Read(c.Hash) + return int64(n) + dec.BytesRead(), err + +} + +// WriteTo implements io.WriterTo +func (c *Phase2Evaluations) WriteTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + c.G1.A, + c.G1.B, + c.G2.B, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom 
implements io.ReaderFrom +func (c *Phase2Evaluations) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.G1.A, + &c.G1.B, + &c.G2.B, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + return dec.BytesRead(), nil +} diff --git a/backend/groth16/bls12-381/mpcsetup/marshal_test.go b/backend/groth16/bls12-381/mpcsetup/marshal_test.go new file mode 100644 index 0000000000..bbcaa65d36 --- /dev/null +++ b/backend/groth16/bls12-381/mpcsetup/marshal_test.go @@ -0,0 +1,79 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + curve "github.com/consensys/gnark-crypto/ecc/bls12-381" + cs "github.com/consensys/gnark/constraint/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/stretchr/testify/require" + "io" + "reflect" + "testing" +) + +func TestContributionSerialization(t *testing.T) { + assert := require.New(t) + + // Phase 1 + srs1 := InitPhase1(9) + srs1.Contribute() + { + var reconstructed Phase1 + roundTripCheck(t, &srs1, &reconstructed) + } + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + r1cs := ccs.(*cs.R1CS) + + // Phase 2 + srs2, _ := InitPhase2(r1cs, &srs1) + srs2.Contribute() + + { + var reconstructed Phase2 + roundTripCheck(t, &srs2, &reconstructed) + } +} + +func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { + t.Helper() + + var buf bytes.Buffer + written, err := from.WriteTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} diff --git a/backend/groth16/bls12-381/mpcsetup/phase1.go b/backend/groth16/bls12-381/mpcsetup/phase1.go new file mode 100644 index 0000000000..14cef5f605 --- /dev/null +++ b/backend/groth16/bls12-381/mpcsetup/phase1.go @@ -0,0 +1,203 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + curve "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "math" + "math/big" +) + +// Phase1 represents the Phase1 of the MPC described in +// https://eprint.iacr.org/2017/1050.pdf +// +// Also known as "Powers of Tau" +type Phase1 struct { + Parameters struct { + G1 struct { + Tau []curve.G1Affine // {[τ⁰]₁, [τ¹]₁, [τ²]₁, …, [τ²ⁿ⁻²]₁} + AlphaTau []curve.G1Affine // {α[τ⁰]₁, α[τ¹]₁, α[τ²]₁, …, α[τⁿ⁻¹]₁} + BetaTau []curve.G1Affine // {β[τ⁰]₁, β[τ¹]₁, β[τ²]₁, …, β[τⁿ⁻¹]₁} + } + G2 struct { + Tau []curve.G2Affine // {[τ⁰]₂, [τ¹]₂, [τ²]₂, …, [τⁿ⁻¹]₂} + Beta curve.G2Affine // [β]₂ + } + } + PublicKeys struct { + Tau, Alpha, Beta PublicKey + } + Hash []byte // sha256 hash +} + +// InitPhase1 initialize phase 1 of the MPC. This is called once by the coordinator before +// any randomness contribution is made (see Contribute()). 
+func InitPhase1(power int) (phase1 Phase1) { + N := int(math.Pow(2, float64(power))) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetOne() + alpha.SetOne() + beta.SetOne() + phase1.PublicKeys.Tau = newPublicKey(tau, nil, 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, nil, 2) + phase1.PublicKeys.Beta = newPublicKey(beta, nil, 3) + + // First contribution use generators + _, _, g1, g2 := curve.Generators() + phase1.Parameters.G2.Beta.Set(&g2) + phase1.Parameters.G1.Tau = make([]curve.G1Affine, 2*N-1) + phase1.Parameters.G2.Tau = make([]curve.G2Affine, N) + phase1.Parameters.G1.AlphaTau = make([]curve.G1Affine, N) + phase1.Parameters.G1.BetaTau = make([]curve.G1Affine, N) + for i := 0; i < len(phase1.Parameters.G1.Tau); i++ { + phase1.Parameters.G1.Tau[i].Set(&g1) + } + for i := 0; i < len(phase1.Parameters.G2.Tau); i++ { + phase1.Parameters.G2.Tau[i].Set(&g2) + phase1.Parameters.G1.AlphaTau[i].Set(&g1) + phase1.Parameters.G1.BetaTau[i].Set(&g1) + } + + phase1.Parameters.G2.Beta.Set(&g2) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() + + return +} + +// Contribute contributes randomness to the phase1 object. This mutates phase1. +func (phase1 *Phase1) Contribute() { + N := len(phase1.Parameters.G2.Tau) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetRandom() + alpha.SetRandom() + beta.SetRandom() + phase1.PublicKeys.Tau = newPublicKey(tau, phase1.Hash[:], 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, phase1.Hash[:], 2) + phase1.PublicKeys.Beta = newPublicKey(beta, phase1.Hash[:], 3) + + // Compute powers of τ, ατ, and βτ + taus := powers(tau, 2*N-1) + alphaTau := make([]fr.Element, N) + betaTau := make([]fr.Element, N) + for i := 0; i < N; i++ { + alphaTau[i].Mul(&taus[i], &alpha) + betaTau[i].Mul(&taus[i], &beta) + } + + // Update using previous parameters + // TODO @gbotrel working with jacobian points here will help with perf. 
+ scaleG1InPlace(phase1.Parameters.G1.Tau, taus) + scaleG2InPlace(phase1.Parameters.G2.Tau, taus[0:N]) + scaleG1InPlace(phase1.Parameters.G1.AlphaTau, alphaTau) + scaleG1InPlace(phase1.Parameters.G1.BetaTau, betaTau) + var betaBI big.Int + beta.BigInt(&betaBI) + phase1.Parameters.G2.Beta.ScalarMultiplication(&phase1.Parameters.G2.Beta, &betaBI) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() +} + +func VerifyPhase1(c0, c1 *Phase1, c ...*Phase1) error { + contribs := append([]*Phase1{c0, c1}, c...) + for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase1(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +// verifyPhase1 checks that a contribution is based on a known previous Phase1 state. +func verifyPhase1(current, contribution *Phase1) error { + // Compute R for τ, α, β + tauR := genR(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, current.Hash[:], 1) + alphaR := genR(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, current.Hash[:], 2) + betaR := genR(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, current.Hash[:], 3) + + // Check for knowledge of toxic parameters + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.PublicKeys.Tau.XR, tauR) { + return errors.New("couldn't verify public key of τ") + } + if !sameRatio(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, contribution.PublicKeys.Alpha.XR, alphaR) { + return errors.New("couldn't verify public key of α") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.PublicKeys.Beta.XR, betaR) { + return errors.New("couldn't verify public key of β") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Tau[1], current.Parameters.G1.Tau[1], tauR, contribution.PublicKeys.Tau.XR) { + return errors.New("couldn't verify that [τ]₁ is based on previous 
contribution") + } + if !sameRatio(contribution.Parameters.G1.AlphaTau[0], current.Parameters.G1.AlphaTau[0], alphaR, contribution.PublicKeys.Alpha.XR) { + return errors.New("couldn't verify that [α]₁ is based on previous contribution") + } + if !sameRatio(contribution.Parameters.G1.BetaTau[0], current.Parameters.G1.BetaTau[0], betaR, contribution.PublicKeys.Beta.XR) { + return errors.New("couldn't verify that [β]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.Parameters.G2.Tau[1], current.Parameters.G2.Tau[1]) { + return errors.New("couldn't verify that [τ]₂ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.Parameters.G2.Beta, current.Parameters.G2.Beta) { + return errors.New("couldn't verify that [β]₂ is based on previous contribution") + } + + // Check for valid updates using powers of τ + _, _, g1, g2 := curve.Generators() + tauL1, tauL2 := linearCombinationG1(contribution.Parameters.G1.Tau) + if !sameRatio(tauL1, tauL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of τ in G₁") + } + alphaL1, alphaL2 := linearCombinationG1(contribution.Parameters.G1.AlphaTau) + if !sameRatio(alphaL1, alphaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of α(τ) in G₁") + } + betaL1, betaL2 := linearCombinationG1(contribution.Parameters.G1.BetaTau) + if !sameRatio(betaL1, betaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of β(τ) in G₁") + } + tau2L1, tau2L2 := linearCombinationG2(contribution.Parameters.G2.Tau) + if !sameRatio(contribution.Parameters.G1.Tau[1], g1, tau2L1, tau2L2) { + return errors.New("couldn't verify valid powers of τ in G₂") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != 
contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (phase1 *Phase1) hash() []byte { + sha := sha256.New() + phase1.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bls12-381/mpcsetup/phase2.go b/backend/groth16/bls12-381/mpcsetup/phase2.go new file mode 100644 index 0000000000..ed42a69f9c --- /dev/null +++ b/backend/groth16/bls12-381/mpcsetup/phase2.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + "math/big" + + curve "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark/constraint" + cs "github.com/consensys/gnark/constraint/bls12-381" +) + +type Phase2Evaluations struct { + G1 struct { + A, B, VKK []curve.G1Affine + } + G2 struct { + B []curve.G2Affine + } +} + +type Phase2 struct { + Parameters struct { + G1 struct { + Delta curve.G1Affine + L, Z []curve.G1Affine + } + G2 struct { + Delta curve.G2Affine + } + } + PublicKey PublicKey + Hash []byte +} + +func InitPhase2(r1cs *cs.R1CS, srs1 *Phase1) (Phase2, Phase2Evaluations) { + srs := srs1.Parameters + size := len(srs.G1.AlphaTau) + if size < r1cs.GetNbConstraints() { + panic("Number of constraints is larger than expected") + } + + c2 := Phase2{} + + accumulateG1 := func(res *curve.G1Affine, t constraint.Term, value *curve.G1Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G1Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + accumulateG2 := func(res *curve.G2Affine, t constraint.Term, value *curve.G2Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G2Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + // Prepare Lagrange coefficients of [τ...]₁, [τ...]₂, [ατ...]₁, [βτ...]₁ + coeffTau1 
:= lagrangeCoeffsG1(srs.G1.Tau, size) + coeffTau2 := lagrangeCoeffsG2(srs.G2.Tau, size) + coeffAlphaTau1 := lagrangeCoeffsG1(srs.G1.AlphaTau, size) + coeffBetaTau1 := lagrangeCoeffsG1(srs.G1.BetaTau, size) + + internal, secret, public := r1cs.GetNbVariables() + nWires := internal + secret + public + var evals Phase2Evaluations + evals.G1.A = make([]curve.G1Affine, nWires) + evals.G1.B = make([]curve.G1Affine, nWires) + evals.G2.B = make([]curve.G2Affine, nWires) + bA := make([]curve.G1Affine, nWires) + aB := make([]curve.G1Affine, nWires) + C := make([]curve.G1Affine, nWires) + + // TODO @gbotrel use constraint iterator when available. + + i := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + // A + for _, t := range c.L { + accumulateG1(&evals.G1.A[t.WireID()], t, &coeffTau1[i]) + accumulateG1(&bA[t.WireID()], t, &coeffBetaTau1[i]) + } + // B + for _, t := range c.R { + accumulateG1(&evals.G1.B[t.WireID()], t, &coeffTau1[i]) + accumulateG2(&evals.G2.B[t.WireID()], t, &coeffTau2[i]) + accumulateG1(&aB[t.WireID()], t, &coeffAlphaTau1[i]) + } + // C + for _, t := range c.O { + accumulateG1(&C[t.WireID()], t, &coeffTau1[i]) + } + i++ + } + + // Prepare default contribution + _, _, g1, g2 := curve.Generators() + c2.Parameters.G1.Delta = g1 + c2.Parameters.G2.Delta = g2 + + // Build Z in PK as τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + // τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + n := len(srs.G1.AlphaTau) + c2.Parameters.G1.Z = make([]curve.G1Affine, n) + for i := 0; i < n-1; i++ { + c2.Parameters.G1.Z[i].Sub(&srs.G1.Tau[i+n], &srs.G1.Tau[i]) + } + bitReverse(c2.Parameters.G1.Z) + c2.Parameters.G1.Z = c2.Parameters.G1.Z[:n-1] + + // Evaluate L + nPrivate := internal + secret + c2.Parameters.G1.L = make([]curve.G1Affine, nPrivate) + evals.G1.VKK = make([]curve.G1Affine, public) + offset := public + for i := 0; i < nWires; i++ { + var tmp curve.G1Affine + tmp.Add(&bA[i], &aB[i]) + tmp.Add(&tmp, &C[i]) + if i < public { + 
evals.G1.VKK[i].Set(&tmp) + } else { + c2.Parameters.G1.L[i-offset].Set(&tmp) + } + } + // Set δ public key + var delta fr.Element + delta.SetOne() + c2.PublicKey = newPublicKey(delta, nil, 1) + + // Hash initial contribution + c2.Hash = c2.hash() + return c2, evals +} + +func (c *Phase2) Contribute() { + // Sample toxic δ + var delta, deltaInv fr.Element + var deltaBI, deltaInvBI big.Int + delta.SetRandom() + deltaInv.Inverse(&delta) + + delta.BigInt(&deltaBI) + deltaInv.BigInt(&deltaInvBI) + + // Set δ public key + c.PublicKey = newPublicKey(delta, c.Hash, 1) + + // Update δ + c.Parameters.G1.Delta.ScalarMultiplication(&c.Parameters.G1.Delta, &deltaBI) + c.Parameters.G2.Delta.ScalarMultiplication(&c.Parameters.G2.Delta, &deltaBI) + + // Update Z using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.Z); i++ { + c.Parameters.G1.Z[i].ScalarMultiplication(&c.Parameters.G1.Z[i], &deltaInvBI) + } + + // Update L using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.L); i++ { + c.Parameters.G1.L[i].ScalarMultiplication(&c.Parameters.G1.L[i], &deltaInvBI) + } + + // 4. Hash contribution + c.Hash = c.hash() +} + +func VerifyPhase2(c0, c1 *Phase2, c ...*Phase2) error { + contribs := append([]*Phase2{c0, c1}, c...) 
+ for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase2(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +func verifyPhase2(current, contribution *Phase2) error { + // Compute R for δ + deltaR := genR(contribution.PublicKey.SG, contribution.PublicKey.SXG, current.Hash[:], 1) + + // Check for knowledge of δ + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.PublicKey.XR, deltaR) { + return errors.New("couldn't verify knowledge of δ") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Delta, current.Parameters.G1.Delta, deltaR, contribution.PublicKey.XR) { + return errors.New("couldn't verify that [δ]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify that [δ]₂ is based on previous contribution") + } + + // Check for valid updates of L and Z using δ⁻¹ + L, prevL := merge(contribution.Parameters.G1.L, current.Parameters.G1.L) + if !sameRatio(L, prevL, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of L using δ⁻¹") + } + Z, prevZ := merge(contribution.Parameters.G1.Z, current.Parameters.G1.Z) + if !sameRatio(Z, prevZ, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of Z using δ⁻¹") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (c *Phase2) hash() []byte { + sha := sha256.New() + c.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bls12-381/mpcsetup/setup.go b/backend/groth16/bls12-381/mpcsetup/setup.go new file mode 100644 index 0000000000..dd568aa21e --- 
/dev/null +++ b/backend/groth16/bls12-381/mpcsetup/setup.go @@ -0,0 +1,97 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" + groth16 "github.com/consensys/gnark/backend/groth16/bls12-381" +) + +func ExtractKeys(srs1 *Phase1, srs2 *Phase2, evals *Phase2Evaluations, nConstraints int) (pk groth16.ProvingKey, vk groth16.VerifyingKey) { + _, _, _, g2 := curve.Generators() + + // Initialize PK + pk.Domain = *fft.NewDomain(uint64(nConstraints)) + pk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + pk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + pk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + pk.G1.Z = srs2.Parameters.G1.Z + bitReverse(pk.G1.Z) + + pk.G1.K = srs2.Parameters.G1.L + pk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + pk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + + // Filter out infinity points + nWires := len(evals.G1.A) + pk.InfinityA = make([]bool, nWires) + A := make([]curve.G1Affine, nWires) + j := 0 + for i, e := range evals.G1.A { + if e.IsInfinity() { + pk.InfinityA[i] = true + continue + } + A[j] = evals.G1.A[i] + j++ + } + pk.G1.A = A[:j] + pk.NbInfinityA = uint64(nWires - j) + + pk.InfinityB = make([]bool, nWires) + B := make([]curve.G1Affine, nWires) + j = 0 + for i, e := range evals.G1.B { + if e.IsInfinity() { + 
pk.InfinityB[i] = true + continue + } + B[j] = evals.G1.B[i] + j++ + } + pk.G1.B = B[:j] + pk.NbInfinityB = uint64(nWires - j) + + B2 := make([]curve.G2Affine, nWires) + j = 0 + for i, e := range evals.G2.B { + if e.IsInfinity() { + // pk.InfinityB[i] = true should be the same as in B + continue + } + B2[j] = evals.G2.B[i] + j++ + } + pk.G2.B = B2[:j] + + // Initialize VK + vk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + vk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + vk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + vk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + vk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + vk.G2.Gamma.Set(&g2) + vk.G1.K = evals.G1.VKK + + // sets e, -[δ]2, -[γ]2 + if err := vk.Precompute(); err != nil { + panic(err) + } + + return pk, vk +} diff --git a/backend/groth16/bls12-381/mpcsetup/setup_test.go b/backend/groth16/bls12-381/mpcsetup/setup_test.go new file mode 100644 index 0000000000..0e9880b010 --- /dev/null +++ b/backend/groth16/bls12-381/mpcsetup/setup_test.go @@ -0,0 +1,199 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + cs "github.com/consensys/gnark/constraint/bls12-381" + "testing" + + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/stretchr/testify/require" + + native_mimc "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" +) + +func TestSetupCircuit(t *testing.T) { + if testing.Short() { + t.Skip() + } + const ( + nContributionsPhase1 = 3 + nContributionsPhase2 = 3 + power = 9 + ) + + assert := require.New(t) + + srs1 := InitPhase1(power) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase1; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. + prev := srs1.clone() + + srs1.Contribute() + assert.NoError(VerifyPhase1(&prev, &srs1)) + } + + // Compile the circuit + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + var evals Phase2Evaluations + r1cs := ccs.(*cs.R1CS) + + // Prepare for phase-2 + srs2, evals := InitPhase2(r1cs, &srs1) + + // Make and verify contributions for phase2 + for i := 1; i < nContributionsPhase2; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. 
+ prev := srs2.clone() + + srs2.Contribute() + assert.NoError(VerifyPhase2(&prev, &srs2)) + } + + // Extract the proving and verifying keys + pk, vk := ExtractKeys(&srs1, &srs2, &evals, ccs.GetNbConstraints()) + + // Build the witness + var preImage, hash fr.Element + { + m := native_mimc.NewMiMC() + m.Write(preImage.Marshal()) + hash.SetBytes(m.Sum(nil)) + } + + witness, err := frontend.NewWitness(&Circuit{PreImage: preImage, Hash: hash}, curve.ID.ScalarField()) + assert.NoError(err) + + pubWitness, err := witness.Public() + assert.NoError(err) + + // groth16: ensure proof is verified + proof, err := groth16.Prove(ccs, &pk, witness) + assert.NoError(err) + + err = groth16.Verify(proof, &vk, pubWitness) + assert.NoError(err) +} + +func BenchmarkPhase1(b *testing.B) { + const power = 14 + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = InitPhase1(power) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs1 := InitPhase1(power) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs1.Contribute() + } + }) + +} + +func BenchmarkPhase2(b *testing.B) { + const power = 14 + srs1 := InitPhase1(power) + srs1.Contribute() + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + if err != nil { + b.Fatal(err) + } + + r1cs := ccs.(*cs.R1CS) + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = InitPhase2(r1cs, &srs1) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs2, _ := InitPhase2(r1cs, &srs1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs2.Contribute() + } + }) + +} + +// Circuit defines a pre-image knowledge proof +// mimc(secret preImage) = public hash +type Circuit struct { + PreImage frontend.Variable + Hash frontend.Variable `gnark:",public"` +} + +// Define declares the circuit's constraints +// Hash = mimc(PreImage) +func (circuit *Circuit) Define(api frontend.API) error { + // hash function + mimc, _ := 
mimc.NewMiMC(api) + + // specify constraints + mimc.Write(circuit.PreImage) + api.AssertIsEqual(circuit.Hash, mimc.Sum()) + + return nil +} + +func (phase1 *Phase1) clone() Phase1 { + r := Phase1{} + r.Parameters.G1.Tau = append(r.Parameters.G1.Tau, phase1.Parameters.G1.Tau...) + r.Parameters.G1.AlphaTau = append(r.Parameters.G1.AlphaTau, phase1.Parameters.G1.AlphaTau...) + r.Parameters.G1.BetaTau = append(r.Parameters.G1.BetaTau, phase1.Parameters.G1.BetaTau...) + + r.Parameters.G2.Tau = append(r.Parameters.G2.Tau, phase1.Parameters.G2.Tau...) + r.Parameters.G2.Beta = phase1.Parameters.G2.Beta + + r.PublicKeys = phase1.PublicKeys + r.Hash = append(r.Hash, phase1.Hash...) + + return r +} + +func (phase2 *Phase2) clone() Phase2 { + r := Phase2{} + r.Parameters.G1.Delta = phase2.Parameters.G1.Delta + r.Parameters.G1.L = append(r.Parameters.G1.L, phase2.Parameters.G1.L...) + r.Parameters.G1.Z = append(r.Parameters.G1.Z, phase2.Parameters.G1.Z...) + r.Parameters.G2.Delta = phase2.Parameters.G2.Delta + r.PublicKey = phase2.PublicKey + r.Hash = append(r.Hash, phase2.Hash...) + + return r +} diff --git a/backend/groth16/bls12-381/mpcsetup/utils.go b/backend/groth16/bls12-381/mpcsetup/utils.go new file mode 100644 index 0000000000..e29ec7ae32 --- /dev/null +++ b/backend/groth16/bls12-381/mpcsetup/utils.go @@ -0,0 +1,170 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark/internal/utils" +) + +type PublicKey struct { + SG curve.G1Affine + SXG curve.G1Affine + XR curve.G2Affine +} + +func newPublicKey(x fr.Element, challenge []byte, dst byte) PublicKey { + var pk PublicKey + _, _, g1, _ := curve.Generators() + + var s fr.Element + var sBi big.Int + s.SetRandom() + s.BigInt(&sBi) + pk.SG.ScalarMultiplication(&g1, &sBi) + + // compute x*sG1 + var xBi big.Int + x.BigInt(&xBi) + pk.SXG.ScalarMultiplication(&pk.SG, &xBi) + + // generate R based on sG1, sxG1, challenge, and domain separation tag (tau, alpha or beta) + R := genR(pk.SG, pk.SXG, challenge, dst) + + // compute x*spG2 + pk.XR.ScalarMultiplication(&R, &xBi) + return pk +} + +func bitReverse[T any](a []T) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + irev := bits.Reverse64(i) >> nn + if irev > i { + a[i], a[irev] = a[irev], a[i] + } + } +} + +// Returns [1, a, a², ..., aⁿ⁻¹ ] in Montgomery form +func powers(a fr.Element, n int) []fr.Element { + result := make([]fr.Element, n) + result[0] = fr.NewElement(1) + for i := 1; i < n; i++ { + result[i].Mul(&result[i-1], &a) + } + return result +} + +// Returns [aᵢAᵢ, ...] in G1 +func scaleG1InPlace(A []curve.G1Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Returns [aᵢAᵢ, ...] 
in G2 +func scaleG2InPlace(A []curve.G2Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Check e(a₁, a₂) = e(b₁, b₂) +func sameRatio(a1, b1 curve.G1Affine, a2, b2 curve.G2Affine) bool { + if !a1.IsInSubGroup() || !b1.IsInSubGroup() || !a2.IsInSubGroup() || !b2.IsInSubGroup() { + panic("invalid point not in subgroup") + } + var na2 curve.G2Affine + na2.Neg(&a2) + res, err := curve.PairingCheck( + []curve.G1Affine{a1, b1}, + []curve.G2Affine{na2, b2}) + if err != nil { + panic(err) + } + return res +} + +// returns a = ∑ rᵢAᵢ, b = ∑ rᵢBᵢ +func merge(A, B []curve.G1Affine) (a, b curve.G1Affine) { + nc := runtime.NumCPU() + r := make([]fr.Element, len(A)) + for i := 0; i < len(A); i++ { + r[i].SetRandom() + } + a.MultiExp(A, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + b.MultiExp(B, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G1 +func linearCombinationG1(A []curve.G1Affine) (L1, L2 curve.G1Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G2 +func linearCombinationG2(A []curve.G2Affine) (L1, L2 curve.G2Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// Generate R in G₂ as Hash(gˢ, gˢˣ, challenge, dst) +func genR(sG1, sxG1 curve.G1Affine, challenge []byte, dst byte) curve.G2Affine { + var buf bytes.Buffer + buf.Grow(len(challenge) + curve.SizeOfG1AffineUncompressed*2) + buf.Write(sG1.Marshal()) + 
buf.Write(sxG1.Marshal()) + buf.Write(challenge) + spG2, err := curve.HashToG2(buf.Bytes(), []byte{dst}) + if err != nil { + panic(err) + } + return spG2 +} diff --git a/internal/backend/bls12-381/groth16/prove.go b/backend/groth16/bls12-381/prove.go similarity index 62% rename from internal/backend/bls12-381/groth16/prove.go rename to backend/groth16/bls12-381/prove.go index 9afd4240c1..6e4c0a5227 100644 --- a/internal/backend/bls12-381/groth16/prove.go +++ b/backend/groth16/bls12-381/prove.go @@ -17,13 +17,17 @@ package groth16 import ( - "fmt" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls12-381" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/pedersen" "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/constraint/bls12-381" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + cs "github.com/consensys/gnark/constraint/bls12-381" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" "math/big" @@ -35,9 +39,10 @@ import ( // with a valid statement and a VerifyingKey // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type Proof struct { - Ar, Krs curve.G1Affine - Bs curve.G2Affine - Commitment, CommitmentPok curve.G1Affine + Ar, Krs curve.G1Affine + Bs curve.G2Affine + Commitments []curve.G1Affine // Pedersen commitments a la https://eprint.iacr.org/2022/1072 + CommitmentPok curve.G1Affine // Batched proof of knowledge of the above commitments } // isValid ensures proof elements are in the correct subgroup @@ -51,72 +56,78 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). 
-func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - // TODO @gbotrel witness size check is done by R1CS, doesn't mean we shouldn't sanitize here. - // if len(witness) != r1cs.NbPublicVariables-1+r1cs.NbSecretVariables { - // return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), r1cs.NbPublicVariables-1+r1cs.NbSecretVariables, r1cs.NbPublicVariables, r1cs.NbSecretVariables) - // } - - log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", len(r1cs.Constraints)).Str("backend", "groth16").Logger() - - // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - - proof := &Proof{} - if r1cs.CommitmentInfo.Is() { - opt.HintFunctions[r1cs.CommitmentInfo.HintID] = func(_ *big.Int, in []*big.Int, out []*big.Int) error { - // Perf-TODO: Converting these values to big.Int and back may be a performance bottleneck. - // If that is the case, figure out a way to feed the solution vector into this function - if len(in) != r1cs.CommitmentInfo.NbCommitted() { // TODO: Remove - return fmt.Errorf("unexpected number of committed variables") - } - values := make([]fr.Element, r1cs.CommitmentInfo.NbPrivateCommitted) - nbPublicCommitted := len(in) - len(values) - inPrivate := in[nbPublicCommitted:] - for i, inI := range inPrivate { - values[i].SetBigInt(inI) - } +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } + + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] - var err error - proof.Commitment, proof.CommitmentPok, err = pk.CommitmentKey.Commit(values) - if err != nil { + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + solverOpts = append(solverOpts, solver.OverrideHint(commitmentInfo[i].HintID, func(i int) solver.Hint { + return func(_ *big.Int, in []*big.Int, out []*big.Int) error { + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[len(hashed):] + for j, inJ := range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } + + var res fr.Element + res, err = solveCommitmentWire(&proof.Commitments[i], hashed) + res.BigInt(out[0]) return err } + }(i))) + } - var res fr.Element - res, err = solveCommitmentWire(&r1cs.CommitmentInfo, &proof.Commitment, in[:r1cs.CommitmentInfo.NbPublicCommitted()]) - res.BigInt(out[0]) //Perf-TODO: Regular (non-mont) hashToField to obviate this conversion? 
- return err - } + if r1cs.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + solverOpts = append(solverOpts, + solver.OverrideHint(r1cs.GkrInfo.SolveHintID, cs.GkrSolveHint(r1cs.GkrInfo, &gkrData)), + solver.OverrideHint(r1cs.GkrInfo.ProveHintID, cs.GkrProveHint(r1cs.GkrInfo.HashName, &gkrData))) } - var wireValues []fr.Element - var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill wireValues with random values else multi exps don't do much - var r fr.Element - _, _ = r.SetRandom() - for i := r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables(); i < len(wireValues); i++ { - wireValues[i] = r - r.Double(&r) - } - } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + start := time.Now() + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + + if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil { + return nil, err + } + // H (witness reduction / FFT part) var h []fr.Element chHDone := make(chan struct{}, 1) go func() { - h = computeH(a, b, c, &pk.Domain) - a = nil - b = nil - c = nil + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + solution.A = nil + solution.B = nil + solution.C = nil chHDone <- struct{}{} }() @@ -203,15 +214,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC var krs, krs2, p1 curve.G1Jac chKrs2Done := make(chan error, 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 go func() { - _, err := krs2.MultiExp(pk.G1.Z, h, ecc.MultiExpConfig{NbTasks: n / 2}) + _, err := krs2.MultiExp(pk.G1.Z, 
h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) chKrs2Done <- err }() - // filter the wire values if needed; - _wireValues := filter(wireValues, r1cs.CommitmentInfo.PrivateToPublic()) + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - if _, err := krs.MultiExp(pk.G1.K, _wireValues[r1cs.GetNbPublicVariables():], ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chKrsDone <- err return } @@ -292,26 +307,32 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC } // if len(toRemove) == 0, returns slice -// else, returns a new slice without the indexes in toRemove -// this assumes toRemove indexes are sorted and len(slice) > len(toRemove) -func filter(slice []fr.Element, toRemove []int) (r []fr.Element) { +// else, returns a new slice without the indexes in toRemove. 
The first value in the slice is taken as indexes as sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { if len(toRemove) == 0 { return slice } - r = make([]fr.Element, 0, len(slice)-len(toRemove)) - j := 0 + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) + // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) for i := 0; i < len(slice); i++ { - if j < len(toRemove) && i == toRemove[j] { - j++ + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } continue } r = append(r, slice[i]) } - return r + return } func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { @@ -334,9 +355,9 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { domain.FFTInverse(b, fft.DIF) domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, true) - domain.FFT(b, fft.DIT, true) - domain.FFT(c, fft.DIT, true) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) var den, one fr.Element one.SetOne() @@ -344,7 +365,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) - // reusing a to avoid unecessary memalloc + // reusing a to avoid unnecessary memory allocation utils.Parallelize(n, func(start, end int) { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). 
@@ -354,7 +375,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { }) // ifft_coset - domain.FFTInverse(a, fft.DIF, true) + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) return a } diff --git a/internal/backend/bls12-381/groth16/setup.go b/backend/groth16/bls12-381/setup.go similarity index 75% rename from internal/backend/bls12-381/groth16/setup.go rename to backend/groth16/bls12-381/setup.go index f1b05f3cc8..b5333ba374 100644 --- a/internal/backend/bls12-381/groth16/setup.go +++ b/backend/groth16/bls12-381/setup.go @@ -17,13 +17,15 @@ package groth16 import ( + "errors" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls12-381" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls12-381" + cs "github.com/consensys/gnark/constraint/bls12-381" "math/big" "math/bits" ) @@ -34,15 +36,15 @@ type ProvingKey struct { // domain Domain fft.Domain - // [α]1, [β]1, [δ]1 - // [A(t)]1, [B(t)]1, [Kpk(t)]1, [Z(t)]1 + // [α]₁, [β]₁, [δ]₁ + // [A(t)]₁, [B(t)]₁, [Kpk(t)]₁, [Z(t)]₁ G1 struct { Alpha, Beta, Delta curve.G1Affine A, B, Z []curve.G1Affine K []curve.G1Affine // the indexes correspond to the private wires } - // [β]2, [δ]2, [B(t)]2 + // [β]₂, [δ]₂, [B(t)]₂ G2 struct { Beta, Delta curve.G2Affine B []curve.G2Affine @@ -52,21 +54,21 @@ type ProvingKey struct { InfinityA, InfinityB []bool NbInfinityA, NbInfinityB uint64 - CommitmentKey pedersen.Key + CommitmentKeys []pedersen.ProvingKey } // VerifyingKey is used by a Groth16 verifier to verify the validity of a proof and a statement // Notation follows Figure 4. 
in DIZK paper https://eprint.iacr.org/2018/691.pdf type VerifyingKey struct { - // [α]1, [Kvk]1 + // [α]₁, [Kvk]₁ G1 struct { Alpha curve.G1Affine Beta, Delta curve.G1Affine // unused, here for compatibility purposes K []curve.G1Affine // The indexes correspond to the public wires } - // [β]2, [δ]2, [γ]2, - // -[δ]2, -[γ]2: see proof.Verify() for more details + // [β]₂, [δ]₂, [γ]₂, + // -[δ]₂, -[γ]₂: see proof.Verify() for more details G2 struct { Beta, Delta, Gamma curve.G2Affine deltaNeg, gammaNeg curve.G2Affine // not serialized @@ -75,8 +77,8 @@ type VerifyingKey struct { // e(α, β) e curve.GT // not serialized - CommitmentKey pedersen.Key - CommitmentInfo constraint.Commitment // since the verifier doesn't input a constraint system, this needs to be provided here + CommitmentKey pedersen.VerifyingKey + PublicAndCommitmentCommitted [][]int // indexes of public/commitment committed variables } // Setup constructs the SRS @@ -93,17 +95,20 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbPrivateCommittedWires := r1cs.CommitmentInfo.NbPrivateCommitted - nbPublicWires := r1cs.GetNbPublicVariables() - nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - if r1cs.CommitmentInfo.Is() { // the commitment itself is defined by a hint so the prover considers it private - nbPublicWires++ // but the verifier will need to inject the value itself so on the groth16 - nbPrivateWires-- // level it must be considered public - } + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + commitmentWires := commitmentInfo.CommitmentIndexes() + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateCommittedWires := internal.NbElements(privateCommitted) + + // a commitment is itself defined by a hint so the prover considers it 
private + // but the verifier will need to inject the value itself so on the groth16 + // level it must be considered public + nbPublicWires := r1cs.GetNbPublicVariables() + len(commitmentInfo) + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - len(commitmentInfo) // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints))) + domain := fft.NewDomain(uint64(r1cs.GetNbConstraints())) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -137,7 +142,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // compute scalars for pkK, vkK and ckK pkK := make([]fr.Element, nbPrivateWires) vkK := make([]fr.Element, nbPublicWires) - ckK := make([]fr.Element, nbPrivateCommittedWires) + ckK := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + ckK[i] = make([]fr.Element, len(privateCommitted[i])) + } var t0, t1 fr.Element @@ -148,28 +156,42 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { Add(&t1, &C[i]). 
Mul(&t1, coeff) } - - vI, cI := 0, 0 - privateCommitted := r1cs.CommitmentInfo.PrivateCommitted() + vI := 0 // number of public wires seen so far + cI := make([]int, len(commitmentInfo)) // number of private committed wires seen so far for each commitment + nbPrivateCommittedSeen := 0 // = ∑ᵢ cI[i] + nbCommitmentsSeen := 0 for i := range A { - isCommittedPrivate := cI < len(privateCommitted) && i == privateCommitted[cI] - isCommitment := r1cs.CommitmentInfo.Is() && i == r1cs.CommitmentInfo.CommitmentIndex - isPublic := i < r1cs.GetNbPublicVariables() + commitment := -1 // index of the commitment that commits to this variable as a private or commitment value + var isCommitment, isPublic bool + if isPublic = i < r1cs.GetNbPublicVariables(); !isPublic { + if nbCommitmentsSeen < len(commitmentWires) && commitmentWires[nbCommitmentsSeen] == i { + isCommitment = true + nbCommitmentsSeen++ + } - if isPublic || isCommittedPrivate || isCommitment { + for j := range commitmentInfo { // does commitment j commit to i? + if cI[j] < len(privateCommitted[j]) && privateCommitted[j][cI[j]] == i { + commitment = j + break // frontend guarantees that no private variable is committed to more than once + } + } + } + + if isPublic || commitment != -1 || isCommitment { computeK(i, &toxicWaste.gammaInv) - if isCommittedPrivate { - ckK[cI] = t1 - cI++ - } else { + if isPublic || isCommitment { vkK[vI] = t1 vI++ + } else { // committed and private + ckK[commitment][cI[commitment]] = t1 + cI[commitment]++ + nbPrivateCommittedSeen++ } } else { computeK(i, &toxicWaste.deltaInv) - pkK[i-vI-cI] = t1 + pkK[i-vI-nbPrivateCommittedSeen] = t1 // vI = nbPublicSeen + nbCommitmentsSeen } } @@ -222,11 +244,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { g1Scalars = append(g1Scalars, Z...) g1Scalars = append(g1Scalars, vkK...) g1Scalars = append(g1Scalars, pkK...) - g1Scalars = append(g1Scalars, ckK...) + for i := range ckK { + g1Scalars = append(g1Scalars, ckK[i]...) 
+ } g1PointsAff := curve.BatchScalarMultiplicationG1(&g1, g1Scalars) - // sets pk: [α]1, [β]1, [δ]1 + // sets pk: [α]₁, [β]₁, [δ]₁ pk.G1.Alpha = g1PointsAff[0] pk.G1.Beta = g1PointsAff[1] pk.G1.Delta = g1PointsAff[2] @@ -238,8 +262,9 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G1.B = g1PointsAff[offset : offset+len(B)] offset += len(B) - pk.G1.Z = g1PointsAff[offset : offset+int(domain.Cardinality)] - bitReverse(pk.G1.Z) + bitReverse(g1PointsAff[offset : offset+int(domain.Cardinality)]) + sizeZ := int(domain.Cardinality) - 1 // deg(H)=deg(A*B-C/X^n-1)=(n-1)+(n-1)-n=n-2 + pk.G1.Z = g1PointsAff[offset : offset+sizeZ] offset += int(domain.Cardinality) @@ -252,17 +277,22 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // --------------------------------------------------------------------------------------------- // Commitment setup - if nbPrivateCommittedWires != 0 { - commitmentBasis := g1PointsAff[offset:] + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(ckK[i]) + commitmentBases[i] = g1PointsAff[offset : offset+size] + offset += size + } + if offset != len(g1PointsAff) { + return errors.New("didn't consume all G1 points") // TODO @Tabaie Remove this + } - vk.CommitmentKey, err = pedersen.Setup(commitmentBasis) - if err != nil { - return err - } - pk.CommitmentKey = vk.CommitmentKey + pk.CommitmentKeys, vk.CommitmentKey, err = pedersen.Setup(commitmentBases...) 
+ if err != nil { + return err } - vk.CommitmentInfo = r1cs.CommitmentInfo // unfortunate but necessary + vk.PublicAndCommitmentCommitted = commitmentInfo.GetPublicAndCommitmentCommitted(commitmentWires, r1cs.GetNbPublicVariables()) // --------------------------------------------------------------------------------------------- // G2 scalars @@ -279,15 +309,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G2.B = g2PointsAff[:len(B)] - // sets pk: [β]2, [δ]2 + // sets pk: [β]₂, [δ]₂ pk.G2.Beta = g2PointsAff[len(B)+0] pk.G2.Delta = g2PointsAff[len(B)+1] - // sets vk: [δ]2, [γ]2, -[δ]2, -[γ]2 + // sets vk: [δ]₂, [γ]₂ vk.G2.Delta = g2PointsAff[len(B)+1] vk.G2.Gamma = g2PointsAff[len(B)+2] - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) // --------------------------------------------------------------------------------------------- // Pairing: vk.e @@ -298,16 +326,29 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { vk.G1.Beta = pk.G1.Beta vk.G1.Delta = pk.G1.Delta - vk.e, err = curve.Pair([]curve.G1Affine{pk.G1.Alpha}, []curve.G2Affine{pk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return err } + // set domain pk.Domain = *domain return nil } +// Precompute sets e, -[δ]₂, -[γ]₂ +// This is meant to be called internally during setup or deserialization. +func (vk *VerifyingKey) Precompute() error { + var err error + vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) + if err != nil { + return err + } + vk.G2.deltaNeg.Neg(&vk.G2.Delta) + vk.G2.gammaNeg.Neg(&vk.G2.Gamma) + return nil +} + func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr.Element, B []fr.Element, C []fr.Element) { nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() @@ -322,7 +363,7 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. 
var w fr.Element w.Set(&domain.Generator) wi := fr.One() - t := make([]fr.Element, len(r1cs.Constraints)+1) + t := make([]fr.Element, r1cs.GetNbConstraints()+1) for i := 0; i < len(t); i++ { t[i].Sub(&toxicWaste.t, &wi) wi.Mul(&wi, &w) // TODO this is already pre computed in fft.Domain @@ -366,8 +407,10 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // for each term appearing in the linear expression, // we compute term.Coefficient * L, and cumulate it in // A, B or C at the index of the variable - for i, c := range r1cs.Constraints { + j := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } @@ -380,9 +423,12 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // Li+1 = w*Li*(t-w^i)/(t-w^(i+1)) L.Mul(&L, &w) - L.Mul(&L, &t[i]) - L.Mul(&L, &tInv[i+1]) + L.Mul(&L, &t[j]) + L.Mul(&L, &tInv[j+1]) + + j++ } + return } @@ -436,7 +482,10 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbConstraints := len(r1cs.Constraints) + nbConstraints := r1cs.GetNbConstraints() + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - internal.NbElements(privateCommitted) - len(commitmentInfo) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints)) @@ -448,8 +497,8 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // initialize proving key pk.G1.A = make([]curve.G1Affine, nbWires-nbZeroesA) pk.G1.B = make([]curve.G1Affine, nbWires-nbZeroesB) - pk.G1.K = make([]curve.G1Affine, nbWires-r1cs.GetNbPublicVariables()) - pk.G1.Z = make([]curve.G1Affine, 
domain.Cardinality) + pk.G1.K = make([]curve.G1Affine, nbPrivateWires) + pk.G1.Z = make([]curve.G1Affine, domain.Cardinality-1) pk.G2.B = make([]curve.G2Affine, nbWires-nbZeroesB) // set infinity markers @@ -503,6 +552,22 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { pk.Domain = *domain + // --------------------------------------------------------------------------------------------- + // Commitment setup + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(privateCommitted[i]) + commitmentBases[i] = make([]curve.G1Affine, size) + for j := range commitmentBases[i] { + commitmentBases[i][j] = r1Aff + } + } + + pk.CommitmentKeys, _, err = pedersen.Setup(commitmentBases...) + if err != nil { + return err + } + return nil } @@ -514,7 +579,9 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) - for _, c := range r1cs.Constraints { + + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { A[t.WireID()] = true } @@ -522,6 +589,7 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { B[t.WireID()] = true } } + for i := 0; i < nbWires; i++ { if !A[i] { nbZeroesA++ @@ -606,7 +674,7 @@ func (pk *ProvingKey) NbG2() int { return 2 + len(pk.G2.B) } -// bitRerverse permutation as in fft.BitReverse , but with []curve.G1Affine +// bitReverse permutation as in fft.BitReverse , but with []curve.G1Affine func bitReverse(a []curve.G1Affine) { n := uint(len(a)) nn := uint(bits.UintSize - bits.TrailingZeros(n)) diff --git a/internal/backend/bls12-381/groth16/verify.go b/backend/groth16/bls12-381/verify.go similarity index 64% rename from internal/backend/bls12-381/groth16/verify.go rename to backend/groth16/bls12-381/verify.go index 0a9de8c797..646e052f42 100644 --- a/internal/backend/bls12-381/groth16/verify.go +++ b/backend/groth16/bls12-381/verify.go @@ -22,9 +22,11 @@ import ( 
"github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls12-381" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/pedersen" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/logger" "io" - "math/big" "time" ) @@ -36,10 +38,8 @@ var ( // Verify verifies a proof with given VerifyingKey and publicWitness func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - nbPublicVars := len(vk.G1.K) - if vk.CommitmentInfo.Is() { - nbPublicVars-- - } + nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted) + if len(publicWitness) != nbPublicVars-1 { return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(publicWitness), len(vk.G1.K)-1) } @@ -62,21 +62,32 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { close(chDone) }() - if vk.CommitmentInfo.Is() { - - if err := vk.CommitmentKey.VerifyKnowledgeProof(proof.Commitment, proof.CommitmentPok); err != nil { - return err + maxNbPublicCommitted := 0 + for _, s := range vk.PublicAndCommitmentCommitted { // iterate over commitments + maxNbPublicCommitted = utils.Max(maxNbPublicCommitted, len(s)) + } + commitmentsSerialized := make([]byte, len(vk.PublicAndCommitmentCommitted)*fr.Bytes) + commitmentPrehashSerialized := make([]byte, curve.SizeOfG1AffineUncompressed+maxNbPublicCommitted*fr.Bytes) + for i := range vk.PublicAndCommitmentCommitted { // solveCommitmentWire + copy(commitmentPrehashSerialized, proof.Commitments[i].Marshal()) + offset := curve.SizeOfG1AffineUncompressed + for j := range vk.PublicAndCommitmentCommitted[i] { + copy(commitmentPrehashSerialized[offset:], publicWitness[vk.PublicAndCommitmentCommitted[i][j]-1].Marshal()) + offset += fr.Bytes } - - publicCommitted := make([]*big.Int, vk.CommitmentInfo.NbPublicCommitted()) - for i := range publicCommitted { - var b big.Int - 
publicWitness[vk.CommitmentInfo.Committed[i]-1].BigInt(&b) - publicCommitted[i] = &b + if res, err := fr.Hash(commitmentPrehashSerialized[:offset], []byte(constraint.CommitmentDst), 1); err != nil { + return err + } else { + publicWitness = append(publicWitness, res[0]) + copy(commitmentsSerialized[i*fr.Bytes:], res[0].Marshal()) } + } - if res, err := solveCommitmentWire(&vk.CommitmentInfo, &proof.Commitment, publicCommitted); err == nil { - publicWitness = append(publicWitness, res) + if folded, err := pedersen.FoldCommitments(proof.Commitments, commitmentsSerialized); err != nil { + return err + } else { + if err = vk.CommitmentKey.Verify(folded, proof.CommitmentPok); err != nil { + return err } } @@ -87,8 +98,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { } kSum.AddMixed(&vk.G1.K[0]) - if vk.CommitmentInfo.Is() { - kSum.AddMixed(&proof.Commitment) + for i := range proof.Commitments { + kSum.AddMixed(&proof.Commitments[i]) } var kSumAff curve.G1Affine diff --git a/internal/backend/bls24-315/groth16/commitment.go b/backend/groth16/bls24-315/commitment.go similarity index 74% rename from internal/backend/bls24-315/groth16/commitment.go rename to backend/groth16/bls24-315/commitment.go index 3cb966f8f0..fc1a3def96 100644 --- a/internal/backend/bls24-315/groth16/commitment.go +++ b/backend/groth16/bls24-315/commitment.go @@ -23,7 +23,7 @@ import ( "math/big" ) -func solveCommitmentWire(commitmentInfo *constraint.Commitment, commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { - res, err := fr.Hash(commitmentInfo.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) +func solveCommitmentWire(commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { + res, err := fr.Hash(constraint.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) return res[0], err } diff --git 
a/internal/backend/bls24-315/groth16/commitment_test.go b/backend/groth16/bls24-315/commitment_test.go similarity index 91% rename from internal/backend/bls24-315/groth16/commitment_test.go rename to backend/groth16/bls24-315/commitment_test.go index 236bcda605..0f626448b3 100644 --- a/internal/backend/bls24-315/groth16/commitment_test.go +++ b/backend/groth16/bls24-315/commitment_test.go @@ -17,6 +17,9 @@ package groth16_test import ( + "fmt" + "testing" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/witness" @@ -24,7 +27,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/assert" - "testing" ) type singleSecretCommittedCircuit struct { @@ -33,7 +35,11 @@ type singleSecretCommittedCircuit struct { func (c *singleSecretCommittedCircuit) Define(api frontend.API) error { api.AssertIsEqual(c.One, 1) - commit, err := api.Compiler().Commit(c.One) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One) if err != nil { return err } @@ -119,8 +125,11 @@ type oneSecretOnePublicCommittedCircuit struct { } func (c *oneSecretOnePublicCommittedCircuit) Define(api frontend.API) error { - - commit, err := api.Compiler().Commit(c.One, c.Two) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One, c.Two) if err != nil { return err } diff --git a/internal/backend/bls24-315/groth16/marshal.go b/backend/groth16/bls24-315/marshal.go similarity index 81% rename from internal/backend/bls24-315/groth16/marshal.go rename to backend/groth16/bls24-315/marshal.go index ede9181c06..3da8e55772 100644 --- a/internal/backend/bls24-315/groth16/marshal.go +++ b/backend/groth16/bls24-315/marshal.go @@ -18,6 +18,9 @@ package 
groth16 import ( curve "github.com/consensys/gnark-crypto/ecc/bls24-315" + + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/pedersen" + "github.com/consensys/gnark/internal/utils" "io" ) @@ -78,14 +81,24 @@ func (proof *Proof) ReadFrom(r io.Reader) (n int64, err error) { // points are compressed // use WriteRawTo(...) to encode the key without point compression func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, false) + if n, err = vk.writeTo(w, false); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteTo(w) + return m + n, err } // WriteRawTo writes binary encoding of the key elements to writer // points are not compressed // use WriteTo(...) to encode the key with point compression func (vk *VerifyingKey) WriteRawTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, true) + if n, err = vk.writeTo(w, true); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteRawTo(w) + return m + n, err } // writeTo serialization format: @@ -124,6 +137,14 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { if err := enc.Encode(vk.G1.K); err != nil { return enc.BytesWritten(), err } + + if vk.PublicAndCommitmentCommitted == nil { + vk.PublicAndCommitmentCommitted = [][]int{} // only matters in tests + } + if err := enc.Encode(utils.IntSliceSliceToUint64SliceSlice(vk.PublicAndCommitmentCommitted)); err != nil { + return enc.BytesWritten(), err + } + return enc.BytesWritten(), nil } @@ -133,13 +154,25 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { // https://github.com/zkcrypto/bellman/blob/fa9be45588227a8c6ec34957de3f68705f07bd92/src/groth16/mod.rs#L143 // [α]1,[β]1,[β]2,[γ]2,[δ]1,[δ]2,uint32(len(Kvk)),[Kvk]1 func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r) + n, err := vk.readFrom(r) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.ReadFrom(r) + return m + 
n, err } // UnsafeReadFrom has the same behavior as ReadFrom, except that it will not check that decode points // are on the curve and in the correct subgroup. func (vk *VerifyingKey) UnsafeReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r, curve.NoSubgroupChecks()) + n, err := vk.readFrom(r, curve.NoSubgroupChecks()) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.UnsafeReadFrom(r) + return m + n, err } func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { @@ -169,15 +202,16 @@ func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder) if err := dec.Decode(&vk.G1.K); err != nil { return dec.BytesRead(), err } + var publicCommitted [][]uint64 + if err := dec.Decode(&publicCommitted); err != nil { + return dec.BytesRead(), err + } + vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 - var err error - vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return dec.BytesRead(), err } - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) return dec.BytesRead(), nil } @@ -226,6 +260,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { pk.NbInfinityB, pk.InfinityA, pk.InfinityB, + uint32(len(pk.CommitmentKeys)), } for _, v := range toEncode { @@ -234,6 +269,23 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { } } + for i := range pk.CommitmentKeys { + var ( + n2 int64 + err error + ) + if raw { + n2, err = pk.CommitmentKeys[i].WriteRawTo(w) + } else { + n2, err = pk.CommitmentKeys[i].WriteTo(w) + } + + n += n2 + if err != nil { + return n, err + } + } + return n + enc.BytesWritten(), nil } @@ -260,6 +312,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) dec := curve.NewDecoder(r, decOptions...) 
var nbWires uint64 + var nbCommitments uint32 toDecode := []interface{}{ &pk.G1.Alpha, @@ -291,6 +344,18 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) if err := dec.Decode(&pk.InfinityB); err != nil { return n + dec.BytesRead(), err } + if err := dec.Decode(&nbCommitments); err != nil { + return n + dec.BytesRead(), err + } + + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) + for i := range pk.CommitmentKeys { + n2, err := pk.CommitmentKeys[i].ReadFrom(r) + n += n2 + if err != nil { + return n, err + } + } return n + dec.BytesRead(), nil } diff --git a/internal/backend/bls24-315/groth16/marshal_test.go b/backend/groth16/bls24-315/marshal_test.go similarity index 78% rename from internal/backend/bls24-315/groth16/marshal_test.go rename to backend/groth16/bls24-315/marshal_test.go index 06bf1ec9de..ecdd8b4ea1 100644 --- a/internal/backend/bls24-315/groth16/marshal_test.go +++ b/backend/groth16/bls24-315/marshal_test.go @@ -21,11 +21,17 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal/test_utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "bytes" "math/big" "reflect" "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" "github.com/leanovate/gopter/prop" "testing" @@ -87,13 +93,9 @@ func TestProofSerialization(t *testing.T) { } func TestVerifyingKeySerialization(t *testing.T) { - parameters := gopter.DefaultTestParameters() - parameters.MinSuccessfulTests = 10 - properties := gopter.NewProperties(parameters) - - properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + roundTrip := func(withCommitment bool) func(curve.G1Affine, curve.G2Affine) bool { + return func(p1 curve.G1Affine, p2 curve.G2Affine) bool { var vk, vkCompressed, 
vkRaw VerifyingKey // create a random vk @@ -121,6 +123,21 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } + if withCommitment { + vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization + bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) + elem := p1 + for i := 0; i < len(vk.PublicAndCommitmentCommitted); i++ { + bases[i] = make([]curve.G1Affine, len(vk.PublicAndCommitmentCommitted[i])) + for j := range bases[i] { + bases[i][j] = elem + elem.Add(&elem, &p1) + } + } + _, vk.CommitmentKey, err = pedersen.Setup(bases...) + assert.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := vk.WriteTo(&bufCompressed) if err != nil { @@ -158,7 +175,22 @@ func TestVerifyingKeySerialization(t *testing.T) { } return reflect.DeepEqual(&vk, &vkCompressed) && reflect.DeepEqual(&vk, &vkRaw) - }, + } + } + + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 10 + + properties := gopter.NewProperties(parameters) + + properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(false), + GenG1(), + GenG2(), + )) + + properties.Property("VerifyingKey (with commitments) -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(true), GenG1(), GenG2(), )) @@ -173,7 +205,7 @@ func TestProvingKeySerialization(t *testing.T) { properties := gopter.NewProperties(parameters) properties.Property("ProvingKey -> writer -> reader -> ProvingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + func(p1 curve.G1Affine, p2 curve.G2Affine, nbCommitment int) bool { var pk, pkCompressed, pkRaw ProvingKey // create a random pk @@ -202,6 +234,19 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, 
nbCommitment) + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases...) + require.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := pk.WriteTo(&bufCompressed) if err != nil { @@ -242,6 +287,7 @@ func TestProvingKeySerialization(t *testing.T) { }, GenG1(), GenG2(), + gen.IntRange(0, 2), )) properties.TestingRun(t, gopter.ConsoleReporter(false)) diff --git a/backend/groth16/bls24-315/mpcsetup/lagrange.go b/backend/groth16/bls24-315/mpcsetup/lagrange.go new file mode 100644 index 0000000000..6fb61f2c68 --- /dev/null +++ b/backend/groth16/bls24-315/mpcsetup/lagrange.go @@ -0,0 +1,216 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bls24-315" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" + "github.com/consensys/gnark/internal/utils" +) + +func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { + coeffs := make([]curve.G1Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { + coeffs := make([]curve.G2Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func butterflyG1(a *curve.G1Affine, b *curve.G1Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +func butterflyG2(a *curve.G2Affine, b *curve.G2Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G1(a []curve.G1Affine, twiddles [][]fr.Element, stage int) { + butterflyG1(&a[0], &a[4]) + butterflyG1(&a[1], &a[5]) + 
butterflyG1(&a[2], &a[6]) + butterflyG1(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[2]) + butterflyG1(&a[1], &a[3]) + butterflyG1(&a[4], &a[6]) + butterflyG1(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[1]) + butterflyG1(&a[2], &a[3]) + butterflyG1(&a[4], &a[5]) + butterflyG1(&a[6], &a[7]) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G2(a []curve.G2Affine, twiddles [][]fr.Element, stage int) { + butterflyG2(&a[0], &a[4]) + butterflyG2(&a[1], &a[5]) + butterflyG2(&a[2], &a[6]) + butterflyG2(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[2]) + butterflyG2(&a[1], &a[3]) + butterflyG2(&a[4], &a[6]) + butterflyG2(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[1]) + butterflyG2(&a[2], &a[3]) + butterflyG2(&a[4], &a[5]) + butterflyG2(&a[6], &a[7]) +} + +func difFFTG1(a []curve.G1Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G1(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG1(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ 
{ + butterflyG1(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} +func difFFTG2(a []curve.G2Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G2(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG2(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + butterflyG2(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG2(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG2(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} diff --git a/backend/groth16/bls24-315/mpcsetup/marshal.go b/backend/groth16/bls24-315/mpcsetup/marshal.go new file mode 100644 index 0000000000..f670a7af74 --- /dev/null +++ b/backend/groth16/bls24-315/mpcsetup/marshal.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls24-315" + "io" +) + +// WriteTo implements io.WriterTo +func (phase1 *Phase1) WriteTo(writer io.Writer) (int64, error) { + n, err := phase1.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase1.Hash) + return int64(nBytes) + n, err +} + +func (phase1 *Phase1) writeTo(writer io.Writer) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + phase1.Parameters.G1.Tau, + phase1.Parameters.G1.AlphaTau, + phase1.Parameters.G1.BetaTau, + phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + enc := curve.NewEncoder(writer) + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (phase1 *Phase1) ReadFrom(reader io.Reader) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + &phase1.Parameters.G1.Tau, + 
&phase1.Parameters.G1.AlphaTau, + &phase1.Parameters.G1.BetaTau, + &phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + dec := curve.NewDecoder(reader) + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + phase1.Hash = make([]byte, 32) + nBytes, err := reader.Read(phase1.Hash) + return dec.BytesRead() + int64(nBytes), err +} + +// WriteTo implements io.WriterTo +func (phase2 *Phase2) WriteTo(writer io.Writer) (int64, error) { + n, err := phase2.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase2.Hash) + return int64(nBytes) + n, err +} + +func (c *Phase2) writeTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + c.Parameters.G1.L, + c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c *Phase2) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + &c.Parameters.G1.L, + &c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + c.Hash = make([]byte, 32) + n, err := reader.Read(c.Hash) + return int64(n) + dec.BytesRead(), err + +} + +// WriteTo implements io.WriterTo +func (c *Phase2Evaluations) WriteTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + c.G1.A, + c.G1.B, + c.G2.B, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom 
implements io.ReaderFrom +func (c *Phase2Evaluations) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.G1.A, + &c.G1.B, + &c.G2.B, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + return dec.BytesRead(), nil +} diff --git a/backend/groth16/bls24-315/mpcsetup/marshal_test.go b/backend/groth16/bls24-315/mpcsetup/marshal_test.go new file mode 100644 index 0000000000..02d61df6a9 --- /dev/null +++ b/backend/groth16/bls24-315/mpcsetup/marshal_test.go @@ -0,0 +1,79 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + curve "github.com/consensys/gnark-crypto/ecc/bls24-315" + cs "github.com/consensys/gnark/constraint/bls24-315" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/stretchr/testify/require" + "io" + "reflect" + "testing" +) + +func TestContributionSerialization(t *testing.T) { + assert := require.New(t) + + // Phase 1 + srs1 := InitPhase1(9) + srs1.Contribute() + { + var reconstructed Phase1 + roundTripCheck(t, &srs1, &reconstructed) + } + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + r1cs := ccs.(*cs.R1CS) + + // Phase 2 + srs2, _ := InitPhase2(r1cs, &srs1) + srs2.Contribute() + + { + var reconstructed Phase2 + roundTripCheck(t, &srs2, &reconstructed) + } +} + +func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { + t.Helper() + + var buf bytes.Buffer + written, err := from.WriteTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} diff --git a/backend/groth16/bls24-315/mpcsetup/phase1.go b/backend/groth16/bls24-315/mpcsetup/phase1.go new file mode 100644 index 0000000000..cefde6c90f --- /dev/null +++ b/backend/groth16/bls24-315/mpcsetup/phase1.go @@ -0,0 +1,203 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + curve "github.com/consensys/gnark-crypto/ecc/bls24-315" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "math" + "math/big" +) + +// Phase1 represents the Phase1 of the MPC described in +// https://eprint.iacr.org/2017/1050.pdf +// +// Also known as "Powers of Tau" +type Phase1 struct { + Parameters struct { + G1 struct { + Tau []curve.G1Affine // {[τ⁰]₁, [τ¹]₁, [τ²]₁, …, [τ²ⁿ⁻²]₁} + AlphaTau []curve.G1Affine // {α[τ⁰]₁, α[τ¹]₁, α[τ²]₁, …, α[τⁿ⁻¹]₁} + BetaTau []curve.G1Affine // {β[τ⁰]₁, β[τ¹]₁, β[τ²]₁, …, β[τⁿ⁻¹]₁} + } + G2 struct { + Tau []curve.G2Affine // {[τ⁰]₂, [τ¹]₂, [τ²]₂, …, [τⁿ⁻¹]₂} + Beta curve.G2Affine // [β]₂ + } + } + PublicKeys struct { + Tau, Alpha, Beta PublicKey + } + Hash []byte // sha256 hash +} + +// InitPhase1 initialize phase 1 of the MPC. This is called once by the coordinator before +// any randomness contribution is made (see Contribute()). 
+func InitPhase1(power int) (phase1 Phase1) { + N := int(math.Pow(2, float64(power))) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetOne() + alpha.SetOne() + beta.SetOne() + phase1.PublicKeys.Tau = newPublicKey(tau, nil, 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, nil, 2) + phase1.PublicKeys.Beta = newPublicKey(beta, nil, 3) + + // First contribution use generators + _, _, g1, g2 := curve.Generators() + phase1.Parameters.G2.Beta.Set(&g2) + phase1.Parameters.G1.Tau = make([]curve.G1Affine, 2*N-1) + phase1.Parameters.G2.Tau = make([]curve.G2Affine, N) + phase1.Parameters.G1.AlphaTau = make([]curve.G1Affine, N) + phase1.Parameters.G1.BetaTau = make([]curve.G1Affine, N) + for i := 0; i < len(phase1.Parameters.G1.Tau); i++ { + phase1.Parameters.G1.Tau[i].Set(&g1) + } + for i := 0; i < len(phase1.Parameters.G2.Tau); i++ { + phase1.Parameters.G2.Tau[i].Set(&g2) + phase1.Parameters.G1.AlphaTau[i].Set(&g1) + phase1.Parameters.G1.BetaTau[i].Set(&g1) + } + + phase1.Parameters.G2.Beta.Set(&g2) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() + + return +} + +// Contribute contributes randomness to the phase1 object. This mutates phase1. +func (phase1 *Phase1) Contribute() { + N := len(phase1.Parameters.G2.Tau) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetRandom() + alpha.SetRandom() + beta.SetRandom() + phase1.PublicKeys.Tau = newPublicKey(tau, phase1.Hash[:], 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, phase1.Hash[:], 2) + phase1.PublicKeys.Beta = newPublicKey(beta, phase1.Hash[:], 3) + + // Compute powers of τ, ατ, and βτ + taus := powers(tau, 2*N-1) + alphaTau := make([]fr.Element, N) + betaTau := make([]fr.Element, N) + for i := 0; i < N; i++ { + alphaTau[i].Mul(&taus[i], &alpha) + betaTau[i].Mul(&taus[i], &beta) + } + + // Update using previous parameters + // TODO @gbotrel working with jacobian points here will help with perf. 
+ scaleG1InPlace(phase1.Parameters.G1.Tau, taus) + scaleG2InPlace(phase1.Parameters.G2.Tau, taus[0:N]) + scaleG1InPlace(phase1.Parameters.G1.AlphaTau, alphaTau) + scaleG1InPlace(phase1.Parameters.G1.BetaTau, betaTau) + var betaBI big.Int + beta.BigInt(&betaBI) + phase1.Parameters.G2.Beta.ScalarMultiplication(&phase1.Parameters.G2.Beta, &betaBI) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() +} + +func VerifyPhase1(c0, c1 *Phase1, c ...*Phase1) error { + contribs := append([]*Phase1{c0, c1}, c...) + for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase1(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +// verifyPhase1 checks that a contribution is based on a known previous Phase1 state. +func verifyPhase1(current, contribution *Phase1) error { + // Compute R for τ, α, β + tauR := genR(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, current.Hash[:], 1) + alphaR := genR(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, current.Hash[:], 2) + betaR := genR(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, current.Hash[:], 3) + + // Check for knowledge of toxic parameters + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.PublicKeys.Tau.XR, tauR) { + return errors.New("couldn't verify public key of τ") + } + if !sameRatio(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, contribution.PublicKeys.Alpha.XR, alphaR) { + return errors.New("couldn't verify public key of α") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.PublicKeys.Beta.XR, betaR) { + return errors.New("couldn't verify public key of β") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Tau[1], current.Parameters.G1.Tau[1], tauR, contribution.PublicKeys.Tau.XR) { + return errors.New("couldn't verify that [τ]₁ is based on previous 
contribution") + } + if !sameRatio(contribution.Parameters.G1.AlphaTau[0], current.Parameters.G1.AlphaTau[0], alphaR, contribution.PublicKeys.Alpha.XR) { + return errors.New("couldn't verify that [α]₁ is based on previous contribution") + } + if !sameRatio(contribution.Parameters.G1.BetaTau[0], current.Parameters.G1.BetaTau[0], betaR, contribution.PublicKeys.Beta.XR) { + return errors.New("couldn't verify that [β]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.Parameters.G2.Tau[1], current.Parameters.G2.Tau[1]) { + return errors.New("couldn't verify that [τ]₂ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.Parameters.G2.Beta, current.Parameters.G2.Beta) { + return errors.New("couldn't verify that [β]₂ is based on previous contribution") + } + + // Check for valid updates using powers of τ + _, _, g1, g2 := curve.Generators() + tauL1, tauL2 := linearCombinationG1(contribution.Parameters.G1.Tau) + if !sameRatio(tauL1, tauL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of τ in G₁") + } + alphaL1, alphaL2 := linearCombinationG1(contribution.Parameters.G1.AlphaTau) + if !sameRatio(alphaL1, alphaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of α(τ) in G₁") + } + betaL1, betaL2 := linearCombinationG1(contribution.Parameters.G1.BetaTau) + if !sameRatio(betaL1, betaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of β(τ) in G₁") + } + tau2L1, tau2L2 := linearCombinationG2(contribution.Parameters.G2.Tau) + if !sameRatio(contribution.Parameters.G1.Tau[1], g1, tau2L1, tau2L2) { + return errors.New("couldn't verify valid powers of τ in G₂") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != 
contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (phase1 *Phase1) hash() []byte { + sha := sha256.New() + phase1.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bls24-315/mpcsetup/phase2.go b/backend/groth16/bls24-315/mpcsetup/phase2.go new file mode 100644 index 0000000000..48939131d3 --- /dev/null +++ b/backend/groth16/bls24-315/mpcsetup/phase2.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + "math/big" + + curve "github.com/consensys/gnark-crypto/ecc/bls24-315" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark/constraint" + cs "github.com/consensys/gnark/constraint/bls24-315" +) + +type Phase2Evaluations struct { + G1 struct { + A, B, VKK []curve.G1Affine + } + G2 struct { + B []curve.G2Affine + } +} + +type Phase2 struct { + Parameters struct { + G1 struct { + Delta curve.G1Affine + L, Z []curve.G1Affine + } + G2 struct { + Delta curve.G2Affine + } + } + PublicKey PublicKey + Hash []byte +} + +func InitPhase2(r1cs *cs.R1CS, srs1 *Phase1) (Phase2, Phase2Evaluations) { + srs := srs1.Parameters + size := len(srs.G1.AlphaTau) + if size < r1cs.GetNbConstraints() { + panic("Number of constraints is larger than expected") + } + + c2 := Phase2{} + + accumulateG1 := func(res *curve.G1Affine, t constraint.Term, value *curve.G1Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G1Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + accumulateG2 := func(res *curve.G2Affine, t constraint.Term, value *curve.G2Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G2Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + // Prepare Lagrange coefficients of [τ...]₁, [τ...]₂, [ατ...]₁, [βτ...]₁ + coeffTau1 
:= lagrangeCoeffsG1(srs.G1.Tau, size) + coeffTau2 := lagrangeCoeffsG2(srs.G2.Tau, size) + coeffAlphaTau1 := lagrangeCoeffsG1(srs.G1.AlphaTau, size) + coeffBetaTau1 := lagrangeCoeffsG1(srs.G1.BetaTau, size) + + internal, secret, public := r1cs.GetNbVariables() + nWires := internal + secret + public + var evals Phase2Evaluations + evals.G1.A = make([]curve.G1Affine, nWires) + evals.G1.B = make([]curve.G1Affine, nWires) + evals.G2.B = make([]curve.G2Affine, nWires) + bA := make([]curve.G1Affine, nWires) + aB := make([]curve.G1Affine, nWires) + C := make([]curve.G1Affine, nWires) + + // TODO @gbotrel use constraint iterator when available. + + i := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + // A + for _, t := range c.L { + accumulateG1(&evals.G1.A[t.WireID()], t, &coeffTau1[i]) + accumulateG1(&bA[t.WireID()], t, &coeffBetaTau1[i]) + } + // B + for _, t := range c.R { + accumulateG1(&evals.G1.B[t.WireID()], t, &coeffTau1[i]) + accumulateG2(&evals.G2.B[t.WireID()], t, &coeffTau2[i]) + accumulateG1(&aB[t.WireID()], t, &coeffAlphaTau1[i]) + } + // C + for _, t := range c.O { + accumulateG1(&C[t.WireID()], t, &coeffTau1[i]) + } + i++ + } + + // Prepare default contribution + _, _, g1, g2 := curve.Generators() + c2.Parameters.G1.Delta = g1 + c2.Parameters.G2.Delta = g2 + + // Build Z in PK as τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + // τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + n := len(srs.G1.AlphaTau) + c2.Parameters.G1.Z = make([]curve.G1Affine, n) + for i := 0; i < n-1; i++ { + c2.Parameters.G1.Z[i].Sub(&srs.G1.Tau[i+n], &srs.G1.Tau[i]) + } + bitReverse(c2.Parameters.G1.Z) + c2.Parameters.G1.Z = c2.Parameters.G1.Z[:n-1] + + // Evaluate L + nPrivate := internal + secret + c2.Parameters.G1.L = make([]curve.G1Affine, nPrivate) + evals.G1.VKK = make([]curve.G1Affine, public) + offset := public + for i := 0; i < nWires; i++ { + var tmp curve.G1Affine + tmp.Add(&bA[i], &aB[i]) + tmp.Add(&tmp, &C[i]) + if i < public { + 
evals.G1.VKK[i].Set(&tmp) + } else { + c2.Parameters.G1.L[i-offset].Set(&tmp) + } + } + // Set δ public key + var delta fr.Element + delta.SetOne() + c2.PublicKey = newPublicKey(delta, nil, 1) + + // Hash initial contribution + c2.Hash = c2.hash() + return c2, evals +} + +func (c *Phase2) Contribute() { + // Sample toxic δ + var delta, deltaInv fr.Element + var deltaBI, deltaInvBI big.Int + delta.SetRandom() + deltaInv.Inverse(&delta) + + delta.BigInt(&deltaBI) + deltaInv.BigInt(&deltaInvBI) + + // Set δ public key + c.PublicKey = newPublicKey(delta, c.Hash, 1) + + // Update δ + c.Parameters.G1.Delta.ScalarMultiplication(&c.Parameters.G1.Delta, &deltaBI) + c.Parameters.G2.Delta.ScalarMultiplication(&c.Parameters.G2.Delta, &deltaBI) + + // Update Z using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.Z); i++ { + c.Parameters.G1.Z[i].ScalarMultiplication(&c.Parameters.G1.Z[i], &deltaInvBI) + } + + // Update L using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.L); i++ { + c.Parameters.G1.L[i].ScalarMultiplication(&c.Parameters.G1.L[i], &deltaInvBI) + } + + // 4. Hash contribution + c.Hash = c.hash() +} + +func VerifyPhase2(c0, c1 *Phase2, c ...*Phase2) error { + contribs := append([]*Phase2{c0, c1}, c...) 
+ for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase2(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +func verifyPhase2(current, contribution *Phase2) error { + // Compute R for δ + deltaR := genR(contribution.PublicKey.SG, contribution.PublicKey.SXG, current.Hash[:], 1) + + // Check for knowledge of δ + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.PublicKey.XR, deltaR) { + return errors.New("couldn't verify knowledge of δ") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Delta, current.Parameters.G1.Delta, deltaR, contribution.PublicKey.XR) { + return errors.New("couldn't verify that [δ]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify that [δ]₂ is based on previous contribution") + } + + // Check for valid updates of L and Z using δ⁻¹ + L, prevL := merge(contribution.Parameters.G1.L, current.Parameters.G1.L) + if !sameRatio(L, prevL, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of L using δ⁻¹") + } + Z, prevZ := merge(contribution.Parameters.G1.Z, current.Parameters.G1.Z) + if !sameRatio(Z, prevZ, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of Z using δ⁻¹") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (c *Phase2) hash() []byte { + sha := sha256.New() + c.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bls24-315/mpcsetup/setup.go b/backend/groth16/bls24-315/mpcsetup/setup.go new file mode 100644 index 0000000000..98fc63f1ad --- 
/dev/null +++ b/backend/groth16/bls24-315/mpcsetup/setup.go @@ -0,0 +1,97 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls24-315" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" + groth16 "github.com/consensys/gnark/backend/groth16/bls24-315" +) + +func ExtractKeys(srs1 *Phase1, srs2 *Phase2, evals *Phase2Evaluations, nConstraints int) (pk groth16.ProvingKey, vk groth16.VerifyingKey) { + _, _, _, g2 := curve.Generators() + + // Initialize PK + pk.Domain = *fft.NewDomain(uint64(nConstraints)) + pk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + pk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + pk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + pk.G1.Z = srs2.Parameters.G1.Z + bitReverse(pk.G1.Z) + + pk.G1.K = srs2.Parameters.G1.L + pk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + pk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + + // Filter out infinity points + nWires := len(evals.G1.A) + pk.InfinityA = make([]bool, nWires) + A := make([]curve.G1Affine, nWires) + j := 0 + for i, e := range evals.G1.A { + if e.IsInfinity() { + pk.InfinityA[i] = true + continue + } + A[j] = evals.G1.A[i] + j++ + } + pk.G1.A = A[:j] + pk.NbInfinityA = uint64(nWires - j) + + pk.InfinityB = make([]bool, nWires) + B := make([]curve.G1Affine, nWires) + j = 0 + for i, e := range evals.G1.B { + if e.IsInfinity() { + 
pk.InfinityB[i] = true + continue + } + B[j] = evals.G1.B[i] + j++ + } + pk.G1.B = B[:j] + pk.NbInfinityB = uint64(nWires - j) + + B2 := make([]curve.G2Affine, nWires) + j = 0 + for i, e := range evals.G2.B { + if e.IsInfinity() { + // pk.InfinityB[i] = true should be the same as in B + continue + } + B2[j] = evals.G2.B[i] + j++ + } + pk.G2.B = B2[:j] + + // Initialize VK + vk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + vk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + vk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + vk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + vk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + vk.G2.Gamma.Set(&g2) + vk.G1.K = evals.G1.VKK + + // sets e, -[δ]2, -[γ]2 + if err := vk.Precompute(); err != nil { + panic(err) + } + + return pk, vk +} diff --git a/backend/groth16/bls24-315/mpcsetup/setup_test.go b/backend/groth16/bls24-315/mpcsetup/setup_test.go new file mode 100644 index 0000000000..25c8affc68 --- /dev/null +++ b/backend/groth16/bls24-315/mpcsetup/setup_test.go @@ -0,0 +1,199 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls24-315" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + cs "github.com/consensys/gnark/constraint/bls24-315" + "testing" + + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/stretchr/testify/require" + + native_mimc "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/mimc" +) + +func TestSetupCircuit(t *testing.T) { + if testing.Short() { + t.Skip() + } + const ( + nContributionsPhase1 = 3 + nContributionsPhase2 = 3 + power = 9 + ) + + assert := require.New(t) + + srs1 := InitPhase1(power) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase1; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. + prev := srs1.clone() + + srs1.Contribute() + assert.NoError(VerifyPhase1(&prev, &srs1)) + } + + // Compile the circuit + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + var evals Phase2Evaluations + r1cs := ccs.(*cs.R1CS) + + // Prepare for phase-2 + srs2, evals := InitPhase2(r1cs, &srs1) + + // Make and verify contributions for phase2 + for i := 1; i < nContributionsPhase2; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. 
+ prev := srs2.clone() + + srs2.Contribute() + assert.NoError(VerifyPhase2(&prev, &srs2)) + } + + // Extract the proving and verifying keys + pk, vk := ExtractKeys(&srs1, &srs2, &evals, ccs.GetNbConstraints()) + + // Build the witness + var preImage, hash fr.Element + { + m := native_mimc.NewMiMC() + m.Write(preImage.Marshal()) + hash.SetBytes(m.Sum(nil)) + } + + witness, err := frontend.NewWitness(&Circuit{PreImage: preImage, Hash: hash}, curve.ID.ScalarField()) + assert.NoError(err) + + pubWitness, err := witness.Public() + assert.NoError(err) + + // groth16: ensure proof is verified + proof, err := groth16.Prove(ccs, &pk, witness) + assert.NoError(err) + + err = groth16.Verify(proof, &vk, pubWitness) + assert.NoError(err) +} + +func BenchmarkPhase1(b *testing.B) { + const power = 14 + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = InitPhase1(power) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs1 := InitPhase1(power) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs1.Contribute() + } + }) + +} + +func BenchmarkPhase2(b *testing.B) { + const power = 14 + srs1 := InitPhase1(power) + srs1.Contribute() + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + if err != nil { + b.Fatal(err) + } + + r1cs := ccs.(*cs.R1CS) + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = InitPhase2(r1cs, &srs1) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs2, _ := InitPhase2(r1cs, &srs1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs2.Contribute() + } + }) + +} + +// Circuit defines a pre-image knowledge proof +// mimc(secret preImage) = public hash +type Circuit struct { + PreImage frontend.Variable + Hash frontend.Variable `gnark:",public"` +} + +// Define declares the circuit's constraints +// Hash = mimc(PreImage) +func (circuit *Circuit) Define(api frontend.API) error { + // hash function + mimc, _ := 
mimc.NewMiMC(api) + + // specify constraints + mimc.Write(circuit.PreImage) + api.AssertIsEqual(circuit.Hash, mimc.Sum()) + + return nil +} + +func (phase1 *Phase1) clone() Phase1 { + r := Phase1{} + r.Parameters.G1.Tau = append(r.Parameters.G1.Tau, phase1.Parameters.G1.Tau...) + r.Parameters.G1.AlphaTau = append(r.Parameters.G1.AlphaTau, phase1.Parameters.G1.AlphaTau...) + r.Parameters.G1.BetaTau = append(r.Parameters.G1.BetaTau, phase1.Parameters.G1.BetaTau...) + + r.Parameters.G2.Tau = append(r.Parameters.G2.Tau, phase1.Parameters.G2.Tau...) + r.Parameters.G2.Beta = phase1.Parameters.G2.Beta + + r.PublicKeys = phase1.PublicKeys + r.Hash = append(r.Hash, phase1.Hash...) + + return r +} + +func (phase2 *Phase2) clone() Phase2 { + r := Phase2{} + r.Parameters.G1.Delta = phase2.Parameters.G1.Delta + r.Parameters.G1.L = append(r.Parameters.G1.L, phase2.Parameters.G1.L...) + r.Parameters.G1.Z = append(r.Parameters.G1.Z, phase2.Parameters.G1.Z...) + r.Parameters.G2.Delta = phase2.Parameters.G2.Delta + r.PublicKey = phase2.PublicKey + r.Hash = append(r.Hash, phase2.Hash...) + + return r +} diff --git a/backend/groth16/bls24-315/mpcsetup/utils.go b/backend/groth16/bls24-315/mpcsetup/utils.go new file mode 100644 index 0000000000..c86248ac16 --- /dev/null +++ b/backend/groth16/bls24-315/mpcsetup/utils.go @@ -0,0 +1,170 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bls24-315" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark/internal/utils" +) + +type PublicKey struct { + SG curve.G1Affine + SXG curve.G1Affine + XR curve.G2Affine +} + +func newPublicKey(x fr.Element, challenge []byte, dst byte) PublicKey { + var pk PublicKey + _, _, g1, _ := curve.Generators() + + var s fr.Element + var sBi big.Int + s.SetRandom() + s.BigInt(&sBi) + pk.SG.ScalarMultiplication(&g1, &sBi) + + // compute x*sG1 + var xBi big.Int + x.BigInt(&xBi) + pk.SXG.ScalarMultiplication(&pk.SG, &xBi) + + // generate R based on sG1, sxG1, challenge, and domain separation tag (tau, alpha or beta) + R := genR(pk.SG, pk.SXG, challenge, dst) + + // compute x*spG2 + pk.XR.ScalarMultiplication(&R, &xBi) + return pk +} + +func bitReverse[T any](a []T) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + irev := bits.Reverse64(i) >> nn + if irev > i { + a[i], a[irev] = a[irev], a[i] + } + } +} + +// Returns [1, a, a², ..., aⁿ⁻¹ ] in Montgomery form +func powers(a fr.Element, n int) []fr.Element { + result := make([]fr.Element, n) + result[0] = fr.NewElement(1) + for i := 1; i < n; i++ { + result[i].Mul(&result[i-1], &a) + } + return result +} + +// Returns [aᵢAᵢ, ...] in G1 +func scaleG1InPlace(A []curve.G1Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Returns [aᵢAᵢ, ...] 
in G2 +func scaleG2InPlace(A []curve.G2Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Check e(a₁, a₂) = e(b₁, b₂) +func sameRatio(a1, b1 curve.G1Affine, a2, b2 curve.G2Affine) bool { + if !a1.IsInSubGroup() || !b1.IsInSubGroup() || !a2.IsInSubGroup() || !b2.IsInSubGroup() { + panic("invalid point not in subgroup") + } + var na2 curve.G2Affine + na2.Neg(&a2) + res, err := curve.PairingCheck( + []curve.G1Affine{a1, b1}, + []curve.G2Affine{na2, b2}) + if err != nil { + panic(err) + } + return res +} + +// returns a = ∑ rᵢAᵢ, b = ∑ rᵢBᵢ +func merge(A, B []curve.G1Affine) (a, b curve.G1Affine) { + nc := runtime.NumCPU() + r := make([]fr.Element, len(A)) + for i := 0; i < len(A); i++ { + r[i].SetRandom() + } + a.MultiExp(A, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + b.MultiExp(B, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G1 +func linearCombinationG1(A []curve.G1Affine) (L1, L2 curve.G1Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G2 +func linearCombinationG2(A []curve.G2Affine) (L1, L2 curve.G2Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// Generate R in G₂ as Hash(gˢ, gˢˣ, challenge, dst) +func genR(sG1, sxG1 curve.G1Affine, challenge []byte, dst byte) curve.G2Affine { + var buf bytes.Buffer + buf.Grow(len(challenge) + curve.SizeOfG1AffineUncompressed*2) + buf.Write(sG1.Marshal()) + 
buf.Write(sxG1.Marshal()) + buf.Write(challenge) + spG2, err := curve.HashToG2(buf.Bytes(), []byte{dst}) + if err != nil { + panic(err) + } + return spG2 +} diff --git a/internal/backend/bls24-315/groth16/prove.go b/backend/groth16/bls24-315/prove.go similarity index 62% rename from internal/backend/bls24-315/groth16/prove.go rename to backend/groth16/bls24-315/prove.go index cd2fa623e5..c464544ad0 100644 --- a/internal/backend/bls24-315/groth16/prove.go +++ b/backend/groth16/bls24-315/prove.go @@ -17,13 +17,17 @@ package groth16 import ( - "fmt" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/pedersen" "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/constraint/bls24-315" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + cs "github.com/consensys/gnark/constraint/bls24-315" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" "math/big" @@ -35,9 +39,10 @@ import ( // with a valid statement and a VerifyingKey // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type Proof struct { - Ar, Krs curve.G1Affine - Bs curve.G2Affine - Commitment, CommitmentPok curve.G1Affine + Ar, Krs curve.G1Affine + Bs curve.G2Affine + Commitments []curve.G1Affine // Pedersen commitments a la https://eprint.iacr.org/2022/1072 + CommitmentPok curve.G1Affine // Batched proof of knowledge of the above commitments } // isValid ensures proof elements are in the correct subgroup @@ -51,72 +56,78 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). 
-func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - // TODO @gbotrel witness size check is done by R1CS, doesn't mean we shouldn't sanitize here. - // if len(witness) != r1cs.NbPublicVariables-1+r1cs.NbSecretVariables { - // return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), r1cs.NbPublicVariables-1+r1cs.NbSecretVariables, r1cs.NbPublicVariables, r1cs.NbSecretVariables) - // } - - log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", len(r1cs.Constraints)).Str("backend", "groth16").Logger() - - // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - - proof := &Proof{} - if r1cs.CommitmentInfo.Is() { - opt.HintFunctions[r1cs.CommitmentInfo.HintID] = func(_ *big.Int, in []*big.Int, out []*big.Int) error { - // Perf-TODO: Converting these values to big.Int and back may be a performance bottleneck. - // If that is the case, figure out a way to feed the solution vector into this function - if len(in) != r1cs.CommitmentInfo.NbCommitted() { // TODO: Remove - return fmt.Errorf("unexpected number of committed variables") - } - values := make([]fr.Element, r1cs.CommitmentInfo.NbPrivateCommitted) - nbPublicCommitted := len(in) - len(values) - inPrivate := in[nbPublicCommitted:] - for i, inI := range inPrivate { - values[i].SetBigInt(inI) - } +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } + + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] - var err error - proof.Commitment, proof.CommitmentPok, err = pk.CommitmentKey.Commit(values) - if err != nil { + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + solverOpts = append(solverOpts, solver.OverrideHint(commitmentInfo[i].HintID, func(i int) solver.Hint { + return func(_ *big.Int, in []*big.Int, out []*big.Int) error { + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[len(hashed):] + for j, inJ := range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } + + var res fr.Element + res, err = solveCommitmentWire(&proof.Commitments[i], hashed) + res.BigInt(out[0]) return err } + }(i))) + } - var res fr.Element - res, err = solveCommitmentWire(&r1cs.CommitmentInfo, &proof.Commitment, in[:r1cs.CommitmentInfo.NbPublicCommitted()]) - res.BigInt(out[0]) //Perf-TODO: Regular (non-mont) hashToField to obviate this conversion? 
- return err - } + if r1cs.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + solverOpts = append(solverOpts, + solver.OverrideHint(r1cs.GkrInfo.SolveHintID, cs.GkrSolveHint(r1cs.GkrInfo, &gkrData)), + solver.OverrideHint(r1cs.GkrInfo.ProveHintID, cs.GkrProveHint(r1cs.GkrInfo.HashName, &gkrData))) } - var wireValues []fr.Element - var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill wireValues with random values else multi exps don't do much - var r fr.Element - _, _ = r.SetRandom() - for i := r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables(); i < len(wireValues); i++ { - wireValues[i] = r - r.Double(&r) - } - } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + start := time.Now() + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + + if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil { + return nil, err + } + // H (witness reduction / FFT part) var h []fr.Element chHDone := make(chan struct{}, 1) go func() { - h = computeH(a, b, c, &pk.Domain) - a = nil - b = nil - c = nil + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + solution.A = nil + solution.B = nil + solution.C = nil chHDone <- struct{}{} }() @@ -203,15 +214,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC var krs, krs2, p1 curve.G1Jac chKrs2Done := make(chan error, 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 go func() { - _, err := krs2.MultiExp(pk.G1.Z, h, ecc.MultiExpConfig{NbTasks: n / 2}) + _, err := krs2.MultiExp(pk.G1.Z, 
h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) chKrs2Done <- err }() - // filter the wire values if needed; - _wireValues := filter(wireValues, r1cs.CommitmentInfo.PrivateToPublic()) + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - if _, err := krs.MultiExp(pk.G1.K, _wireValues[r1cs.GetNbPublicVariables():], ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chKrsDone <- err return } @@ -292,26 +307,32 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC } // if len(toRemove) == 0, returns slice -// else, returns a new slice without the indexes in toRemove -// this assumes toRemove indexes are sorted and len(slice) > len(toRemove) -func filter(slice []fr.Element, toRemove []int) (r []fr.Element) { +// else, returns a new slice without the indexes in toRemove. 
The first value in the slice is taken as indexes as sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { if len(toRemove) == 0 { return slice } - r = make([]fr.Element, 0, len(slice)-len(toRemove)) - j := 0 + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) + // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) for i := 0; i < len(slice); i++ { - if j < len(toRemove) && i == toRemove[j] { - j++ + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } continue } r = append(r, slice[i]) } - return r + return } func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { @@ -334,9 +355,9 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { domain.FFTInverse(b, fft.DIF) domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, true) - domain.FFT(b, fft.DIT, true) - domain.FFT(c, fft.DIT, true) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) var den, one fr.Element one.SetOne() @@ -344,7 +365,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) - // reusing a to avoid unecessary memalloc + // reusing a to avoid unnecessary memory allocation utils.Parallelize(n, func(start, end int) { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). 
@@ -354,7 +375,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { }) // ifft_coset - domain.FFTInverse(a, fft.DIF, true) + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) return a } diff --git a/internal/backend/bls24-315/groth16/setup.go b/backend/groth16/bls24-315/setup.go similarity index 75% rename from internal/backend/bls24-315/groth16/setup.go rename to backend/groth16/bls24-315/setup.go index 8a80f7f023..6a8c8e60d2 100644 --- a/internal/backend/bls24-315/groth16/setup.go +++ b/backend/groth16/bls24-315/setup.go @@ -17,13 +17,15 @@ package groth16 import ( + "errors" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls24-315" + cs "github.com/consensys/gnark/constraint/bls24-315" "math/big" "math/bits" ) @@ -34,15 +36,15 @@ type ProvingKey struct { // domain Domain fft.Domain - // [α]1, [β]1, [δ]1 - // [A(t)]1, [B(t)]1, [Kpk(t)]1, [Z(t)]1 + // [α]₁, [β]₁, [δ]₁ + // [A(t)]₁, [B(t)]₁, [Kpk(t)]₁, [Z(t)]₁ G1 struct { Alpha, Beta, Delta curve.G1Affine A, B, Z []curve.G1Affine K []curve.G1Affine // the indexes correspond to the private wires } - // [β]2, [δ]2, [B(t)]2 + // [β]₂, [δ]₂, [B(t)]₂ G2 struct { Beta, Delta curve.G2Affine B []curve.G2Affine @@ -52,21 +54,21 @@ type ProvingKey struct { InfinityA, InfinityB []bool NbInfinityA, NbInfinityB uint64 - CommitmentKey pedersen.Key + CommitmentKeys []pedersen.ProvingKey } // VerifyingKey is used by a Groth16 verifier to verify the validity of a proof and a statement // Notation follows Figure 4. 
in DIZK paper https://eprint.iacr.org/2018/691.pdf type VerifyingKey struct { - // [α]1, [Kvk]1 + // [α]₁, [Kvk]₁ G1 struct { Alpha curve.G1Affine Beta, Delta curve.G1Affine // unused, here for compatibility purposes K []curve.G1Affine // The indexes correspond to the public wires } - // [β]2, [δ]2, [γ]2, - // -[δ]2, -[γ]2: see proof.Verify() for more details + // [β]₂, [δ]₂, [γ]₂, + // -[δ]₂, -[γ]₂: see proof.Verify() for more details G2 struct { Beta, Delta, Gamma curve.G2Affine deltaNeg, gammaNeg curve.G2Affine // not serialized @@ -75,8 +77,8 @@ type VerifyingKey struct { // e(α, β) e curve.GT // not serialized - CommitmentKey pedersen.Key - CommitmentInfo constraint.Commitment // since the verifier doesn't input a constraint system, this needs to be provided here + CommitmentKey pedersen.VerifyingKey + PublicAndCommitmentCommitted [][]int // indexes of public/commitment committed variables } // Setup constructs the SRS @@ -93,17 +95,20 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbPrivateCommittedWires := r1cs.CommitmentInfo.NbPrivateCommitted - nbPublicWires := r1cs.GetNbPublicVariables() - nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - if r1cs.CommitmentInfo.Is() { // the commitment itself is defined by a hint so the prover considers it private - nbPublicWires++ // but the verifier will need to inject the value itself so on the groth16 - nbPrivateWires-- // level it must be considered public - } + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + commitmentWires := commitmentInfo.CommitmentIndexes() + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateCommittedWires := internal.NbElements(privateCommitted) + + // a commitment is itself defined by a hint so the prover considers it 
private + // but the verifier will need to inject the value itself so on the groth16 + // level it must be considered public + nbPublicWires := r1cs.GetNbPublicVariables() + len(commitmentInfo) + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - len(commitmentInfo) // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints))) + domain := fft.NewDomain(uint64(r1cs.GetNbConstraints())) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -137,7 +142,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // compute scalars for pkK, vkK and ckK pkK := make([]fr.Element, nbPrivateWires) vkK := make([]fr.Element, nbPublicWires) - ckK := make([]fr.Element, nbPrivateCommittedWires) + ckK := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + ckK[i] = make([]fr.Element, len(privateCommitted[i])) + } var t0, t1 fr.Element @@ -148,28 +156,42 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { Add(&t1, &C[i]). 
Mul(&t1, coeff) } - - vI, cI := 0, 0 - privateCommitted := r1cs.CommitmentInfo.PrivateCommitted() + vI := 0 // number of public wires seen so far + cI := make([]int, len(commitmentInfo)) // number of private committed wires seen so far for each commitment + nbPrivateCommittedSeen := 0 // = ∑ᵢ cI[i] + nbCommitmentsSeen := 0 for i := range A { - isCommittedPrivate := cI < len(privateCommitted) && i == privateCommitted[cI] - isCommitment := r1cs.CommitmentInfo.Is() && i == r1cs.CommitmentInfo.CommitmentIndex - isPublic := i < r1cs.GetNbPublicVariables() + commitment := -1 // index of the commitment that commits to this variable as a private or commitment value + var isCommitment, isPublic bool + if isPublic = i < r1cs.GetNbPublicVariables(); !isPublic { + if nbCommitmentsSeen < len(commitmentWires) && commitmentWires[nbCommitmentsSeen] == i { + isCommitment = true + nbCommitmentsSeen++ + } - if isPublic || isCommittedPrivate || isCommitment { + for j := range commitmentInfo { // does commitment j commit to i? + if cI[j] < len(privateCommitted[j]) && privateCommitted[j][cI[j]] == i { + commitment = j + break // frontend guarantees that no private variable is committed to more than once + } + } + } + + if isPublic || commitment != -1 || isCommitment { computeK(i, &toxicWaste.gammaInv) - if isCommittedPrivate { - ckK[cI] = t1 - cI++ - } else { + if isPublic || isCommitment { vkK[vI] = t1 vI++ + } else { // committed and private + ckK[commitment][cI[commitment]] = t1 + cI[commitment]++ + nbPrivateCommittedSeen++ } } else { computeK(i, &toxicWaste.deltaInv) - pkK[i-vI-cI] = t1 + pkK[i-vI-nbPrivateCommittedSeen] = t1 // vI = nbPublicSeen + nbCommitmentsSeen } } @@ -222,11 +244,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { g1Scalars = append(g1Scalars, Z...) g1Scalars = append(g1Scalars, vkK...) g1Scalars = append(g1Scalars, pkK...) - g1Scalars = append(g1Scalars, ckK...) + for i := range ckK { + g1Scalars = append(g1Scalars, ckK[i]...) 
+ } g1PointsAff := curve.BatchScalarMultiplicationG1(&g1, g1Scalars) - // sets pk: [α]1, [β]1, [δ]1 + // sets pk: [α]₁, [β]₁, [δ]₁ pk.G1.Alpha = g1PointsAff[0] pk.G1.Beta = g1PointsAff[1] pk.G1.Delta = g1PointsAff[2] @@ -238,8 +262,9 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G1.B = g1PointsAff[offset : offset+len(B)] offset += len(B) - pk.G1.Z = g1PointsAff[offset : offset+int(domain.Cardinality)] - bitReverse(pk.G1.Z) + bitReverse(g1PointsAff[offset : offset+int(domain.Cardinality)]) + sizeZ := int(domain.Cardinality) - 1 // deg(H)=deg(A*B-C/X^n-1)=(n-1)+(n-1)-n=n-2 + pk.G1.Z = g1PointsAff[offset : offset+sizeZ] offset += int(domain.Cardinality) @@ -252,17 +277,22 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // --------------------------------------------------------------------------------------------- // Commitment setup - if nbPrivateCommittedWires != 0 { - commitmentBasis := g1PointsAff[offset:] + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(ckK[i]) + commitmentBases[i] = g1PointsAff[offset : offset+size] + offset += size + } + if offset != len(g1PointsAff) { + return errors.New("didn't consume all G1 points") // TODO @Tabaie Remove this + } - vk.CommitmentKey, err = pedersen.Setup(commitmentBasis) - if err != nil { - return err - } - pk.CommitmentKey = vk.CommitmentKey + pk.CommitmentKeys, vk.CommitmentKey, err = pedersen.Setup(commitmentBases...) 
+ if err != nil { + return err } - vk.CommitmentInfo = r1cs.CommitmentInfo // unfortunate but necessary + vk.PublicAndCommitmentCommitted = commitmentInfo.GetPublicAndCommitmentCommitted(commitmentWires, r1cs.GetNbPublicVariables()) // --------------------------------------------------------------------------------------------- // G2 scalars @@ -279,15 +309,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G2.B = g2PointsAff[:len(B)] - // sets pk: [β]2, [δ]2 + // sets pk: [β]₂, [δ]₂ pk.G2.Beta = g2PointsAff[len(B)+0] pk.G2.Delta = g2PointsAff[len(B)+1] - // sets vk: [δ]2, [γ]2, -[δ]2, -[γ]2 + // sets vk: [δ]₂, [γ]₂ vk.G2.Delta = g2PointsAff[len(B)+1] vk.G2.Gamma = g2PointsAff[len(B)+2] - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) // --------------------------------------------------------------------------------------------- // Pairing: vk.e @@ -298,16 +326,29 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { vk.G1.Beta = pk.G1.Beta vk.G1.Delta = pk.G1.Delta - vk.e, err = curve.Pair([]curve.G1Affine{pk.G1.Alpha}, []curve.G2Affine{pk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return err } + // set domain pk.Domain = *domain return nil } +// Precompute sets e, -[δ]₂, -[γ]₂ +// This is meant to be called internally during setup or deserialization. +func (vk *VerifyingKey) Precompute() error { + var err error + vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) + if err != nil { + return err + } + vk.G2.deltaNeg.Neg(&vk.G2.Delta) + vk.G2.gammaNeg.Neg(&vk.G2.Gamma) + return nil +} + func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr.Element, B []fr.Element, C []fr.Element) { nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() @@ -322,7 +363,7 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. 
var w fr.Element w.Set(&domain.Generator) wi := fr.One() - t := make([]fr.Element, len(r1cs.Constraints)+1) + t := make([]fr.Element, r1cs.GetNbConstraints()+1) for i := 0; i < len(t); i++ { t[i].Sub(&toxicWaste.t, &wi) wi.Mul(&wi, &w) // TODO this is already pre computed in fft.Domain @@ -366,8 +407,10 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // for each term appearing in the linear expression, // we compute term.Coefficient * L, and cumulate it in // A, B or C at the index of the variable - for i, c := range r1cs.Constraints { + j := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } @@ -380,9 +423,12 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // Li+1 = w*Li*(t-w^i)/(t-w^(i+1)) L.Mul(&L, &w) - L.Mul(&L, &t[i]) - L.Mul(&L, &tInv[i+1]) + L.Mul(&L, &t[j]) + L.Mul(&L, &tInv[j+1]) + + j++ } + return } @@ -436,7 +482,10 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbConstraints := len(r1cs.Constraints) + nbConstraints := r1cs.GetNbConstraints() + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - internal.NbElements(privateCommitted) - len(commitmentInfo) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints)) @@ -448,8 +497,8 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // initialize proving key pk.G1.A = make([]curve.G1Affine, nbWires-nbZeroesA) pk.G1.B = make([]curve.G1Affine, nbWires-nbZeroesB) - pk.G1.K = make([]curve.G1Affine, nbWires-r1cs.GetNbPublicVariables()) - pk.G1.Z = make([]curve.G1Affine, 
domain.Cardinality) + pk.G1.K = make([]curve.G1Affine, nbPrivateWires) + pk.G1.Z = make([]curve.G1Affine, domain.Cardinality-1) pk.G2.B = make([]curve.G2Affine, nbWires-nbZeroesB) // set infinity markers @@ -503,6 +552,22 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { pk.Domain = *domain + // --------------------------------------------------------------------------------------------- + // Commitment setup + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(privateCommitted[i]) + commitmentBases[i] = make([]curve.G1Affine, size) + for j := range commitmentBases[i] { + commitmentBases[i][j] = r1Aff + } + } + + pk.CommitmentKeys, _, err = pedersen.Setup(commitmentBases...) + if err != nil { + return err + } + return nil } @@ -514,7 +579,9 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) - for _, c := range r1cs.Constraints { + + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { A[t.WireID()] = true } @@ -522,6 +589,7 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { B[t.WireID()] = true } } + for i := 0; i < nbWires; i++ { if !A[i] { nbZeroesA++ @@ -606,7 +674,7 @@ func (pk *ProvingKey) NbG2() int { return 2 + len(pk.G2.B) } -// bitRerverse permutation as in fft.BitReverse , but with []curve.G1Affine +// bitReverse permutation as in fft.BitReverse , but with []curve.G1Affine func bitReverse(a []curve.G1Affine) { n := uint(len(a)) nn := uint(bits.UintSize - bits.TrailingZeros(n)) diff --git a/internal/backend/bls24-315/groth16/verify.go b/backend/groth16/bls24-315/verify.go similarity index 64% rename from internal/backend/bls24-315/groth16/verify.go rename to backend/groth16/bls24-315/verify.go index f0f170d9b3..6e85b70ecf 100644 --- a/internal/backend/bls24-315/groth16/verify.go +++ b/backend/groth16/bls24-315/verify.go @@ -22,9 +22,11 @@ import ( 
"github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/pedersen" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/logger" "io" - "math/big" "time" ) @@ -36,10 +38,8 @@ var ( // Verify verifies a proof with given VerifyingKey and publicWitness func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - nbPublicVars := len(vk.G1.K) - if vk.CommitmentInfo.Is() { - nbPublicVars-- - } + nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted) + if len(publicWitness) != nbPublicVars-1 { return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(publicWitness), len(vk.G1.K)-1) } @@ -62,21 +62,32 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { close(chDone) }() - if vk.CommitmentInfo.Is() { - - if err := vk.CommitmentKey.VerifyKnowledgeProof(proof.Commitment, proof.CommitmentPok); err != nil { - return err + maxNbPublicCommitted := 0 + for _, s := range vk.PublicAndCommitmentCommitted { // iterate over commitments + maxNbPublicCommitted = utils.Max(maxNbPublicCommitted, len(s)) + } + commitmentsSerialized := make([]byte, len(vk.PublicAndCommitmentCommitted)*fr.Bytes) + commitmentPrehashSerialized := make([]byte, curve.SizeOfG1AffineUncompressed+maxNbPublicCommitted*fr.Bytes) + for i := range vk.PublicAndCommitmentCommitted { // solveCommitmentWire + copy(commitmentPrehashSerialized, proof.Commitments[i].Marshal()) + offset := curve.SizeOfG1AffineUncompressed + for j := range vk.PublicAndCommitmentCommitted[i] { + copy(commitmentPrehashSerialized[offset:], publicWitness[vk.PublicAndCommitmentCommitted[i][j]-1].Marshal()) + offset += fr.Bytes } - - publicCommitted := make([]*big.Int, vk.CommitmentInfo.NbPublicCommitted()) - for i := range publicCommitted { - var b big.Int - 
publicWitness[vk.CommitmentInfo.Committed[i]-1].BigInt(&b) - publicCommitted[i] = &b + if res, err := fr.Hash(commitmentPrehashSerialized[:offset], []byte(constraint.CommitmentDst), 1); err != nil { + return err + } else { + publicWitness = append(publicWitness, res[0]) + copy(commitmentsSerialized[i*fr.Bytes:], res[0].Marshal()) } + } - if res, err := solveCommitmentWire(&vk.CommitmentInfo, &proof.Commitment, publicCommitted); err == nil { - publicWitness = append(publicWitness, res) + if folded, err := pedersen.FoldCommitments(proof.Commitments, commitmentsSerialized); err != nil { + return err + } else { + if err = vk.CommitmentKey.Verify(folded, proof.CommitmentPok); err != nil { + return err } } @@ -87,8 +98,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { } kSum.AddMixed(&vk.G1.K[0]) - if vk.CommitmentInfo.Is() { - kSum.AddMixed(&proof.Commitment) + for i := range proof.Commitments { + kSum.AddMixed(&proof.Commitments[i]) } var kSumAff curve.G1Affine diff --git a/internal/backend/bls24-317/groth16/commitment.go b/backend/groth16/bls24-317/commitment.go similarity index 74% rename from internal/backend/bls24-317/groth16/commitment.go rename to backend/groth16/bls24-317/commitment.go index 00b3713dd6..05d71ba172 100644 --- a/internal/backend/bls24-317/groth16/commitment.go +++ b/backend/groth16/bls24-317/commitment.go @@ -23,7 +23,7 @@ import ( "math/big" ) -func solveCommitmentWire(commitmentInfo *constraint.Commitment, commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { - res, err := fr.Hash(commitmentInfo.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) +func solveCommitmentWire(commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { + res, err := fr.Hash(constraint.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) return res[0], err } diff --git 
a/internal/backend/bls24-317/groth16/commitment_test.go b/backend/groth16/bls24-317/commitment_test.go similarity index 91% rename from internal/backend/bls24-317/groth16/commitment_test.go rename to backend/groth16/bls24-317/commitment_test.go index 3963ef349e..bfb2fc578e 100644 --- a/internal/backend/bls24-317/groth16/commitment_test.go +++ b/backend/groth16/bls24-317/commitment_test.go @@ -17,6 +17,9 @@ package groth16_test import ( + "fmt" + "testing" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/witness" @@ -24,7 +27,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/assert" - "testing" ) type singleSecretCommittedCircuit struct { @@ -33,7 +35,11 @@ type singleSecretCommittedCircuit struct { func (c *singleSecretCommittedCircuit) Define(api frontend.API) error { api.AssertIsEqual(c.One, 1) - commit, err := api.Compiler().Commit(c.One) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One) if err != nil { return err } @@ -119,8 +125,11 @@ type oneSecretOnePublicCommittedCircuit struct { } func (c *oneSecretOnePublicCommittedCircuit) Define(api frontend.API) error { - - commit, err := api.Compiler().Commit(c.One, c.Two) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One, c.Two) if err != nil { return err } diff --git a/internal/backend/bls24-317/groth16/marshal.go b/backend/groth16/bls24-317/marshal.go similarity index 81% rename from internal/backend/bls24-317/groth16/marshal.go rename to backend/groth16/bls24-317/marshal.go index 3b0ec45426..75deea5847 100644 --- a/internal/backend/bls24-317/groth16/marshal.go +++ b/backend/groth16/bls24-317/marshal.go @@ -18,6 +18,9 @@ package 
groth16 import ( curve "github.com/consensys/gnark-crypto/ecc/bls24-317" + + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/pedersen" + "github.com/consensys/gnark/internal/utils" "io" ) @@ -78,14 +81,24 @@ func (proof *Proof) ReadFrom(r io.Reader) (n int64, err error) { // points are compressed // use WriteRawTo(...) to encode the key without point compression func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, false) + if n, err = vk.writeTo(w, false); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteTo(w) + return m + n, err } // WriteRawTo writes binary encoding of the key elements to writer // points are not compressed // use WriteTo(...) to encode the key with point compression func (vk *VerifyingKey) WriteRawTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, true) + if n, err = vk.writeTo(w, true); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteRawTo(w) + return m + n, err } // writeTo serialization format: @@ -124,6 +137,14 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { if err := enc.Encode(vk.G1.K); err != nil { return enc.BytesWritten(), err } + + if vk.PublicAndCommitmentCommitted == nil { + vk.PublicAndCommitmentCommitted = [][]int{} // only matters in tests + } + if err := enc.Encode(utils.IntSliceSliceToUint64SliceSlice(vk.PublicAndCommitmentCommitted)); err != nil { + return enc.BytesWritten(), err + } + return enc.BytesWritten(), nil } @@ -133,13 +154,25 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { // https://github.com/zkcrypto/bellman/blob/fa9be45588227a8c6ec34957de3f68705f07bd92/src/groth16/mod.rs#L143 // [α]1,[β]1,[β]2,[γ]2,[δ]1,[δ]2,uint32(len(Kvk)),[Kvk]1 func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r) + n, err := vk.readFrom(r) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.ReadFrom(r) + return m + 
n, err } // UnsafeReadFrom has the same behavior as ReadFrom, except that it will not check that decode points // are on the curve and in the correct subgroup. func (vk *VerifyingKey) UnsafeReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r, curve.NoSubgroupChecks()) + n, err := vk.readFrom(r, curve.NoSubgroupChecks()) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.UnsafeReadFrom(r) + return m + n, err } func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { @@ -169,15 +202,16 @@ func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder) if err := dec.Decode(&vk.G1.K); err != nil { return dec.BytesRead(), err } + var publicCommitted [][]uint64 + if err := dec.Decode(&publicCommitted); err != nil { + return dec.BytesRead(), err + } + vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 - var err error - vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return dec.BytesRead(), err } - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) return dec.BytesRead(), nil } @@ -226,6 +260,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { pk.NbInfinityB, pk.InfinityA, pk.InfinityB, + uint32(len(pk.CommitmentKeys)), } for _, v := range toEncode { @@ -234,6 +269,23 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { } } + for i := range pk.CommitmentKeys { + var ( + n2 int64 + err error + ) + if raw { + n2, err = pk.CommitmentKeys[i].WriteRawTo(w) + } else { + n2, err = pk.CommitmentKeys[i].WriteTo(w) + } + + n += n2 + if err != nil { + return n, err + } + } + return n + enc.BytesWritten(), nil } @@ -260,6 +312,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) dec := curve.NewDecoder(r, decOptions...) 
var nbWires uint64 + var nbCommitments uint32 toDecode := []interface{}{ &pk.G1.Alpha, @@ -291,6 +344,18 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) if err := dec.Decode(&pk.InfinityB); err != nil { return n + dec.BytesRead(), err } + if err := dec.Decode(&nbCommitments); err != nil { + return n + dec.BytesRead(), err + } + + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) + for i := range pk.CommitmentKeys { + n2, err := pk.CommitmentKeys[i].ReadFrom(r) + n += n2 + if err != nil { + return n, err + } + } return n + dec.BytesRead(), nil } diff --git a/internal/backend/bls24-317/groth16/marshal_test.go b/backend/groth16/bls24-317/marshal_test.go similarity index 78% rename from internal/backend/bls24-317/groth16/marshal_test.go rename to backend/groth16/bls24-317/marshal_test.go index aa75c1d456..860e08725c 100644 --- a/internal/backend/bls24-317/groth16/marshal_test.go +++ b/backend/groth16/bls24-317/marshal_test.go @@ -21,11 +21,17 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal/test_utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "bytes" "math/big" "reflect" "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" "github.com/leanovate/gopter/prop" "testing" @@ -87,13 +93,9 @@ func TestProofSerialization(t *testing.T) { } func TestVerifyingKeySerialization(t *testing.T) { - parameters := gopter.DefaultTestParameters() - parameters.MinSuccessfulTests = 10 - properties := gopter.NewProperties(parameters) - - properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + roundTrip := func(withCommitment bool) func(curve.G1Affine, curve.G2Affine) bool { + return func(p1 curve.G1Affine, p2 curve.G2Affine) bool { var vk, vkCompressed, 
vkRaw VerifyingKey // create a random vk @@ -121,6 +123,21 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } + if withCommitment { + vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization + bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) + elem := p1 + for i := 0; i < len(vk.PublicAndCommitmentCommitted); i++ { + bases[i] = make([]curve.G1Affine, len(vk.PublicAndCommitmentCommitted[i])) + for j := range bases[i] { + bases[i][j] = elem + elem.Add(&elem, &p1) + } + } + _, vk.CommitmentKey, err = pedersen.Setup(bases...) + assert.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := vk.WriteTo(&bufCompressed) if err != nil { @@ -158,7 +175,22 @@ func TestVerifyingKeySerialization(t *testing.T) { } return reflect.DeepEqual(&vk, &vkCompressed) && reflect.DeepEqual(&vk, &vkRaw) - }, + } + } + + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 10 + + properties := gopter.NewProperties(parameters) + + properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(false), + GenG1(), + GenG2(), + )) + + properties.Property("VerifyingKey (with commitments) -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(true), GenG1(), GenG2(), )) @@ -173,7 +205,7 @@ func TestProvingKeySerialization(t *testing.T) { properties := gopter.NewProperties(parameters) properties.Property("ProvingKey -> writer -> reader -> ProvingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + func(p1 curve.G1Affine, p2 curve.G2Affine, nbCommitment int) bool { var pk, pkCompressed, pkRaw ProvingKey // create a random pk @@ -202,6 +234,19 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, 
nbCommitment) + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases...) + require.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := pk.WriteTo(&bufCompressed) if err != nil { @@ -242,6 +287,7 @@ func TestProvingKeySerialization(t *testing.T) { }, GenG1(), GenG2(), + gen.IntRange(0, 2), )) properties.TestingRun(t, gopter.ConsoleReporter(false)) diff --git a/backend/groth16/bls24-317/mpcsetup/lagrange.go b/backend/groth16/bls24-317/mpcsetup/lagrange.go new file mode 100644 index 0000000000..9d0ba07367 --- /dev/null +++ b/backend/groth16/bls24-317/mpcsetup/lagrange.go @@ -0,0 +1,216 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bls24-317" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" + "github.com/consensys/gnark/internal/utils" +) + +func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { + coeffs := make([]curve.G1Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { + coeffs := make([]curve.G2Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func butterflyG1(a *curve.G1Affine, b *curve.G1Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +func butterflyG2(a *curve.G2Affine, b *curve.G2Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G1(a []curve.G1Affine, twiddles [][]fr.Element, stage int) { + butterflyG1(&a[0], &a[4]) + butterflyG1(&a[1], &a[5]) + 
butterflyG1(&a[2], &a[6]) + butterflyG1(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[2]) + butterflyG1(&a[1], &a[3]) + butterflyG1(&a[4], &a[6]) + butterflyG1(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[1]) + butterflyG1(&a[2], &a[3]) + butterflyG1(&a[4], &a[5]) + butterflyG1(&a[6], &a[7]) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G2(a []curve.G2Affine, twiddles [][]fr.Element, stage int) { + butterflyG2(&a[0], &a[4]) + butterflyG2(&a[1], &a[5]) + butterflyG2(&a[2], &a[6]) + butterflyG2(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[2]) + butterflyG2(&a[1], &a[3]) + butterflyG2(&a[4], &a[6]) + butterflyG2(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[1]) + butterflyG2(&a[2], &a[3]) + butterflyG2(&a[4], &a[5]) + butterflyG2(&a[6], &a[7]) +} + +func difFFTG1(a []curve.G1Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G1(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG1(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ 
{ + butterflyG1(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} +func difFFTG2(a []curve.G2Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G2(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG2(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + butterflyG2(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG2(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG2(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} diff --git a/backend/groth16/bls24-317/mpcsetup/marshal.go b/backend/groth16/bls24-317/mpcsetup/marshal.go new file mode 100644 index 0000000000..6e1ebfc02c --- /dev/null +++ b/backend/groth16/bls24-317/mpcsetup/marshal.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls24-317" + "io" +) + +// WriteTo implements io.WriterTo +func (phase1 *Phase1) WriteTo(writer io.Writer) (int64, error) { + n, err := phase1.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase1.Hash) + return int64(nBytes) + n, err +} + +func (phase1 *Phase1) writeTo(writer io.Writer) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + phase1.Parameters.G1.Tau, + phase1.Parameters.G1.AlphaTau, + phase1.Parameters.G1.BetaTau, + phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + enc := curve.NewEncoder(writer) + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (phase1 *Phase1) ReadFrom(reader io.Reader) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + &phase1.Parameters.G1.Tau, + 
&phase1.Parameters.G1.AlphaTau, + &phase1.Parameters.G1.BetaTau, + &phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + dec := curve.NewDecoder(reader) + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + phase1.Hash = make([]byte, 32) + nBytes, err := reader.Read(phase1.Hash) + return dec.BytesRead() + int64(nBytes), err +} + +// WriteTo implements io.WriterTo +func (phase2 *Phase2) WriteTo(writer io.Writer) (int64, error) { + n, err := phase2.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase2.Hash) + return int64(nBytes) + n, err +} + +func (c *Phase2) writeTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + c.Parameters.G1.L, + c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c *Phase2) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + &c.Parameters.G1.L, + &c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + c.Hash = make([]byte, 32) + n, err := reader.Read(c.Hash) + return int64(n) + dec.BytesRead(), err + +} + +// WriteTo implements io.WriterTo +func (c *Phase2Evaluations) WriteTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + c.G1.A, + c.G1.B, + c.G2.B, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom 
implements io.ReaderFrom +func (c *Phase2Evaluations) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.G1.A, + &c.G1.B, + &c.G2.B, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + return dec.BytesRead(), nil +} diff --git a/backend/groth16/bls24-317/mpcsetup/marshal_test.go b/backend/groth16/bls24-317/mpcsetup/marshal_test.go new file mode 100644 index 0000000000..ddc7229489 --- /dev/null +++ b/backend/groth16/bls24-317/mpcsetup/marshal_test.go @@ -0,0 +1,79 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + curve "github.com/consensys/gnark-crypto/ecc/bls24-317" + cs "github.com/consensys/gnark/constraint/bls24-317" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/stretchr/testify/require" + "io" + "reflect" + "testing" +) + +func TestContributionSerialization(t *testing.T) { + assert := require.New(t) + + // Phase 1 + srs1 := InitPhase1(9) + srs1.Contribute() + { + var reconstructed Phase1 + roundTripCheck(t, &srs1, &reconstructed) + } + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + r1cs := ccs.(*cs.R1CS) + + // Phase 2 + srs2, _ := InitPhase2(r1cs, &srs1) + srs2.Contribute() + + { + var reconstructed Phase2 + roundTripCheck(t, &srs2, &reconstructed) + } +} + +func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { + t.Helper() + + var buf bytes.Buffer + written, err := from.WriteTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} diff --git a/backend/groth16/bls24-317/mpcsetup/phase1.go b/backend/groth16/bls24-317/mpcsetup/phase1.go new file mode 100644 index 0000000000..72e61e9977 --- /dev/null +++ b/backend/groth16/bls24-317/mpcsetup/phase1.go @@ -0,0 +1,203 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + curve "github.com/consensys/gnark-crypto/ecc/bls24-317" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "math" + "math/big" +) + +// Phase1 represents the Phase1 of the MPC described in +// https://eprint.iacr.org/2017/1050.pdf +// +// Also known as "Powers of Tau" +type Phase1 struct { + Parameters struct { + G1 struct { + Tau []curve.G1Affine // {[τ⁰]₁, [τ¹]₁, [τ²]₁, …, [τ²ⁿ⁻²]₁} + AlphaTau []curve.G1Affine // {α[τ⁰]₁, α[τ¹]₁, α[τ²]₁, …, α[τⁿ⁻¹]₁} + BetaTau []curve.G1Affine // {β[τ⁰]₁, β[τ¹]₁, β[τ²]₁, …, β[τⁿ⁻¹]₁} + } + G2 struct { + Tau []curve.G2Affine // {[τ⁰]₂, [τ¹]₂, [τ²]₂, …, [τⁿ⁻¹]₂} + Beta curve.G2Affine // [β]₂ + } + } + PublicKeys struct { + Tau, Alpha, Beta PublicKey + } + Hash []byte // sha256 hash +} + +// InitPhase1 initialize phase 1 of the MPC. This is called once by the coordinator before +// any randomness contribution is made (see Contribute()). 
+func InitPhase1(power int) (phase1 Phase1) { + N := int(math.Pow(2, float64(power))) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetOne() + alpha.SetOne() + beta.SetOne() + phase1.PublicKeys.Tau = newPublicKey(tau, nil, 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, nil, 2) + phase1.PublicKeys.Beta = newPublicKey(beta, nil, 3) + + // First contribution use generators + _, _, g1, g2 := curve.Generators() + phase1.Parameters.G2.Beta.Set(&g2) + phase1.Parameters.G1.Tau = make([]curve.G1Affine, 2*N-1) + phase1.Parameters.G2.Tau = make([]curve.G2Affine, N) + phase1.Parameters.G1.AlphaTau = make([]curve.G1Affine, N) + phase1.Parameters.G1.BetaTau = make([]curve.G1Affine, N) + for i := 0; i < len(phase1.Parameters.G1.Tau); i++ { + phase1.Parameters.G1.Tau[i].Set(&g1) + } + for i := 0; i < len(phase1.Parameters.G2.Tau); i++ { + phase1.Parameters.G2.Tau[i].Set(&g2) + phase1.Parameters.G1.AlphaTau[i].Set(&g1) + phase1.Parameters.G1.BetaTau[i].Set(&g1) + } + + phase1.Parameters.G2.Beta.Set(&g2) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() + + return +} + +// Contribute contributes randomness to the phase1 object. This mutates phase1. +func (phase1 *Phase1) Contribute() { + N := len(phase1.Parameters.G2.Tau) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetRandom() + alpha.SetRandom() + beta.SetRandom() + phase1.PublicKeys.Tau = newPublicKey(tau, phase1.Hash[:], 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, phase1.Hash[:], 2) + phase1.PublicKeys.Beta = newPublicKey(beta, phase1.Hash[:], 3) + + // Compute powers of τ, ατ, and βτ + taus := powers(tau, 2*N-1) + alphaTau := make([]fr.Element, N) + betaTau := make([]fr.Element, N) + for i := 0; i < N; i++ { + alphaTau[i].Mul(&taus[i], &alpha) + betaTau[i].Mul(&taus[i], &beta) + } + + // Update using previous parameters + // TODO @gbotrel working with jacobian points here will help with perf. 
+ scaleG1InPlace(phase1.Parameters.G1.Tau, taus) + scaleG2InPlace(phase1.Parameters.G2.Tau, taus[0:N]) + scaleG1InPlace(phase1.Parameters.G1.AlphaTau, alphaTau) + scaleG1InPlace(phase1.Parameters.G1.BetaTau, betaTau) + var betaBI big.Int + beta.BigInt(&betaBI) + phase1.Parameters.G2.Beta.ScalarMultiplication(&phase1.Parameters.G2.Beta, &betaBI) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() +} + +func VerifyPhase1(c0, c1 *Phase1, c ...*Phase1) error { + contribs := append([]*Phase1{c0, c1}, c...) + for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase1(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +// verifyPhase1 checks that a contribution is based on a known previous Phase1 state. +func verifyPhase1(current, contribution *Phase1) error { + // Compute R for τ, α, β + tauR := genR(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, current.Hash[:], 1) + alphaR := genR(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, current.Hash[:], 2) + betaR := genR(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, current.Hash[:], 3) + + // Check for knowledge of toxic parameters + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.PublicKeys.Tau.XR, tauR) { + return errors.New("couldn't verify public key of τ") + } + if !sameRatio(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, contribution.PublicKeys.Alpha.XR, alphaR) { + return errors.New("couldn't verify public key of α") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.PublicKeys.Beta.XR, betaR) { + return errors.New("couldn't verify public key of β") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Tau[1], current.Parameters.G1.Tau[1], tauR, contribution.PublicKeys.Tau.XR) { + return errors.New("couldn't verify that [τ]₁ is based on previous 
contribution") + } + if !sameRatio(contribution.Parameters.G1.AlphaTau[0], current.Parameters.G1.AlphaTau[0], alphaR, contribution.PublicKeys.Alpha.XR) { + return errors.New("couldn't verify that [α]₁ is based on previous contribution") + } + if !sameRatio(contribution.Parameters.G1.BetaTau[0], current.Parameters.G1.BetaTau[0], betaR, contribution.PublicKeys.Beta.XR) { + return errors.New("couldn't verify that [β]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.Parameters.G2.Tau[1], current.Parameters.G2.Tau[1]) { + return errors.New("couldn't verify that [τ]₂ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.Parameters.G2.Beta, current.Parameters.G2.Beta) { + return errors.New("couldn't verify that [β]₂ is based on previous contribution") + } + + // Check for valid updates using powers of τ + _, _, g1, g2 := curve.Generators() + tauL1, tauL2 := linearCombinationG1(contribution.Parameters.G1.Tau) + if !sameRatio(tauL1, tauL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of τ in G₁") + } + alphaL1, alphaL2 := linearCombinationG1(contribution.Parameters.G1.AlphaTau) + if !sameRatio(alphaL1, alphaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of α(τ) in G₁") + } + betaL1, betaL2 := linearCombinationG1(contribution.Parameters.G1.BetaTau) + if !sameRatio(betaL1, betaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of β(τ) in G₁") + } + tau2L1, tau2L2 := linearCombinationG2(contribution.Parameters.G2.Tau) + if !sameRatio(contribution.Parameters.G1.Tau[1], g1, tau2L1, tau2L2) { + return errors.New("couldn't verify valid powers of τ in G₂") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != 
contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (phase1 *Phase1) hash() []byte { + sha := sha256.New() + phase1.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bls24-317/mpcsetup/phase2.go b/backend/groth16/bls24-317/mpcsetup/phase2.go new file mode 100644 index 0000000000..d3037cc3d3 --- /dev/null +++ b/backend/groth16/bls24-317/mpcsetup/phase2.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + "math/big" + + curve "github.com/consensys/gnark-crypto/ecc/bls24-317" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark/constraint" + cs "github.com/consensys/gnark/constraint/bls24-317" +) + +type Phase2Evaluations struct { + G1 struct { + A, B, VKK []curve.G1Affine + } + G2 struct { + B []curve.G2Affine + } +} + +type Phase2 struct { + Parameters struct { + G1 struct { + Delta curve.G1Affine + L, Z []curve.G1Affine + } + G2 struct { + Delta curve.G2Affine + } + } + PublicKey PublicKey + Hash []byte +} + +func InitPhase2(r1cs *cs.R1CS, srs1 *Phase1) (Phase2, Phase2Evaluations) { + srs := srs1.Parameters + size := len(srs.G1.AlphaTau) + if size < r1cs.GetNbConstraints() { + panic("Number of constraints is larger than expected") + } + + c2 := Phase2{} + + accumulateG1 := func(res *curve.G1Affine, t constraint.Term, value *curve.G1Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G1Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + accumulateG2 := func(res *curve.G2Affine, t constraint.Term, value *curve.G2Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G2Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + // Prepare Lagrange coefficients of [τ...]₁, [τ...]₂, [ατ...]₁, [βτ...]₁ + coeffTau1 
:= lagrangeCoeffsG1(srs.G1.Tau, size) + coeffTau2 := lagrangeCoeffsG2(srs.G2.Tau, size) + coeffAlphaTau1 := lagrangeCoeffsG1(srs.G1.AlphaTau, size) + coeffBetaTau1 := lagrangeCoeffsG1(srs.G1.BetaTau, size) + + internal, secret, public := r1cs.GetNbVariables() + nWires := internal + secret + public + var evals Phase2Evaluations + evals.G1.A = make([]curve.G1Affine, nWires) + evals.G1.B = make([]curve.G1Affine, nWires) + evals.G2.B = make([]curve.G2Affine, nWires) + bA := make([]curve.G1Affine, nWires) + aB := make([]curve.G1Affine, nWires) + C := make([]curve.G1Affine, nWires) + + // TODO @gbotrel use constraint iterator when available. + + i := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + // A + for _, t := range c.L { + accumulateG1(&evals.G1.A[t.WireID()], t, &coeffTau1[i]) + accumulateG1(&bA[t.WireID()], t, &coeffBetaTau1[i]) + } + // B + for _, t := range c.R { + accumulateG1(&evals.G1.B[t.WireID()], t, &coeffTau1[i]) + accumulateG2(&evals.G2.B[t.WireID()], t, &coeffTau2[i]) + accumulateG1(&aB[t.WireID()], t, &coeffAlphaTau1[i]) + } + // C + for _, t := range c.O { + accumulateG1(&C[t.WireID()], t, &coeffTau1[i]) + } + i++ + } + + // Prepare default contribution + _, _, g1, g2 := curve.Generators() + c2.Parameters.G1.Delta = g1 + c2.Parameters.G2.Delta = g2 + + // Build Z in PK as τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + // τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + n := len(srs.G1.AlphaTau) + c2.Parameters.G1.Z = make([]curve.G1Affine, n) + for i := 0; i < n-1; i++ { + c2.Parameters.G1.Z[i].Sub(&srs.G1.Tau[i+n], &srs.G1.Tau[i]) + } + bitReverse(c2.Parameters.G1.Z) + c2.Parameters.G1.Z = c2.Parameters.G1.Z[:n-1] + + // Evaluate L + nPrivate := internal + secret + c2.Parameters.G1.L = make([]curve.G1Affine, nPrivate) + evals.G1.VKK = make([]curve.G1Affine, public) + offset := public + for i := 0; i < nWires; i++ { + var tmp curve.G1Affine + tmp.Add(&bA[i], &aB[i]) + tmp.Add(&tmp, &C[i]) + if i < public { + 
evals.G1.VKK[i].Set(&tmp) + } else { + c2.Parameters.G1.L[i-offset].Set(&tmp) + } + } + // Set δ public key + var delta fr.Element + delta.SetOne() + c2.PublicKey = newPublicKey(delta, nil, 1) + + // Hash initial contribution + c2.Hash = c2.hash() + return c2, evals +} + +func (c *Phase2) Contribute() { + // Sample toxic δ + var delta, deltaInv fr.Element + var deltaBI, deltaInvBI big.Int + delta.SetRandom() + deltaInv.Inverse(&delta) + + delta.BigInt(&deltaBI) + deltaInv.BigInt(&deltaInvBI) + + // Set δ public key + c.PublicKey = newPublicKey(delta, c.Hash, 1) + + // Update δ + c.Parameters.G1.Delta.ScalarMultiplication(&c.Parameters.G1.Delta, &deltaBI) + c.Parameters.G2.Delta.ScalarMultiplication(&c.Parameters.G2.Delta, &deltaBI) + + // Update Z using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.Z); i++ { + c.Parameters.G1.Z[i].ScalarMultiplication(&c.Parameters.G1.Z[i], &deltaInvBI) + } + + // Update L using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.L); i++ { + c.Parameters.G1.L[i].ScalarMultiplication(&c.Parameters.G1.L[i], &deltaInvBI) + } + + // 4. Hash contribution + c.Hash = c.hash() +} + +func VerifyPhase2(c0, c1 *Phase2, c ...*Phase2) error { + contribs := append([]*Phase2{c0, c1}, c...) 
+ for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase2(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +func verifyPhase2(current, contribution *Phase2) error { + // Compute R for δ + deltaR := genR(contribution.PublicKey.SG, contribution.PublicKey.SXG, current.Hash[:], 1) + + // Check for knowledge of δ + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.PublicKey.XR, deltaR) { + return errors.New("couldn't verify knowledge of δ") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Delta, current.Parameters.G1.Delta, deltaR, contribution.PublicKey.XR) { + return errors.New("couldn't verify that [δ]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify that [δ]₂ is based on previous contribution") + } + + // Check for valid updates of L and Z using δ⁻¹ + L, prevL := merge(contribution.Parameters.G1.L, current.Parameters.G1.L) + if !sameRatio(L, prevL, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of L using δ⁻¹") + } + Z, prevZ := merge(contribution.Parameters.G1.Z, current.Parameters.G1.Z) + if !sameRatio(Z, prevZ, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of Z using δ⁻¹") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (c *Phase2) hash() []byte { + sha := sha256.New() + c.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bls24-317/mpcsetup/setup.go b/backend/groth16/bls24-317/mpcsetup/setup.go new file mode 100644 index 0000000000..f90fea816c --- 
/dev/null +++ b/backend/groth16/bls24-317/mpcsetup/setup.go @@ -0,0 +1,97 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls24-317" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" + groth16 "github.com/consensys/gnark/backend/groth16/bls24-317" +) + +func ExtractKeys(srs1 *Phase1, srs2 *Phase2, evals *Phase2Evaluations, nConstraints int) (pk groth16.ProvingKey, vk groth16.VerifyingKey) { + _, _, _, g2 := curve.Generators() + + // Initialize PK + pk.Domain = *fft.NewDomain(uint64(nConstraints)) + pk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + pk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + pk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + pk.G1.Z = srs2.Parameters.G1.Z + bitReverse(pk.G1.Z) + + pk.G1.K = srs2.Parameters.G1.L + pk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + pk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + + // Filter out infinity points + nWires := len(evals.G1.A) + pk.InfinityA = make([]bool, nWires) + A := make([]curve.G1Affine, nWires) + j := 0 + for i, e := range evals.G1.A { + if e.IsInfinity() { + pk.InfinityA[i] = true + continue + } + A[j] = evals.G1.A[i] + j++ + } + pk.G1.A = A[:j] + pk.NbInfinityA = uint64(nWires - j) + + pk.InfinityB = make([]bool, nWires) + B := make([]curve.G1Affine, nWires) + j = 0 + for i, e := range evals.G1.B { + if e.IsInfinity() { + 
pk.InfinityB[i] = true + continue + } + B[j] = evals.G1.B[i] + j++ + } + pk.G1.B = B[:j] + pk.NbInfinityB = uint64(nWires - j) + + B2 := make([]curve.G2Affine, nWires) + j = 0 + for i, e := range evals.G2.B { + if e.IsInfinity() { + // pk.InfinityB[i] = true should be the same as in B + continue + } + B2[j] = evals.G2.B[i] + j++ + } + pk.G2.B = B2[:j] + + // Initialize VK + vk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + vk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + vk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + vk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + vk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + vk.G2.Gamma.Set(&g2) + vk.G1.K = evals.G1.VKK + + // sets e, -[δ]2, -[γ]2 + if err := vk.Precompute(); err != nil { + panic(err) + } + + return pk, vk +} diff --git a/backend/groth16/bls24-317/mpcsetup/setup_test.go b/backend/groth16/bls24-317/mpcsetup/setup_test.go new file mode 100644 index 0000000000..750ab1e0cf --- /dev/null +++ b/backend/groth16/bls24-317/mpcsetup/setup_test.go @@ -0,0 +1,199 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls24-317" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + cs "github.com/consensys/gnark/constraint/bls24-317" + "testing" + + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/stretchr/testify/require" + + native_mimc "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/mimc" +) + +func TestSetupCircuit(t *testing.T) { + if testing.Short() { + t.Skip() + } + const ( + nContributionsPhase1 = 3 + nContributionsPhase2 = 3 + power = 9 + ) + + assert := require.New(t) + + srs1 := InitPhase1(power) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase1; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. + prev := srs1.clone() + + srs1.Contribute() + assert.NoError(VerifyPhase1(&prev, &srs1)) + } + + // Compile the circuit + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + var evals Phase2Evaluations + r1cs := ccs.(*cs.R1CS) + + // Prepare for phase-2 + srs2, evals := InitPhase2(r1cs, &srs1) + + // Make and verify contributions for phase2 + for i := 1; i < nContributionsPhase2; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. 
+ prev := srs2.clone() + + srs2.Contribute() + assert.NoError(VerifyPhase2(&prev, &srs2)) + } + + // Extract the proving and verifying keys + pk, vk := ExtractKeys(&srs1, &srs2, &evals, ccs.GetNbConstraints()) + + // Build the witness + var preImage, hash fr.Element + { + m := native_mimc.NewMiMC() + m.Write(preImage.Marshal()) + hash.SetBytes(m.Sum(nil)) + } + + witness, err := frontend.NewWitness(&Circuit{PreImage: preImage, Hash: hash}, curve.ID.ScalarField()) + assert.NoError(err) + + pubWitness, err := witness.Public() + assert.NoError(err) + + // groth16: ensure proof is verified + proof, err := groth16.Prove(ccs, &pk, witness) + assert.NoError(err) + + err = groth16.Verify(proof, &vk, pubWitness) + assert.NoError(err) +} + +func BenchmarkPhase1(b *testing.B) { + const power = 14 + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = InitPhase1(power) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs1 := InitPhase1(power) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs1.Contribute() + } + }) + +} + +func BenchmarkPhase2(b *testing.B) { + const power = 14 + srs1 := InitPhase1(power) + srs1.Contribute() + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + if err != nil { + b.Fatal(err) + } + + r1cs := ccs.(*cs.R1CS) + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = InitPhase2(r1cs, &srs1) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs2, _ := InitPhase2(r1cs, &srs1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs2.Contribute() + } + }) + +} + +// Circuit defines a pre-image knowledge proof +// mimc(secret preImage) = public hash +type Circuit struct { + PreImage frontend.Variable + Hash frontend.Variable `gnark:",public"` +} + +// Define declares the circuit's constraints +// Hash = mimc(PreImage) +func (circuit *Circuit) Define(api frontend.API) error { + // hash function + mimc, _ := 
mimc.NewMiMC(api) + + // specify constraints + mimc.Write(circuit.PreImage) + api.AssertIsEqual(circuit.Hash, mimc.Sum()) + + return nil +} + +func (phase1 *Phase1) clone() Phase1 { + r := Phase1{} + r.Parameters.G1.Tau = append(r.Parameters.G1.Tau, phase1.Parameters.G1.Tau...) + r.Parameters.G1.AlphaTau = append(r.Parameters.G1.AlphaTau, phase1.Parameters.G1.AlphaTau...) + r.Parameters.G1.BetaTau = append(r.Parameters.G1.BetaTau, phase1.Parameters.G1.BetaTau...) + + r.Parameters.G2.Tau = append(r.Parameters.G2.Tau, phase1.Parameters.G2.Tau...) + r.Parameters.G2.Beta = phase1.Parameters.G2.Beta + + r.PublicKeys = phase1.PublicKeys + r.Hash = append(r.Hash, phase1.Hash...) + + return r +} + +func (phase2 *Phase2) clone() Phase2 { + r := Phase2{} + r.Parameters.G1.Delta = phase2.Parameters.G1.Delta + r.Parameters.G1.L = append(r.Parameters.G1.L, phase2.Parameters.G1.L...) + r.Parameters.G1.Z = append(r.Parameters.G1.Z, phase2.Parameters.G1.Z...) + r.Parameters.G2.Delta = phase2.Parameters.G2.Delta + r.PublicKey = phase2.PublicKey + r.Hash = append(r.Hash, phase2.Hash...) + + return r +} diff --git a/backend/groth16/bls24-317/mpcsetup/utils.go b/backend/groth16/bls24-317/mpcsetup/utils.go new file mode 100644 index 0000000000..877fef7fad --- /dev/null +++ b/backend/groth16/bls24-317/mpcsetup/utils.go @@ -0,0 +1,170 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bls24-317" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark/internal/utils" +) + +type PublicKey struct { + SG curve.G1Affine + SXG curve.G1Affine + XR curve.G2Affine +} + +func newPublicKey(x fr.Element, challenge []byte, dst byte) PublicKey { + var pk PublicKey + _, _, g1, _ := curve.Generators() + + var s fr.Element + var sBi big.Int + s.SetRandom() + s.BigInt(&sBi) + pk.SG.ScalarMultiplication(&g1, &sBi) + + // compute x*sG1 + var xBi big.Int + x.BigInt(&xBi) + pk.SXG.ScalarMultiplication(&pk.SG, &xBi) + + // generate R based on sG1, sxG1, challenge, and domain separation tag (tau, alpha or beta) + R := genR(pk.SG, pk.SXG, challenge, dst) + + // compute x*spG2 + pk.XR.ScalarMultiplication(&R, &xBi) + return pk +} + +func bitReverse[T any](a []T) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + irev := bits.Reverse64(i) >> nn + if irev > i { + a[i], a[irev] = a[irev], a[i] + } + } +} + +// Returns [1, a, a², ..., aⁿ⁻¹ ] in Montgomery form +func powers(a fr.Element, n int) []fr.Element { + result := make([]fr.Element, n) + result[0] = fr.NewElement(1) + for i := 1; i < n; i++ { + result[i].Mul(&result[i-1], &a) + } + return result +} + +// Returns [aᵢAᵢ, ...] in G1 +func scaleG1InPlace(A []curve.G1Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Returns [aᵢAᵢ, ...] 
in G2 +func scaleG2InPlace(A []curve.G2Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Check e(a₁, a₂) = e(b₁, b₂) +func sameRatio(a1, b1 curve.G1Affine, a2, b2 curve.G2Affine) bool { + if !a1.IsInSubGroup() || !b1.IsInSubGroup() || !a2.IsInSubGroup() || !b2.IsInSubGroup() { + panic("invalid point not in subgroup") + } + var na2 curve.G2Affine + na2.Neg(&a2) + res, err := curve.PairingCheck( + []curve.G1Affine{a1, b1}, + []curve.G2Affine{na2, b2}) + if err != nil { + panic(err) + } + return res +} + +// returns a = ∑ rᵢAᵢ, b = ∑ rᵢBᵢ +func merge(A, B []curve.G1Affine) (a, b curve.G1Affine) { + nc := runtime.NumCPU() + r := make([]fr.Element, len(A)) + for i := 0; i < len(A); i++ { + r[i].SetRandom() + } + a.MultiExp(A, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + b.MultiExp(B, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G1 +func linearCombinationG1(A []curve.G1Affine) (L1, L2 curve.G1Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G2 +func linearCombinationG2(A []curve.G2Affine) (L1, L2 curve.G2Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// Generate R in G₂ as Hash(gˢ, gˢˣ, challenge, dst) +func genR(sG1, sxG1 curve.G1Affine, challenge []byte, dst byte) curve.G2Affine { + var buf bytes.Buffer + buf.Grow(len(challenge) + curve.SizeOfG1AffineUncompressed*2) + buf.Write(sG1.Marshal()) + 
buf.Write(sxG1.Marshal()) + buf.Write(challenge) + spG2, err := curve.HashToG2(buf.Bytes(), []byte{dst}) + if err != nil { + panic(err) + } + return spG2 +} diff --git a/internal/backend/bls24-317/groth16/prove.go b/backend/groth16/bls24-317/prove.go similarity index 62% rename from internal/backend/bls24-317/groth16/prove.go rename to backend/groth16/bls24-317/prove.go index 5dfeaf9228..10f38c5f77 100644 --- a/internal/backend/bls24-317/groth16/prove.go +++ b/backend/groth16/bls24-317/prove.go @@ -17,13 +17,17 @@ package groth16 import ( - "fmt" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls24-317" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/pedersen" "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/constraint/bls24-317" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + cs "github.com/consensys/gnark/constraint/bls24-317" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" "math/big" @@ -35,9 +39,10 @@ import ( // with a valid statement and a VerifyingKey // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type Proof struct { - Ar, Krs curve.G1Affine - Bs curve.G2Affine - Commitment, CommitmentPok curve.G1Affine + Ar, Krs curve.G1Affine + Bs curve.G2Affine + Commitments []curve.G1Affine // Pedersen commitments a la https://eprint.iacr.org/2022/1072 + CommitmentPok curve.G1Affine // Batched proof of knowledge of the above commitments } // isValid ensures proof elements are in the correct subgroup @@ -51,72 +56,78 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). 
-func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - // TODO @gbotrel witness size check is done by R1CS, doesn't mean we shouldn't sanitize here. - // if len(witness) != r1cs.NbPublicVariables-1+r1cs.NbSecretVariables { - // return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), r1cs.NbPublicVariables-1+r1cs.NbSecretVariables, r1cs.NbPublicVariables, r1cs.NbSecretVariables) - // } - - log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", len(r1cs.Constraints)).Str("backend", "groth16").Logger() - - // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - - proof := &Proof{} - if r1cs.CommitmentInfo.Is() { - opt.HintFunctions[r1cs.CommitmentInfo.HintID] = func(_ *big.Int, in []*big.Int, out []*big.Int) error { - // Perf-TODO: Converting these values to big.Int and back may be a performance bottleneck. - // If that is the case, figure out a way to feed the solution vector into this function - if len(in) != r1cs.CommitmentInfo.NbCommitted() { // TODO: Remove - return fmt.Errorf("unexpected number of committed variables") - } - values := make([]fr.Element, r1cs.CommitmentInfo.NbPrivateCommitted) - nbPublicCommitted := len(in) - len(values) - inPrivate := in[nbPublicCommitted:] - for i, inI := range inPrivate { - values[i].SetBigInt(inI) - } +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } + + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] - var err error - proof.Commitment, proof.CommitmentPok, err = pk.CommitmentKey.Commit(values) - if err != nil { + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + solverOpts = append(solverOpts, solver.OverrideHint(commitmentInfo[i].HintID, func(i int) solver.Hint { + return func(_ *big.Int, in []*big.Int, out []*big.Int) error { + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[len(hashed):] + for j, inJ := range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } + + var res fr.Element + res, err = solveCommitmentWire(&proof.Commitments[i], hashed) + res.BigInt(out[0]) return err } + }(i))) + } - var res fr.Element - res, err = solveCommitmentWire(&r1cs.CommitmentInfo, &proof.Commitment, in[:r1cs.CommitmentInfo.NbPublicCommitted()]) - res.BigInt(out[0]) //Perf-TODO: Regular (non-mont) hashToField to obviate this conversion? 
- return err - } + if r1cs.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + solverOpts = append(solverOpts, + solver.OverrideHint(r1cs.GkrInfo.SolveHintID, cs.GkrSolveHint(r1cs.GkrInfo, &gkrData)), + solver.OverrideHint(r1cs.GkrInfo.ProveHintID, cs.GkrProveHint(r1cs.GkrInfo.HashName, &gkrData))) } - var wireValues []fr.Element - var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill wireValues with random values else multi exps don't do much - var r fr.Element - _, _ = r.SetRandom() - for i := r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables(); i < len(wireValues); i++ { - wireValues[i] = r - r.Double(&r) - } - } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + start := time.Now() + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + + if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil { + return nil, err + } + // H (witness reduction / FFT part) var h []fr.Element chHDone := make(chan struct{}, 1) go func() { - h = computeH(a, b, c, &pk.Domain) - a = nil - b = nil - c = nil + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + solution.A = nil + solution.B = nil + solution.C = nil chHDone <- struct{}{} }() @@ -203,15 +214,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC var krs, krs2, p1 curve.G1Jac chKrs2Done := make(chan error, 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 go func() { - _, err := krs2.MultiExp(pk.G1.Z, h, ecc.MultiExpConfig{NbTasks: n / 2}) + _, err := krs2.MultiExp(pk.G1.Z, 
h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) chKrs2Done <- err }() - // filter the wire values if needed; - _wireValues := filter(wireValues, r1cs.CommitmentInfo.PrivateToPublic()) + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - if _, err := krs.MultiExp(pk.G1.K, _wireValues[r1cs.GetNbPublicVariables():], ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chKrsDone <- err return } @@ -292,26 +307,32 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC } // if len(toRemove) == 0, returns slice -// else, returns a new slice without the indexes in toRemove -// this assumes toRemove indexes are sorted and len(slice) > len(toRemove) -func filter(slice []fr.Element, toRemove []int) (r []fr.Element) { +// else, returns a new slice without the indexes in toRemove. 
The first value in the slice is taken as indexes as sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { if len(toRemove) == 0 { return slice } - r = make([]fr.Element, 0, len(slice)-len(toRemove)) - j := 0 + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) + // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) for i := 0; i < len(slice); i++ { - if j < len(toRemove) && i == toRemove[j] { - j++ + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } continue } r = append(r, slice[i]) } - return r + return } func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { @@ -334,9 +355,9 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { domain.FFTInverse(b, fft.DIF) domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, true) - domain.FFT(b, fft.DIT, true) - domain.FFT(c, fft.DIT, true) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) var den, one fr.Element one.SetOne() @@ -344,7 +365,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) - // reusing a to avoid unecessary memalloc + // reusing a to avoid unnecessary memory allocation utils.Parallelize(n, func(start, end int) { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). 
@@ -354,7 +375,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { }) // ifft_coset - domain.FFTInverse(a, fft.DIF, true) + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) return a } diff --git a/internal/backend/bls24-317/groth16/setup.go b/backend/groth16/bls24-317/setup.go similarity index 75% rename from internal/backend/bls24-317/groth16/setup.go rename to backend/groth16/bls24-317/setup.go index c49e1718ea..68ee5c8922 100644 --- a/internal/backend/bls24-317/groth16/setup.go +++ b/backend/groth16/bls24-317/setup.go @@ -17,13 +17,15 @@ package groth16 import ( + "errors" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls24-317" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls24-317" + cs "github.com/consensys/gnark/constraint/bls24-317" "math/big" "math/bits" ) @@ -34,15 +36,15 @@ type ProvingKey struct { // domain Domain fft.Domain - // [α]1, [β]1, [δ]1 - // [A(t)]1, [B(t)]1, [Kpk(t)]1, [Z(t)]1 + // [α]₁, [β]₁, [δ]₁ + // [A(t)]₁, [B(t)]₁, [Kpk(t)]₁, [Z(t)]₁ G1 struct { Alpha, Beta, Delta curve.G1Affine A, B, Z []curve.G1Affine K []curve.G1Affine // the indexes correspond to the private wires } - // [β]2, [δ]2, [B(t)]2 + // [β]₂, [δ]₂, [B(t)]₂ G2 struct { Beta, Delta curve.G2Affine B []curve.G2Affine @@ -52,21 +54,21 @@ type ProvingKey struct { InfinityA, InfinityB []bool NbInfinityA, NbInfinityB uint64 - CommitmentKey pedersen.Key + CommitmentKeys []pedersen.ProvingKey } // VerifyingKey is used by a Groth16 verifier to verify the validity of a proof and a statement // Notation follows Figure 4. 
in DIZK paper https://eprint.iacr.org/2018/691.pdf type VerifyingKey struct { - // [α]1, [Kvk]1 + // [α]₁, [Kvk]₁ G1 struct { Alpha curve.G1Affine Beta, Delta curve.G1Affine // unused, here for compatibility purposes K []curve.G1Affine // The indexes correspond to the public wires } - // [β]2, [δ]2, [γ]2, - // -[δ]2, -[γ]2: see proof.Verify() for more details + // [β]₂, [δ]₂, [γ]₂, + // -[δ]₂, -[γ]₂: see proof.Verify() for more details G2 struct { Beta, Delta, Gamma curve.G2Affine deltaNeg, gammaNeg curve.G2Affine // not serialized @@ -75,8 +77,8 @@ type VerifyingKey struct { // e(α, β) e curve.GT // not serialized - CommitmentKey pedersen.Key - CommitmentInfo constraint.Commitment // since the verifier doesn't input a constraint system, this needs to be provided here + CommitmentKey pedersen.VerifyingKey + PublicAndCommitmentCommitted [][]int // indexes of public/commitment committed variables } // Setup constructs the SRS @@ -93,17 +95,20 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbPrivateCommittedWires := r1cs.CommitmentInfo.NbPrivateCommitted - nbPublicWires := r1cs.GetNbPublicVariables() - nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - if r1cs.CommitmentInfo.Is() { // the commitment itself is defined by a hint so the prover considers it private - nbPublicWires++ // but the verifier will need to inject the value itself so on the groth16 - nbPrivateWires-- // level it must be considered public - } + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + commitmentWires := commitmentInfo.CommitmentIndexes() + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateCommittedWires := internal.NbElements(privateCommitted) + + // a commitment is itself defined by a hint so the prover considers it 
private + // but the verifier will need to inject the value itself so on the groth16 + // level it must be considered public + nbPublicWires := r1cs.GetNbPublicVariables() + len(commitmentInfo) + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - len(commitmentInfo) // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints))) + domain := fft.NewDomain(uint64(r1cs.GetNbConstraints())) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -137,7 +142,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // compute scalars for pkK, vkK and ckK pkK := make([]fr.Element, nbPrivateWires) vkK := make([]fr.Element, nbPublicWires) - ckK := make([]fr.Element, nbPrivateCommittedWires) + ckK := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + ckK[i] = make([]fr.Element, len(privateCommitted[i])) + } var t0, t1 fr.Element @@ -148,28 +156,42 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { Add(&t1, &C[i]). 
Mul(&t1, coeff) } - - vI, cI := 0, 0 - privateCommitted := r1cs.CommitmentInfo.PrivateCommitted() + vI := 0 // number of public wires seen so far + cI := make([]int, len(commitmentInfo)) // number of private committed wires seen so far for each commitment + nbPrivateCommittedSeen := 0 // = ∑ᵢ cI[i] + nbCommitmentsSeen := 0 for i := range A { - isCommittedPrivate := cI < len(privateCommitted) && i == privateCommitted[cI] - isCommitment := r1cs.CommitmentInfo.Is() && i == r1cs.CommitmentInfo.CommitmentIndex - isPublic := i < r1cs.GetNbPublicVariables() + commitment := -1 // index of the commitment that commits to this variable as a private or commitment value + var isCommitment, isPublic bool + if isPublic = i < r1cs.GetNbPublicVariables(); !isPublic { + if nbCommitmentsSeen < len(commitmentWires) && commitmentWires[nbCommitmentsSeen] == i { + isCommitment = true + nbCommitmentsSeen++ + } - if isPublic || isCommittedPrivate || isCommitment { + for j := range commitmentInfo { // does commitment j commit to i? + if cI[j] < len(privateCommitted[j]) && privateCommitted[j][cI[j]] == i { + commitment = j + break // frontend guarantees that no private variable is committed to more than once + } + } + } + + if isPublic || commitment != -1 || isCommitment { computeK(i, &toxicWaste.gammaInv) - if isCommittedPrivate { - ckK[cI] = t1 - cI++ - } else { + if isPublic || isCommitment { vkK[vI] = t1 vI++ + } else { // committed and private + ckK[commitment][cI[commitment]] = t1 + cI[commitment]++ + nbPrivateCommittedSeen++ } } else { computeK(i, &toxicWaste.deltaInv) - pkK[i-vI-cI] = t1 + pkK[i-vI-nbPrivateCommittedSeen] = t1 // vI = nbPublicSeen + nbCommitmentsSeen } } @@ -222,11 +244,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { g1Scalars = append(g1Scalars, Z...) g1Scalars = append(g1Scalars, vkK...) g1Scalars = append(g1Scalars, pkK...) - g1Scalars = append(g1Scalars, ckK...) + for i := range ckK { + g1Scalars = append(g1Scalars, ckK[i]...) 
+ } g1PointsAff := curve.BatchScalarMultiplicationG1(&g1, g1Scalars) - // sets pk: [α]1, [β]1, [δ]1 + // sets pk: [α]₁, [β]₁, [δ]₁ pk.G1.Alpha = g1PointsAff[0] pk.G1.Beta = g1PointsAff[1] pk.G1.Delta = g1PointsAff[2] @@ -238,8 +262,9 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G1.B = g1PointsAff[offset : offset+len(B)] offset += len(B) - pk.G1.Z = g1PointsAff[offset : offset+int(domain.Cardinality)] - bitReverse(pk.G1.Z) + bitReverse(g1PointsAff[offset : offset+int(domain.Cardinality)]) + sizeZ := int(domain.Cardinality) - 1 // deg(H)=deg(A*B-C/X^n-1)=(n-1)+(n-1)-n=n-2 + pk.G1.Z = g1PointsAff[offset : offset+sizeZ] offset += int(domain.Cardinality) @@ -252,17 +277,22 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // --------------------------------------------------------------------------------------------- // Commitment setup - if nbPrivateCommittedWires != 0 { - commitmentBasis := g1PointsAff[offset:] + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(ckK[i]) + commitmentBases[i] = g1PointsAff[offset : offset+size] + offset += size + } + if offset != len(g1PointsAff) { + return errors.New("didn't consume all G1 points") // TODO @Tabaie Remove this + } - vk.CommitmentKey, err = pedersen.Setup(commitmentBasis) - if err != nil { - return err - } - pk.CommitmentKey = vk.CommitmentKey + pk.CommitmentKeys, vk.CommitmentKey, err = pedersen.Setup(commitmentBases...) 
+ if err != nil { + return err } - vk.CommitmentInfo = r1cs.CommitmentInfo // unfortunate but necessary + vk.PublicAndCommitmentCommitted = commitmentInfo.GetPublicAndCommitmentCommitted(commitmentWires, r1cs.GetNbPublicVariables()) // --------------------------------------------------------------------------------------------- // G2 scalars @@ -279,15 +309,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G2.B = g2PointsAff[:len(B)] - // sets pk: [β]2, [δ]2 + // sets pk: [β]₂, [δ]₂ pk.G2.Beta = g2PointsAff[len(B)+0] pk.G2.Delta = g2PointsAff[len(B)+1] - // sets vk: [δ]2, [γ]2, -[δ]2, -[γ]2 + // sets vk: [δ]₂, [γ]₂ vk.G2.Delta = g2PointsAff[len(B)+1] vk.G2.Gamma = g2PointsAff[len(B)+2] - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) // --------------------------------------------------------------------------------------------- // Pairing: vk.e @@ -298,16 +326,29 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { vk.G1.Beta = pk.G1.Beta vk.G1.Delta = pk.G1.Delta - vk.e, err = curve.Pair([]curve.G1Affine{pk.G1.Alpha}, []curve.G2Affine{pk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return err } + // set domain pk.Domain = *domain return nil } +// Precompute sets e, -[δ]₂, -[γ]₂ +// This is meant to be called internally during setup or deserialization. +func (vk *VerifyingKey) Precompute() error { + var err error + vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) + if err != nil { + return err + } + vk.G2.deltaNeg.Neg(&vk.G2.Delta) + vk.G2.gammaNeg.Neg(&vk.G2.Gamma) + return nil +} + func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr.Element, B []fr.Element, C []fr.Element) { nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() @@ -322,7 +363,7 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. 
var w fr.Element w.Set(&domain.Generator) wi := fr.One() - t := make([]fr.Element, len(r1cs.Constraints)+1) + t := make([]fr.Element, r1cs.GetNbConstraints()+1) for i := 0; i < len(t); i++ { t[i].Sub(&toxicWaste.t, &wi) wi.Mul(&wi, &w) // TODO this is already pre computed in fft.Domain @@ -366,8 +407,10 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // for each term appearing in the linear expression, // we compute term.Coefficient * L, and cumulate it in // A, B or C at the index of the variable - for i, c := range r1cs.Constraints { + j := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } @@ -380,9 +423,12 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // Li+1 = w*Li*(t-w^i)/(t-w^(i+1)) L.Mul(&L, &w) - L.Mul(&L, &t[i]) - L.Mul(&L, &tInv[i+1]) + L.Mul(&L, &t[j]) + L.Mul(&L, &tInv[j+1]) + + j++ } + return } @@ -436,7 +482,10 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbConstraints := len(r1cs.Constraints) + nbConstraints := r1cs.GetNbConstraints() + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - internal.NbElements(privateCommitted) - len(commitmentInfo) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints)) @@ -448,8 +497,8 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // initialize proving key pk.G1.A = make([]curve.G1Affine, nbWires-nbZeroesA) pk.G1.B = make([]curve.G1Affine, nbWires-nbZeroesB) - pk.G1.K = make([]curve.G1Affine, nbWires-r1cs.GetNbPublicVariables()) - pk.G1.Z = make([]curve.G1Affine, 
domain.Cardinality) + pk.G1.K = make([]curve.G1Affine, nbPrivateWires) + pk.G1.Z = make([]curve.G1Affine, domain.Cardinality-1) pk.G2.B = make([]curve.G2Affine, nbWires-nbZeroesB) // set infinity markers @@ -503,6 +552,22 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { pk.Domain = *domain + // --------------------------------------------------------------------------------------------- + // Commitment setup + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(privateCommitted[i]) + commitmentBases[i] = make([]curve.G1Affine, size) + for j := range commitmentBases[i] { + commitmentBases[i][j] = r1Aff + } + } + + pk.CommitmentKeys, _, err = pedersen.Setup(commitmentBases...) + if err != nil { + return err + } + return nil } @@ -514,7 +579,9 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) - for _, c := range r1cs.Constraints { + + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { A[t.WireID()] = true } @@ -522,6 +589,7 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { B[t.WireID()] = true } } + for i := 0; i < nbWires; i++ { if !A[i] { nbZeroesA++ @@ -606,7 +674,7 @@ func (pk *ProvingKey) NbG2() int { return 2 + len(pk.G2.B) } -// bitRerverse permutation as in fft.BitReverse , but with []curve.G1Affine +// bitReverse permutation as in fft.BitReverse , but with []curve.G1Affine func bitReverse(a []curve.G1Affine) { n := uint(len(a)) nn := uint(bits.UintSize - bits.TrailingZeros(n)) diff --git a/internal/backend/bls24-317/groth16/verify.go b/backend/groth16/bls24-317/verify.go similarity index 64% rename from internal/backend/bls24-317/groth16/verify.go rename to backend/groth16/bls24-317/verify.go index e6f5f33f46..3affc23113 100644 --- a/internal/backend/bls24-317/groth16/verify.go +++ b/backend/groth16/bls24-317/verify.go @@ -22,9 +22,11 @@ import ( 
"github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls24-317" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/pedersen" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/logger" "io" - "math/big" "time" ) @@ -36,10 +38,8 @@ var ( // Verify verifies a proof with given VerifyingKey and publicWitness func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - nbPublicVars := len(vk.G1.K) - if vk.CommitmentInfo.Is() { - nbPublicVars-- - } + nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted) + if len(publicWitness) != nbPublicVars-1 { return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(publicWitness), len(vk.G1.K)-1) } @@ -62,21 +62,32 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { close(chDone) }() - if vk.CommitmentInfo.Is() { - - if err := vk.CommitmentKey.VerifyKnowledgeProof(proof.Commitment, proof.CommitmentPok); err != nil { - return err + maxNbPublicCommitted := 0 + for _, s := range vk.PublicAndCommitmentCommitted { // iterate over commitments + maxNbPublicCommitted = utils.Max(maxNbPublicCommitted, len(s)) + } + commitmentsSerialized := make([]byte, len(vk.PublicAndCommitmentCommitted)*fr.Bytes) + commitmentPrehashSerialized := make([]byte, curve.SizeOfG1AffineUncompressed+maxNbPublicCommitted*fr.Bytes) + for i := range vk.PublicAndCommitmentCommitted { // solveCommitmentWire + copy(commitmentPrehashSerialized, proof.Commitments[i].Marshal()) + offset := curve.SizeOfG1AffineUncompressed + for j := range vk.PublicAndCommitmentCommitted[i] { + copy(commitmentPrehashSerialized[offset:], publicWitness[vk.PublicAndCommitmentCommitted[i][j]-1].Marshal()) + offset += fr.Bytes } - - publicCommitted := make([]*big.Int, vk.CommitmentInfo.NbPublicCommitted()) - for i := range publicCommitted { - var b big.Int - 
publicWitness[vk.CommitmentInfo.Committed[i]-1].BigInt(&b) - publicCommitted[i] = &b + if res, err := fr.Hash(commitmentPrehashSerialized[:offset], []byte(constraint.CommitmentDst), 1); err != nil { + return err + } else { + publicWitness = append(publicWitness, res[0]) + copy(commitmentsSerialized[i*fr.Bytes:], res[0].Marshal()) } + } - if res, err := solveCommitmentWire(&vk.CommitmentInfo, &proof.Commitment, publicCommitted); err == nil { - publicWitness = append(publicWitness, res) + if folded, err := pedersen.FoldCommitments(proof.Commitments, commitmentsSerialized); err != nil { + return err + } else { + if err = vk.CommitmentKey.Verify(folded, proof.CommitmentPok); err != nil { + return err } } @@ -87,8 +98,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { } kSum.AddMixed(&vk.G1.K[0]) - if vk.CommitmentInfo.Is() { - kSum.AddMixed(&proof.Commitment) + for i := range proof.Commitments { + kSum.AddMixed(&proof.Commitments[i]) } var kSumAff curve.G1Affine diff --git a/internal/backend/bn254/groth16/commitment.go b/backend/groth16/bn254/commitment.go similarity index 74% rename from internal/backend/bn254/groth16/commitment.go rename to backend/groth16/bn254/commitment.go index d3930c0f08..435a7c058c 100644 --- a/internal/backend/bn254/groth16/commitment.go +++ b/backend/groth16/bn254/commitment.go @@ -23,7 +23,7 @@ import ( "math/big" ) -func solveCommitmentWire(commitmentInfo *constraint.Commitment, commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { - res, err := fr.Hash(commitmentInfo.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) +func solveCommitmentWire(commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { + res, err := fr.Hash(constraint.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) return res[0], err } diff --git 
a/internal/backend/bn254/groth16/commitment_test.go b/backend/groth16/bn254/commitment_test.go similarity index 91% rename from internal/backend/bn254/groth16/commitment_test.go rename to backend/groth16/bn254/commitment_test.go index 4573f9749b..759501ebe2 100644 --- a/internal/backend/bn254/groth16/commitment_test.go +++ b/backend/groth16/bn254/commitment_test.go @@ -17,6 +17,9 @@ package groth16_test import ( + "fmt" + "testing" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/witness" @@ -24,7 +27,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/assert" - "testing" ) type singleSecretCommittedCircuit struct { @@ -33,7 +35,11 @@ type singleSecretCommittedCircuit struct { func (c *singleSecretCommittedCircuit) Define(api frontend.API) error { api.AssertIsEqual(c.One, 1) - commit, err := api.Compiler().Commit(c.One) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One) if err != nil { return err } @@ -119,8 +125,11 @@ type oneSecretOnePublicCommittedCircuit struct { } func (c *oneSecretOnePublicCommittedCircuit) Define(api frontend.API) error { - - commit, err := api.Compiler().Commit(c.One, c.Two) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One, c.Two) if err != nil { return err } diff --git a/internal/backend/bn254/groth16/marshal.go b/backend/groth16/bn254/marshal.go similarity index 82% rename from internal/backend/bn254/groth16/marshal.go rename to backend/groth16/bn254/marshal.go index 66e3c144a1..0a310d6d15 100644 --- a/internal/backend/bn254/groth16/marshal.go +++ b/backend/groth16/bn254/marshal.go @@ -18,6 +18,9 @@ package groth16 import ( curve 
"github.com/consensys/gnark-crypto/ecc/bn254" + + "github.com/consensys/gnark-crypto/ecc/bn254/fr/pedersen" + "github.com/consensys/gnark/internal/utils" "io" ) @@ -78,14 +81,24 @@ func (proof *Proof) ReadFrom(r io.Reader) (n int64, err error) { // points are compressed // use WriteRawTo(...) to encode the key without point compression func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, false) + if n, err = vk.writeTo(w, false); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteTo(w) + return m + n, err } // WriteRawTo writes binary encoding of the key elements to writer // points are not compressed // use WriteTo(...) to encode the key with point compression func (vk *VerifyingKey) WriteRawTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, true) + if n, err = vk.writeTo(w, true); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteRawTo(w) + return m + n, err } // writeTo serialization format: @@ -124,6 +137,14 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { if err := enc.Encode(vk.G1.K); err != nil { return enc.BytesWritten(), err } + + if vk.PublicAndCommitmentCommitted == nil { + vk.PublicAndCommitmentCommitted = [][]int{} // only matters in tests + } + if err := enc.Encode(utils.IntSliceSliceToUint64SliceSlice(vk.PublicAndCommitmentCommitted)); err != nil { + return enc.BytesWritten(), err + } + return enc.BytesWritten(), nil } @@ -133,13 +154,25 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { // https://github.com/zkcrypto/bellman/blob/fa9be45588227a8c6ec34957de3f68705f07bd92/src/groth16/mod.rs#L143 // [α]1,[β]1,[β]2,[γ]2,[δ]1,[δ]2,uint32(len(Kvk)),[Kvk]1 func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r) + n, err := vk.readFrom(r) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.ReadFrom(r) + return m + n, err } // UnsafeReadFrom has 
the same behavior as ReadFrom, except that it will not check that decode points // are on the curve and in the correct subgroup. func (vk *VerifyingKey) UnsafeReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r, curve.NoSubgroupChecks()) + n, err := vk.readFrom(r, curve.NoSubgroupChecks()) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.UnsafeReadFrom(r) + return m + n, err } func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { @@ -169,15 +202,16 @@ func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder) if err := dec.Decode(&vk.G1.K); err != nil { return dec.BytesRead(), err } + var publicCommitted [][]uint64 + if err := dec.Decode(&publicCommitted); err != nil { + return dec.BytesRead(), err + } + vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 - var err error - vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return dec.BytesRead(), err } - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) return dec.BytesRead(), nil } @@ -226,6 +260,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { pk.NbInfinityB, pk.InfinityA, pk.InfinityB, + uint32(len(pk.CommitmentKeys)), } for _, v := range toEncode { @@ -234,6 +269,23 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { } } + for i := range pk.CommitmentKeys { + var ( + n2 int64 + err error + ) + if raw { + n2, err = pk.CommitmentKeys[i].WriteRawTo(w) + } else { + n2, err = pk.CommitmentKeys[i].WriteTo(w) + } + + n += n2 + if err != nil { + return n, err + } + } + return n + enc.BytesWritten(), nil } @@ -260,6 +312,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) dec := curve.NewDecoder(r, decOptions...) 
var nbWires uint64 + var nbCommitments uint32 toDecode := []interface{}{ &pk.G1.Alpha, @@ -291,6 +344,18 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) if err := dec.Decode(&pk.InfinityB); err != nil { return n + dec.BytesRead(), err } + if err := dec.Decode(&nbCommitments); err != nil { + return n + dec.BytesRead(), err + } + + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) + for i := range pk.CommitmentKeys { + n2, err := pk.CommitmentKeys[i].ReadFrom(r) + n += n2 + if err != nil { + return n, err + } + } return n + dec.BytesRead(), nil } diff --git a/internal/backend/bn254/groth16/marshal_test.go b/backend/groth16/bn254/marshal_test.go similarity index 78% rename from internal/backend/bn254/groth16/marshal_test.go rename to backend/groth16/bn254/marshal_test.go index 2db9e45415..ff9d807805 100644 --- a/internal/backend/bn254/groth16/marshal_test.go +++ b/backend/groth16/bn254/marshal_test.go @@ -21,11 +21,17 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal/test_utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "bytes" "math/big" "reflect" "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" "github.com/leanovate/gopter/prop" "testing" @@ -87,13 +93,9 @@ func TestProofSerialization(t *testing.T) { } func TestVerifyingKeySerialization(t *testing.T) { - parameters := gopter.DefaultTestParameters() - parameters.MinSuccessfulTests = 10 - properties := gopter.NewProperties(parameters) - - properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + roundTrip := func(withCommitment bool) func(curve.G1Affine, curve.G2Affine) bool { + return func(p1 curve.G1Affine, p2 curve.G2Affine) bool { var vk, vkCompressed, vkRaw VerifyingKey // create a 
random vk @@ -121,6 +123,21 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } + if withCommitment { + vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization + bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) + elem := p1 + for i := 0; i < len(vk.PublicAndCommitmentCommitted); i++ { + bases[i] = make([]curve.G1Affine, len(vk.PublicAndCommitmentCommitted[i])) + for j := range bases[i] { + bases[i][j] = elem + elem.Add(&elem, &p1) + } + } + _, vk.CommitmentKey, err = pedersen.Setup(bases...) + assert.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := vk.WriteTo(&bufCompressed) if err != nil { @@ -158,7 +175,22 @@ func TestVerifyingKeySerialization(t *testing.T) { } return reflect.DeepEqual(&vk, &vkCompressed) && reflect.DeepEqual(&vk, &vkRaw) - }, + } + } + + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 10 + + properties := gopter.NewProperties(parameters) + + properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(false), + GenG1(), + GenG2(), + )) + + properties.Property("VerifyingKey (with commitments) -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(true), GenG1(), GenG2(), )) @@ -173,7 +205,7 @@ func TestProvingKeySerialization(t *testing.T) { properties := gopter.NewProperties(parameters) properties.Property("ProvingKey -> writer -> reader -> ProvingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + func(p1 curve.G1Affine, p2 curve.G2Affine, nbCommitment int) bool { var pk, pkCompressed, pkRaw ProvingKey // create a random pk @@ -202,6 +234,19 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, nbCommitment) + 
pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases...) + require.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := pk.WriteTo(&bufCompressed) if err != nil { @@ -242,6 +287,7 @@ func TestProvingKeySerialization(t *testing.T) { }, GenG1(), GenG2(), + gen.IntRange(0, 2), )) properties.TestingRun(t, gopter.ConsoleReporter(false)) diff --git a/backend/groth16/bn254/mpcsetup/lagrange.go b/backend/groth16/bn254/mpcsetup/lagrange.go new file mode 100644 index 0000000000..cbf377dc25 --- /dev/null +++ b/backend/groth16/bn254/mpcsetup/lagrange.go @@ -0,0 +1,216 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" + "github.com/consensys/gnark/internal/utils" +) + +func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { + coeffs := make([]curve.G1Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { + coeffs := make([]curve.G2Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func butterflyG1(a *curve.G1Affine, b *curve.G1Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +func butterflyG2(a *curve.G2Affine, b *curve.G2Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G1(a []curve.G1Affine, twiddles [][]fr.Element, stage int) { + butterflyG1(&a[0], &a[4]) + butterflyG1(&a[1], &a[5]) + butterflyG1(&a[2], 
&a[6]) + butterflyG1(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[2]) + butterflyG1(&a[1], &a[3]) + butterflyG1(&a[4], &a[6]) + butterflyG1(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[1]) + butterflyG1(&a[2], &a[3]) + butterflyG1(&a[4], &a[5]) + butterflyG1(&a[6], &a[7]) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G2(a []curve.G2Affine, twiddles [][]fr.Element, stage int) { + butterflyG2(&a[0], &a[4]) + butterflyG2(&a[1], &a[5]) + butterflyG2(&a[2], &a[6]) + butterflyG2(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[2]) + butterflyG2(&a[1], &a[3]) + butterflyG2(&a[4], &a[6]) + butterflyG2(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[1]) + butterflyG2(&a[2], &a[3]) + butterflyG2(&a[4], &a[5]) + butterflyG2(&a[6], &a[7]) +} + +func difFFTG1(a []curve.G1Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G1(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG1(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + 
butterflyG1(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} +func difFFTG2(a []curve.G2Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G2(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG2(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + butterflyG2(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG2(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG2(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} diff --git a/backend/groth16/bn254/mpcsetup/marshal.go b/backend/groth16/bn254/mpcsetup/marshal.go new file mode 100644 index 0000000000..08cb2ae3d1 --- /dev/null +++ b/backend/groth16/bn254/mpcsetup/marshal.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bn254" + "io" +) + +// WriteTo implements io.WriterTo +func (phase1 *Phase1) WriteTo(writer io.Writer) (int64, error) { + n, err := phase1.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase1.Hash) + return int64(nBytes) + n, err +} + +func (phase1 *Phase1) writeTo(writer io.Writer) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + phase1.Parameters.G1.Tau, + phase1.Parameters.G1.AlphaTau, + phase1.Parameters.G1.BetaTau, + phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + enc := curve.NewEncoder(writer) + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (phase1 *Phase1) ReadFrom(reader io.Reader) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + &phase1.Parameters.G1.Tau, + &phase1.Parameters.G1.AlphaTau, + 
&phase1.Parameters.G1.BetaTau, + &phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + dec := curve.NewDecoder(reader) + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + phase1.Hash = make([]byte, 32) + nBytes, err := reader.Read(phase1.Hash) + return dec.BytesRead() + int64(nBytes), err +} + +// WriteTo implements io.WriterTo +func (phase2 *Phase2) WriteTo(writer io.Writer) (int64, error) { + n, err := phase2.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase2.Hash) + return int64(nBytes) + n, err +} + +func (c *Phase2) writeTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + c.Parameters.G1.L, + c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c *Phase2) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + &c.Parameters.G1.L, + &c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + c.Hash = make([]byte, 32) + n, err := reader.Read(c.Hash) + return int64(n) + dec.BytesRead(), err + +} + +// WriteTo implements io.WriterTo +func (c *Phase2Evaluations) WriteTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + c.G1.A, + c.G1.B, + c.G2.B, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c 
*Phase2Evaluations) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.G1.A, + &c.G1.B, + &c.G2.B, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + return dec.BytesRead(), nil +} diff --git a/backend/groth16/bn254/mpcsetup/marshal_test.go b/backend/groth16/bn254/mpcsetup/marshal_test.go new file mode 100644 index 0000000000..ce03c11bd4 --- /dev/null +++ b/backend/groth16/bn254/mpcsetup/marshal_test.go @@ -0,0 +1,79 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + curve "github.com/consensys/gnark-crypto/ecc/bn254" + cs "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/stretchr/testify/require" + "io" + "reflect" + "testing" +) + +func TestContributionSerialization(t *testing.T) { + assert := require.New(t) + + // Phase 1 + srs1 := InitPhase1(9) + srs1.Contribute() + { + var reconstructed Phase1 + roundTripCheck(t, &srs1, &reconstructed) + } + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + r1cs := ccs.(*cs.R1CS) + + // Phase 2 + srs2, _ := InitPhase2(r1cs, &srs1) + srs2.Contribute() + + { + var reconstructed Phase2 + roundTripCheck(t, &srs2, &reconstructed) + } +} + +func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { + t.Helper() + + var buf bytes.Buffer + written, err := from.WriteTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} diff --git a/backend/groth16/bn254/mpcsetup/phase1.go b/backend/groth16/bn254/mpcsetup/phase1.go new file mode 100644 index 0000000000..a912a473aa --- /dev/null +++ b/backend/groth16/bn254/mpcsetup/phase1.go @@ -0,0 +1,203 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + curve "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "math" + "math/big" +) + +// Phase1 represents the Phase1 of the MPC described in +// https://eprint.iacr.org/2017/1050.pdf +// +// Also known as "Powers of Tau" +type Phase1 struct { + Parameters struct { + G1 struct { + Tau []curve.G1Affine // {[τ⁰]₁, [τ¹]₁, [τ²]₁, …, [τ²ⁿ⁻²]₁} + AlphaTau []curve.G1Affine // {α[τ⁰]₁, α[τ¹]₁, α[τ²]₁, …, α[τⁿ⁻¹]₁} + BetaTau []curve.G1Affine // {β[τ⁰]₁, β[τ¹]₁, β[τ²]₁, …, β[τⁿ⁻¹]₁} + } + G2 struct { + Tau []curve.G2Affine // {[τ⁰]₂, [τ¹]₂, [τ²]₂, …, [τⁿ⁻¹]₂} + Beta curve.G2Affine // [β]₂ + } + } + PublicKeys struct { + Tau, Alpha, Beta PublicKey + } + Hash []byte // sha256 hash +} + +// InitPhase1 initialize phase 1 of the MPC. This is called once by the coordinator before +// any randomness contribution is made (see Contribute()). 
+func InitPhase1(power int) (phase1 Phase1) { + N := int(math.Pow(2, float64(power))) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetOne() + alpha.SetOne() + beta.SetOne() + phase1.PublicKeys.Tau = newPublicKey(tau, nil, 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, nil, 2) + phase1.PublicKeys.Beta = newPublicKey(beta, nil, 3) + + // First contribution use generators + _, _, g1, g2 := curve.Generators() + phase1.Parameters.G2.Beta.Set(&g2) + phase1.Parameters.G1.Tau = make([]curve.G1Affine, 2*N-1) + phase1.Parameters.G2.Tau = make([]curve.G2Affine, N) + phase1.Parameters.G1.AlphaTau = make([]curve.G1Affine, N) + phase1.Parameters.G1.BetaTau = make([]curve.G1Affine, N) + for i := 0; i < len(phase1.Parameters.G1.Tau); i++ { + phase1.Parameters.G1.Tau[i].Set(&g1) + } + for i := 0; i < len(phase1.Parameters.G2.Tau); i++ { + phase1.Parameters.G2.Tau[i].Set(&g2) + phase1.Parameters.G1.AlphaTau[i].Set(&g1) + phase1.Parameters.G1.BetaTau[i].Set(&g1) + } + + phase1.Parameters.G2.Beta.Set(&g2) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() + + return +} + +// Contribute contributes randomness to the phase1 object. This mutates phase1. +func (phase1 *Phase1) Contribute() { + N := len(phase1.Parameters.G2.Tau) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetRandom() + alpha.SetRandom() + beta.SetRandom() + phase1.PublicKeys.Tau = newPublicKey(tau, phase1.Hash[:], 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, phase1.Hash[:], 2) + phase1.PublicKeys.Beta = newPublicKey(beta, phase1.Hash[:], 3) + + // Compute powers of τ, ατ, and βτ + taus := powers(tau, 2*N-1) + alphaTau := make([]fr.Element, N) + betaTau := make([]fr.Element, N) + for i := 0; i < N; i++ { + alphaTau[i].Mul(&taus[i], &alpha) + betaTau[i].Mul(&taus[i], &beta) + } + + // Update using previous parameters + // TODO @gbotrel working with jacobian points here will help with perf. 
+ scaleG1InPlace(phase1.Parameters.G1.Tau, taus) + scaleG2InPlace(phase1.Parameters.G2.Tau, taus[0:N]) + scaleG1InPlace(phase1.Parameters.G1.AlphaTau, alphaTau) + scaleG1InPlace(phase1.Parameters.G1.BetaTau, betaTau) + var betaBI big.Int + beta.BigInt(&betaBI) + phase1.Parameters.G2.Beta.ScalarMultiplication(&phase1.Parameters.G2.Beta, &betaBI) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() +} + +func VerifyPhase1(c0, c1 *Phase1, c ...*Phase1) error { + contribs := append([]*Phase1{c0, c1}, c...) + for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase1(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +// verifyPhase1 checks that a contribution is based on a known previous Phase1 state. +func verifyPhase1(current, contribution *Phase1) error { + // Compute R for τ, α, β + tauR := genR(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, current.Hash[:], 1) + alphaR := genR(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, current.Hash[:], 2) + betaR := genR(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, current.Hash[:], 3) + + // Check for knowledge of toxic parameters + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.PublicKeys.Tau.XR, tauR) { + return errors.New("couldn't verify public key of τ") + } + if !sameRatio(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, contribution.PublicKeys.Alpha.XR, alphaR) { + return errors.New("couldn't verify public key of α") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.PublicKeys.Beta.XR, betaR) { + return errors.New("couldn't verify public key of β") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Tau[1], current.Parameters.G1.Tau[1], tauR, contribution.PublicKeys.Tau.XR) { + return errors.New("couldn't verify that [τ]₁ is based on previous 
contribution") + } + if !sameRatio(contribution.Parameters.G1.AlphaTau[0], current.Parameters.G1.AlphaTau[0], alphaR, contribution.PublicKeys.Alpha.XR) { + return errors.New("couldn't verify that [α]₁ is based on previous contribution") + } + if !sameRatio(contribution.Parameters.G1.BetaTau[0], current.Parameters.G1.BetaTau[0], betaR, contribution.PublicKeys.Beta.XR) { + return errors.New("couldn't verify that [β]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.Parameters.G2.Tau[1], current.Parameters.G2.Tau[1]) { + return errors.New("couldn't verify that [τ]₂ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.Parameters.G2.Beta, current.Parameters.G2.Beta) { + return errors.New("couldn't verify that [β]₂ is based on previous contribution") + } + + // Check for valid updates using powers of τ + _, _, g1, g2 := curve.Generators() + tauL1, tauL2 := linearCombinationG1(contribution.Parameters.G1.Tau) + if !sameRatio(tauL1, tauL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of τ in G₁") + } + alphaL1, alphaL2 := linearCombinationG1(contribution.Parameters.G1.AlphaTau) + if !sameRatio(alphaL1, alphaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of α(τ) in G₁") + } + betaL1, betaL2 := linearCombinationG1(contribution.Parameters.G1.BetaTau) + if !sameRatio(betaL1, betaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of β(τ) in G₁") + } + tau2L1, tau2L2 := linearCombinationG2(contribution.Parameters.G2.Tau) + if !sameRatio(contribution.Parameters.G1.Tau[1], g1, tau2L1, tau2L2) { + return errors.New("couldn't verify valid powers of τ in G₂") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] !=
contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (phase1 *Phase1) hash() []byte { + sha := sha256.New() + phase1.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bn254/mpcsetup/phase2.go b/backend/groth16/bn254/mpcsetup/phase2.go new file mode 100644 index 0000000000..3fcafb30da --- /dev/null +++ b/backend/groth16/bn254/mpcsetup/phase2.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + "math/big" + + curve "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark/constraint" + cs "github.com/consensys/gnark/constraint/bn254" +) + +type Phase2Evaluations struct { + G1 struct { + A, B, VKK []curve.G1Affine + } + G2 struct { + B []curve.G2Affine + } +} + +type Phase2 struct { + Parameters struct { + G1 struct { + Delta curve.G1Affine + L, Z []curve.G1Affine + } + G2 struct { + Delta curve.G2Affine + } + } + PublicKey PublicKey + Hash []byte +} + +func InitPhase2(r1cs *cs.R1CS, srs1 *Phase1) (Phase2, Phase2Evaluations) { + srs := srs1.Parameters + size := len(srs.G1.AlphaTau) + if size < r1cs.GetNbConstraints() { + panic("Number of constraints is larger than expected") + } + + c2 := Phase2{} + + accumulateG1 := func(res *curve.G1Affine, t constraint.Term, value *curve.G1Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G1Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + accumulateG2 := func(res *curve.G2Affine, t constraint.Term, value *curve.G2Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G2Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + // Prepare Lagrange coefficients of [τ...]₁, [τ...]₂, [ατ...]₁, [βτ...]₁ + coeffTau1 := 
lagrangeCoeffsG1(srs.G1.Tau, size) + coeffTau2 := lagrangeCoeffsG2(srs.G2.Tau, size) + coeffAlphaTau1 := lagrangeCoeffsG1(srs.G1.AlphaTau, size) + coeffBetaTau1 := lagrangeCoeffsG1(srs.G1.BetaTau, size) + + internal, secret, public := r1cs.GetNbVariables() + nWires := internal + secret + public + var evals Phase2Evaluations + evals.G1.A = make([]curve.G1Affine, nWires) + evals.G1.B = make([]curve.G1Affine, nWires) + evals.G2.B = make([]curve.G2Affine, nWires) + bA := make([]curve.G1Affine, nWires) + aB := make([]curve.G1Affine, nWires) + C := make([]curve.G1Affine, nWires) + + // TODO @gbotrel use constraint iterator when available. + + i := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + // A + for _, t := range c.L { + accumulateG1(&evals.G1.A[t.WireID()], t, &coeffTau1[i]) + accumulateG1(&bA[t.WireID()], t, &coeffBetaTau1[i]) + } + // B + for _, t := range c.R { + accumulateG1(&evals.G1.B[t.WireID()], t, &coeffTau1[i]) + accumulateG2(&evals.G2.B[t.WireID()], t, &coeffTau2[i]) + accumulateG1(&aB[t.WireID()], t, &coeffAlphaTau1[i]) + } + // C + for _, t := range c.O { + accumulateG1(&C[t.WireID()], t, &coeffTau1[i]) + } + i++ + } + + // Prepare default contribution + _, _, g1, g2 := curve.Generators() + c2.Parameters.G1.Delta = g1 + c2.Parameters.G2.Delta = g2 + + // Build Z in PK as τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + // τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + n := len(srs.G1.AlphaTau) + c2.Parameters.G1.Z = make([]curve.G1Affine, n) + for i := 0; i < n-1; i++ { + c2.Parameters.G1.Z[i].Sub(&srs.G1.Tau[i+n], &srs.G1.Tau[i]) + } + bitReverse(c2.Parameters.G1.Z) + c2.Parameters.G1.Z = c2.Parameters.G1.Z[:n-1] + + // Evaluate L + nPrivate := internal + secret + c2.Parameters.G1.L = make([]curve.G1Affine, nPrivate) + evals.G1.VKK = make([]curve.G1Affine, public) + offset := public + for i := 0; i < nWires; i++ { + var tmp curve.G1Affine + tmp.Add(&bA[i], &aB[i]) + tmp.Add(&tmp, &C[i]) + if i < public { + 
evals.G1.VKK[i].Set(&tmp) + } else { + c2.Parameters.G1.L[i-offset].Set(&tmp) + } + } + // Set δ public key + var delta fr.Element + delta.SetOne() + c2.PublicKey = newPublicKey(delta, nil, 1) + + // Hash initial contribution + c2.Hash = c2.hash() + return c2, evals +} + +func (c *Phase2) Contribute() { + // Sample toxic δ + var delta, deltaInv fr.Element + var deltaBI, deltaInvBI big.Int + delta.SetRandom() + deltaInv.Inverse(&delta) + + delta.BigInt(&deltaBI) + deltaInv.BigInt(&deltaInvBI) + + // Set δ public key + c.PublicKey = newPublicKey(delta, c.Hash, 1) + + // Update δ + c.Parameters.G1.Delta.ScalarMultiplication(&c.Parameters.G1.Delta, &deltaBI) + c.Parameters.G2.Delta.ScalarMultiplication(&c.Parameters.G2.Delta, &deltaBI) + + // Update Z using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.Z); i++ { + c.Parameters.G1.Z[i].ScalarMultiplication(&c.Parameters.G1.Z[i], &deltaInvBI) + } + + // Update L using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.L); i++ { + c.Parameters.G1.L[i].ScalarMultiplication(&c.Parameters.G1.L[i], &deltaInvBI) + } + + // 4. Hash contribution + c.Hash = c.hash() +} + +func VerifyPhase2(c0, c1 *Phase2, c ...*Phase2) error { + contribs := append([]*Phase2{c0, c1}, c...) 
+ for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase2(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +func verifyPhase2(current, contribution *Phase2) error { + // Compute R for δ + deltaR := genR(contribution.PublicKey.SG, contribution.PublicKey.SXG, current.Hash[:], 1) + + // Check for knowledge of δ + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.PublicKey.XR, deltaR) { + return errors.New("couldn't verify knowledge of δ") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Delta, current.Parameters.G1.Delta, deltaR, contribution.PublicKey.XR) { + return errors.New("couldn't verify that [δ]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify that [δ]₂ is based on previous contribution") + } + + // Check for valid updates of L and Z using δ⁻¹ + L, prevL := merge(contribution.Parameters.G1.L, current.Parameters.G1.L) + if !sameRatio(L, prevL, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of L using δ⁻¹") + } + Z, prevZ := merge(contribution.Parameters.G1.Z, current.Parameters.G1.Z) + if !sameRatio(Z, prevZ, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of Z using δ⁻¹") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (c *Phase2) hash() []byte { + sha := sha256.New() + c.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bn254/mpcsetup/setup.go b/backend/groth16/bn254/mpcsetup/setup.go new file mode 100644 index 0000000000..4946e9f597 ---
/dev/null +++ b/backend/groth16/bn254/mpcsetup/setup.go @@ -0,0 +1,97 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" + groth16 "github.com/consensys/gnark/backend/groth16/bn254" +) + +func ExtractKeys(srs1 *Phase1, srs2 *Phase2, evals *Phase2Evaluations, nConstraints int) (pk groth16.ProvingKey, vk groth16.VerifyingKey) { + _, _, _, g2 := curve.Generators() + + // Initialize PK + pk.Domain = *fft.NewDomain(uint64(nConstraints)) + pk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + pk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + pk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + pk.G1.Z = srs2.Parameters.G1.Z + bitReverse(pk.G1.Z) + + pk.G1.K = srs2.Parameters.G1.L + pk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + pk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + + // Filter out infinity points + nWires := len(evals.G1.A) + pk.InfinityA = make([]bool, nWires) + A := make([]curve.G1Affine, nWires) + j := 0 + for i, e := range evals.G1.A { + if e.IsInfinity() { + pk.InfinityA[i] = true + continue + } + A[j] = evals.G1.A[i] + j++ + } + pk.G1.A = A[:j] + pk.NbInfinityA = uint64(nWires - j) + + pk.InfinityB = make([]bool, nWires) + B := make([]curve.G1Affine, nWires) + j = 0 + for i, e := range evals.G1.B { + if e.IsInfinity() { + pk.InfinityB[i] = true 
+ continue + } + B[j] = evals.G1.B[i] + j++ + } + pk.G1.B = B[:j] + pk.NbInfinityB = uint64(nWires - j) + + B2 := make([]curve.G2Affine, nWires) + j = 0 + for i, e := range evals.G2.B { + if e.IsInfinity() { + // pk.InfinityB[i] = true should be the same as in B + continue + } + B2[j] = evals.G2.B[i] + j++ + } + pk.G2.B = B2[:j] + + // Initialize VK + vk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + vk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + vk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + vk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + vk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + vk.G2.Gamma.Set(&g2) + vk.G1.K = evals.G1.VKK + + // sets e, -[δ]2, -[γ]2 + if err := vk.Precompute(); err != nil { + panic(err) + } + + return pk, vk +} diff --git a/backend/groth16/bn254/mpcsetup/setup_test.go b/backend/groth16/bn254/mpcsetup/setup_test.go new file mode 100644 index 0000000000..63b717cac4 --- /dev/null +++ b/backend/groth16/bn254/mpcsetup/setup_test.go @@ -0,0 +1,196 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + cs "github.com/consensys/gnark/constraint/bn254" + "testing" + + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/stretchr/testify/require" + + native_mimc "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" +) + +func TestSetupCircuit(t *testing.T) { + const ( + nContributionsPhase1 = 3 + nContributionsPhase2 = 3 + power = 9 + ) + + assert := require.New(t) + + srs1 := InitPhase1(power) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase1; i++ { + // we clone test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. + prev := srs1.clone() + + srs1.Contribute() + assert.NoError(VerifyPhase1(&prev, &srs1)) + } + + // Compile the circuit + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + var evals Phase2Evaluations + r1cs := ccs.(*cs.R1CS) + + // Prepare for phase-2 + srs2, evals := InitPhase2(r1cs, &srs1) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase2; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. 
+ prev := srs2.clone() + + srs2.Contribute() + assert.NoError(VerifyPhase2(&prev, &srs2)) + } + + // Extract the proving and verifying keys + pk, vk := ExtractKeys(&srs1, &srs2, &evals, ccs.GetNbConstraints()) + + // Build the witness + var preImage, hash fr.Element + { + m := native_mimc.NewMiMC() + m.Write(preImage.Marshal()) + hash.SetBytes(m.Sum(nil)) + } + + witness, err := frontend.NewWitness(&Circuit{PreImage: preImage, Hash: hash}, curve.ID.ScalarField()) + assert.NoError(err) + + pubWitness, err := witness.Public() + assert.NoError(err) + + // groth16: ensure proof is verified + proof, err := groth16.Prove(ccs, &pk, witness) + assert.NoError(err) + + err = groth16.Verify(proof, &vk, pubWitness) + assert.NoError(err) +} + +func BenchmarkPhase1(b *testing.B) { + const power = 14 + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = InitPhase1(power) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs1 := InitPhase1(power) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs1.Contribute() + } + }) + +} + +func BenchmarkPhase2(b *testing.B) { + const power = 14 + srs1 := InitPhase1(power) + srs1.Contribute() + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + if err != nil { + b.Fatal(err) + } + + r1cs := ccs.(*cs.R1CS) + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = InitPhase2(r1cs, &srs1) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs2, _ := InitPhase2(r1cs, &srs1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs2.Contribute() + } + }) + +} + +// Circuit defines a pre-image knowledge proof +// mimc(secret preImage) = public hash +type Circuit struct { + PreImage frontend.Variable + Hash frontend.Variable `gnark:",public"` +} + +// Define declares the circuit's constraints +// Hash = mimc(PreImage) +func (circuit *Circuit) Define(api frontend.API) error { + // hash function + mimc, _ := 
mimc.NewMiMC(api) + + // specify constraints + mimc.Write(circuit.PreImage) + api.AssertIsEqual(circuit.Hash, mimc.Sum()) + + return nil +} + +func (phase1 *Phase1) clone() Phase1 { + r := Phase1{} + r.Parameters.G1.Tau = append(r.Parameters.G1.Tau, phase1.Parameters.G1.Tau...) + r.Parameters.G1.AlphaTau = append(r.Parameters.G1.AlphaTau, phase1.Parameters.G1.AlphaTau...) + r.Parameters.G1.BetaTau = append(r.Parameters.G1.BetaTau, phase1.Parameters.G1.BetaTau...) + + r.Parameters.G2.Tau = append(r.Parameters.G2.Tau, phase1.Parameters.G2.Tau...) + r.Parameters.G2.Beta = phase1.Parameters.G2.Beta + + r.PublicKeys = phase1.PublicKeys + r.Hash = append(r.Hash, phase1.Hash...) + + return r +} + +func (phase2 *Phase2) clone() Phase2 { + r := Phase2{} + r.Parameters.G1.Delta = phase2.Parameters.G1.Delta + r.Parameters.G1.L = append(r.Parameters.G1.L, phase2.Parameters.G1.L...) + r.Parameters.G1.Z = append(r.Parameters.G1.Z, phase2.Parameters.G1.Z...) + r.Parameters.G2.Delta = phase2.Parameters.G2.Delta + r.PublicKey = phase2.PublicKey + r.Hash = append(r.Hash, phase2.Hash...) + + return r +} diff --git a/backend/groth16/bn254/mpcsetup/utils.go b/backend/groth16/bn254/mpcsetup/utils.go new file mode 100644 index 0000000000..e3b47d1121 --- /dev/null +++ b/backend/groth16/bn254/mpcsetup/utils.go @@ -0,0 +1,170 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark/internal/utils" +) + +type PublicKey struct { + SG curve.G1Affine + SXG curve.G1Affine + XR curve.G2Affine +} + +func newPublicKey(x fr.Element, challenge []byte, dst byte) PublicKey { + var pk PublicKey + _, _, g1, _ := curve.Generators() + + var s fr.Element + var sBi big.Int + s.SetRandom() + s.BigInt(&sBi) + pk.SG.ScalarMultiplication(&g1, &sBi) + + // compute x*sG1 + var xBi big.Int + x.BigInt(&xBi) + pk.SXG.ScalarMultiplication(&pk.SG, &xBi) + + // generate R based on sG1, sxG1, challenge, and domain separation tag (tau, alpha or beta) + R := genR(pk.SG, pk.SXG, challenge, dst) + + // compute x*spG2 + pk.XR.ScalarMultiplication(&R, &xBi) + return pk +} + +func bitReverse[T any](a []T) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + irev := bits.Reverse64(i) >> nn + if irev > i { + a[i], a[irev] = a[irev], a[i] + } + } +} + +// Returns [1, a, a², ..., aⁿ⁻¹ ] in Montgomery form +func powers(a fr.Element, n int) []fr.Element { + result := make([]fr.Element, n) + result[0] = fr.NewElement(1) + for i := 1; i < n; i++ { + result[i].Mul(&result[i-1], &a) + } + return result +} + +// Returns [aᵢAᵢ, ...] in G1 +func scaleG1InPlace(A []curve.G1Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Returns [aᵢAᵢ, ...] 
in G2 +func scaleG2InPlace(A []curve.G2Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Check e(a₁, a₂) = e(b₁, b₂) +func sameRatio(a1, b1 curve.G1Affine, a2, b2 curve.G2Affine) bool { + if !a1.IsInSubGroup() || !b1.IsInSubGroup() || !a2.IsInSubGroup() || !b2.IsInSubGroup() { + panic("invalid point not in subgroup") + } + var na2 curve.G2Affine + na2.Neg(&a2) + res, err := curve.PairingCheck( + []curve.G1Affine{a1, b1}, + []curve.G2Affine{na2, b2}) + if err != nil { + panic(err) + } + return res +} + +// returns a = ∑ rᵢAᵢ, b = ∑ rᵢBᵢ +func merge(A, B []curve.G1Affine) (a, b curve.G1Affine) { + nc := runtime.NumCPU() + r := make([]fr.Element, len(A)) + for i := 0; i < len(A); i++ { + r[i].SetRandom() + } + a.MultiExp(A, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + b.MultiExp(B, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G1 +func linearCombinationG1(A []curve.G1Affine) (L1, L2 curve.G1Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G2 +func linearCombinationG2(A []curve.G2Affine) (L1, L2 curve.G2Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// Generate R in G₂ as Hash(gˢ, gˢˣ, challenge, dst) +func genR(sG1, sxG1 curve.G1Affine, challenge []byte, dst byte) curve.G2Affine { + var buf bytes.Buffer + buf.Grow(len(challenge) + curve.SizeOfG1AffineUncompressed*2) + buf.Write(sG1.Marshal()) + 
buf.Write(sxG1.Marshal()) + buf.Write(challenge) + spG2, err := curve.HashToG2(buf.Bytes(), []byte{dst}) + if err != nil { + panic(err) + } + return spG2 +} diff --git a/internal/backend/bn254/groth16/prove.go b/backend/groth16/bn254/prove.go similarity index 62% rename from internal/backend/bn254/groth16/prove.go rename to backend/groth16/bn254/prove.go index 8ca7d568d8..42ec4de8b9 100644 --- a/internal/backend/bn254/groth16/prove.go +++ b/backend/groth16/bn254/prove.go @@ -17,13 +17,17 @@ package groth16 import ( - "fmt" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bn254" "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/pedersen" "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + cs "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" "math/big" @@ -35,9 +39,10 @@ import ( // with a valid statement and a VerifyingKey // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type Proof struct { - Ar, Krs curve.G1Affine - Bs curve.G2Affine - Commitment, CommitmentPok curve.G1Affine + Ar, Krs curve.G1Affine + Bs curve.G2Affine + Commitments []curve.G1Affine // Pedersen commitments a la https://eprint.iacr.org/2022/1072 + CommitmentPok curve.G1Affine // Batched proof of knowledge of the above commitments } // isValid ensures proof elements are in the correct subgroup @@ -51,72 +56,78 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). 
-func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - // TODO @gbotrel witness size check is done by R1CS, doesn't mean we shouldn't sanitize here. - // if len(witness) != r1cs.NbPublicVariables-1+r1cs.NbSecretVariables { - // return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), r1cs.NbPublicVariables-1+r1cs.NbSecretVariables, r1cs.NbPublicVariables, r1cs.NbSecretVariables) - // } - - log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", len(r1cs.Constraints)).Str("backend", "groth16").Logger() - - // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - - proof := &Proof{} - if r1cs.CommitmentInfo.Is() { - opt.HintFunctions[r1cs.CommitmentInfo.HintID] = func(_ *big.Int, in []*big.Int, out []*big.Int) error { - // Perf-TODO: Converting these values to big.Int and back may be a performance bottleneck. - // If that is the case, figure out a way to feed the solution vector into this function - if len(in) != r1cs.CommitmentInfo.NbCommitted() { // TODO: Remove - return fmt.Errorf("unexpected number of committed variables") - } - values := make([]fr.Element, r1cs.CommitmentInfo.NbPrivateCommitted) - nbPublicCommitted := len(in) - len(values) - inPrivate := in[nbPublicCommitted:] - for i, inI := range inPrivate { - values[i].SetBigInt(inI) - } +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } + + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] - var err error - proof.Commitment, proof.CommitmentPok, err = pk.CommitmentKey.Commit(values) - if err != nil { + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + solverOpts = append(solverOpts, solver.OverrideHint(commitmentInfo[i].HintID, func(i int) solver.Hint { + return func(_ *big.Int, in []*big.Int, out []*big.Int) error { + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[len(hashed):] + for j, inJ := range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } + + var res fr.Element + res, err = solveCommitmentWire(&proof.Commitments[i], hashed) + res.BigInt(out[0]) return err } + }(i))) + } - var res fr.Element - res, err = solveCommitmentWire(&r1cs.CommitmentInfo, &proof.Commitment, in[:r1cs.CommitmentInfo.NbPublicCommitted()]) - res.BigInt(out[0]) //Perf-TODO: Regular (non-mont) hashToField to obviate this conversion? 
- return err - } + if r1cs.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + solverOpts = append(solverOpts, + solver.OverrideHint(r1cs.GkrInfo.SolveHintID, cs.GkrSolveHint(r1cs.GkrInfo, &gkrData)), + solver.OverrideHint(r1cs.GkrInfo.ProveHintID, cs.GkrProveHint(r1cs.GkrInfo.HashName, &gkrData))) } - var wireValues []fr.Element - var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill wireValues with random values else multi exps don't do much - var r fr.Element - _, _ = r.SetRandom() - for i := r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables(); i < len(wireValues); i++ { - wireValues[i] = r - r.Double(&r) - } - } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + start := time.Now() + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + + if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil { + return nil, err + } + // H (witness reduction / FFT part) var h []fr.Element chHDone := make(chan struct{}, 1) go func() { - h = computeH(a, b, c, &pk.Domain) - a = nil - b = nil - c = nil + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + solution.A = nil + solution.B = nil + solution.C = nil chHDone <- struct{}{} }() @@ -203,15 +214,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC var krs, krs2, p1 curve.G1Jac chKrs2Done := make(chan error, 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 go func() { - _, err := krs2.MultiExp(pk.G1.Z, h, ecc.MultiExpConfig{NbTasks: n / 2}) + _, err := krs2.MultiExp(pk.G1.Z, 
h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) chKrs2Done <- err }() - // filter the wire values if needed; - _wireValues := filter(wireValues, r1cs.CommitmentInfo.PrivateToPublic()) + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - if _, err := krs.MultiExp(pk.G1.K, _wireValues[r1cs.GetNbPublicVariables():], ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chKrsDone <- err return } @@ -292,26 +307,32 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC } // if len(toRemove) == 0, returns slice -// else, returns a new slice without the indexes in toRemove -// this assumes toRemove indexes are sorted and len(slice) > len(toRemove) -func filter(slice []fr.Element, toRemove []int) (r []fr.Element) { +// else, returns a new slice without the indexes in toRemove. 
The first value in the slice is taken as indexes as sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { if len(toRemove) == 0 { return slice } - r = make([]fr.Element, 0, len(slice)-len(toRemove)) - j := 0 + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) + // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) for i := 0; i < len(slice); i++ { - if j < len(toRemove) && i == toRemove[j] { - j++ + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } continue } r = append(r, slice[i]) } - return r + return } func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { @@ -334,9 +355,9 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { domain.FFTInverse(b, fft.DIF) domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, true) - domain.FFT(b, fft.DIT, true) - domain.FFT(c, fft.DIT, true) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) var den, one fr.Element one.SetOne() @@ -344,7 +365,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) - // reusing a to avoid unecessary memalloc + // reusing a to avoid unnecessary memory allocation utils.Parallelize(n, func(start, end int) { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). 
@@ -354,7 +375,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { }) // ifft_coset - domain.FFTInverse(a, fft.DIF, true) + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) return a } diff --git a/internal/backend/bn254/groth16/setup.go b/backend/groth16/bn254/setup.go similarity index 75% rename from internal/backend/bn254/groth16/setup.go rename to backend/groth16/bn254/setup.go index e13869bbcb..372c723da0 100644 --- a/internal/backend/bn254/groth16/setup.go +++ b/backend/groth16/bn254/setup.go @@ -17,13 +17,15 @@ package groth16 import ( + "errors" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bn254" "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" "github.com/consensys/gnark-crypto/ecc/bn254/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bn254" + cs "github.com/consensys/gnark/constraint/bn254" "math/big" "math/bits" ) @@ -34,15 +36,15 @@ type ProvingKey struct { // domain Domain fft.Domain - // [α]1, [β]1, [δ]1 - // [A(t)]1, [B(t)]1, [Kpk(t)]1, [Z(t)]1 + // [α]₁, [β]₁, [δ]₁ + // [A(t)]₁, [B(t)]₁, [Kpk(t)]₁, [Z(t)]₁ G1 struct { Alpha, Beta, Delta curve.G1Affine A, B, Z []curve.G1Affine K []curve.G1Affine // the indexes correspond to the private wires } - // [β]2, [δ]2, [B(t)]2 + // [β]₂, [δ]₂, [B(t)]₂ G2 struct { Beta, Delta curve.G2Affine B []curve.G2Affine @@ -52,21 +54,21 @@ type ProvingKey struct { InfinityA, InfinityB []bool NbInfinityA, NbInfinityB uint64 - CommitmentKey pedersen.Key + CommitmentKeys []pedersen.ProvingKey } // VerifyingKey is used by a Groth16 verifier to verify the validity of a proof and a statement // Notation follows Figure 4. 
in DIZK paper https://eprint.iacr.org/2018/691.pdf type VerifyingKey struct { - // [α]1, [Kvk]1 + // [α]₁, [Kvk]₁ G1 struct { Alpha curve.G1Affine Beta, Delta curve.G1Affine // unused, here for compatibility purposes K []curve.G1Affine // The indexes correspond to the public wires } - // [β]2, [δ]2, [γ]2, - // -[δ]2, -[γ]2: see proof.Verify() for more details + // [β]₂, [δ]₂, [γ]₂, + // -[δ]₂, -[γ]₂: see proof.Verify() for more details G2 struct { Beta, Delta, Gamma curve.G2Affine deltaNeg, gammaNeg curve.G2Affine // not serialized @@ -75,8 +77,8 @@ type VerifyingKey struct { // e(α, β) e curve.GT // not serialized - CommitmentKey pedersen.Key - CommitmentInfo constraint.Commitment // since the verifier doesn't input a constraint system, this needs to be provided here + CommitmentKey pedersen.VerifyingKey + PublicAndCommitmentCommitted [][]int // indexes of public/commitment committed variables } // Setup constructs the SRS @@ -93,17 +95,20 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbPrivateCommittedWires := r1cs.CommitmentInfo.NbPrivateCommitted - nbPublicWires := r1cs.GetNbPublicVariables() - nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - if r1cs.CommitmentInfo.Is() { // the commitment itself is defined by a hint so the prover considers it private - nbPublicWires++ // but the verifier will need to inject the value itself so on the groth16 - nbPrivateWires-- // level it must be considered public - } + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + commitmentWires := commitmentInfo.CommitmentIndexes() + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateCommittedWires := internal.NbElements(privateCommitted) + + // a commitment is itself defined by a hint so the prover considers it 
private + // but the verifier will need to inject the value itself so on the groth16 + // level it must be considered public + nbPublicWires := r1cs.GetNbPublicVariables() + len(commitmentInfo) + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - len(commitmentInfo) // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints))) + domain := fft.NewDomain(uint64(r1cs.GetNbConstraints())) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -137,7 +142,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // compute scalars for pkK, vkK and ckK pkK := make([]fr.Element, nbPrivateWires) vkK := make([]fr.Element, nbPublicWires) - ckK := make([]fr.Element, nbPrivateCommittedWires) + ckK := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + ckK[i] = make([]fr.Element, len(privateCommitted[i])) + } var t0, t1 fr.Element @@ -148,28 +156,42 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { Add(&t1, &C[i]). 
Mul(&t1, coeff) } - - vI, cI := 0, 0 - privateCommitted := r1cs.CommitmentInfo.PrivateCommitted() + vI := 0 // number of public wires seen so far + cI := make([]int, len(commitmentInfo)) // number of private committed wires seen so far for each commitment + nbPrivateCommittedSeen := 0 // = ∑ᵢ cI[i] + nbCommitmentsSeen := 0 for i := range A { - isCommittedPrivate := cI < len(privateCommitted) && i == privateCommitted[cI] - isCommitment := r1cs.CommitmentInfo.Is() && i == r1cs.CommitmentInfo.CommitmentIndex - isPublic := i < r1cs.GetNbPublicVariables() + commitment := -1 // index of the commitment that commits to this variable as a private or commitment value + var isCommitment, isPublic bool + if isPublic = i < r1cs.GetNbPublicVariables(); !isPublic { + if nbCommitmentsSeen < len(commitmentWires) && commitmentWires[nbCommitmentsSeen] == i { + isCommitment = true + nbCommitmentsSeen++ + } - if isPublic || isCommittedPrivate || isCommitment { + for j := range commitmentInfo { // does commitment j commit to i? + if cI[j] < len(privateCommitted[j]) && privateCommitted[j][cI[j]] == i { + commitment = j + break // frontend guarantees that no private variable is committed to more than once + } + } + } + + if isPublic || commitment != -1 || isCommitment { computeK(i, &toxicWaste.gammaInv) - if isCommittedPrivate { - ckK[cI] = t1 - cI++ - } else { + if isPublic || isCommitment { vkK[vI] = t1 vI++ + } else { // committed and private + ckK[commitment][cI[commitment]] = t1 + cI[commitment]++ + nbPrivateCommittedSeen++ } } else { computeK(i, &toxicWaste.deltaInv) - pkK[i-vI-cI] = t1 + pkK[i-vI-nbPrivateCommittedSeen] = t1 // vI = nbPublicSeen + nbCommitmentsSeen } } @@ -222,11 +244,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { g1Scalars = append(g1Scalars, Z...) g1Scalars = append(g1Scalars, vkK...) g1Scalars = append(g1Scalars, pkK...) - g1Scalars = append(g1Scalars, ckK...) + for i := range ckK { + g1Scalars = append(g1Scalars, ckK[i]...) 
+ } g1PointsAff := curve.BatchScalarMultiplicationG1(&g1, g1Scalars) - // sets pk: [α]1, [β]1, [δ]1 + // sets pk: [α]₁, [β]₁, [δ]₁ pk.G1.Alpha = g1PointsAff[0] pk.G1.Beta = g1PointsAff[1] pk.G1.Delta = g1PointsAff[2] @@ -238,8 +262,9 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G1.B = g1PointsAff[offset : offset+len(B)] offset += len(B) - pk.G1.Z = g1PointsAff[offset : offset+int(domain.Cardinality)] - bitReverse(pk.G1.Z) + bitReverse(g1PointsAff[offset : offset+int(domain.Cardinality)]) + sizeZ := int(domain.Cardinality) - 1 // deg(H)=deg(A*B-C/X^n-1)=(n-1)+(n-1)-n=n-2 + pk.G1.Z = g1PointsAff[offset : offset+sizeZ] offset += int(domain.Cardinality) @@ -252,17 +277,22 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // --------------------------------------------------------------------------------------------- // Commitment setup - if nbPrivateCommittedWires != 0 { - commitmentBasis := g1PointsAff[offset:] + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(ckK[i]) + commitmentBases[i] = g1PointsAff[offset : offset+size] + offset += size + } + if offset != len(g1PointsAff) { + return errors.New("didn't consume all G1 points") // TODO @Tabaie Remove this + } - vk.CommitmentKey, err = pedersen.Setup(commitmentBasis) - if err != nil { - return err - } - pk.CommitmentKey = vk.CommitmentKey + pk.CommitmentKeys, vk.CommitmentKey, err = pedersen.Setup(commitmentBases...) 
+ if err != nil { + return err } - vk.CommitmentInfo = r1cs.CommitmentInfo // unfortunate but necessary + vk.PublicAndCommitmentCommitted = commitmentInfo.GetPublicAndCommitmentCommitted(commitmentWires, r1cs.GetNbPublicVariables()) // --------------------------------------------------------------------------------------------- // G2 scalars @@ -279,15 +309,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G2.B = g2PointsAff[:len(B)] - // sets pk: [β]2, [δ]2 + // sets pk: [β]₂, [δ]₂ pk.G2.Beta = g2PointsAff[len(B)+0] pk.G2.Delta = g2PointsAff[len(B)+1] - // sets vk: [δ]2, [γ]2, -[δ]2, -[γ]2 + // sets vk: [δ]₂, [γ]₂ vk.G2.Delta = g2PointsAff[len(B)+1] vk.G2.Gamma = g2PointsAff[len(B)+2] - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) // --------------------------------------------------------------------------------------------- // Pairing: vk.e @@ -298,16 +326,29 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { vk.G1.Beta = pk.G1.Beta vk.G1.Delta = pk.G1.Delta - vk.e, err = curve.Pair([]curve.G1Affine{pk.G1.Alpha}, []curve.G2Affine{pk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return err } + // set domain pk.Domain = *domain return nil } +// Precompute sets e, -[δ]₂, -[γ]₂ +// This is meant to be called internally during setup or deserialization. +func (vk *VerifyingKey) Precompute() error { + var err error + vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) + if err != nil { + return err + } + vk.G2.deltaNeg.Neg(&vk.G2.Delta) + vk.G2.gammaNeg.Neg(&vk.G2.Gamma) + return nil +} + func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr.Element, B []fr.Element, C []fr.Element) { nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() @@ -322,7 +363,7 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. 
var w fr.Element w.Set(&domain.Generator) wi := fr.One() - t := make([]fr.Element, len(r1cs.Constraints)+1) + t := make([]fr.Element, r1cs.GetNbConstraints()+1) for i := 0; i < len(t); i++ { t[i].Sub(&toxicWaste.t, &wi) wi.Mul(&wi, &w) // TODO this is already pre computed in fft.Domain @@ -366,8 +407,10 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // for each term appearing in the linear expression, // we compute term.Coefficient * L, and cumulate it in // A, B or C at the index of the variable - for i, c := range r1cs.Constraints { + j := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } @@ -380,9 +423,12 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // Li+1 = w*Li*(t-w^i)/(t-w^(i+1)) L.Mul(&L, &w) - L.Mul(&L, &t[i]) - L.Mul(&L, &tInv[i+1]) + L.Mul(&L, &t[j]) + L.Mul(&L, &tInv[j+1]) + + j++ } + return } @@ -436,7 +482,10 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbConstraints := len(r1cs.Constraints) + nbConstraints := r1cs.GetNbConstraints() + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - internal.NbElements(privateCommitted) - len(commitmentInfo) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints)) @@ -448,8 +497,8 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // initialize proving key pk.G1.A = make([]curve.G1Affine, nbWires-nbZeroesA) pk.G1.B = make([]curve.G1Affine, nbWires-nbZeroesB) - pk.G1.K = make([]curve.G1Affine, nbWires-r1cs.GetNbPublicVariables()) - pk.G1.Z = make([]curve.G1Affine, 
domain.Cardinality) + pk.G1.K = make([]curve.G1Affine, nbPrivateWires) + pk.G1.Z = make([]curve.G1Affine, domain.Cardinality-1) pk.G2.B = make([]curve.G2Affine, nbWires-nbZeroesB) // set infinity markers @@ -503,6 +552,22 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { pk.Domain = *domain + // --------------------------------------------------------------------------------------------- + // Commitment setup + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(privateCommitted[i]) + commitmentBases[i] = make([]curve.G1Affine, size) + for j := range commitmentBases[i] { + commitmentBases[i][j] = r1Aff + } + } + + pk.CommitmentKeys, _, err = pedersen.Setup(commitmentBases...) + if err != nil { + return err + } + return nil } @@ -514,7 +579,9 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) - for _, c := range r1cs.Constraints { + + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { A[t.WireID()] = true } @@ -522,6 +589,7 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { B[t.WireID()] = true } } + for i := 0; i < nbWires; i++ { if !A[i] { nbZeroesA++ @@ -606,7 +674,7 @@ func (pk *ProvingKey) NbG2() int { return 2 + len(pk.G2.B) } -// bitRerverse permutation as in fft.BitReverse , but with []curve.G1Affine +// bitReverse permutation as in fft.BitReverse , but with []curve.G1Affine func bitReverse(a []curve.G1Affine) { n := uint(len(a)) nn := uint(bits.UintSize - bits.TrailingZeros(n)) diff --git a/internal/backend/bn254/groth16/solidity.go b/backend/groth16/bn254/solidity.go similarity index 100% rename from internal/backend/bn254/groth16/solidity.go rename to backend/groth16/bn254/solidity.go diff --git a/backend/groth16/bn254/utils_test.go b/backend/groth16/bn254/utils_test.go new file mode 100644 index 0000000000..a5bc2a5770 --- /dev/null +++ 
b/backend/groth16/bn254/utils_test.go @@ -0,0 +1,38 @@ +package groth16 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/stretchr/testify/assert" +) + +func assertSliceEquals[T any](t *testing.T, expected []T, seen []T) { + assert.Equal(t, len(expected), len(seen)) + for i := range expected { + assert.Equal(t, expected[i], seen[i]) + } +} + +func TestFilterHeap(t *testing.T) { + elems := []fr.Element{{0}, {1}, {2}, {3}} + + r := filterHeap(elems, 0, []int{1, 2}) + expected := []fr.Element{{0}, {3}} + assertSliceEquals(t, expected, r) + + r = filterHeap(elems[1:], 1, []int{1, 2}) + expected = []fr.Element{{3}} + assertSliceEquals(t, expected, r) +} + +func TestFilterRepeated(t *testing.T) { + elems := []fr.Element{{0}, {1}, {2}, {3}} + r := filterHeap(elems, 0, []int{1, 1, 2}) + expected := []fr.Element{{0}, {3}} + assertSliceEquals(t, expected, r) + + r = filterHeap(elems[1:], 1, []int{1, 1, 2}) + expected = []fr.Element{{3}} + assertSliceEquals(t, expected, r) +} diff --git a/internal/backend/bn254/groth16/verify.go b/backend/groth16/bn254/verify.go similarity index 69% rename from internal/backend/bn254/groth16/verify.go rename to backend/groth16/bn254/verify.go index 917ebf9166..64c5073266 100644 --- a/internal/backend/bn254/groth16/verify.go +++ b/backend/groth16/bn254/verify.go @@ -22,9 +22,11 @@ import ( "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bn254" "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/pedersen" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/logger" "io" - "math/big" "text/template" "time" ) @@ -37,10 +39,8 @@ var ( // Verify verifies a proof with given VerifyingKey and publicWitness func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - nbPublicVars := len(vk.G1.K) - if vk.CommitmentInfo.Is() { - nbPublicVars-- - } + 
nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted) + if len(publicWitness) != nbPublicVars-1 { return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(publicWitness), len(vk.G1.K)-1) } @@ -63,21 +63,32 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { close(chDone) }() - if vk.CommitmentInfo.Is() { - - if err := vk.CommitmentKey.VerifyKnowledgeProof(proof.Commitment, proof.CommitmentPok); err != nil { - return err + maxNbPublicCommitted := 0 + for _, s := range vk.PublicAndCommitmentCommitted { // iterate over commitments + maxNbPublicCommitted = utils.Max(maxNbPublicCommitted, len(s)) + } + commitmentsSerialized := make([]byte, len(vk.PublicAndCommitmentCommitted)*fr.Bytes) + commitmentPrehashSerialized := make([]byte, curve.SizeOfG1AffineUncompressed+maxNbPublicCommitted*fr.Bytes) + for i := range vk.PublicAndCommitmentCommitted { // solveCommitmentWire + copy(commitmentPrehashSerialized, proof.Commitments[i].Marshal()) + offset := curve.SizeOfG1AffineUncompressed + for j := range vk.PublicAndCommitmentCommitted[i] { + copy(commitmentPrehashSerialized[offset:], publicWitness[vk.PublicAndCommitmentCommitted[i][j]-1].Marshal()) + offset += fr.Bytes } - - publicCommitted := make([]*big.Int, vk.CommitmentInfo.NbPublicCommitted()) - for i := range publicCommitted { - var b big.Int - publicWitness[vk.CommitmentInfo.Committed[i]-1].BigInt(&b) - publicCommitted[i] = &b + if res, err := fr.Hash(commitmentPrehashSerialized[:offset], []byte(constraint.CommitmentDst), 1); err != nil { + return err + } else { + publicWitness = append(publicWitness, res[0]) + copy(commitmentsSerialized[i*fr.Bytes:], res[0].Marshal()) } + } - if res, err := solveCommitmentWire(&vk.CommitmentInfo, &proof.Commitment, publicCommitted); err == nil { - publicWitness = append(publicWitness, res) + if folded, err := pedersen.FoldCommitments(proof.Commitments, commitmentsSerialized); err != nil { + return err + } else { + if 
err = vk.CommitmentKey.Verify(folded, proof.CommitmentPok); err != nil { + return err } } @@ -88,8 +99,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { } kSum.AddMixed(&vk.G1.K[0]) - if vk.CommitmentInfo.Is() { - kSum.AddMixed(&proof.Commitment) + for i := range proof.Commitments { + kSum.AddMixed(&proof.Commitments[i]) } var kSumAff curve.G1Affine diff --git a/internal/backend/bw6-633/groth16/commitment.go b/backend/groth16/bw6-633/commitment.go similarity index 74% rename from internal/backend/bw6-633/groth16/commitment.go rename to backend/groth16/bw6-633/commitment.go index d1243f342c..f8af92e4fc 100644 --- a/internal/backend/bw6-633/groth16/commitment.go +++ b/backend/groth16/bw6-633/commitment.go @@ -23,7 +23,7 @@ import ( "math/big" ) -func solveCommitmentWire(commitmentInfo *constraint.Commitment, commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { - res, err := fr.Hash(commitmentInfo.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) +func solveCommitmentWire(commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { + res, err := fr.Hash(constraint.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) return res[0], err } diff --git a/internal/backend/bw6-633/groth16/commitment_test.go b/backend/groth16/bw6-633/commitment_test.go similarity index 91% rename from internal/backend/bw6-633/groth16/commitment_test.go rename to backend/groth16/bw6-633/commitment_test.go index b7f700d2cb..a832752ffd 100644 --- a/internal/backend/bw6-633/groth16/commitment_test.go +++ b/backend/groth16/bw6-633/commitment_test.go @@ -17,6 +17,9 @@ package groth16_test import ( + "fmt" + "testing" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/witness" @@ -24,7 +27,6 @@ import ( 
"github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/assert" - "testing" ) type singleSecretCommittedCircuit struct { @@ -33,7 +35,11 @@ type singleSecretCommittedCircuit struct { func (c *singleSecretCommittedCircuit) Define(api frontend.API) error { api.AssertIsEqual(c.One, 1) - commit, err := api.Compiler().Commit(c.One) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One) if err != nil { return err } @@ -119,8 +125,11 @@ type oneSecretOnePublicCommittedCircuit struct { } func (c *oneSecretOnePublicCommittedCircuit) Define(api frontend.API) error { - - commit, err := api.Compiler().Commit(c.One, c.Two) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One, c.Two) if err != nil { return err } diff --git a/internal/backend/bw6-633/groth16/marshal.go b/backend/groth16/bw6-633/marshal.go similarity index 82% rename from internal/backend/bw6-633/groth16/marshal.go rename to backend/groth16/bw6-633/marshal.go index 8c09e3a403..edca7ba642 100644 --- a/internal/backend/bw6-633/groth16/marshal.go +++ b/backend/groth16/bw6-633/marshal.go @@ -18,6 +18,9 @@ package groth16 import ( curve "github.com/consensys/gnark-crypto/ecc/bw6-633" + + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/pedersen" + "github.com/consensys/gnark/internal/utils" "io" ) @@ -78,14 +81,24 @@ func (proof *Proof) ReadFrom(r io.Reader) (n int64, err error) { // points are compressed // use WriteRawTo(...) 
to encode the key without point compression func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, false) + if n, err = vk.writeTo(w, false); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteTo(w) + return m + n, err } // WriteRawTo writes binary encoding of the key elements to writer // points are not compressed // use WriteTo(...) to encode the key with point compression func (vk *VerifyingKey) WriteRawTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, true) + if n, err = vk.writeTo(w, true); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteRawTo(w) + return m + n, err } // writeTo serialization format: @@ -124,6 +137,14 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { if err := enc.Encode(vk.G1.K); err != nil { return enc.BytesWritten(), err } + + if vk.PublicAndCommitmentCommitted == nil { + vk.PublicAndCommitmentCommitted = [][]int{} // only matters in tests + } + if err := enc.Encode(utils.IntSliceSliceToUint64SliceSlice(vk.PublicAndCommitmentCommitted)); err != nil { + return enc.BytesWritten(), err + } + return enc.BytesWritten(), nil } @@ -133,13 +154,25 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { // https://github.com/zkcrypto/bellman/blob/fa9be45588227a8c6ec34957de3f68705f07bd92/src/groth16/mod.rs#L143 // [α]1,[β]1,[β]2,[γ]2,[δ]1,[δ]2,uint32(len(Kvk)),[Kvk]1 func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r) + n, err := vk.readFrom(r) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.ReadFrom(r) + return m + n, err } // UnsafeReadFrom has the same behavior as ReadFrom, except that it will not check that decode points // are on the curve and in the correct subgroup. 
func (vk *VerifyingKey) UnsafeReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r, curve.NoSubgroupChecks()) + n, err := vk.readFrom(r, curve.NoSubgroupChecks()) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.UnsafeReadFrom(r) + return m + n, err } func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { @@ -169,15 +202,16 @@ func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder) if err := dec.Decode(&vk.G1.K); err != nil { return dec.BytesRead(), err } + var publicCommitted [][]uint64 + if err := dec.Decode(&publicCommitted); err != nil { + return dec.BytesRead(), err + } + vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 - var err error - vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return dec.BytesRead(), err } - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) return dec.BytesRead(), nil } @@ -226,6 +260,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { pk.NbInfinityB, pk.InfinityA, pk.InfinityB, + uint32(len(pk.CommitmentKeys)), } for _, v := range toEncode { @@ -234,6 +269,23 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { } } + for i := range pk.CommitmentKeys { + var ( + n2 int64 + err error + ) + if raw { + n2, err = pk.CommitmentKeys[i].WriteRawTo(w) + } else { + n2, err = pk.CommitmentKeys[i].WriteTo(w) + } + + n += n2 + if err != nil { + return n, err + } + } + return n + enc.BytesWritten(), nil } @@ -260,6 +312,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) dec := curve.NewDecoder(r, decOptions...) 
var nbWires uint64 + var nbCommitments uint32 toDecode := []interface{}{ &pk.G1.Alpha, @@ -291,6 +344,18 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) if err := dec.Decode(&pk.InfinityB); err != nil { return n + dec.BytesRead(), err } + if err := dec.Decode(&nbCommitments); err != nil { + return n + dec.BytesRead(), err + } + + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) + for i := range pk.CommitmentKeys { + n2, err := pk.CommitmentKeys[i].ReadFrom(r) + n += n2 + if err != nil { + return n, err + } + } return n + dec.BytesRead(), nil } diff --git a/internal/backend/bw6-633/groth16/marshal_test.go b/backend/groth16/bw6-633/marshal_test.go similarity index 78% rename from internal/backend/bw6-633/groth16/marshal_test.go rename to backend/groth16/bw6-633/marshal_test.go index b2a3a63d4b..9e1c0f0728 100644 --- a/internal/backend/bw6-633/groth16/marshal_test.go +++ b/backend/groth16/bw6-633/marshal_test.go @@ -21,11 +21,17 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal/test_utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "bytes" "math/big" "reflect" "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" "github.com/leanovate/gopter/prop" "testing" @@ -87,13 +93,9 @@ func TestProofSerialization(t *testing.T) { } func TestVerifyingKeySerialization(t *testing.T) { - parameters := gopter.DefaultTestParameters() - parameters.MinSuccessfulTests = 10 - properties := gopter.NewProperties(parameters) - - properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + roundTrip := func(withCommitment bool) func(curve.G1Affine, curve.G2Affine) bool { + return func(p1 curve.G1Affine, p2 curve.G2Affine) bool { var vk, vkCompressed, vkRaw 
VerifyingKey // create a random vk @@ -121,6 +123,21 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } + if withCommitment { + vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization + bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) + elem := p1 + for i := 0; i < len(vk.PublicAndCommitmentCommitted); i++ { + bases[i] = make([]curve.G1Affine, len(vk.PublicAndCommitmentCommitted[i])) + for j := range bases[i] { + bases[i][j] = elem + elem.Add(&elem, &p1) + } + } + _, vk.CommitmentKey, err = pedersen.Setup(bases...) + assert.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := vk.WriteTo(&bufCompressed) if err != nil { @@ -158,7 +175,22 @@ func TestVerifyingKeySerialization(t *testing.T) { } return reflect.DeepEqual(&vk, &vkCompressed) && reflect.DeepEqual(&vk, &vkRaw) - }, + } + } + + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 10 + + properties := gopter.NewProperties(parameters) + + properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(false), + GenG1(), + GenG2(), + )) + + properties.Property("VerifyingKey (with commitments) -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(true), GenG1(), GenG2(), )) @@ -173,7 +205,7 @@ func TestProvingKeySerialization(t *testing.T) { properties := gopter.NewProperties(parameters) properties.Property("ProvingKey -> writer -> reader -> ProvingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + func(p1 curve.G1Affine, p2 curve.G2Affine, nbCommitment int) bool { var pk, pkCompressed, pkRaw ProvingKey // create a random pk @@ -202,6 +234,19 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, 
nbCommitment) + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases...) + require.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := pk.WriteTo(&bufCompressed) if err != nil { @@ -242,6 +287,7 @@ func TestProvingKeySerialization(t *testing.T) { }, GenG1(), GenG2(), + gen.IntRange(0, 2), )) properties.TestingRun(t, gopter.ConsoleReporter(false)) diff --git a/backend/groth16/bw6-633/mpcsetup/lagrange.go b/backend/groth16/bw6-633/mpcsetup/lagrange.go new file mode 100644 index 0000000000..563daf9891 --- /dev/null +++ b/backend/groth16/bw6-633/mpcsetup/lagrange.go @@ -0,0 +1,216 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bw6-633" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" + "github.com/consensys/gnark/internal/utils" +) + +func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { + coeffs := make([]curve.G1Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { + coeffs := make([]curve.G2Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func butterflyG1(a *curve.G1Affine, b *curve.G1Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +func butterflyG2(a *curve.G2Affine, b *curve.G2Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G1(a []curve.G1Affine, twiddles [][]fr.Element, stage int) { + butterflyG1(&a[0], &a[4]) + butterflyG1(&a[1], &a[5]) + 
butterflyG1(&a[2], &a[6]) + butterflyG1(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[2]) + butterflyG1(&a[1], &a[3]) + butterflyG1(&a[4], &a[6]) + butterflyG1(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[1]) + butterflyG1(&a[2], &a[3]) + butterflyG1(&a[4], &a[5]) + butterflyG1(&a[6], &a[7]) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G2(a []curve.G2Affine, twiddles [][]fr.Element, stage int) { + butterflyG2(&a[0], &a[4]) + butterflyG2(&a[1], &a[5]) + butterflyG2(&a[2], &a[6]) + butterflyG2(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[2]) + butterflyG2(&a[1], &a[3]) + butterflyG2(&a[4], &a[6]) + butterflyG2(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[1]) + butterflyG2(&a[2], &a[3]) + butterflyG2(&a[4], &a[5]) + butterflyG2(&a[6], &a[7]) +} + +func difFFTG1(a []curve.G1Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G1(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG1(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ 
{ + butterflyG1(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} +func difFFTG2(a []curve.G2Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G2(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG2(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + butterflyG2(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG2(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG2(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} diff --git a/backend/groth16/bw6-633/mpcsetup/marshal.go b/backend/groth16/bw6-633/mpcsetup/marshal.go new file mode 100644 index 0000000000..4955ec6d92 --- /dev/null +++ b/backend/groth16/bw6-633/mpcsetup/marshal.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bw6-633" + "io" +) + +// WriteTo implements io.WriterTo +func (phase1 *Phase1) WriteTo(writer io.Writer) (int64, error) { + n, err := phase1.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase1.Hash) + return int64(nBytes) + n, err +} + +func (phase1 *Phase1) writeTo(writer io.Writer) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + phase1.Parameters.G1.Tau, + phase1.Parameters.G1.AlphaTau, + phase1.Parameters.G1.BetaTau, + phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + enc := curve.NewEncoder(writer) + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (phase1 *Phase1) ReadFrom(reader io.Reader) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + &phase1.Parameters.G1.Tau, + &phase1.Parameters.G1.AlphaTau, 
+ &phase1.Parameters.G1.BetaTau, + &phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + dec := curve.NewDecoder(reader) + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + phase1.Hash = make([]byte, 32) + nBytes, err := reader.Read(phase1.Hash) + return dec.BytesRead() + int64(nBytes), err +} + +// WriteTo implements io.WriterTo +func (phase2 *Phase2) WriteTo(writer io.Writer) (int64, error) { + n, err := phase2.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase2.Hash) + return int64(nBytes) + n, err +} + +func (c *Phase2) writeTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + c.Parameters.G1.L, + c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c *Phase2) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + &c.Parameters.G1.L, + &c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + c.Hash = make([]byte, 32) + n, err := reader.Read(c.Hash) + return int64(n) + dec.BytesRead(), err + +} + +// WriteTo implements io.WriterTo +func (c *Phase2Evaluations) WriteTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + c.G1.A, + c.G1.B, + c.G2.B, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func 
(c *Phase2Evaluations) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.G1.A, + &c.G1.B, + &c.G2.B, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + return dec.BytesRead(), nil +} diff --git a/backend/groth16/bw6-633/mpcsetup/marshal_test.go b/backend/groth16/bw6-633/mpcsetup/marshal_test.go new file mode 100644 index 0000000000..b9d346f3bd --- /dev/null +++ b/backend/groth16/bw6-633/mpcsetup/marshal_test.go @@ -0,0 +1,79 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + curve "github.com/consensys/gnark-crypto/ecc/bw6-633" + cs "github.com/consensys/gnark/constraint/bw6-633" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/stretchr/testify/require" + "io" + "reflect" + "testing" +) + +func TestContributionSerialization(t *testing.T) { + assert := require.New(t) + + // Phase 1 + srs1 := InitPhase1(9) + srs1.Contribute() + { + var reconstructed Phase1 + roundTripCheck(t, &srs1, &reconstructed) + } + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + r1cs := ccs.(*cs.R1CS) + + // Phase 2 + srs2, _ := InitPhase2(r1cs, &srs1) + srs2.Contribute() + + { + var reconstructed Phase2 + roundTripCheck(t, &srs2, &reconstructed) + } +} + +func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { + t.Helper() + + var buf bytes.Buffer + written, err := from.WriteTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object doesn't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} diff --git a/backend/groth16/bw6-633/mpcsetup/phase1.go b/backend/groth16/bw6-633/mpcsetup/phase1.go new file mode 100644 index 0000000000..fa91cc5025 --- /dev/null +++ b/backend/groth16/bw6-633/mpcsetup/phase1.go @@ -0,0 +1,203 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + curve "github.com/consensys/gnark-crypto/ecc/bw6-633" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "math" + "math/big" +) + +// Phase1 represents the Phase1 of the MPC described in +// https://eprint.iacr.org/2017/1050.pdf +// +// Also known as "Powers of Tau" +type Phase1 struct { + Parameters struct { + G1 struct { + Tau []curve.G1Affine // {[τ⁰]₁, [τ¹]₁, [τ²]₁, …, [τ²ⁿ⁻²]₁} + AlphaTau []curve.G1Affine // {α[τ⁰]₁, α[τ¹]₁, α[τ²]₁, …, α[τⁿ⁻¹]₁} + BetaTau []curve.G1Affine // {β[τ⁰]₁, β[τ¹]₁, β[τ²]₁, …, β[τⁿ⁻¹]₁} + } + G2 struct { + Tau []curve.G2Affine // {[τ⁰]₂, [τ¹]₂, [τ²]₂, …, [τⁿ⁻¹]₂} + Beta curve.G2Affine // [β]₂ + } + } + PublicKeys struct { + Tau, Alpha, Beta PublicKey + } + Hash []byte // sha256 hash +} + +// InitPhase1 initialize phase 1 of the MPC. This is called once by the coordinator before +// any randomness contribution is made (see Contribute()). 
+func InitPhase1(power int) (phase1 Phase1) { + N := int(math.Pow(2, float64(power))) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetOne() + alpha.SetOne() + beta.SetOne() + phase1.PublicKeys.Tau = newPublicKey(tau, nil, 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, nil, 2) + phase1.PublicKeys.Beta = newPublicKey(beta, nil, 3) + + // First contribution use generators + _, _, g1, g2 := curve.Generators() + phase1.Parameters.G2.Beta.Set(&g2) + phase1.Parameters.G1.Tau = make([]curve.G1Affine, 2*N-1) + phase1.Parameters.G2.Tau = make([]curve.G2Affine, N) + phase1.Parameters.G1.AlphaTau = make([]curve.G1Affine, N) + phase1.Parameters.G1.BetaTau = make([]curve.G1Affine, N) + for i := 0; i < len(phase1.Parameters.G1.Tau); i++ { + phase1.Parameters.G1.Tau[i].Set(&g1) + } + for i := 0; i < len(phase1.Parameters.G2.Tau); i++ { + phase1.Parameters.G2.Tau[i].Set(&g2) + phase1.Parameters.G1.AlphaTau[i].Set(&g1) + phase1.Parameters.G1.BetaTau[i].Set(&g1) + } + + phase1.Parameters.G2.Beta.Set(&g2) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() + + return +} + +// Contribute contributes randomness to the phase1 object. This mutates phase1. +func (phase1 *Phase1) Contribute() { + N := len(phase1.Parameters.G2.Tau) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetRandom() + alpha.SetRandom() + beta.SetRandom() + phase1.PublicKeys.Tau = newPublicKey(tau, phase1.Hash[:], 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, phase1.Hash[:], 2) + phase1.PublicKeys.Beta = newPublicKey(beta, phase1.Hash[:], 3) + + // Compute powers of τ, ατ, and βτ + taus := powers(tau, 2*N-1) + alphaTau := make([]fr.Element, N) + betaTau := make([]fr.Element, N) + for i := 0; i < N; i++ { + alphaTau[i].Mul(&taus[i], &alpha) + betaTau[i].Mul(&taus[i], &beta) + } + + // Update using previous parameters + // TODO @gbotrel working with jacobian points here will help with perf. 
+ scaleG1InPlace(phase1.Parameters.G1.Tau, taus) + scaleG2InPlace(phase1.Parameters.G2.Tau, taus[0:N]) + scaleG1InPlace(phase1.Parameters.G1.AlphaTau, alphaTau) + scaleG1InPlace(phase1.Parameters.G1.BetaTau, betaTau) + var betaBI big.Int + beta.BigInt(&betaBI) + phase1.Parameters.G2.Beta.ScalarMultiplication(&phase1.Parameters.G2.Beta, &betaBI) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() +} + +func VerifyPhase1(c0, c1 *Phase1, c ...*Phase1) error { + contribs := append([]*Phase1{c0, c1}, c...) + for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase1(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +// verifyPhase1 checks that a contribution is based on a known previous Phase1 state. +func verifyPhase1(current, contribution *Phase1) error { + // Compute R for τ, α, β + tauR := genR(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, current.Hash[:], 1) + alphaR := genR(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, current.Hash[:], 2) + betaR := genR(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, current.Hash[:], 3) + + // Check for knowledge of toxic parameters + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.PublicKeys.Tau.XR, tauR) { + return errors.New("couldn't verify public key of τ") + } + if !sameRatio(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, contribution.PublicKeys.Alpha.XR, alphaR) { + return errors.New("couldn't verify public key of α") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.PublicKeys.Beta.XR, betaR) { + return errors.New("couldn't verify public key of β") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Tau[1], current.Parameters.G1.Tau[1], tauR, contribution.PublicKeys.Tau.XR) { + return errors.New("couldn't verify that [τ]₁ is based on previous 
contribution") + } + if !sameRatio(contribution.Parameters.G1.AlphaTau[0], current.Parameters.G1.AlphaTau[0], alphaR, contribution.PublicKeys.Alpha.XR) { + return errors.New("couldn't verify that [α]₁ is based on previous contribution") + } + if !sameRatio(contribution.Parameters.G1.BetaTau[0], current.Parameters.G1.BetaTau[0], betaR, contribution.PublicKeys.Beta.XR) { + return errors.New("couldn't verify that [β]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.Parameters.G2.Tau[1], current.Parameters.G2.Tau[1]) { + return errors.New("couldn't verify that [τ]₂ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.Parameters.G2.Beta, current.Parameters.G2.Beta) { + return errors.New("couldn't verify that [β]₂ is based on previous contribution") + } + + // Check for valid updates using powers of τ + _, _, g1, g2 := curve.Generators() + tauL1, tauL2 := linearCombinationG1(contribution.Parameters.G1.Tau) + if !sameRatio(tauL1, tauL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of τ in G₁") + } + alphaL1, alphaL2 := linearCombinationG1(contribution.Parameters.G1.AlphaTau) + if !sameRatio(alphaL1, alphaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of α(τ) in G₁") + } + betaL1, betaL2 := linearCombinationG1(contribution.Parameters.G1.BetaTau) + if !sameRatio(betaL1, betaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of α(τ) in G₁") + } + tau2L1, tau2L2 := linearCombinationG2(contribution.Parameters.G2.Tau) + if !sameRatio(contribution.Parameters.G1.Tau[1], g1, tau2L1, tau2L2) { + return errors.New("couldn't verify valid powers of τ in G₂") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != 
contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (phase1 *Phase1) hash() []byte { + sha := sha256.New() + phase1.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bw6-633/mpcsetup/phase2.go b/backend/groth16/bw6-633/mpcsetup/phase2.go new file mode 100644 index 0000000000..cdf0bb7578 --- /dev/null +++ b/backend/groth16/bw6-633/mpcsetup/phase2.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + "math/big" + + curve "github.com/consensys/gnark-crypto/ecc/bw6-633" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark/constraint" + cs "github.com/consensys/gnark/constraint/bw6-633" +) + +type Phase2Evaluations struct { + G1 struct { + A, B, VKK []curve.G1Affine + } + G2 struct { + B []curve.G2Affine + } +} + +type Phase2 struct { + Parameters struct { + G1 struct { + Delta curve.G1Affine + L, Z []curve.G1Affine + } + G2 struct { + Delta curve.G2Affine + } + } + PublicKey PublicKey + Hash []byte +} + +func InitPhase2(r1cs *cs.R1CS, srs1 *Phase1) (Phase2, Phase2Evaluations) { + srs := srs1.Parameters + size := len(srs.G1.AlphaTau) + if size < r1cs.GetNbConstraints() { + panic("Number of constraints is larger than expected") + } + + c2 := Phase2{} + + accumulateG1 := func(res *curve.G1Affine, t constraint.Term, value *curve.G1Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G1Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + accumulateG2 := func(res *curve.G2Affine, t constraint.Term, value *curve.G2Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G2Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + // Prepare Lagrange coefficients of [τ...]₁, [τ...]₂, [ατ...]₁, [βτ...]₁ + coeffTau1 := 
lagrangeCoeffsG1(srs.G1.Tau, size) + coeffTau2 := lagrangeCoeffsG2(srs.G2.Tau, size) + coeffAlphaTau1 := lagrangeCoeffsG1(srs.G1.AlphaTau, size) + coeffBetaTau1 := lagrangeCoeffsG1(srs.G1.BetaTau, size) + + internal, secret, public := r1cs.GetNbVariables() + nWires := internal + secret + public + var evals Phase2Evaluations + evals.G1.A = make([]curve.G1Affine, nWires) + evals.G1.B = make([]curve.G1Affine, nWires) + evals.G2.B = make([]curve.G2Affine, nWires) + bA := make([]curve.G1Affine, nWires) + aB := make([]curve.G1Affine, nWires) + C := make([]curve.G1Affine, nWires) + + // TODO @gbotrel use constraint iterator when available. + + i := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + // A + for _, t := range c.L { + accumulateG1(&evals.G1.A[t.WireID()], t, &coeffTau1[i]) + accumulateG1(&bA[t.WireID()], t, &coeffBetaTau1[i]) + } + // B + for _, t := range c.R { + accumulateG1(&evals.G1.B[t.WireID()], t, &coeffTau1[i]) + accumulateG2(&evals.G2.B[t.WireID()], t, &coeffTau2[i]) + accumulateG1(&aB[t.WireID()], t, &coeffAlphaTau1[i]) + } + // C + for _, t := range c.O { + accumulateG1(&C[t.WireID()], t, &coeffTau1[i]) + } + i++ + } + + // Prepare default contribution + _, _, g1, g2 := curve.Generators() + c2.Parameters.G1.Delta = g1 + c2.Parameters.G2.Delta = g2 + + // Build Z in PK as τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + // τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + n := len(srs.G1.AlphaTau) + c2.Parameters.G1.Z = make([]curve.G1Affine, n) + for i := 0; i < n-1; i++ { + c2.Parameters.G1.Z[i].Sub(&srs.G1.Tau[i+n], &srs.G1.Tau[i]) + } + bitReverse(c2.Parameters.G1.Z) + c2.Parameters.G1.Z = c2.Parameters.G1.Z[:n-1] + + // Evaluate L + nPrivate := internal + secret + c2.Parameters.G1.L = make([]curve.G1Affine, nPrivate) + evals.G1.VKK = make([]curve.G1Affine, public) + offset := public + for i := 0; i < nWires; i++ { + var tmp curve.G1Affine + tmp.Add(&bA[i], &aB[i]) + tmp.Add(&tmp, &C[i]) + if i < public { + 
evals.G1.VKK[i].Set(&tmp) + } else { + c2.Parameters.G1.L[i-offset].Set(&tmp) + } + } + // Set δ public key + var delta fr.Element + delta.SetOne() + c2.PublicKey = newPublicKey(delta, nil, 1) + + // Hash initial contribution + c2.Hash = c2.hash() + return c2, evals +} + +func (c *Phase2) Contribute() { + // Sample toxic δ + var delta, deltaInv fr.Element + var deltaBI, deltaInvBI big.Int + delta.SetRandom() + deltaInv.Inverse(&delta) + + delta.BigInt(&deltaBI) + deltaInv.BigInt(&deltaInvBI) + + // Set δ public key + c.PublicKey = newPublicKey(delta, c.Hash, 1) + + // Update δ + c.Parameters.G1.Delta.ScalarMultiplication(&c.Parameters.G1.Delta, &deltaBI) + c.Parameters.G2.Delta.ScalarMultiplication(&c.Parameters.G2.Delta, &deltaBI) + + // Update Z using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.Z); i++ { + c.Parameters.G1.Z[i].ScalarMultiplication(&c.Parameters.G1.Z[i], &deltaInvBI) + } + + // Update L using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.L); i++ { + c.Parameters.G1.L[i].ScalarMultiplication(&c.Parameters.G1.L[i], &deltaInvBI) + } + + // 4. Hash contribution + c.Hash = c.hash() +} + +func VerifyPhase2(c0, c1 *Phase2, c ...*Phase2) error { + contribs := append([]*Phase2{c0, c1}, c...) 
+ for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase2(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +func verifyPhase2(current, contribution *Phase2) error { + // Compute R for δ + deltaR := genR(contribution.PublicKey.SG, contribution.PublicKey.SXG, current.Hash[:], 1) + + // Check for knowledge of δ + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.PublicKey.XR, deltaR) { + return errors.New("couldn't verify knowledge of δ") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Delta, current.Parameters.G1.Delta, deltaR, contribution.PublicKey.XR) { + return errors.New("couldn't verify that [δ]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify that [δ]₂ is based on previous contribution") + } + + // Check for valid updates of L and Z using δ⁻¹ + L, prevL := merge(contribution.Parameters.G1.L, current.Parameters.G1.L) + if !sameRatio(L, prevL, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of L using δ⁻¹") + } + Z, prevZ := merge(contribution.Parameters.G1.Z, current.Parameters.G1.Z) + if !sameRatio(Z, prevZ, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of Z using δ⁻¹") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (c *Phase2) hash() []byte { + sha := sha256.New() + c.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bw6-633/mpcsetup/setup.go b/backend/groth16/bw6-633/mpcsetup/setup.go new file mode 100644 index 0000000000..35bfdca0e9 --- 
/dev/null +++ b/backend/groth16/bw6-633/mpcsetup/setup.go @@ -0,0 +1,97 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bw6-633" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" + groth16 "github.com/consensys/gnark/backend/groth16/bw6-633" +) + +func ExtractKeys(srs1 *Phase1, srs2 *Phase2, evals *Phase2Evaluations, nConstraints int) (pk groth16.ProvingKey, vk groth16.VerifyingKey) { + _, _, _, g2 := curve.Generators() + + // Initialize PK + pk.Domain = *fft.NewDomain(uint64(nConstraints)) + pk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + pk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + pk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + pk.G1.Z = srs2.Parameters.G1.Z + bitReverse(pk.G1.Z) + + pk.G1.K = srs2.Parameters.G1.L + pk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + pk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + + // Filter out infinity points + nWires := len(evals.G1.A) + pk.InfinityA = make([]bool, nWires) + A := make([]curve.G1Affine, nWires) + j := 0 + for i, e := range evals.G1.A { + if e.IsInfinity() { + pk.InfinityA[i] = true + continue + } + A[j] = evals.G1.A[i] + j++ + } + pk.G1.A = A[:j] + pk.NbInfinityA = uint64(nWires - j) + + pk.InfinityB = make([]bool, nWires) + B := make([]curve.G1Affine, nWires) + j = 0 + for i, e := range evals.G1.B { + if e.IsInfinity() { + 
pk.InfinityB[i] = true + continue + } + B[j] = evals.G1.B[i] + j++ + } + pk.G1.B = B[:j] + pk.NbInfinityB = uint64(nWires - j) + + B2 := make([]curve.G2Affine, nWires) + j = 0 + for i, e := range evals.G2.B { + if e.IsInfinity() { + // pk.InfinityB[i] = true should be the same as in B + continue + } + B2[j] = evals.G2.B[i] + j++ + } + pk.G2.B = B2[:j] + + // Initialize VK + vk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + vk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + vk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + vk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + vk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + vk.G2.Gamma.Set(&g2) + vk.G1.K = evals.G1.VKK + + // sets e, -[δ]2, -[γ]2 + if err := vk.Precompute(); err != nil { + panic(err) + } + + return pk, vk +} diff --git a/backend/groth16/bw6-633/mpcsetup/setup_test.go b/backend/groth16/bw6-633/mpcsetup/setup_test.go new file mode 100644 index 0000000000..fa51d16fe2 --- /dev/null +++ b/backend/groth16/bw6-633/mpcsetup/setup_test.go @@ -0,0 +1,199 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bw6-633" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + cs "github.com/consensys/gnark/constraint/bw6-633" + "testing" + + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/stretchr/testify/require" + + native_mimc "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/mimc" +) + +func TestSetupCircuit(t *testing.T) { + if testing.Short() { + t.Skip() + } + const ( + nContributionsPhase1 = 3 + nContributionsPhase2 = 3 + power = 9 + ) + + assert := require.New(t) + + srs1 := InitPhase1(power) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase1; i++ { + // we clone test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. + prev := srs1.clone() + + srs1.Contribute() + assert.NoError(VerifyPhase1(&prev, &srs1)) + } + + // Compile the circuit + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + var evals Phase2Evaluations + r1cs := ccs.(*cs.R1CS) + + // Prepare for phase-2 + srs2, evals := InitPhase2(r1cs, &srs1) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase2; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. 
+ prev := srs2.clone() + + srs2.Contribute() + assert.NoError(VerifyPhase2(&prev, &srs2)) + } + + // Extract the proving and verifying keys + pk, vk := ExtractKeys(&srs1, &srs2, &evals, ccs.GetNbConstraints()) + + // Build the witness + var preImage, hash fr.Element + { + m := native_mimc.NewMiMC() + m.Write(preImage.Marshal()) + hash.SetBytes(m.Sum(nil)) + } + + witness, err := frontend.NewWitness(&Circuit{PreImage: preImage, Hash: hash}, curve.ID.ScalarField()) + assert.NoError(err) + + pubWitness, err := witness.Public() + assert.NoError(err) + + // groth16: ensure proof is verified + proof, err := groth16.Prove(ccs, &pk, witness) + assert.NoError(err) + + err = groth16.Verify(proof, &vk, pubWitness) + assert.NoError(err) +} + +func BenchmarkPhase1(b *testing.B) { + const power = 14 + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = InitPhase1(power) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs1 := InitPhase1(power) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs1.Contribute() + } + }) + +} + +func BenchmarkPhase2(b *testing.B) { + const power = 14 + srs1 := InitPhase1(power) + srs1.Contribute() + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + if err != nil { + b.Fatal(err) + } + + r1cs := ccs.(*cs.R1CS) + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = InitPhase2(r1cs, &srs1) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs2, _ := InitPhase2(r1cs, &srs1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs2.Contribute() + } + }) + +} + +// Circuit defines a pre-image knowledge proof +// mimc(secret preImage) = public hash +type Circuit struct { + PreImage frontend.Variable + Hash frontend.Variable `gnark:",public"` +} + +// Define declares the circuit's constraints +// Hash = mimc(PreImage) +func (circuit *Circuit) Define(api frontend.API) error { + // hash function + mimc, _ := 
mimc.NewMiMC(api) + + // specify constraints + mimc.Write(circuit.PreImage) + api.AssertIsEqual(circuit.Hash, mimc.Sum()) + + return nil +} + +func (phase1 *Phase1) clone() Phase1 { + r := Phase1{} + r.Parameters.G1.Tau = append(r.Parameters.G1.Tau, phase1.Parameters.G1.Tau...) + r.Parameters.G1.AlphaTau = append(r.Parameters.G1.AlphaTau, phase1.Parameters.G1.AlphaTau...) + r.Parameters.G1.BetaTau = append(r.Parameters.G1.BetaTau, phase1.Parameters.G1.BetaTau...) + + r.Parameters.G2.Tau = append(r.Parameters.G2.Tau, phase1.Parameters.G2.Tau...) + r.Parameters.G2.Beta = phase1.Parameters.G2.Beta + + r.PublicKeys = phase1.PublicKeys + r.Hash = append(r.Hash, phase1.Hash...) + + return r +} + +func (phase2 *Phase2) clone() Phase2 { + r := Phase2{} + r.Parameters.G1.Delta = phase2.Parameters.G1.Delta + r.Parameters.G1.L = append(r.Parameters.G1.L, phase2.Parameters.G1.L...) + r.Parameters.G1.Z = append(r.Parameters.G1.Z, phase2.Parameters.G1.Z...) + r.Parameters.G2.Delta = phase2.Parameters.G2.Delta + r.PublicKey = phase2.PublicKey + r.Hash = append(r.Hash, phase2.Hash...) + + return r +} diff --git a/backend/groth16/bw6-633/mpcsetup/utils.go b/backend/groth16/bw6-633/mpcsetup/utils.go new file mode 100644 index 0000000000..7a2979c53c --- /dev/null +++ b/backend/groth16/bw6-633/mpcsetup/utils.go @@ -0,0 +1,170 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bw6-633" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark/internal/utils" +) + +type PublicKey struct { + SG curve.G1Affine + SXG curve.G1Affine + XR curve.G2Affine +} + +func newPublicKey(x fr.Element, challenge []byte, dst byte) PublicKey { + var pk PublicKey + _, _, g1, _ := curve.Generators() + + var s fr.Element + var sBi big.Int + s.SetRandom() + s.BigInt(&sBi) + pk.SG.ScalarMultiplication(&g1, &sBi) + + // compute x*sG1 + var xBi big.Int + x.BigInt(&xBi) + pk.SXG.ScalarMultiplication(&pk.SG, &xBi) + + // generate R based on sG1, sxG1, challenge, and domain separation tag (tau, alpha or beta) + R := genR(pk.SG, pk.SXG, challenge, dst) + + // compute x*spG2 + pk.XR.ScalarMultiplication(&R, &xBi) + return pk +} + +func bitReverse[T any](a []T) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + irev := bits.Reverse64(i) >> nn + if irev > i { + a[i], a[irev] = a[irev], a[i] + } + } +} + +// Returns [1, a, a², ..., aⁿ⁻¹ ] in Montgomery form +func powers(a fr.Element, n int) []fr.Element { + result := make([]fr.Element, n) + result[0] = fr.NewElement(1) + for i := 1; i < n; i++ { + result[i].Mul(&result[i-1], &a) + } + return result +} + +// Returns [aᵢAᵢ, ...] in G1 +func scaleG1InPlace(A []curve.G1Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Returns [aᵢAᵢ, ...] 
in G2 +func scaleG2InPlace(A []curve.G2Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Check e(a₁, a₂) = e(b₁, b₂) +func sameRatio(a1, b1 curve.G1Affine, a2, b2 curve.G2Affine) bool { + if !a1.IsInSubGroup() || !b1.IsInSubGroup() || !a2.IsInSubGroup() || !b2.IsInSubGroup() { + panic("invalid point not in subgroup") + } + var na2 curve.G2Affine + na2.Neg(&a2) + res, err := curve.PairingCheck( + []curve.G1Affine{a1, b1}, + []curve.G2Affine{na2, b2}) + if err != nil { + panic(err) + } + return res +} + +// returns a = ∑ rᵢAᵢ, b = ∑ rᵢBᵢ +func merge(A, B []curve.G1Affine) (a, b curve.G1Affine) { + nc := runtime.NumCPU() + r := make([]fr.Element, len(A)) + for i := 0; i < len(A); i++ { + r[i].SetRandom() + } + a.MultiExp(A, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + b.MultiExp(B, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G1 +func linearCombinationG1(A []curve.G1Affine) (L1, L2 curve.G1Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G2 +func linearCombinationG2(A []curve.G2Affine) (L1, L2 curve.G2Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// Generate R in G₂ as Hash(gˢ, gˢˣ, challenge, dst) +func genR(sG1, sxG1 curve.G1Affine, challenge []byte, dst byte) curve.G2Affine { + var buf bytes.Buffer + buf.Grow(len(challenge) + curve.SizeOfG1AffineUncompressed*2) + buf.Write(sG1.Marshal()) + 
buf.Write(sxG1.Marshal()) + buf.Write(challenge) + spG2, err := curve.HashToG2(buf.Bytes(), []byte{dst}) + if err != nil { + panic(err) + } + return spG2 +} diff --git a/internal/backend/bw6-633/groth16/prove.go b/backend/groth16/bw6-633/prove.go similarity index 62% rename from internal/backend/bw6-633/groth16/prove.go rename to backend/groth16/bw6-633/prove.go index 335a9f0ccb..b92dbb6943 100644 --- a/internal/backend/bw6-633/groth16/prove.go +++ b/backend/groth16/bw6-633/prove.go @@ -17,13 +17,17 @@ package groth16 import ( - "fmt" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bw6-633" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/pedersen" "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/constraint/bw6-633" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + cs "github.com/consensys/gnark/constraint/bw6-633" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" "math/big" @@ -35,9 +39,10 @@ import ( // with a valid statement and a VerifyingKey // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type Proof struct { - Ar, Krs curve.G1Affine - Bs curve.G2Affine - Commitment, CommitmentPok curve.G1Affine + Ar, Krs curve.G1Affine + Bs curve.G2Affine + Commitments []curve.G1Affine // Pedersen commitments a la https://eprint.iacr.org/2022/1072 + CommitmentPok curve.G1Affine // Batched proof of knowledge of the above commitments } // isValid ensures proof elements are in the correct subgroup @@ -51,72 +56,78 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). 
-func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - // TODO @gbotrel witness size check is done by R1CS, doesn't mean we shouldn't sanitize here. - // if len(witness) != r1cs.NbPublicVariables-1+r1cs.NbSecretVariables { - // return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), r1cs.NbPublicVariables-1+r1cs.NbSecretVariables, r1cs.NbPublicVariables, r1cs.NbSecretVariables) - // } - - log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", len(r1cs.Constraints)).Str("backend", "groth16").Logger() - - // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - - proof := &Proof{} - if r1cs.CommitmentInfo.Is() { - opt.HintFunctions[r1cs.CommitmentInfo.HintID] = func(_ *big.Int, in []*big.Int, out []*big.Int) error { - // Perf-TODO: Converting these values to big.Int and back may be a performance bottleneck. - // If that is the case, figure out a way to feed the solution vector into this function - if len(in) != r1cs.CommitmentInfo.NbCommitted() { // TODO: Remove - return fmt.Errorf("unexpected number of committed variables") - } - values := make([]fr.Element, r1cs.CommitmentInfo.NbPrivateCommitted) - nbPublicCommitted := len(in) - len(values) - inPrivate := in[nbPublicCommitted:] - for i, inI := range inPrivate { - values[i].SetBigInt(inI) - } +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } + + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] - var err error - proof.Commitment, proof.CommitmentPok, err = pk.CommitmentKey.Commit(values) - if err != nil { + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + solverOpts = append(solverOpts, solver.OverrideHint(commitmentInfo[i].HintID, func(i int) solver.Hint { + return func(_ *big.Int, in []*big.Int, out []*big.Int) error { + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[len(hashed):] + for j, inJ := range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } + + var res fr.Element + res, err = solveCommitmentWire(&proof.Commitments[i], hashed) + res.BigInt(out[0]) return err } + }(i))) + } - var res fr.Element - res, err = solveCommitmentWire(&r1cs.CommitmentInfo, &proof.Commitment, in[:r1cs.CommitmentInfo.NbPublicCommitted()]) - res.BigInt(out[0]) //Perf-TODO: Regular (non-mont) hashToField to obviate this conversion? 
- return err - } + if r1cs.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + solverOpts = append(solverOpts, + solver.OverrideHint(r1cs.GkrInfo.SolveHintID, cs.GkrSolveHint(r1cs.GkrInfo, &gkrData)), + solver.OverrideHint(r1cs.GkrInfo.ProveHintID, cs.GkrProveHint(r1cs.GkrInfo.HashName, &gkrData))) } - var wireValues []fr.Element - var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill wireValues with random values else multi exps don't do much - var r fr.Element - _, _ = r.SetRandom() - for i := r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables(); i < len(wireValues); i++ { - wireValues[i] = r - r.Double(&r) - } - } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + start := time.Now() + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + + if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil { + return nil, err + } + // H (witness reduction / FFT part) var h []fr.Element chHDone := make(chan struct{}, 1) go func() { - h = computeH(a, b, c, &pk.Domain) - a = nil - b = nil - c = nil + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + solution.A = nil + solution.B = nil + solution.C = nil chHDone <- struct{}{} }() @@ -203,15 +214,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC var krs, krs2, p1 curve.G1Jac chKrs2Done := make(chan error, 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 go func() { - _, err := krs2.MultiExp(pk.G1.Z, h, ecc.MultiExpConfig{NbTasks: n / 2}) + _, err := krs2.MultiExp(pk.G1.Z, 
h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) chKrs2Done <- err }() - // filter the wire values if needed; - _wireValues := filter(wireValues, r1cs.CommitmentInfo.PrivateToPublic()) + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - if _, err := krs.MultiExp(pk.G1.K, _wireValues[r1cs.GetNbPublicVariables():], ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chKrsDone <- err return } @@ -292,26 +307,32 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC } // if len(toRemove) == 0, returns slice -// else, returns a new slice without the indexes in toRemove -// this assumes toRemove indexes are sorted and len(slice) > len(toRemove) -func filter(slice []fr.Element, toRemove []int) (r []fr.Element) { +// else, returns a new slice without the indexes in toRemove. 
The first value in the slice is taken as indexes as sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { if len(toRemove) == 0 { return slice } - r = make([]fr.Element, 0, len(slice)-len(toRemove)) - j := 0 + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) + // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) for i := 0; i < len(slice); i++ { - if j < len(toRemove) && i == toRemove[j] { - j++ + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } continue } r = append(r, slice[i]) } - return r + return } func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { @@ -334,9 +355,9 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { domain.FFTInverse(b, fft.DIF) domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, true) - domain.FFT(b, fft.DIT, true) - domain.FFT(c, fft.DIT, true) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) var den, one fr.Element one.SetOne() @@ -344,7 +365,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) - // reusing a to avoid unecessary memalloc + // reusing a to avoid unnecessary memory allocation utils.Parallelize(n, func(start, end int) { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). 
@@ -354,7 +375,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { }) // ifft_coset - domain.FFTInverse(a, fft.DIF, true) + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) return a } diff --git a/internal/backend/bw6-633/groth16/setup.go b/backend/groth16/bw6-633/setup.go similarity index 75% rename from internal/backend/bw6-633/groth16/setup.go rename to backend/groth16/bw6-633/setup.go index d6777090eb..f168993476 100644 --- a/internal/backend/bw6-633/groth16/setup.go +++ b/backend/groth16/bw6-633/setup.go @@ -17,13 +17,15 @@ package groth16 import ( + "errors" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bw6-633" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bw6-633" + cs "github.com/consensys/gnark/constraint/bw6-633" "math/big" "math/bits" ) @@ -34,15 +36,15 @@ type ProvingKey struct { // domain Domain fft.Domain - // [α]1, [β]1, [δ]1 - // [A(t)]1, [B(t)]1, [Kpk(t)]1, [Z(t)]1 + // [α]₁, [β]₁, [δ]₁ + // [A(t)]₁, [B(t)]₁, [Kpk(t)]₁, [Z(t)]₁ G1 struct { Alpha, Beta, Delta curve.G1Affine A, B, Z []curve.G1Affine K []curve.G1Affine // the indexes correspond to the private wires } - // [β]2, [δ]2, [B(t)]2 + // [β]₂, [δ]₂, [B(t)]₂ G2 struct { Beta, Delta curve.G2Affine B []curve.G2Affine @@ -52,21 +54,21 @@ type ProvingKey struct { InfinityA, InfinityB []bool NbInfinityA, NbInfinityB uint64 - CommitmentKey pedersen.Key + CommitmentKeys []pedersen.ProvingKey } // VerifyingKey is used by a Groth16 verifier to verify the validity of a proof and a statement // Notation follows Figure 4. 
in DIZK paper https://eprint.iacr.org/2018/691.pdf type VerifyingKey struct { - // [α]1, [Kvk]1 + // [α]₁, [Kvk]₁ G1 struct { Alpha curve.G1Affine Beta, Delta curve.G1Affine // unused, here for compatibility purposes K []curve.G1Affine // The indexes correspond to the public wires } - // [β]2, [δ]2, [γ]2, - // -[δ]2, -[γ]2: see proof.Verify() for more details + // [β]₂, [δ]₂, [γ]₂, + // -[δ]₂, -[γ]₂: see proof.Verify() for more details G2 struct { Beta, Delta, Gamma curve.G2Affine deltaNeg, gammaNeg curve.G2Affine // not serialized @@ -75,8 +77,8 @@ type VerifyingKey struct { // e(α, β) e curve.GT // not serialized - CommitmentKey pedersen.Key - CommitmentInfo constraint.Commitment // since the verifier doesn't input a constraint system, this needs to be provided here + CommitmentKey pedersen.VerifyingKey + PublicAndCommitmentCommitted [][]int // indexes of public/commitment committed variables } // Setup constructs the SRS @@ -93,17 +95,20 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbPrivateCommittedWires := r1cs.CommitmentInfo.NbPrivateCommitted - nbPublicWires := r1cs.GetNbPublicVariables() - nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - if r1cs.CommitmentInfo.Is() { // the commitment itself is defined by a hint so the prover considers it private - nbPublicWires++ // but the verifier will need to inject the value itself so on the groth16 - nbPrivateWires-- // level it must be considered public - } + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + commitmentWires := commitmentInfo.CommitmentIndexes() + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateCommittedWires := internal.NbElements(privateCommitted) + + // a commitment is itself defined by a hint so the prover considers it 
private + // but the verifier will need to inject the value itself so on the groth16 + // level it must be considered public + nbPublicWires := r1cs.GetNbPublicVariables() + len(commitmentInfo) + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - len(commitmentInfo) // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints))) + domain := fft.NewDomain(uint64(r1cs.GetNbConstraints())) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -137,7 +142,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // compute scalars for pkK, vkK and ckK pkK := make([]fr.Element, nbPrivateWires) vkK := make([]fr.Element, nbPublicWires) - ckK := make([]fr.Element, nbPrivateCommittedWires) + ckK := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + ckK[i] = make([]fr.Element, len(privateCommitted[i])) + } var t0, t1 fr.Element @@ -148,28 +156,42 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { Add(&t1, &C[i]). 
Mul(&t1, coeff) } - - vI, cI := 0, 0 - privateCommitted := r1cs.CommitmentInfo.PrivateCommitted() + vI := 0 // number of public wires seen so far + cI := make([]int, len(commitmentInfo)) // number of private committed wires seen so far for each commitment + nbPrivateCommittedSeen := 0 // = ∑ᵢ cI[i] + nbCommitmentsSeen := 0 for i := range A { - isCommittedPrivate := cI < len(privateCommitted) && i == privateCommitted[cI] - isCommitment := r1cs.CommitmentInfo.Is() && i == r1cs.CommitmentInfo.CommitmentIndex - isPublic := i < r1cs.GetNbPublicVariables() + commitment := -1 // index of the commitment that commits to this variable as a private or commitment value + var isCommitment, isPublic bool + if isPublic = i < r1cs.GetNbPublicVariables(); !isPublic { + if nbCommitmentsSeen < len(commitmentWires) && commitmentWires[nbCommitmentsSeen] == i { + isCommitment = true + nbCommitmentsSeen++ + } - if isPublic || isCommittedPrivate || isCommitment { + for j := range commitmentInfo { // does commitment j commit to i? + if cI[j] < len(privateCommitted[j]) && privateCommitted[j][cI[j]] == i { + commitment = j + break // frontend guarantees that no private variable is committed to more than once + } + } + } + + if isPublic || commitment != -1 || isCommitment { computeK(i, &toxicWaste.gammaInv) - if isCommittedPrivate { - ckK[cI] = t1 - cI++ - } else { + if isPublic || isCommitment { vkK[vI] = t1 vI++ + } else { // committed and private + ckK[commitment][cI[commitment]] = t1 + cI[commitment]++ + nbPrivateCommittedSeen++ } } else { computeK(i, &toxicWaste.deltaInv) - pkK[i-vI-cI] = t1 + pkK[i-vI-nbPrivateCommittedSeen] = t1 // vI = nbPublicSeen + nbCommitmentsSeen } } @@ -222,11 +244,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { g1Scalars = append(g1Scalars, Z...) g1Scalars = append(g1Scalars, vkK...) g1Scalars = append(g1Scalars, pkK...) - g1Scalars = append(g1Scalars, ckK...) + for i := range ckK { + g1Scalars = append(g1Scalars, ckK[i]...) 
+ } g1PointsAff := curve.BatchScalarMultiplicationG1(&g1, g1Scalars) - // sets pk: [α]1, [β]1, [δ]1 + // sets pk: [α]₁, [β]₁, [δ]₁ pk.G1.Alpha = g1PointsAff[0] pk.G1.Beta = g1PointsAff[1] pk.G1.Delta = g1PointsAff[2] @@ -238,8 +262,9 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G1.B = g1PointsAff[offset : offset+len(B)] offset += len(B) - pk.G1.Z = g1PointsAff[offset : offset+int(domain.Cardinality)] - bitReverse(pk.G1.Z) + bitReverse(g1PointsAff[offset : offset+int(domain.Cardinality)]) + sizeZ := int(domain.Cardinality) - 1 // deg(H)=deg(A*B-C/X^n-1)=(n-1)+(n-1)-n=n-2 + pk.G1.Z = g1PointsAff[offset : offset+sizeZ] offset += int(domain.Cardinality) @@ -252,17 +277,22 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // --------------------------------------------------------------------------------------------- // Commitment setup - if nbPrivateCommittedWires != 0 { - commitmentBasis := g1PointsAff[offset:] + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(ckK[i]) + commitmentBases[i] = g1PointsAff[offset : offset+size] + offset += size + } + if offset != len(g1PointsAff) { + return errors.New("didn't consume all G1 points") // TODO @Tabaie Remove this + } - vk.CommitmentKey, err = pedersen.Setup(commitmentBasis) - if err != nil { - return err - } - pk.CommitmentKey = vk.CommitmentKey + pk.CommitmentKeys, vk.CommitmentKey, err = pedersen.Setup(commitmentBases...) 
+ if err != nil { + return err } - vk.CommitmentInfo = r1cs.CommitmentInfo // unfortunate but necessary + vk.PublicAndCommitmentCommitted = commitmentInfo.GetPublicAndCommitmentCommitted(commitmentWires, r1cs.GetNbPublicVariables()) // --------------------------------------------------------------------------------------------- // G2 scalars @@ -279,15 +309,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G2.B = g2PointsAff[:len(B)] - // sets pk: [β]2, [δ]2 + // sets pk: [β]₂, [δ]₂ pk.G2.Beta = g2PointsAff[len(B)+0] pk.G2.Delta = g2PointsAff[len(B)+1] - // sets vk: [δ]2, [γ]2, -[δ]2, -[γ]2 + // sets vk: [δ]₂, [γ]₂ vk.G2.Delta = g2PointsAff[len(B)+1] vk.G2.Gamma = g2PointsAff[len(B)+2] - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) // --------------------------------------------------------------------------------------------- // Pairing: vk.e @@ -298,16 +326,29 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { vk.G1.Beta = pk.G1.Beta vk.G1.Delta = pk.G1.Delta - vk.e, err = curve.Pair([]curve.G1Affine{pk.G1.Alpha}, []curve.G2Affine{pk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return err } + // set domain pk.Domain = *domain return nil } +// Precompute sets e, -[δ]₂, -[γ]₂ +// This is meant to be called internally during setup or deserialization. +func (vk *VerifyingKey) Precompute() error { + var err error + vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) + if err != nil { + return err + } + vk.G2.deltaNeg.Neg(&vk.G2.Delta) + vk.G2.gammaNeg.Neg(&vk.G2.Gamma) + return nil +} + func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr.Element, B []fr.Element, C []fr.Element) { nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() @@ -322,7 +363,7 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. 
var w fr.Element w.Set(&domain.Generator) wi := fr.One() - t := make([]fr.Element, len(r1cs.Constraints)+1) + t := make([]fr.Element, r1cs.GetNbConstraints()+1) for i := 0; i < len(t); i++ { t[i].Sub(&toxicWaste.t, &wi) wi.Mul(&wi, &w) // TODO this is already pre computed in fft.Domain @@ -366,8 +407,10 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // for each term appearing in the linear expression, // we compute term.Coefficient * L, and cumulate it in // A, B or C at the index of the variable - for i, c := range r1cs.Constraints { + j := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } @@ -380,9 +423,12 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // Li+1 = w*Li*(t-w^i)/(t-w^(i+1)) L.Mul(&L, &w) - L.Mul(&L, &t[i]) - L.Mul(&L, &tInv[i+1]) + L.Mul(&L, &t[j]) + L.Mul(&L, &tInv[j+1]) + + j++ } + return } @@ -436,7 +482,10 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbConstraints := len(r1cs.Constraints) + nbConstraints := r1cs.GetNbConstraints() + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - internal.NbElements(privateCommitted) - len(commitmentInfo) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints)) @@ -448,8 +497,8 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // initialize proving key pk.G1.A = make([]curve.G1Affine, nbWires-nbZeroesA) pk.G1.B = make([]curve.G1Affine, nbWires-nbZeroesB) - pk.G1.K = make([]curve.G1Affine, nbWires-r1cs.GetNbPublicVariables()) - pk.G1.Z = make([]curve.G1Affine, 
domain.Cardinality) + pk.G1.K = make([]curve.G1Affine, nbPrivateWires) + pk.G1.Z = make([]curve.G1Affine, domain.Cardinality-1) pk.G2.B = make([]curve.G2Affine, nbWires-nbZeroesB) // set infinity markers @@ -503,6 +552,22 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { pk.Domain = *domain + // --------------------------------------------------------------------------------------------- + // Commitment setup + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(privateCommitted[i]) + commitmentBases[i] = make([]curve.G1Affine, size) + for j := range commitmentBases[i] { + commitmentBases[i][j] = r1Aff + } + } + + pk.CommitmentKeys, _, err = pedersen.Setup(commitmentBases...) + if err != nil { + return err + } + return nil } @@ -514,7 +579,9 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) - for _, c := range r1cs.Constraints { + + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { A[t.WireID()] = true } @@ -522,6 +589,7 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { B[t.WireID()] = true } } + for i := 0; i < nbWires; i++ { if !A[i] { nbZeroesA++ @@ -606,7 +674,7 @@ func (pk *ProvingKey) NbG2() int { return 2 + len(pk.G2.B) } -// bitRerverse permutation as in fft.BitReverse , but with []curve.G1Affine +// bitReverse permutation as in fft.BitReverse , but with []curve.G1Affine func bitReverse(a []curve.G1Affine) { n := uint(len(a)) nn := uint(bits.UintSize - bits.TrailingZeros(n)) diff --git a/internal/backend/bw6-633/groth16/verify.go b/backend/groth16/bw6-633/verify.go similarity index 64% rename from internal/backend/bw6-633/groth16/verify.go rename to backend/groth16/bw6-633/verify.go index 642bedc2d4..c32a71e6a6 100644 --- a/internal/backend/bw6-633/groth16/verify.go +++ b/backend/groth16/bw6-633/verify.go @@ -22,9 +22,11 @@ import ( 
"github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bw6-633" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/pedersen" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/logger" "io" - "math/big" "time" ) @@ -36,10 +38,8 @@ var ( // Verify verifies a proof with given VerifyingKey and publicWitness func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - nbPublicVars := len(vk.G1.K) - if vk.CommitmentInfo.Is() { - nbPublicVars-- - } + nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted) + if len(publicWitness) != nbPublicVars-1 { return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(publicWitness), len(vk.G1.K)-1) } @@ -62,21 +62,32 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { close(chDone) }() - if vk.CommitmentInfo.Is() { - - if err := vk.CommitmentKey.VerifyKnowledgeProof(proof.Commitment, proof.CommitmentPok); err != nil { - return err + maxNbPublicCommitted := 0 + for _, s := range vk.PublicAndCommitmentCommitted { // iterate over commitments + maxNbPublicCommitted = utils.Max(maxNbPublicCommitted, len(s)) + } + commitmentsSerialized := make([]byte, len(vk.PublicAndCommitmentCommitted)*fr.Bytes) + commitmentPrehashSerialized := make([]byte, curve.SizeOfG1AffineUncompressed+maxNbPublicCommitted*fr.Bytes) + for i := range vk.PublicAndCommitmentCommitted { // solveCommitmentWire + copy(commitmentPrehashSerialized, proof.Commitments[i].Marshal()) + offset := curve.SizeOfG1AffineUncompressed + for j := range vk.PublicAndCommitmentCommitted[i] { + copy(commitmentPrehashSerialized[offset:], publicWitness[vk.PublicAndCommitmentCommitted[i][j]-1].Marshal()) + offset += fr.Bytes } - - publicCommitted := make([]*big.Int, vk.CommitmentInfo.NbPublicCommitted()) - for i := range publicCommitted { - var b big.Int - 
publicWitness[vk.CommitmentInfo.Committed[i]-1].BigInt(&b) - publicCommitted[i] = &b + if res, err := fr.Hash(commitmentPrehashSerialized[:offset], []byte(constraint.CommitmentDst), 1); err != nil { + return err + } else { + publicWitness = append(publicWitness, res[0]) + copy(commitmentsSerialized[i*fr.Bytes:], res[0].Marshal()) } + } - if res, err := solveCommitmentWire(&vk.CommitmentInfo, &proof.Commitment, publicCommitted); err == nil { - publicWitness = append(publicWitness, res) + if folded, err := pedersen.FoldCommitments(proof.Commitments, commitmentsSerialized); err != nil { + return err + } else { + if err = vk.CommitmentKey.Verify(folded, proof.CommitmentPok); err != nil { + return err } } @@ -87,8 +98,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { } kSum.AddMixed(&vk.G1.K[0]) - if vk.CommitmentInfo.Is() { - kSum.AddMixed(&proof.Commitment) + for i := range proof.Commitments { + kSum.AddMixed(&proof.Commitments[i]) } var kSumAff curve.G1Affine diff --git a/internal/backend/bw6-761/groth16/commitment.go b/backend/groth16/bw6-761/commitment.go similarity index 74% rename from internal/backend/bw6-761/groth16/commitment.go rename to backend/groth16/bw6-761/commitment.go index c332669a12..5c357c24ad 100644 --- a/internal/backend/bw6-761/groth16/commitment.go +++ b/backend/groth16/bw6-761/commitment.go @@ -23,7 +23,7 @@ import ( "math/big" ) -func solveCommitmentWire(commitmentInfo *constraint.Commitment, commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { - res, err := fr.Hash(commitmentInfo.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) +func solveCommitmentWire(commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { + res, err := fr.Hash(constraint.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) return res[0], err } diff --git 
a/internal/backend/bw6-761/groth16/commitment_test.go b/backend/groth16/bw6-761/commitment_test.go similarity index 91% rename from internal/backend/bw6-761/groth16/commitment_test.go rename to backend/groth16/bw6-761/commitment_test.go index ccddf6be25..f16cb25786 100644 --- a/internal/backend/bw6-761/groth16/commitment_test.go +++ b/backend/groth16/bw6-761/commitment_test.go @@ -17,6 +17,9 @@ package groth16_test import ( + "fmt" + "testing" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/witness" @@ -24,7 +27,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/assert" - "testing" ) type singleSecretCommittedCircuit struct { @@ -33,7 +35,11 @@ type singleSecretCommittedCircuit struct { func (c *singleSecretCommittedCircuit) Define(api frontend.API) error { api.AssertIsEqual(c.One, 1) - commit, err := api.Compiler().Commit(c.One) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One) if err != nil { return err } @@ -119,8 +125,11 @@ type oneSecretOnePublicCommittedCircuit struct { } func (c *oneSecretOnePublicCommittedCircuit) Define(api frontend.API) error { - - commit, err := api.Compiler().Commit(c.One, c.Two) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One, c.Two) if err != nil { return err } diff --git a/internal/backend/bw6-761/groth16/marshal.go b/backend/groth16/bw6-761/marshal.go similarity index 82% rename from internal/backend/bw6-761/groth16/marshal.go rename to backend/groth16/bw6-761/marshal.go index 0a18c29463..6f072bc90e 100644 --- a/internal/backend/bw6-761/groth16/marshal.go +++ b/backend/groth16/bw6-761/marshal.go @@ -18,6 +18,9 @@ package groth16 import ( curve 
"github.com/consensys/gnark-crypto/ecc/bw6-761" + + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/pedersen" + "github.com/consensys/gnark/internal/utils" "io" ) @@ -78,14 +81,24 @@ func (proof *Proof) ReadFrom(r io.Reader) (n int64, err error) { // points are compressed // use WriteRawTo(...) to encode the key without point compression func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, false) + if n, err = vk.writeTo(w, false); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteTo(w) + return m + n, err } // WriteRawTo writes binary encoding of the key elements to writer // points are not compressed // use WriteTo(...) to encode the key with point compression func (vk *VerifyingKey) WriteRawTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, true) + if n, err = vk.writeTo(w, true); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteRawTo(w) + return m + n, err } // writeTo serialization format: @@ -124,6 +137,14 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { if err := enc.Encode(vk.G1.K); err != nil { return enc.BytesWritten(), err } + + if vk.PublicAndCommitmentCommitted == nil { + vk.PublicAndCommitmentCommitted = [][]int{} // only matters in tests + } + if err := enc.Encode(utils.IntSliceSliceToUint64SliceSlice(vk.PublicAndCommitmentCommitted)); err != nil { + return enc.BytesWritten(), err + } + return enc.BytesWritten(), nil } @@ -133,13 +154,25 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { // https://github.com/zkcrypto/bellman/blob/fa9be45588227a8c6ec34957de3f68705f07bd92/src/groth16/mod.rs#L143 // [α]1,[β]1,[β]2,[γ]2,[δ]1,[δ]2,uint32(len(Kvk)),[Kvk]1 func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r) + n, err := vk.readFrom(r) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.ReadFrom(r) + return m + n, err } // UnsafeReadFrom 
has the same behavior as ReadFrom, except that it will not check that decode points // are on the curve and in the correct subgroup. func (vk *VerifyingKey) UnsafeReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r, curve.NoSubgroupChecks()) + n, err := vk.readFrom(r, curve.NoSubgroupChecks()) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.UnsafeReadFrom(r) + return m + n, err } func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { @@ -169,15 +202,16 @@ func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder) if err := dec.Decode(&vk.G1.K); err != nil { return dec.BytesRead(), err } + var publicCommitted [][]uint64 + if err := dec.Decode(&publicCommitted); err != nil { + return dec.BytesRead(), err + } + vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 - var err error - vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return dec.BytesRead(), err } - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) return dec.BytesRead(), nil } @@ -226,6 +260,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { pk.NbInfinityB, pk.InfinityA, pk.InfinityB, + uint32(len(pk.CommitmentKeys)), } for _, v := range toEncode { @@ -234,6 +269,23 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { } } + for i := range pk.CommitmentKeys { + var ( + n2 int64 + err error + ) + if raw { + n2, err = pk.CommitmentKeys[i].WriteRawTo(w) + } else { + n2, err = pk.CommitmentKeys[i].WriteTo(w) + } + + n += n2 + if err != nil { + return n, err + } + } + return n + enc.BytesWritten(), nil } @@ -260,6 +312,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) dec := curve.NewDecoder(r, decOptions...) 
var nbWires uint64 + var nbCommitments uint32 toDecode := []interface{}{ &pk.G1.Alpha, @@ -291,6 +344,18 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) if err := dec.Decode(&pk.InfinityB); err != nil { return n + dec.BytesRead(), err } + if err := dec.Decode(&nbCommitments); err != nil { + return n + dec.BytesRead(), err + } + + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) + for i := range pk.CommitmentKeys { + n2, err := pk.CommitmentKeys[i].ReadFrom(r) + n += n2 + if err != nil { + return n, err + } + } return n + dec.BytesRead(), nil } diff --git a/internal/backend/bw6-761/groth16/marshal_test.go b/backend/groth16/bw6-761/marshal_test.go similarity index 78% rename from internal/backend/bw6-761/groth16/marshal_test.go rename to backend/groth16/bw6-761/marshal_test.go index bccf077d51..36e5f692a4 100644 --- a/internal/backend/bw6-761/groth16/marshal_test.go +++ b/backend/groth16/bw6-761/marshal_test.go @@ -21,11 +21,17 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal/test_utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "bytes" "math/big" "reflect" "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" "github.com/leanovate/gopter/prop" "testing" @@ -87,13 +93,9 @@ func TestProofSerialization(t *testing.T) { } func TestVerifyingKeySerialization(t *testing.T) { - parameters := gopter.DefaultTestParameters() - parameters.MinSuccessfulTests = 10 - properties := gopter.NewProperties(parameters) - - properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + roundTrip := func(withCommitment bool) func(curve.G1Affine, curve.G2Affine) bool { + return func(p1 curve.G1Affine, p2 curve.G2Affine) bool { var vk, vkCompressed, vkRaw 
VerifyingKey // create a random vk @@ -121,6 +123,21 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } + if withCommitment { + vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization + bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) + elem := p1 + for i := 0; i < len(vk.PublicAndCommitmentCommitted); i++ { + bases[i] = make([]curve.G1Affine, len(vk.PublicAndCommitmentCommitted[i])) + for j := range bases[i] { + bases[i][j] = elem + elem.Add(&elem, &p1) + } + } + _, vk.CommitmentKey, err = pedersen.Setup(bases...) + assert.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := vk.WriteTo(&bufCompressed) if err != nil { @@ -158,7 +175,22 @@ func TestVerifyingKeySerialization(t *testing.T) { } return reflect.DeepEqual(&vk, &vkCompressed) && reflect.DeepEqual(&vk, &vkRaw) - }, + } + } + + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 10 + + properties := gopter.NewProperties(parameters) + + properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(false), + GenG1(), + GenG2(), + )) + + properties.Property("VerifyingKey (with commitments) -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(true), GenG1(), GenG2(), )) @@ -173,7 +205,7 @@ func TestProvingKeySerialization(t *testing.T) { properties := gopter.NewProperties(parameters) properties.Property("ProvingKey -> writer -> reader -> ProvingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + func(p1 curve.G1Affine, p2 curve.G2Affine, nbCommitment int) bool { var pk, pkCompressed, pkRaw ProvingKey // create a random pk @@ -202,6 +234,19 @@ func TestProvingKeySerialization(t *testing.T) { pk.InfinityB = make([]bool, nbWires) pk.InfinityA[2] = true + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, 
nbCommitment) + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases...) + require.NoError(t, err) + } + var bufCompressed bytes.Buffer written, err := pk.WriteTo(&bufCompressed) if err != nil { @@ -242,6 +287,7 @@ func TestProvingKeySerialization(t *testing.T) { }, GenG1(), GenG2(), + gen.IntRange(0, 2), )) properties.TestingRun(t, gopter.ConsoleReporter(false)) diff --git a/backend/groth16/bw6-761/mpcsetup/lagrange.go b/backend/groth16/bw6-761/mpcsetup/lagrange.go new file mode 100644 index 0000000000..3b8aaa3b7c --- /dev/null +++ b/backend/groth16/bw6-761/mpcsetup/lagrange.go @@ -0,0 +1,216 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bw6-761" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" + "github.com/consensys/gnark/internal/utils" +) + +func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { + coeffs := make([]curve.G1Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { + coeffs := make([]curve.G2Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func butterflyG1(a *curve.G1Affine, b *curve.G1Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +func butterflyG2(a *curve.G2Affine, b *curve.G2Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G1(a []curve.G1Affine, twiddles [][]fr.Element, stage int) { + butterflyG1(&a[0], &a[4]) + butterflyG1(&a[1], &a[5]) + 
butterflyG1(&a[2], &a[6]) + butterflyG1(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[2]) + butterflyG1(&a[1], &a[3]) + butterflyG1(&a[4], &a[6]) + butterflyG1(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[1]) + butterflyG1(&a[2], &a[3]) + butterflyG1(&a[4], &a[5]) + butterflyG1(&a[6], &a[7]) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G2(a []curve.G2Affine, twiddles [][]fr.Element, stage int) { + butterflyG2(&a[0], &a[4]) + butterflyG2(&a[1], &a[5]) + butterflyG2(&a[2], &a[6]) + butterflyG2(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[2]) + butterflyG2(&a[1], &a[3]) + butterflyG2(&a[4], &a[6]) + butterflyG2(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[1]) + butterflyG2(&a[2], &a[3]) + butterflyG2(&a[4], &a[5]) + butterflyG2(&a[6], &a[7]) +} + +func difFFTG1(a []curve.G1Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G1(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG1(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ 
{ + butterflyG1(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} +func difFFTG2(a []curve.G2Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G2(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG2(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + butterflyG2(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG2(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG2(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} diff --git a/backend/groth16/bw6-761/mpcsetup/marshal.go b/backend/groth16/bw6-761/mpcsetup/marshal.go new file mode 100644 index 0000000000..d1a071aa5d --- /dev/null +++ b/backend/groth16/bw6-761/mpcsetup/marshal.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bw6-761" + "io" +) + +// WriteTo implements io.WriterTo +func (phase1 *Phase1) WriteTo(writer io.Writer) (int64, error) { + n, err := phase1.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase1.Hash) + return int64(nBytes) + n, err +} + +func (phase1 *Phase1) writeTo(writer io.Writer) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + phase1.Parameters.G1.Tau, + phase1.Parameters.G1.AlphaTau, + phase1.Parameters.G1.BetaTau, + phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + enc := curve.NewEncoder(writer) + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (phase1 *Phase1) ReadFrom(reader io.Reader) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + &phase1.Parameters.G1.Tau, + &phase1.Parameters.G1.AlphaTau, 
+ &phase1.Parameters.G1.BetaTau, + &phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + dec := curve.NewDecoder(reader) + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + phase1.Hash = make([]byte, 32) + nBytes, err := reader.Read(phase1.Hash) + return dec.BytesRead() + int64(nBytes), err +} + +// WriteTo implements io.WriterTo +func (phase2 *Phase2) WriteTo(writer io.Writer) (int64, error) { + n, err := phase2.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase2.Hash) + return int64(nBytes) + n, err +} + +func (c *Phase2) writeTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + c.Parameters.G1.L, + c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c *Phase2) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + &c.Parameters.G1.L, + &c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + c.Hash = make([]byte, 32) + n, err := reader.Read(c.Hash) + return int64(n) + dec.BytesRead(), err + +} + +// WriteTo implements io.WriterTo +func (c *Phase2Evaluations) WriteTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + c.G1.A, + c.G1.B, + c.G2.B, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func 
(c *Phase2Evaluations) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.G1.A, + &c.G1.B, + &c.G2.B, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + return dec.BytesRead(), nil +} diff --git a/backend/groth16/bw6-761/mpcsetup/marshal_test.go b/backend/groth16/bw6-761/mpcsetup/marshal_test.go new file mode 100644 index 0000000000..cdb362ab70 --- /dev/null +++ b/backend/groth16/bw6-761/mpcsetup/marshal_test.go @@ -0,0 +1,79 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + curve "github.com/consensys/gnark-crypto/ecc/bw6-761" + cs "github.com/consensys/gnark/constraint/bw6-761" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/stretchr/testify/require" + "io" + "reflect" + "testing" +) + +func TestContributionSerialization(t *testing.T) { + assert := require.New(t) + + // Phase 1 + srs1 := InitPhase1(9) + srs1.Contribute() + { + var reconstructed Phase1 + roundTripCheck(t, &srs1, &reconstructed) + } + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + r1cs := ccs.(*cs.R1CS) + + // Phase 2 + srs2, _ := InitPhase2(r1cs, &srs1) + srs2.Contribute() + + { + var reconstructed Phase2 + roundTripCheck(t, &srs2, &reconstructed) + } +} + +func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { + t.Helper() + + var buf bytes.Buffer + written, err := from.WriteTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} diff --git a/backend/groth16/bw6-761/mpcsetup/phase1.go b/backend/groth16/bw6-761/mpcsetup/phase1.go new file mode 100644 index 0000000000..07af196d3f --- /dev/null +++ b/backend/groth16/bw6-761/mpcsetup/phase1.go @@ -0,0 +1,203 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + curve "github.com/consensys/gnark-crypto/ecc/bw6-761" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "math" + "math/big" +) + +// Phase1 represents the Phase1 of the MPC described in +// https://eprint.iacr.org/2017/1050.pdf +// +// Also known as "Powers of Tau" +type Phase1 struct { + Parameters struct { + G1 struct { + Tau []curve.G1Affine // {[τ⁰]₁, [τ¹]₁, [τ²]₁, …, [τ²ⁿ⁻²]₁} + AlphaTau []curve.G1Affine // {α[τ⁰]₁, α[τ¹]₁, α[τ²]₁, …, α[τⁿ⁻¹]₁} + BetaTau []curve.G1Affine // {β[τ⁰]₁, β[τ¹]₁, β[τ²]₁, …, β[τⁿ⁻¹]₁} + } + G2 struct { + Tau []curve.G2Affine // {[τ⁰]₂, [τ¹]₂, [τ²]₂, …, [τⁿ⁻¹]₂} + Beta curve.G2Affine // [β]₂ + } + } + PublicKeys struct { + Tau, Alpha, Beta PublicKey + } + Hash []byte // sha256 hash +} + +// InitPhase1 initialize phase 1 of the MPC. This is called once by the coordinator before +// any randomness contribution is made (see Contribute()). 
+func InitPhase1(power int) (phase1 Phase1) { + N := int(math.Pow(2, float64(power))) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetOne() + alpha.SetOne() + beta.SetOne() + phase1.PublicKeys.Tau = newPublicKey(tau, nil, 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, nil, 2) + phase1.PublicKeys.Beta = newPublicKey(beta, nil, 3) + + // First contribution use generators + _, _, g1, g2 := curve.Generators() + phase1.Parameters.G2.Beta.Set(&g2) + phase1.Parameters.G1.Tau = make([]curve.G1Affine, 2*N-1) + phase1.Parameters.G2.Tau = make([]curve.G2Affine, N) + phase1.Parameters.G1.AlphaTau = make([]curve.G1Affine, N) + phase1.Parameters.G1.BetaTau = make([]curve.G1Affine, N) + for i := 0; i < len(phase1.Parameters.G1.Tau); i++ { + phase1.Parameters.G1.Tau[i].Set(&g1) + } + for i := 0; i < len(phase1.Parameters.G2.Tau); i++ { + phase1.Parameters.G2.Tau[i].Set(&g2) + phase1.Parameters.G1.AlphaTau[i].Set(&g1) + phase1.Parameters.G1.BetaTau[i].Set(&g1) + } + + phase1.Parameters.G2.Beta.Set(&g2) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() + + return +} + +// Contribute contributes randomness to the phase1 object. This mutates phase1. +func (phase1 *Phase1) Contribute() { + N := len(phase1.Parameters.G2.Tau) + + // Generate key pairs + var tau, alpha, beta fr.Element + tau.SetRandom() + alpha.SetRandom() + beta.SetRandom() + phase1.PublicKeys.Tau = newPublicKey(tau, phase1.Hash[:], 1) + phase1.PublicKeys.Alpha = newPublicKey(alpha, phase1.Hash[:], 2) + phase1.PublicKeys.Beta = newPublicKey(beta, phase1.Hash[:], 3) + + // Compute powers of τ, ατ, and βτ + taus := powers(tau, 2*N-1) + alphaTau := make([]fr.Element, N) + betaTau := make([]fr.Element, N) + for i := 0; i < N; i++ { + alphaTau[i].Mul(&taus[i], &alpha) + betaTau[i].Mul(&taus[i], &beta) + } + + // Update using previous parameters + // TODO @gbotrel working with jacobian points here will help with perf. 
+ scaleG1InPlace(phase1.Parameters.G1.Tau, taus) + scaleG2InPlace(phase1.Parameters.G2.Tau, taus[0:N]) + scaleG1InPlace(phase1.Parameters.G1.AlphaTau, alphaTau) + scaleG1InPlace(phase1.Parameters.G1.BetaTau, betaTau) + var betaBI big.Int + beta.BigInt(&betaBI) + phase1.Parameters.G2.Beta.ScalarMultiplication(&phase1.Parameters.G2.Beta, &betaBI) + + // Compute hash of Contribution + phase1.Hash = phase1.hash() +} + +func VerifyPhase1(c0, c1 *Phase1, c ...*Phase1) error { + contribs := append([]*Phase1{c0, c1}, c...) + for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase1(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +// verifyPhase1 checks that a contribution is based on a known previous Phase1 state. +func verifyPhase1(current, contribution *Phase1) error { + // Compute R for τ, α, β + tauR := genR(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, current.Hash[:], 1) + alphaR := genR(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, current.Hash[:], 2) + betaR := genR(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, current.Hash[:], 3) + + // Check for knowledge of toxic parameters + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.PublicKeys.Tau.XR, tauR) { + return errors.New("couldn't verify public key of τ") + } + if !sameRatio(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, contribution.PublicKeys.Alpha.XR, alphaR) { + return errors.New("couldn't verify public key of α") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.PublicKeys.Beta.XR, betaR) { + return errors.New("couldn't verify public key of β") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Tau[1], current.Parameters.G1.Tau[1], tauR, contribution.PublicKeys.Tau.XR) { + return errors.New("couldn't verify that [τ]₁ is based on previous 
contribution") + } + if !sameRatio(contribution.Parameters.G1.AlphaTau[0], current.Parameters.G1.AlphaTau[0], alphaR, contribution.PublicKeys.Alpha.XR) { + return errors.New("couldn't verify that [α]₁ is based on previous contribution") + } + if !sameRatio(contribution.Parameters.G1.BetaTau[0], current.Parameters.G1.BetaTau[0], betaR, contribution.PublicKeys.Beta.XR) { + return errors.New("couldn't verify that [β]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.Parameters.G2.Tau[1], current.Parameters.G2.Tau[1]) { + return errors.New("couldn't verify that [τ]₂ is based on previous contribution") + } + if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.Parameters.G2.Beta, current.Parameters.G2.Beta) { + return errors.New("couldn't verify that [β]₂ is based on previous contribution") + } + + // Check for valid updates using powers of τ + _, _, g1, g2 := curve.Generators() + tauL1, tauL2 := linearCombinationG1(contribution.Parameters.G1.Tau) + if !sameRatio(tauL1, tauL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of τ in G₁") + } + alphaL1, alphaL2 := linearCombinationG1(contribution.Parameters.G1.AlphaTau) + if !sameRatio(alphaL1, alphaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of α(τ) in G₁") + } + betaL1, betaL2 := linearCombinationG1(contribution.Parameters.G1.BetaTau) + if !sameRatio(betaL1, betaL2, contribution.Parameters.G2.Tau[1], g2) { + return errors.New("couldn't verify valid powers of β(τ) in G₁") + } + tau2L1, tau2L2 := linearCombinationG2(contribution.Parameters.G2.Tau) + if !sameRatio(contribution.Parameters.G1.Tau[1], g1, tau2L1, tau2L2) { + return errors.New("couldn't verify valid powers of τ in G₂") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != 
contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (phase1 *Phase1) hash() []byte { + sha := sha256.New() + phase1.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bw6-761/mpcsetup/phase2.go b/backend/groth16/bw6-761/mpcsetup/phase2.go new file mode 100644 index 0000000000..cb0c6f9768 --- /dev/null +++ b/backend/groth16/bw6-761/mpcsetup/phase2.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "crypto/sha256" + "errors" + "math/big" + + curve "github.com/consensys/gnark-crypto/ecc/bw6-761" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark/constraint" + cs "github.com/consensys/gnark/constraint/bw6-761" +) + +type Phase2Evaluations struct { + G1 struct { + A, B, VKK []curve.G1Affine + } + G2 struct { + B []curve.G2Affine + } +} + +type Phase2 struct { + Parameters struct { + G1 struct { + Delta curve.G1Affine + L, Z []curve.G1Affine + } + G2 struct { + Delta curve.G2Affine + } + } + PublicKey PublicKey + Hash []byte +} + +func InitPhase2(r1cs *cs.R1CS, srs1 *Phase1) (Phase2, Phase2Evaluations) { + srs := srs1.Parameters + size := len(srs.G1.AlphaTau) + if size < r1cs.GetNbConstraints() { + panic("Number of constraints is larger than expected") + } + + c2 := Phase2{} + + accumulateG1 := func(res *curve.G1Affine, t constraint.Term, value *curve.G1Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G1Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + accumulateG2 := func(res *curve.G2Affine, t constraint.Term, value *curve.G2Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G2Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + // Prepare Lagrange coefficients of [τ...]₁, [τ...]₂, [ατ...]₁, [βτ...]₁ + coeffTau1 := 
lagrangeCoeffsG1(srs.G1.Tau, size) + coeffTau2 := lagrangeCoeffsG2(srs.G2.Tau, size) + coeffAlphaTau1 := lagrangeCoeffsG1(srs.G1.AlphaTau, size) + coeffBetaTau1 := lagrangeCoeffsG1(srs.G1.BetaTau, size) + + internal, secret, public := r1cs.GetNbVariables() + nWires := internal + secret + public + var evals Phase2Evaluations + evals.G1.A = make([]curve.G1Affine, nWires) + evals.G1.B = make([]curve.G1Affine, nWires) + evals.G2.B = make([]curve.G2Affine, nWires) + bA := make([]curve.G1Affine, nWires) + aB := make([]curve.G1Affine, nWires) + C := make([]curve.G1Affine, nWires) + + // TODO @gbotrel use constraint iterator when available. + + i := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + // A + for _, t := range c.L { + accumulateG1(&evals.G1.A[t.WireID()], t, &coeffTau1[i]) + accumulateG1(&bA[t.WireID()], t, &coeffBetaTau1[i]) + } + // B + for _, t := range c.R { + accumulateG1(&evals.G1.B[t.WireID()], t, &coeffTau1[i]) + accumulateG2(&evals.G2.B[t.WireID()], t, &coeffTau2[i]) + accumulateG1(&aB[t.WireID()], t, &coeffAlphaTau1[i]) + } + // C + for _, t := range c.O { + accumulateG1(&C[t.WireID()], t, &coeffTau1[i]) + } + i++ + } + + // Prepare default contribution + _, _, g1, g2 := curve.Generators() + c2.Parameters.G1.Delta = g1 + c2.Parameters.G2.Delta = g2 + + // Build Z in PK as τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + // τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + n := len(srs.G1.AlphaTau) + c2.Parameters.G1.Z = make([]curve.G1Affine, n) + for i := 0; i < n-1; i++ { + c2.Parameters.G1.Z[i].Sub(&srs.G1.Tau[i+n], &srs.G1.Tau[i]) + } + bitReverse(c2.Parameters.G1.Z) + c2.Parameters.G1.Z = c2.Parameters.G1.Z[:n-1] + + // Evaluate L + nPrivate := internal + secret + c2.Parameters.G1.L = make([]curve.G1Affine, nPrivate) + evals.G1.VKK = make([]curve.G1Affine, public) + offset := public + for i := 0; i < nWires; i++ { + var tmp curve.G1Affine + tmp.Add(&bA[i], &aB[i]) + tmp.Add(&tmp, &C[i]) + if i < public { + 
evals.G1.VKK[i].Set(&tmp) + } else { + c2.Parameters.G1.L[i-offset].Set(&tmp) + } + } + // Set δ public key + var delta fr.Element + delta.SetOne() + c2.PublicKey = newPublicKey(delta, nil, 1) + + // Hash initial contribution + c2.Hash = c2.hash() + return c2, evals +} + +func (c *Phase2) Contribute() { + // Sample toxic δ + var delta, deltaInv fr.Element + var deltaBI, deltaInvBI big.Int + delta.SetRandom() + deltaInv.Inverse(&delta) + + delta.BigInt(&deltaBI) + deltaInv.BigInt(&deltaInvBI) + + // Set δ public key + c.PublicKey = newPublicKey(delta, c.Hash, 1) + + // Update δ + c.Parameters.G1.Delta.ScalarMultiplication(&c.Parameters.G1.Delta, &deltaBI) + c.Parameters.G2.Delta.ScalarMultiplication(&c.Parameters.G2.Delta, &deltaBI) + + // Update Z using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.Z); i++ { + c.Parameters.G1.Z[i].ScalarMultiplication(&c.Parameters.G1.Z[i], &deltaInvBI) + } + + // Update L using δ⁻¹ + for i := 0; i < len(c.Parameters.G1.L); i++ { + c.Parameters.G1.L[i].ScalarMultiplication(&c.Parameters.G1.L[i], &deltaInvBI) + } + + // 4. Hash contribution + c.Hash = c.hash() +} + +func VerifyPhase2(c0, c1 *Phase2, c ...*Phase2) error { + contribs := append([]*Phase2{c0, c1}, c...) 
+ for i := 0; i < len(contribs)-1; i++ { + if err := verifyPhase2(contribs[i], contribs[i+1]); err != nil { + return err + } + } + return nil +} + +func verifyPhase2(current, contribution *Phase2) error { + // Compute R for δ + deltaR := genR(contribution.PublicKey.SG, contribution.PublicKey.SXG, current.Hash[:], 1) + + // Check for knowledge of δ + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.PublicKey.XR, deltaR) { + return errors.New("couldn't verify knowledge of δ") + } + + // Check for valid updates using previous parameters + if !sameRatio(contribution.Parameters.G1.Delta, current.Parameters.G1.Delta, deltaR, contribution.PublicKey.XR) { + return errors.New("couldn't verify that [δ]₁ is based on previous contribution") + } + if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify that [δ]₂ is based on previous contribution") + } + + // Check for valid updates of L and Z using + L, prevL := merge(contribution.Parameters.G1.L, current.Parameters.G1.L) + if !sameRatio(L, prevL, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of L using δ⁻¹") + } + Z, prevZ := merge(contribution.Parameters.G1.Z, current.Parameters.G1.Z) + if !sameRatio(Z, prevZ, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) { + return errors.New("couldn't verify valid updates of L using δ⁻¹") + } + + // Check hash of the contribution + h := contribution.hash() + for i := 0; i < len(h); i++ { + if h[i] != contribution.Hash[i] { + return errors.New("couldn't verify hash of contribution") + } + } + + return nil +} + +func (c *Phase2) hash() []byte { + sha := sha256.New() + c.writeTo(sha) + return sha.Sum(nil) +} diff --git a/backend/groth16/bw6-761/mpcsetup/setup.go b/backend/groth16/bw6-761/mpcsetup/setup.go new file mode 100644 index 0000000000..9008fd26b0 --- 
/dev/null +++ b/backend/groth16/bw6-761/mpcsetup/setup.go @@ -0,0 +1,97 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bw6-761" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" + groth16 "github.com/consensys/gnark/backend/groth16/bw6-761" +) + +func ExtractKeys(srs1 *Phase1, srs2 *Phase2, evals *Phase2Evaluations, nConstraints int) (pk groth16.ProvingKey, vk groth16.VerifyingKey) { + _, _, _, g2 := curve.Generators() + + // Initialize PK + pk.Domain = *fft.NewDomain(uint64(nConstraints)) + pk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + pk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + pk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + pk.G1.Z = srs2.Parameters.G1.Z + bitReverse(pk.G1.Z) + + pk.G1.K = srs2.Parameters.G1.L + pk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + pk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + + // Filter out infinity points + nWires := len(evals.G1.A) + pk.InfinityA = make([]bool, nWires) + A := make([]curve.G1Affine, nWires) + j := 0 + for i, e := range evals.G1.A { + if e.IsInfinity() { + pk.InfinityA[i] = true + continue + } + A[j] = evals.G1.A[i] + j++ + } + pk.G1.A = A[:j] + pk.NbInfinityA = uint64(nWires - j) + + pk.InfinityB = make([]bool, nWires) + B := make([]curve.G1Affine, nWires) + j = 0 + for i, e := range evals.G1.B { + if e.IsInfinity() { + 
pk.InfinityB[i] = true + continue + } + B[j] = evals.G1.B[i] + j++ + } + pk.G1.B = B[:j] + pk.NbInfinityB = uint64(nWires - j) + + B2 := make([]curve.G2Affine, nWires) + j = 0 + for i, e := range evals.G2.B { + if e.IsInfinity() { + // pk.InfinityB[i] = true should be the same as in B + continue + } + B2[j] = evals.G2.B[i] + j++ + } + pk.G2.B = B2[:j] + + // Initialize VK + vk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + vk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + vk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + vk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + vk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + vk.G2.Gamma.Set(&g2) + vk.G1.K = evals.G1.VKK + + // sets e, -[δ]2, -[γ]2 + if err := vk.Precompute(); err != nil { + panic(err) + } + + return pk, vk +} diff --git a/backend/groth16/bw6-761/mpcsetup/setup_test.go b/backend/groth16/bw6-761/mpcsetup/setup_test.go new file mode 100644 index 0000000000..83994ca73d --- /dev/null +++ b/backend/groth16/bw6-761/mpcsetup/setup_test.go @@ -0,0 +1,199 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bw6-761" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + cs "github.com/consensys/gnark/constraint/bw6-761" + "testing" + + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/stretchr/testify/require" + + native_mimc "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/mimc" +) + +func TestSetupCircuit(t *testing.T) { + if testing.Short() { + t.Skip() + } + const ( + nContributionsPhase1 = 3 + nContributionsPhase2 = 3 + power = 9 + ) + + assert := require.New(t) + + srs1 := InitPhase1(power) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase1; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. + prev := srs1.clone() + + srs1.Contribute() + assert.NoError(VerifyPhase1(&prev, &srs1)) + } + + // Compile the circuit + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + var evals Phase2Evaluations + r1cs := ccs.(*cs.R1CS) + + // Prepare for phase-2 + srs2, evals := InitPhase2(r1cs, &srs1) + + // Make and verify contributions for phase2 + for i := 1; i < nContributionsPhase2; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. 
+ prev := srs2.clone() + + srs2.Contribute() + assert.NoError(VerifyPhase2(&prev, &srs2)) + } + + // Extract the proving and verifying keys + pk, vk := ExtractKeys(&srs1, &srs2, &evals, ccs.GetNbConstraints()) + + // Build the witness + var preImage, hash fr.Element + { + m := native_mimc.NewMiMC() + m.Write(preImage.Marshal()) + hash.SetBytes(m.Sum(nil)) + } + + witness, err := frontend.NewWitness(&Circuit{PreImage: preImage, Hash: hash}, curve.ID.ScalarField()) + assert.NoError(err) + + pubWitness, err := witness.Public() + assert.NoError(err) + + // groth16: ensure proof is verified + proof, err := groth16.Prove(ccs, &pk, witness) + assert.NoError(err) + + err = groth16.Verify(proof, &vk, pubWitness) + assert.NoError(err) +} + +func BenchmarkPhase1(b *testing.B) { + const power = 14 + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = InitPhase1(power) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs1 := InitPhase1(power) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs1.Contribute() + } + }) + +} + +func BenchmarkPhase2(b *testing.B) { + const power = 14 + srs1 := InitPhase1(power) + srs1.Contribute() + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + if err != nil { + b.Fatal(err) + } + + r1cs := ccs.(*cs.R1CS) + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = InitPhase2(r1cs, &srs1) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs2, _ := InitPhase2(r1cs, &srs1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs2.Contribute() + } + }) + +} + +// Circuit defines a pre-image knowledge proof +// mimc(secret preImage) = public hash +type Circuit struct { + PreImage frontend.Variable + Hash frontend.Variable `gnark:",public"` +} + +// Define declares the circuit's constraints +// Hash = mimc(PreImage) +func (circuit *Circuit) Define(api frontend.API) error { + // hash function + mimc, _ := 
mimc.NewMiMC(api) + + // specify constraints + mimc.Write(circuit.PreImage) + api.AssertIsEqual(circuit.Hash, mimc.Sum()) + + return nil +} + +func (phase1 *Phase1) clone() Phase1 { + r := Phase1{} + r.Parameters.G1.Tau = append(r.Parameters.G1.Tau, phase1.Parameters.G1.Tau...) + r.Parameters.G1.AlphaTau = append(r.Parameters.G1.AlphaTau, phase1.Parameters.G1.AlphaTau...) + r.Parameters.G1.BetaTau = append(r.Parameters.G1.BetaTau, phase1.Parameters.G1.BetaTau...) + + r.Parameters.G2.Tau = append(r.Parameters.G2.Tau, phase1.Parameters.G2.Tau...) + r.Parameters.G2.Beta = phase1.Parameters.G2.Beta + + r.PublicKeys = phase1.PublicKeys + r.Hash = append(r.Hash, phase1.Hash...) + + return r +} + +func (phase2 *Phase2) clone() Phase2 { + r := Phase2{} + r.Parameters.G1.Delta = phase2.Parameters.G1.Delta + r.Parameters.G1.L = append(r.Parameters.G1.L, phase2.Parameters.G1.L...) + r.Parameters.G1.Z = append(r.Parameters.G1.Z, phase2.Parameters.G1.Z...) + r.Parameters.G2.Delta = phase2.Parameters.G2.Delta + r.PublicKey = phase2.PublicKey + r.Hash = append(r.Hash, phase2.Hash...) + + return r +} diff --git a/backend/groth16/bw6-761/mpcsetup/utils.go b/backend/groth16/bw6-761/mpcsetup/utils.go new file mode 100644 index 0000000000..dfdd1e8a97 --- /dev/null +++ b/backend/groth16/bw6-761/mpcsetup/utils.go @@ -0,0 +1,170 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package mpcsetup + +import ( + "bytes" + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + curve "github.com/consensys/gnark-crypto/ecc/bw6-761" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark/internal/utils" +) + +type PublicKey struct { + SG curve.G1Affine + SXG curve.G1Affine + XR curve.G2Affine +} + +func newPublicKey(x fr.Element, challenge []byte, dst byte) PublicKey { + var pk PublicKey + _, _, g1, _ := curve.Generators() + + var s fr.Element + var sBi big.Int + s.SetRandom() + s.BigInt(&sBi) + pk.SG.ScalarMultiplication(&g1, &sBi) + + // compute x*sG1 + var xBi big.Int + x.BigInt(&xBi) + pk.SXG.ScalarMultiplication(&pk.SG, &xBi) + + // generate R based on sG1, sxG1, challenge, and domain separation tag (tau, alpha or beta) + R := genR(pk.SG, pk.SXG, challenge, dst) + + // compute x*spG2 + pk.XR.ScalarMultiplication(&R, &xBi) + return pk +} + +func bitReverse[T any](a []T) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + irev := bits.Reverse64(i) >> nn + if irev > i { + a[i], a[irev] = a[irev], a[i] + } + } +} + +// Returns [1, a, a², ..., aⁿ⁻¹ ] in Montgomery form +func powers(a fr.Element, n int) []fr.Element { + result := make([]fr.Element, n) + result[0] = fr.NewElement(1) + for i := 1; i < n; i++ { + result[i].Mul(&result[i-1], &a) + } + return result +} + +// Returns [aᵢAᵢ, ...] in G1 +func scaleG1InPlace(A []curve.G1Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Returns [aᵢAᵢ, ...] 
in G2 +func scaleG2InPlace(A []curve.G2Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Check e(a₁, a₂) = e(b₁, b₂) +func sameRatio(a1, b1 curve.G1Affine, a2, b2 curve.G2Affine) bool { + if !a1.IsInSubGroup() || !b1.IsInSubGroup() || !a2.IsInSubGroup() || !b2.IsInSubGroup() { + panic("invalid point not in subgroup") + } + var na2 curve.G2Affine + na2.Neg(&a2) + res, err := curve.PairingCheck( + []curve.G1Affine{a1, b1}, + []curve.G2Affine{na2, b2}) + if err != nil { + panic(err) + } + return res +} + +// returns a = ∑ rᵢAᵢ, b = ∑ rᵢBᵢ +func merge(A, B []curve.G1Affine) (a, b curve.G1Affine) { + nc := runtime.NumCPU() + r := make([]fr.Element, len(A)) + for i := 0; i < len(A); i++ { + r[i].SetRandom() + } + a.MultiExp(A, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + b.MultiExp(B, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G1 +func linearCombinationG1(A []curve.G1Affine) (L1, L2 curve.G1Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G2 +func linearCombinationG2(A []curve.G2Affine) (L1, L2 curve.G2Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// Generate R in G₂ as Hash(gˢ, gˢˣ, challenge, dst) +func genR(sG1, sxG1 curve.G1Affine, challenge []byte, dst byte) curve.G2Affine { + var buf bytes.Buffer + buf.Grow(len(challenge) + curve.SizeOfG1AffineUncompressed*2) + buf.Write(sG1.Marshal()) + 
buf.Write(sxG1.Marshal()) + buf.Write(challenge) + spG2, err := curve.HashToG2(buf.Bytes(), []byte{dst}) + if err != nil { + panic(err) + } + return spG2 +} diff --git a/internal/backend/bw6-761/groth16/prove.go b/backend/groth16/bw6-761/prove.go similarity index 62% rename from internal/backend/bw6-761/groth16/prove.go rename to backend/groth16/bw6-761/prove.go index 950932327f..3ee6b9ad0f 100644 --- a/internal/backend/bw6-761/groth16/prove.go +++ b/backend/groth16/bw6-761/prove.go @@ -17,13 +17,17 @@ package groth16 import ( - "fmt" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bw6-761" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/pedersen" "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/constraint/bw6-761" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + cs "github.com/consensys/gnark/constraint/bw6-761" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" "math/big" @@ -35,9 +39,10 @@ import ( // with a valid statement and a VerifyingKey // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf type Proof struct { - Ar, Krs curve.G1Affine - Bs curve.G2Affine - Commitment, CommitmentPok curve.G1Affine + Ar, Krs curve.G1Affine + Bs curve.G2Affine + Commitments []curve.G1Affine // Pedersen commitments a la https://eprint.iacr.org/2022/1072 + CommitmentPok curve.G1Affine // Batched proof of knowledge of the above commitments } // isValid ensures proof elements are in the correct subgroup @@ -51,72 +56,78 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). 
-func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - // TODO @gbotrel witness size check is done by R1CS, doesn't mean we shouldn't sanitize here. - // if len(witness) != r1cs.NbPublicVariables-1+r1cs.NbSecretVariables { - // return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), r1cs.NbPublicVariables-1+r1cs.NbSecretVariables, r1cs.NbPublicVariables, r1cs.NbSecretVariables) - // } - - log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", len(r1cs.Constraints)).Str("backend", "groth16").Logger() - - // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - - proof := &Proof{} - if r1cs.CommitmentInfo.Is() { - opt.HintFunctions[r1cs.CommitmentInfo.HintID] = func(_ *big.Int, in []*big.Int, out []*big.Int) error { - // Perf-TODO: Converting these values to big.Int and back may be a performance bottleneck. - // If that is the case, figure out a way to feed the solution vector into this function - if len(in) != r1cs.CommitmentInfo.NbCommitted() { // TODO: Remove - return fmt.Errorf("unexpected number of committed variables") - } - values := make([]fr.Element, r1cs.CommitmentInfo.NbPrivateCommitted) - nbPublicCommitted := len(in) - len(values) - inPrivate := in[nbPublicCommitted:] - for i, inI := range inPrivate { - values[i].SetBigInt(inI) - } +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } + + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] - var err error - proof.Commitment, proof.CommitmentPok, err = pk.CommitmentKey.Commit(values) - if err != nil { + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + solverOpts = append(solverOpts, solver.OverrideHint(commitmentInfo[i].HintID, func(i int) solver.Hint { + return func(_ *big.Int, in []*big.Int, out []*big.Int) error { + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[len(hashed):] + for j, inJ := range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } + + var res fr.Element + res, err = solveCommitmentWire(&proof.Commitments[i], hashed) + res.BigInt(out[0]) return err } + }(i))) + } - var res fr.Element - res, err = solveCommitmentWire(&r1cs.CommitmentInfo, &proof.Commitment, in[:r1cs.CommitmentInfo.NbPublicCommitted()]) - res.BigInt(out[0]) //Perf-TODO: Regular (non-mont) hashToField to obviate this conversion? 
- return err - } + if r1cs.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + solverOpts = append(solverOpts, + solver.OverrideHint(r1cs.GkrInfo.SolveHintID, cs.GkrSolveHint(r1cs.GkrInfo, &gkrData)), + solver.OverrideHint(r1cs.GkrInfo.ProveHintID, cs.GkrProveHint(r1cs.GkrInfo.HashName, &gkrData))) } - var wireValues []fr.Element - var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill wireValues with random values else multi exps don't do much - var r fr.Element - _, _ = r.SetRandom() - for i := r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables(); i < len(wireValues); i++ { - wireValues[i] = r - r.Double(&r) - } - } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + start := time.Now() + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + + if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil { + return nil, err + } + // H (witness reduction / FFT part) var h []fr.Element chHDone := make(chan struct{}, 1) go func() { - h = computeH(a, b, c, &pk.Domain) - a = nil - b = nil - c = nil + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + solution.A = nil + solution.B = nil + solution.C = nil chHDone <- struct{}{} }() @@ -203,15 +214,19 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC var krs, krs2, p1 curve.G1Jac chKrs2Done := make(chan error, 1) + sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 go func() { - _, err := krs2.MultiExp(pk.G1.Z, h, ecc.MultiExpConfig{NbTasks: n / 2}) + _, err := krs2.MultiExp(pk.G1.Z, 
h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) chKrs2Done <- err }() - // filter the wire values if needed; - _wireValues := filter(wireValues, r1cs.CommitmentInfo.PrivateToPublic()) + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - if _, err := krs.MultiExp(pk.G1.K, _wireValues[r1cs.GetNbPublicVariables():], ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { chKrsDone <- err return } @@ -292,26 +307,32 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverC } // if len(toRemove) == 0, returns slice -// else, returns a new slice without the indexes in toRemove -// this assumes toRemove indexes are sorted and len(slice) > len(toRemove) -func filter(slice []fr.Element, toRemove []int) (r []fr.Element) { +// else, returns a new slice without the indexes in toRemove. 
The first value in the slice is taken as indexes as sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { if len(toRemove) == 0 { return slice } - r = make([]fr.Element, 0, len(slice)-len(toRemove)) - j := 0 + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) + // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) for i := 0; i < len(slice); i++ { - if j < len(toRemove) && i == toRemove[j] { - j++ + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } continue } r = append(r, slice[i]) } - return r + return } func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { @@ -334,9 +355,9 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { domain.FFTInverse(b, fft.DIF) domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, true) - domain.FFT(b, fft.DIT, true) - domain.FFT(c, fft.DIT, true) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) var den, one fr.Element one.SetOne() @@ -344,7 +365,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) - // reusing a to avoid unecessary memalloc + // reusing a to avoid unnecessary memory allocation utils.Parallelize(n, func(start, end int) { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). 
@@ -354,7 +375,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { }) // ifft_coset - domain.FFTInverse(a, fft.DIF, true) + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) return a } diff --git a/internal/backend/bw6-761/groth16/setup.go b/backend/groth16/bw6-761/setup.go similarity index 75% rename from internal/backend/bw6-761/groth16/setup.go rename to backend/groth16/bw6-761/setup.go index c0616ff53a..b0fa2811e6 100644 --- a/internal/backend/bw6-761/groth16/setup.go +++ b/backend/groth16/bw6-761/setup.go @@ -17,13 +17,15 @@ package groth16 import ( + "errors" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bw6-761" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/pedersen" + "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bw6-761" + cs "github.com/consensys/gnark/constraint/bw6-761" "math/big" "math/bits" ) @@ -34,15 +36,15 @@ type ProvingKey struct { // domain Domain fft.Domain - // [α]1, [β]1, [δ]1 - // [A(t)]1, [B(t)]1, [Kpk(t)]1, [Z(t)]1 + // [α]₁, [β]₁, [δ]₁ + // [A(t)]₁, [B(t)]₁, [Kpk(t)]₁, [Z(t)]₁ G1 struct { Alpha, Beta, Delta curve.G1Affine A, B, Z []curve.G1Affine K []curve.G1Affine // the indexes correspond to the private wires } - // [β]2, [δ]2, [B(t)]2 + // [β]₂, [δ]₂, [B(t)]₂ G2 struct { Beta, Delta curve.G2Affine B []curve.G2Affine @@ -52,21 +54,21 @@ type ProvingKey struct { InfinityA, InfinityB []bool NbInfinityA, NbInfinityB uint64 - CommitmentKey pedersen.Key + CommitmentKeys []pedersen.ProvingKey } // VerifyingKey is used by a Groth16 verifier to verify the validity of a proof and a statement // Notation follows Figure 4. 
in DIZK paper https://eprint.iacr.org/2018/691.pdf type VerifyingKey struct { - // [α]1, [Kvk]1 + // [α]₁, [Kvk]₁ G1 struct { Alpha curve.G1Affine Beta, Delta curve.G1Affine // unused, here for compatibility purposes K []curve.G1Affine // The indexes correspond to the public wires } - // [β]2, [δ]2, [γ]2, - // -[δ]2, -[γ]2: see proof.Verify() for more details + // [β]₂, [δ]₂, [γ]₂, + // -[δ]₂, -[γ]₂: see proof.Verify() for more details G2 struct { Beta, Delta, Gamma curve.G2Affine deltaNeg, gammaNeg curve.G2Affine // not serialized @@ -75,8 +77,8 @@ type VerifyingKey struct { // e(α, β) e curve.GT // not serialized - CommitmentKey pedersen.Key - CommitmentInfo constraint.Commitment // since the verifier doesn't input a constraint system, this needs to be provided here + CommitmentKey pedersen.VerifyingKey + PublicAndCommitmentCommitted [][]int // indexes of public/commitment committed variables } // Setup constructs the SRS @@ -93,17 +95,20 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbPrivateCommittedWires := r1cs.CommitmentInfo.NbPrivateCommitted - nbPublicWires := r1cs.GetNbPublicVariables() - nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - if r1cs.CommitmentInfo.Is() { // the commitment itself is defined by a hint so the prover considers it private - nbPublicWires++ // but the verifier will need to inject the value itself so on the groth16 - nbPrivateWires-- // level it must be considered public - } + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + commitmentWires := commitmentInfo.CommitmentIndexes() + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateCommittedWires := internal.NbElements(privateCommitted) + + // a commitment is itself defined by a hint so the prover considers it 
private + // but the verifier will need to inject the value itself so on the groth16 + // level it must be considered public + nbPublicWires := r1cs.GetNbPublicVariables() + len(commitmentInfo) + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - len(commitmentInfo) // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints))) + domain := fft.NewDomain(uint64(r1cs.GetNbConstraints())) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -137,7 +142,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // compute scalars for pkK, vkK and ckK pkK := make([]fr.Element, nbPrivateWires) vkK := make([]fr.Element, nbPublicWires) - ckK := make([]fr.Element, nbPrivateCommittedWires) + ckK := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + ckK[i] = make([]fr.Element, len(privateCommitted[i])) + } var t0, t1 fr.Element @@ -148,28 +156,42 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { Add(&t1, &C[i]). 
Mul(&t1, coeff) } - - vI, cI := 0, 0 - privateCommitted := r1cs.CommitmentInfo.PrivateCommitted() + vI := 0 // number of public wires seen so far + cI := make([]int, len(commitmentInfo)) // number of private committed wires seen so far for each commitment + nbPrivateCommittedSeen := 0 // = ∑ᵢ cI[i] + nbCommitmentsSeen := 0 for i := range A { - isCommittedPrivate := cI < len(privateCommitted) && i == privateCommitted[cI] - isCommitment := r1cs.CommitmentInfo.Is() && i == r1cs.CommitmentInfo.CommitmentIndex - isPublic := i < r1cs.GetNbPublicVariables() + commitment := -1 // index of the commitment that commits to this variable as a private or commitment value + var isCommitment, isPublic bool + if isPublic = i < r1cs.GetNbPublicVariables(); !isPublic { + if nbCommitmentsSeen < len(commitmentWires) && commitmentWires[nbCommitmentsSeen] == i { + isCommitment = true + nbCommitmentsSeen++ + } - if isPublic || isCommittedPrivate || isCommitment { + for j := range commitmentInfo { // does commitment j commit to i? + if cI[j] < len(privateCommitted[j]) && privateCommitted[j][cI[j]] == i { + commitment = j + break // frontend guarantees that no private variable is committed to more than once + } + } + } + + if isPublic || commitment != -1 || isCommitment { computeK(i, &toxicWaste.gammaInv) - if isCommittedPrivate { - ckK[cI] = t1 - cI++ - } else { + if isPublic || isCommitment { vkK[vI] = t1 vI++ + } else { // committed and private + ckK[commitment][cI[commitment]] = t1 + cI[commitment]++ + nbPrivateCommittedSeen++ } } else { computeK(i, &toxicWaste.deltaInv) - pkK[i-vI-cI] = t1 + pkK[i-vI-nbPrivateCommittedSeen] = t1 // vI = nbPublicSeen + nbCommitmentsSeen } } @@ -222,11 +244,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { g1Scalars = append(g1Scalars, Z...) g1Scalars = append(g1Scalars, vkK...) g1Scalars = append(g1Scalars, pkK...) - g1Scalars = append(g1Scalars, ckK...) + for i := range ckK { + g1Scalars = append(g1Scalars, ckK[i]...) 
+ } g1PointsAff := curve.BatchScalarMultiplicationG1(&g1, g1Scalars) - // sets pk: [α]1, [β]1, [δ]1 + // sets pk: [α]₁, [β]₁, [δ]₁ pk.G1.Alpha = g1PointsAff[0] pk.G1.Beta = g1PointsAff[1] pk.G1.Delta = g1PointsAff[2] @@ -238,8 +262,9 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G1.B = g1PointsAff[offset : offset+len(B)] offset += len(B) - pk.G1.Z = g1PointsAff[offset : offset+int(domain.Cardinality)] - bitReverse(pk.G1.Z) + bitReverse(g1PointsAff[offset : offset+int(domain.Cardinality)]) + sizeZ := int(domain.Cardinality) - 1 // deg(H)=deg(A*B-C/X^n-1)=(n-1)+(n-1)-n=n-2 + pk.G1.Z = g1PointsAff[offset : offset+sizeZ] offset += int(domain.Cardinality) @@ -252,17 +277,22 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // --------------------------------------------------------------------------------------------- // Commitment setup - if nbPrivateCommittedWires != 0 { - commitmentBasis := g1PointsAff[offset:] + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(ckK[i]) + commitmentBases[i] = g1PointsAff[offset : offset+size] + offset += size + } + if offset != len(g1PointsAff) { + return errors.New("didn't consume all G1 points") // TODO @Tabaie Remove this + } - vk.CommitmentKey, err = pedersen.Setup(commitmentBasis) - if err != nil { - return err - } - pk.CommitmentKey = vk.CommitmentKey + pk.CommitmentKeys, vk.CommitmentKey, err = pedersen.Setup(commitmentBases...) 
+ if err != nil { + return err } - vk.CommitmentInfo = r1cs.CommitmentInfo // unfortunate but necessary + vk.PublicAndCommitmentCommitted = commitmentInfo.GetPublicAndCommitmentCommitted(commitmentWires, r1cs.GetNbPublicVariables()) // --------------------------------------------------------------------------------------------- // G2 scalars @@ -279,15 +309,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G2.B = g2PointsAff[:len(B)] - // sets pk: [β]2, [δ]2 + // sets pk: [β]₂, [δ]₂ pk.G2.Beta = g2PointsAff[len(B)+0] pk.G2.Delta = g2PointsAff[len(B)+1] - // sets vk: [δ]2, [γ]2, -[δ]2, -[γ]2 + // sets vk: [δ]₂, [γ]₂ vk.G2.Delta = g2PointsAff[len(B)+1] vk.G2.Gamma = g2PointsAff[len(B)+2] - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) // --------------------------------------------------------------------------------------------- // Pairing: vk.e @@ -298,16 +326,29 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { vk.G1.Beta = pk.G1.Beta vk.G1.Delta = pk.G1.Delta - vk.e, err = curve.Pair([]curve.G1Affine{pk.G1.Alpha}, []curve.G2Affine{pk.G2.Beta}) - if err != nil { + if err := vk.Precompute(); err != nil { return err } + // set domain pk.Domain = *domain return nil } +// Precompute sets e, -[δ]₂, -[γ]₂ +// This is meant to be called internally during setup or deserialization. +func (vk *VerifyingKey) Precompute() error { + var err error + vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) + if err != nil { + return err + } + vk.G2.deltaNeg.Neg(&vk.G2.Delta) + vk.G2.gammaNeg.Neg(&vk.G2.Gamma) + return nil +} + func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr.Element, B []fr.Element, C []fr.Element) { nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() @@ -322,7 +363,7 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. 
var w fr.Element w.Set(&domain.Generator) wi := fr.One() - t := make([]fr.Element, len(r1cs.Constraints)+1) + t := make([]fr.Element, r1cs.GetNbConstraints()+1) for i := 0; i < len(t); i++ { t[i].Sub(&toxicWaste.t, &wi) wi.Mul(&wi, &w) // TODO this is already pre computed in fft.Domain @@ -366,8 +407,10 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // for each term appearing in the linear expression, // we compute term.Coefficient * L, and cumulate it in // A, B or C at the index of the variable - for i, c := range r1cs.Constraints { + j := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } @@ -380,9 +423,12 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // Li+1 = w*Li*(t-w^i)/(t-w^(i+1)) L.Mul(&L, &w) - L.Mul(&L, &t[i]) - L.Mul(&L, &tInv[i+1]) + L.Mul(&L, &t[j]) + L.Mul(&L, &tInv[j+1]) + + j++ } + return } @@ -436,7 +482,10 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbConstraints := len(r1cs.Constraints) + nbConstraints := r1cs.GetNbConstraints() + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - internal.NbElements(privateCommitted) - len(commitmentInfo) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints)) @@ -448,8 +497,8 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // initialize proving key pk.G1.A = make([]curve.G1Affine, nbWires-nbZeroesA) pk.G1.B = make([]curve.G1Affine, nbWires-nbZeroesB) - pk.G1.K = make([]curve.G1Affine, nbWires-r1cs.GetNbPublicVariables()) - pk.G1.Z = make([]curve.G1Affine, 
domain.Cardinality) + pk.G1.K = make([]curve.G1Affine, nbPrivateWires) + pk.G1.Z = make([]curve.G1Affine, domain.Cardinality-1) pk.G2.B = make([]curve.G2Affine, nbWires-nbZeroesB) // set infinity markers @@ -503,6 +552,22 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { pk.Domain = *domain + // --------------------------------------------------------------------------------------------- + // Commitment setup + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(privateCommitted[i]) + commitmentBases[i] = make([]curve.G1Affine, size) + for j := range commitmentBases[i] { + commitmentBases[i][j] = r1Aff + } + } + + pk.CommitmentKeys, _, err = pedersen.Setup(commitmentBases...) + if err != nil { + return err + } + return nil } @@ -514,7 +579,9 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) - for _, c := range r1cs.Constraints { + + it := r1cs.GetR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { for _, t := range c.L { A[t.WireID()] = true } @@ -522,6 +589,7 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { B[t.WireID()] = true } } + for i := 0; i < nbWires; i++ { if !A[i] { nbZeroesA++ @@ -606,7 +674,7 @@ func (pk *ProvingKey) NbG2() int { return 2 + len(pk.G2.B) } -// bitRerverse permutation as in fft.BitReverse , but with []curve.G1Affine +// bitReverse permutation as in fft.BitReverse , but with []curve.G1Affine func bitReverse(a []curve.G1Affine) { n := uint(len(a)) nn := uint(bits.UintSize - bits.TrailingZeros(n)) diff --git a/internal/backend/bw6-761/groth16/verify.go b/backend/groth16/bw6-761/verify.go similarity index 64% rename from internal/backend/bw6-761/groth16/verify.go rename to backend/groth16/bw6-761/verify.go index d1644cdb3b..ca49685a5e 100644 --- a/internal/backend/bw6-761/groth16/verify.go +++ b/backend/groth16/bw6-761/verify.go @@ -22,9 +22,11 @@ import ( 
"github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bw6-761" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/pedersen" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/logger" "io" - "math/big" "time" ) @@ -36,10 +38,8 @@ var ( // Verify verifies a proof with given VerifyingKey and publicWitness func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - nbPublicVars := len(vk.G1.K) - if vk.CommitmentInfo.Is() { - nbPublicVars-- - } + nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted) + if len(publicWitness) != nbPublicVars-1 { return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(publicWitness), len(vk.G1.K)-1) } @@ -62,21 +62,32 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { close(chDone) }() - if vk.CommitmentInfo.Is() { - - if err := vk.CommitmentKey.VerifyKnowledgeProof(proof.Commitment, proof.CommitmentPok); err != nil { - return err + maxNbPublicCommitted := 0 + for _, s := range vk.PublicAndCommitmentCommitted { // iterate over commitments + maxNbPublicCommitted = utils.Max(maxNbPublicCommitted, len(s)) + } + commitmentsSerialized := make([]byte, len(vk.PublicAndCommitmentCommitted)*fr.Bytes) + commitmentPrehashSerialized := make([]byte, curve.SizeOfG1AffineUncompressed+maxNbPublicCommitted*fr.Bytes) + for i := range vk.PublicAndCommitmentCommitted { // solveCommitmentWire + copy(commitmentPrehashSerialized, proof.Commitments[i].Marshal()) + offset := curve.SizeOfG1AffineUncompressed + for j := range vk.PublicAndCommitmentCommitted[i] { + copy(commitmentPrehashSerialized[offset:], publicWitness[vk.PublicAndCommitmentCommitted[i][j]-1].Marshal()) + offset += fr.Bytes } - - publicCommitted := make([]*big.Int, vk.CommitmentInfo.NbPublicCommitted()) - for i := range publicCommitted { - var b big.Int - 
publicWitness[vk.CommitmentInfo.Committed[i]-1].BigInt(&b) - publicCommitted[i] = &b + if res, err := fr.Hash(commitmentPrehashSerialized[:offset], []byte(constraint.CommitmentDst), 1); err != nil { + return err + } else { + publicWitness = append(publicWitness, res[0]) + copy(commitmentsSerialized[i*fr.Bytes:], res[0].Marshal()) } + } - if res, err := solveCommitmentWire(&vk.CommitmentInfo, &proof.Commitment, publicCommitted); err == nil { - publicWitness = append(publicWitness, res) + if folded, err := pedersen.FoldCommitments(proof.Commitments, commitmentsSerialized); err != nil { + return err + } else { + if err = vk.CommitmentKey.Verify(folded, proof.CommitmentPok); err != nil { + return err } } @@ -87,8 +98,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { } kSum.AddMixed(&vk.G1.K[0]) - if vk.CommitmentInfo.Is() { - kSum.AddMixed(&proof.Commitment) + for i := range proof.Commitments { + kSum.AddMixed(&proof.Commitments[i]) } var kSumAff curve.G1Affine diff --git a/backend/groth16/groth16.go b/backend/groth16/groth16.go index 1aefdfa072..41e0f63c7e 100644 --- a/backend/groth16/groth16.go +++ b/backend/groth16/groth16.go @@ -44,13 +44,13 @@ import ( gnarkio "github.com/consensys/gnark/io" - groth16_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/groth16" - groth16_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/groth16" - groth16_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/groth16" - groth16_bls24317 "github.com/consensys/gnark/internal/backend/bls24-317/groth16" - groth16_bn254 "github.com/consensys/gnark/internal/backend/bn254/groth16" - groth16_bw6633 "github.com/consensys/gnark/internal/backend/bw6-633/groth16" - groth16_bw6761 "github.com/consensys/gnark/internal/backend/bw6-761/groth16" + groth16_bls12377 "github.com/consensys/gnark/backend/groth16/bls12-377" + groth16_bls12381 "github.com/consensys/gnark/backend/groth16/bls12-381" + groth16_bls24315 
"github.com/consensys/gnark/backend/groth16/bls24-315" + groth16_bls24317 "github.com/consensys/gnark/backend/groth16/bls24-317" + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + groth16_bw6633 "github.com/consensys/gnark/backend/groth16/bw6-633" + groth16_bw6761 "github.com/consensys/gnark/backend/groth16/bw6-761" ) type groth16Object interface { @@ -168,55 +168,28 @@ func Verify(proof Proof, vk VerifyingKey, publicWitness witness.Witness) error { // internally, the solution vector to the R1CS will be filled with random values which may impact benchmarking func Prove(r1cs constraint.ConstraintSystem, pk ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (Proof, error) { - // apply options - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return nil, err - } - switch _r1cs := r1cs.(type) { case *cs_bls12377.R1CS: - w, ok := fullWitness.Vector().(fr_bls12377.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return groth16_bls12377.Prove(_r1cs, pk.(*groth16_bls12377.ProvingKey), w, opt) + return groth16_bls12377.Prove(_r1cs, pk.(*groth16_bls12377.ProvingKey), fullWitness, opts...) + case *cs_bls12381.R1CS: - w, ok := fullWitness.Vector().(fr_bls12381.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), w, opt) + return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), fullWitness, opts...) + case *cs_bn254.R1CS: - w, ok := fullWitness.Vector().(fr_bn254.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return groth16_bn254.Prove(_r1cs, pk.(*groth16_bn254.ProvingKey), w, opt) + return groth16_bn254.Prove(_r1cs, pk.(*groth16_bn254.ProvingKey), fullWitness, opts...) 
+ case *cs_bw6761.R1CS: - w, ok := fullWitness.Vector().(fr_bw6761.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return groth16_bw6761.Prove(_r1cs, pk.(*groth16_bw6761.ProvingKey), w, opt) + return groth16_bw6761.Prove(_r1cs, pk.(*groth16_bw6761.ProvingKey), fullWitness, opts...) + case *cs_bls24317.R1CS: - w, ok := fullWitness.Vector().(fr_bls24317.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return groth16_bls24317.Prove(_r1cs, pk.(*groth16_bls24317.ProvingKey), w, opt) + return groth16_bls24317.Prove(_r1cs, pk.(*groth16_bls24317.ProvingKey), fullWitness, opts...) + case *cs_bls24315.R1CS: - w, ok := fullWitness.Vector().(fr_bls24315.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return groth16_bls24315.Prove(_r1cs, pk.(*groth16_bls24315.ProvingKey), w, opt) + return groth16_bls24315.Prove(_r1cs, pk.(*groth16_bls24315.ProvingKey), fullWitness, opts...) + case *cs_bw6633.R1CS: - w, ok := fullWitness.Vector().(fr_bw6633.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return groth16_bw6633.Prove(_r1cs, pk.(*groth16_bw6633.ProvingKey), w, opt) + return groth16_bw6633.Prove(_r1cs, pk.(*groth16_bw6633.ProvingKey), fullWitness, opts...) 
+ default: panic("unrecognized R1CS curve type") } diff --git a/backend/groth16/internal/test_utils/test_utils.go b/backend/groth16/internal/test_utils/test_utils.go new file mode 100644 index 0000000000..4a76414298 --- /dev/null +++ b/backend/groth16/internal/test_utils/test_utils.go @@ -0,0 +1,14 @@ +package test_utils + +import "math/rand" + +func Random2DIntSlice(maxN, maxM int) [][]int { + res := make([][]int, rand.Intn(maxN)) //#nosec G404 weak rng OK for test + for i := range res { + res[i] = make([]int, rand.Intn(maxM)) //#nosec G404 weak rng OK for test + for j := range res[i] { + res[i][j] = rand.Int() //#nosec G404 weak rng OK for test + } + } + return res +} diff --git a/backend/groth16/internal/utils.go b/backend/groth16/internal/utils.go new file mode 100644 index 0000000000..6062ef57ce --- /dev/null +++ b/backend/groth16/internal/utils.go @@ -0,0 +1,22 @@ +package internal + +func ConcatAll(slices ...[]int) []int { // copyright note: written by GitHub Copilot + totalLen := 0 + for _, s := range slices { + totalLen += len(s) + } + res := make([]int, totalLen) + i := 0 + for _, s := range slices { + i += copy(res[i:], s) + } + return res +} + +func NbElements(slices [][]int) int { // copyright note: written by GitHub Copilot + totalLen := 0 + for _, s := range slices { + totalLen += len(s) + } + return totalLen +} diff --git a/backend/hint/builtin.go b/backend/hint/builtin.go deleted file mode 100644 index 4d0ee814d3..0000000000 --- a/backend/hint/builtin.go +++ /dev/null @@ -1,25 +0,0 @@ -package hint - -import ( - "math/big" -) - -func init() { - Register(InvZero) -} - -// InvZero computes the value 1/a for the single input a. If a == 0, returns 0. 
-func InvZero(q *big.Int, inputs []*big.Int, results []*big.Int) error { - result := results[0] - - // save input - result.Set(inputs[0]) - - // a == 0, return - if result.IsUint64() && result.Uint64() == 0 { - return nil - } - - result.ModInverse(result, q) - return nil -} diff --git a/backend/hint/hint.go b/backend/hint/hint.go deleted file mode 100644 index 60976a3dc8..0000000000 --- a/backend/hint/hint.go +++ /dev/null @@ -1,102 +0,0 @@ -/* -Package hint allows to define computations outside of a circuit. - -Usually, it is expected that computations in circuits are performed on -variables. However, in some cases defining the computations in circuits may be -complicated or computationally expensive. By using hints, the computations are -performed outside of the circuit on integers (compared to the frontend.Variable -values inside the circuits) and the result of a hint function is assigned to a -newly created variable in a circuit. - -As the computations are perfomed outside of the circuit, then the correctness of -the result is not guaranteed. This also means that the result of a hint function -is unconstrained by default, leading to failure while composing circuit proof. -Thus, it is the circuit developer responsibility to verify the correctness hint -result by adding necessary constraints in the circuit. - -As an example, lets say the hint function computes a factorization of a -semiprime n: - - p, q <- hint(n) st. p * q = n - -into primes p and q. Then, the circuit developer needs to assert in the circuit -that p*q indeed equals to n: - - n == p * q. - -However, if the hint function is incorrectly defined (e.g. in the previous -example, it returns 1 and n instead of p and q), then the assertion may still -hold, but the constructed proof is semantically invalid. Thus, the user -constructing the proof must be extremely cautious when using hints. 
- -# Using hint functions in circuits - -To use a hint function in a circuit, the developer first needs to define a hint -function hintFn according to the Function interface. Then, in a circuit, the -developer applies the hint function with frontend.API.NewHint(hintFn, vars...), -where vars are the variables the hint function will be applied to (and -correspond to the argument inputs in the Function type) which returns a new -unconstrained variable. The returned variables must be constrained using -frontend.API.Assert[.*] methods. - -As explained, the hints are essentially black boxes from the circuit point of -view and thus the defined hints in circuits are not used when constructing a -proof. To allow the particular hint functions to be used during proof -construction, the user needs to supply a backend.ProverOption indicating the -enabled hints. Such options can be optained by a call to -backend.WithHints(hintFns...), where hintFns are the corresponding hint -functions. - -# Using hint functions in gadgets - -Similar considerations apply for hint functions used in gadgets as in -user-defined circuits. However, listing all hint functions used in a particular -gadget for constructing backend.ProverOption puts high overhead for the user to -enable all necessary hints. - -For that, this package also provides a registry of trusted hint functions. When -a gadget registers a hint function, then it is automatically enabled during -proof computation and the prover does not need to provide a corresponding -proving option. - -In the init() method of the gadget, call the method Register(hintFn) method on -the hint function hintFn to register a hint function in the package registry. -*/ -package hint - -import ( - "hash/fnv" - "math/big" - "reflect" - "runtime" -) - -// ID is a unique identifier for a hint function used for lookup. 
-type ID uint32 - -// Function defines an annotated hint function; the number of inputs and outputs injected at solving -// time is defined in the circuit (compile time). -// -// For example: -// -// b := api.NewHint(hint, 2, a) -// --> at solving time, hint is going to be invoked with 1 input (a) and is expected to return 2 outputs -// b[0] and b[1]. -type Function func(field *big.Int, inputs []*big.Int, outputs []*big.Int) error - -// UUID is a reference function for computing the hint ID based on a function name -func UUID(fn Function) ID { - hf := fnv.New32a() - name := Name(fn) - - // TODO relying on name to derive UUID is risky; if fn is an anonymous func, wil be package.glob..funcN - // and if new anonymous functions are added in the package, N may change, so will UUID. - hf.Write([]byte(name)) // #nosec G104 -- does not err - - return ID(hf.Sum32()) -} - -func Name(fn Function) string { - fnptr := reflect.ValueOf(fn).Pointer() - return runtime.FuncForPC(fnptr).Name() -} diff --git a/backend/hint/registry.go b/backend/hint/registry.go deleted file mode 100644 index dfd5a66b8d..0000000000 --- a/backend/hint/registry.go +++ /dev/null @@ -1,37 +0,0 @@ -package hint - -import ( - "sync" - - "github.com/consensys/gnark/logger" -) - -var registry = make(map[ID]Function) -var registryM sync.RWMutex - -// Register registers a hint function in the global registry. -func Register(hintFns ...Function) { - registryM.Lock() - defer registryM.Unlock() - for _, hintFn := range hintFns { - key := UUID(hintFn) - name := Name(hintFn) - if _, ok := registry[key]; ok { - log := logger.Logger() - log.Warn().Str("name", name).Msg("function registered multiple times") - return - } - registry[key] = hintFn - } -} - -// GetRegistered returns all registered hint functions. 
-func GetRegistered() []Function { - registryM.RLock() - defer registryM.RUnlock() - ret := make([]Function, 0, len(registry)) - for _, v := range registry { - ret = append(ret, v) - } - return ret -} diff --git a/backend/plonk/bls12-377/marshal.go b/backend/plonk/bls12-377/marshal.go new file mode 100644 index 0000000000..3d5eebd1d5 --- /dev/null +++ b/backend/plonk/bls12-377/marshal.go @@ -0,0 +1,387 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package plonk + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + + "errors" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/iop" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/kzg" + "io" +) + +// WriteRawTo writes binary encoding of Proof to w without point compression +func (proof *Proof) WriteRawTo(w io.Writer) (int64, error) { + return proof.writeTo(w, curve.RawEncoding()) +} + +// WriteTo writes binary encoding of Proof to w with point compression +func (proof *Proof) WriteTo(w io.Writer) (int64, error) { + return proof.writeTo(w) +} + +func (proof *Proof) writeTo(w io.Writer, options ...func(*curve.Encoder)) (int64, error) { + enc := curve.NewEncoder(w, options...) 
+ + toEncode := []interface{}{ + &proof.LRO[0], + &proof.LRO[1], + &proof.LRO[2], + &proof.Z, + &proof.H[0], + &proof.H[1], + &proof.H[2], + &proof.BatchedProof.H, + proof.BatchedProof.ClaimedValues, + &proof.ZShiftedOpening.H, + &proof.ZShiftedOpening.ClaimedValue, + proof.Bsb22Commitments, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom reads binary representation of Proof from r +func (proof *Proof) ReadFrom(r io.Reader) (int64, error) { + dec := curve.NewDecoder(r) + toDecode := []interface{}{ + &proof.LRO[0], + &proof.LRO[1], + &proof.LRO[2], + &proof.Z, + &proof.H[0], + &proof.H[1], + &proof.H[2], + &proof.BatchedProof.H, + &proof.BatchedProof.ClaimedValues, + &proof.ZShiftedOpening.H, + &proof.ZShiftedOpening.ClaimedValue, + &proof.Bsb22Commitments, + } + + for _, v := range toDecode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + if proof.Bsb22Commitments == nil { + proof.Bsb22Commitments = []kzg.Digest{} + } + + return dec.BytesRead(), nil +} + +// WriteTo writes binary encoding of ProvingKey to w +func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { + return pk.writeTo(w, true) +} + +// WriteRawTo writes binary encoding of ProvingKey to w without point compression +func (pk *ProvingKey) WriteRawTo(w io.Writer) (n int64, err error) { + return pk.writeTo(w, false) +} + +func (pk *ProvingKey) writeTo(w io.Writer, withCompression bool) (n int64, err error) { + // encode the verifying key + if withCompression { + n, err = pk.Vk.WriteTo(w) + } else { + n, err = pk.Vk.WriteRawTo(w) + } + if err != nil { + return + } + + // fft domains + n2, err := pk.Domain[0].WriteTo(w) + if err != nil { + return + } + n += n2 + + n2, err = pk.Domain[1].WriteTo(w) + if err != nil { + return + } + n += n2 + + // KZG key + if withCompression { + n2, err = pk.Kzg.WriteTo(w) + } else { + n2, err = 
pk.Kzg.WriteRawTo(w) + } + if err != nil { + return + } + n += n2 + + // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) + if len(pk.trace.S) != (3 * int(pk.Domain[0].Cardinality)) { + return n, errors.New("invalid permutation size, expected 3*domain cardinality") + } + + enc := curve.NewEncoder(w) + // note: type Polynomial, which is handled by default binary.Write(...) op and doesn't + // encode the size (nor does it convert from Montgomery to Regular form) + // so we explicitly transmit []fr.Element + toEncode := []interface{}{ + pk.trace.Ql.Coefficients(), + pk.trace.Qr.Coefficients(), + pk.trace.Qm.Coefficients(), + pk.trace.Qo.Coefficients(), + pk.trace.Qk.Coefficients(), + coefficients(pk.trace.Qcp), + pk.lQk.Coefficients(), + pk.trace.S1.Coefficients(), + pk.trace.S2.Coefficients(), + pk.trace.S3.Coefficients(), + pk.trace.S, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return n + enc.BytesWritten(), err + } + } + + return n + enc.BytesWritten(), nil +} + +// ReadFrom reads from binary representation in r into ProvingKey +func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { + return pk.readFrom(r, true) +} + +// UnsafeReadFrom reads from binary representation in r into ProvingKey without subgroup checks +func (pk *ProvingKey) UnsafeReadFrom(r io.Reader) (int64, error) { + return pk.readFrom(r, false) +} + +func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, error) { + pk.Vk = &VerifyingKey{} + n, err := pk.Vk.ReadFrom(r) + if err != nil { + return n, err + } + + n2, err, chDomain0 := pk.Domain[0].AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + + n2, err, chDomain1 := pk.Domain[1].AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + + if withSubgroupChecks { + n2, err = pk.Kzg.ReadFrom(r) + } else { + n2, err = pk.Kzg.UnsafeReadFrom(r) + } + n += n2 + if err != nil { + return n, err + } + + pk.trace.S = make([]int64, 
3*pk.Domain[0].Cardinality) + + dec := curve.NewDecoder(r) + + var ql, qr, qm, qo, qk, lqk, s1, s2, s3 []fr.Element + var qcp [][]fr.Element + + // TODO @gbotrel: this is a bit ugly, we should probably refactor this. + // The order of the variables is important, as it matches the order in which they are + // encoded in the WriteTo(...) method. + + // Note: instead of calling dec.Decode(...) for each of the above variables, + // we call AsyncReadFrom when possible which allows to consume bytes from the reader + // and perform the decoding in parallel + + type v struct { + data *fr.Vector + chErr chan error + } + + vectors := make([]v, 9) + vectors[0] = v{data: (*fr.Vector)(&ql)} + vectors[1] = v{data: (*fr.Vector)(&qr)} + vectors[2] = v{data: (*fr.Vector)(&qm)} + vectors[3] = v{data: (*fr.Vector)(&qo)} + vectors[4] = v{data: (*fr.Vector)(&qk)} + vectors[5] = v{data: (*fr.Vector)(&lqk)} + vectors[6] = v{data: (*fr.Vector)(&s1)} + vectors[7] = v{data: (*fr.Vector)(&s2)} + vectors[8] = v{data: (*fr.Vector)(&s3)} + + // read ql, qr, qm, qo, qk + for i := 0; i < 5; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read qcp + if err := dec.Decode(&qcp); err != nil { + return n + dec.BytesRead(), err + } + + // read lqk, s1, s2, s3 + for i := 5; i < 9; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read pk.Trace.S + if err := dec.Decode(&pk.trace.S); err != nil { + return n + dec.BytesRead(), err + } + + // wait for all AsyncReadFrom(...) 
to complete + for i := range vectors { + if err := <-vectors[i].chErr; err != nil { + return n, err + } + } + + canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} + pk.trace.Ql = iop.NewPolynomial(&ql, canReg) + pk.trace.Qr = iop.NewPolynomial(&qr, canReg) + pk.trace.Qm = iop.NewPolynomial(&qm, canReg) + pk.trace.Qo = iop.NewPolynomial(&qo, canReg) + pk.trace.Qk = iop.NewPolynomial(&qk, canReg) + pk.trace.S1 = iop.NewPolynomial(&s1, canReg) + pk.trace.S2 = iop.NewPolynomial(&s2, canReg) + pk.trace.S3 = iop.NewPolynomial(&s3, canReg) + + pk.trace.Qcp = make([]*iop.Polynomial, len(qcp)) + for i := range qcp { + pk.trace.Qcp[i] = iop.NewPolynomial(&qcp[i], canReg) + } + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + pk.lQk = iop.NewPolynomial(&lqk, lagReg) + + // wait for FFT to be precomputed + <-chDomain0 + <-chDomain1 + + pk.computeLagrangeCosetPolys() + + return n + dec.BytesRead(), nil + +} + +// WriteTo writes binary encoding of VerifyingKey to w +func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { + return vk.writeTo(w) +} + +// WriteRawTo writes binary encoding of VerifyingKey to w without point compression +func (vk *VerifyingKey) WriteRawTo(w io.Writer) (int64, error) { + return vk.writeTo(w, curve.RawEncoding()) +} + +func (vk *VerifyingKey) writeTo(w io.Writer, options ...func(*curve.Encoder)) (n int64, err error) { + enc := curve.NewEncoder(w) + + toEncode := []interface{}{ + vk.Size, + &vk.SizeInv, + &vk.Generator, + vk.NbPublicVariables, + &vk.CosetShift, + &vk.S[0], + &vk.S[1], + &vk.S[2], + &vk.Ql, + &vk.Qr, + &vk.Qm, + &vk.Qo, + &vk.Qk, + vk.Qcp, + &vk.Kzg.G1, + &vk.Kzg.G2[0], + &vk.Kzg.G2[1], + vk.CommitmentConstraintIndexes, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom reads from binary representation in r into VerifyingKey +func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, 
error) { + dec := curve.NewDecoder(r) + toDecode := []interface{}{ + &vk.Size, + &vk.SizeInv, + &vk.Generator, + &vk.NbPublicVariables, + &vk.CosetShift, + &vk.S[0], + &vk.S[1], + &vk.S[2], + &vk.Ql, + &vk.Qr, + &vk.Qm, + &vk.Qo, + &vk.Qk, + &vk.Qcp, + &vk.Kzg.G1, + &vk.Kzg.G2[0], + &vk.Kzg.G2[1], + &vk.CommitmentConstraintIndexes, + } + + for _, v := range toDecode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + if vk.Qcp == nil { + vk.Qcp = []kzg.Digest{} + } + + return dec.BytesRead(), nil +} diff --git a/backend/plonk/bls12-377/marshal_test.go b/backend/plonk/bls12-377/marshal_test.go new file mode 100644 index 0000000000..9325743aaf --- /dev/null +++ b/backend/plonk/bls12-377/marshal_test.go @@ -0,0 +1,270 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package plonk + +import ( + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + + "bytes" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/iop" + gnarkio "github.com/consensys/gnark/io" + "io" + "math/big" + "math/rand" + "reflect" + "testing" +) + +func TestProofSerialization(t *testing.T) { + // create a proof + var proof, reconstructed Proof + proof.randomize() + + roundTripCheck(t, &proof, &reconstructed) +} + +func TestProofSerializationRaw(t *testing.T) { + // create a proof + var proof, reconstructed Proof + proof.randomize() + + roundTripCheckRaw(t, &proof, &reconstructed) +} + +func TestProvingKeySerialization(t *testing.T) { + // random pk + var pk, reconstructed ProvingKey + pk.randomize() + + roundTripCheck(t, &pk, &reconstructed) +} + +func TestProvingKeySerializationRaw(t *testing.T) { + // random pk + var pk, reconstructed ProvingKey + pk.randomize() + + roundTripCheckRaw(t, &pk, &reconstructed) +} + +func TestProvingKeySerializationRawUnsafe(t *testing.T) { + // random pk + var pk, reconstructed ProvingKey + pk.randomize() + + roundTripCheckRawUnsafe(t, &pk, &reconstructed) +} + +func TestVerifyingKeySerialization(t *testing.T) { + // create a random vk + var vk, reconstructed VerifyingKey + vk.randomize() + + roundTripCheck(t, &vk, &reconstructed) +} + +func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { + var buf bytes.Buffer + written, err := from.WriteTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} + +func roundTripCheckRaw(t *testing.T, 
from gnarkio.WriterRawTo, reconstructed io.ReaderFrom) { + var buf bytes.Buffer + written, err := from.WriteRawTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} + +type unsafeReaderFrom interface { + UnsafeReadFrom(io.Reader) (int64, error) +} + +func roundTripCheckRawUnsafe(t *testing.T, from gnarkio.WriterRawTo, reconstructed unsafeReaderFrom) { + var buf bytes.Buffer + written, err := from.WriteRawTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.UnsafeReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} + +func (pk *ProvingKey) randomize() { + + var vk VerifyingKey + vk.randomize() + pk.Vk = &vk + pk.Domain[0] = *fft.NewDomain(42) + pk.Domain[1] = *fft.NewDomain(4 * 42) + + pk.Kzg.G1 = make([]curve.G1Affine, 7) + for i := range pk.Kzg.G1 { + pk.Kzg.G1[i] = randomG1Point() + } + + n := int(pk.Domain[0].Cardinality) + ql := randomScalars(n) + qr := randomScalars(n) + qm := randomScalars(n) + qo := randomScalars(n) + qk := randomScalars(n) + lqk := randomScalars(n) + s1 := randomScalars(n) + s2 := randomScalars(n) + s3 := randomScalars(n) + + canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} + pk.trace.Ql = iop.NewPolynomial(&ql, canReg) + pk.trace.Qr = iop.NewPolynomial(&qr, canReg) + pk.trace.Qm = iop.NewPolynomial(&qm, canReg) + pk.trace.Qo = iop.NewPolynomial(&qo, canReg) + pk.trace.Qk = iop.NewPolynomial(&qk, canReg) + pk.trace.S1 = iop.NewPolynomial(&s1, canReg) + pk.trace.S2 = 
iop.NewPolynomial(&s2, canReg) + pk.trace.S3 = iop.NewPolynomial(&s3, canReg) + + pk.trace.Qcp = make([]*iop.Polynomial, rand.Intn(4)) //#nosec G404 weak rng is fine here + for i := range pk.trace.Qcp { + qcp := randomScalars(rand.Intn(n / 4)) //#nosec G404 weak rng is fine here + pk.trace.Qcp[i] = iop.NewPolynomial(&qcp, canReg) + } + + pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) + pk.trace.S[0] = -12 + pk.trace.S[len(pk.trace.S)-1] = 8888 + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + pk.lQk = iop.NewPolynomial(&lqk, lagReg) + + pk.computeLagrangeCosetPolys() +} + +func (vk *VerifyingKey) randomize() { + vk.Size = rand.Uint64() //#nosec G404 weak rng is fine here + vk.SizeInv.SetRandom() + vk.Generator.SetRandom() + vk.NbPublicVariables = rand.Uint64() //#nosec G404 weak rng is fine here + vk.CommitmentConstraintIndexes = []uint64{rand.Uint64()} //#nosec G404 weak rng is fine here + vk.CosetShift.SetRandom() + + vk.S[0] = randomG1Point() + vk.S[1] = randomG1Point() + vk.S[2] = randomG1Point() + + vk.Kzg.G1 = randomG1Point() + vk.Kzg.G2[0] = randomG2Point() + vk.Kzg.G2[1] = randomG2Point() + + vk.Ql = randomG1Point() + vk.Qr = randomG1Point() + vk.Qm = randomG1Point() + vk.Qo = randomG1Point() + vk.Qk = randomG1Point() + vk.Qcp = randomG1Points(rand.Intn(4)) //#nosec G404 weak rng is fine here +} + +func (proof *Proof) randomize() { + proof.LRO[0] = randomG1Point() + proof.LRO[1] = randomG1Point() + proof.LRO[2] = randomG1Point() + proof.Z = randomG1Point() + proof.H[0] = randomG1Point() + proof.H[1] = randomG1Point() + proof.H[2] = randomG1Point() + proof.BatchedProof.H = randomG1Point() + proof.BatchedProof.ClaimedValues = randomScalars(2) + proof.ZShiftedOpening.H = randomG1Point() + proof.ZShiftedOpening.ClaimedValue.SetRandom() + proof.Bsb22Commitments = randomG1Points(rand.Intn(4)) //#nosec G404 weak rng is fine here +} + +func randomG2Point() curve.G2Affine { + _, _, _, r := curve.Generators() + r.ScalarMultiplication(&r, 
big.NewInt(int64(rand.Uint64()))) //#nosec G404 weak rng is fine here + return r +} + +func randomG1Point() curve.G1Affine { + _, _, r, _ := curve.Generators() + r.ScalarMultiplication(&r, big.NewInt(int64(rand.Uint64()))) //#nosec G404 weak rng is fine here + return r +} + +func randomG1Points(n int) []curve.G1Affine { + res := make([]curve.G1Affine, n) + for i := range res { + res[i] = randomG1Point() + } + return res +} + +func randomScalars(n int) []fr.Element { + v := make([]fr.Element, n) + one := fr.One() + for i := 0; i < len(v); i++ { + if i == 0 { + v[i].SetRandom() + } else { + v[i].Add(&v[i-1], &one) + } + } + return v +} diff --git a/backend/plonk/bls12-377/prove.go b/backend/plonk/bls12-377/prove.go new file mode 100644 index 0000000000..3f767215ba --- /dev/null +++ b/backend/plonk/bls12-377/prove.go @@ -0,0 +1,728 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package plonk + +import ( + "crypto/sha256" + "math/big" + "runtime" + "sync" + "time" + + "github.com/consensys/gnark/backend/witness" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + + curve "github.com/consensys/gnark-crypto/ecc/bls12-377" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/kzg" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/iop" + cs "github.com/consensys/gnark/constraint/bls12-377" + + "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/internal/utils" + "github.com/consensys/gnark/logger" +) + +type Proof struct { + + // Commitments to the solution vectors + LRO [3]kzg.Digest + + // Commitment to Z, the permutation polynomial + Z kzg.Digest + + // Commitments to h1, h2, h3 such that h = h1 + Xh2 + X**2h3 is the quotient polynomial + H [3]kzg.Digest + + Bsb22Commitments []kzg.Digest + + // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2, qCPrime + BatchedProof kzg.BatchOpeningProof + + // Opening proof of Z at zeta*mu + ZShiftedOpening kzg.OpeningProof +} + +// Computing and verifying Bsb22 multi-commits explained in https://hackmd.io/x8KsadW3RRyX7YTCFJIkHg +func bsb22ComputeCommitmentHint(spr *cs.SparseR1CS, pk *ProvingKey, proof *Proof, cCommitments []*iop.Polynomial, res *fr.Element, commDepth int) solver.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments)[commDepth] + committedValues := make([]fr.Element, pk.Domain[0].Cardinality) + offset := spr.GetNbPublicVariables() + for i := range ins { + committedValues[offset+commitmentInfo.Committed[i]].SetBigInt(ins[i]) + } + var ( + err error + hashRes []fr.Element + ) + if _, err = 
committedValues[offset+commitmentInfo.CommitmentIndex].SetRandom(); err != nil { // Commitment injection constraint has qcp = 0. Safe to use for blinding. + return err + } + if _, err = committedValues[offset+spr.GetNbConstraints()-1].SetRandom(); err != nil { // Last constraint has qcp = 0. Safe to use for blinding + return err + } + pi2iop := iop.NewPolynomial(&committedValues, iop.Form{Basis: iop.Lagrange, Layout: iop.Regular}) + cCommitments[commDepth] = pi2iop.ShallowClone() + cCommitments[commDepth].ToCanonical(&pk.Domain[0]).ToRegular() + if proof.Bsb22Commitments[commDepth], err = kzg.Commit(cCommitments[commDepth].Coefficients(), pk.Kzg); err != nil { + return err + } + if hashRes, err = fr.Hash(proof.Bsb22Commitments[commDepth].Marshal(), []byte("BSB22-Plonk"), 1); err != nil { + return err + } + res.Set(&hashRes[0]) // TODO @Tabaie use CommitmentIndex for this; create a new variable CommitmentConstraintIndex for other uses + res.BigInt(outs[0]) + + return nil + } +} + +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + + log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", spr.GetNbConstraints()).Str("backend", "plonk").Logger() + + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } + + start := time.Now() + // pick a hash function that will be used to derive the challenges + hFunc := sha256.New() + + // create a transcript manager to apply Fiat Shamir + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") + + // result + proof := &Proof{} + + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments) + commitmentVal := make([]fr.Element, len(commitmentInfo)) // TODO @Tabaie get rid of this + cCommitments := make([]*iop.Polynomial, len(commitmentInfo)) + proof.Bsb22Commitments = make([]kzg.Digest, len(commitmentInfo)) + for i := range commitmentInfo { + opt.SolverOpts = append(opt.SolverOpts, solver.OverrideHint(commitmentInfo[i].HintID, + bsb22ComputeCommitmentHint(spr, pk, proof, cCommitments, &commitmentVal[i], i))) + } + + if spr.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + opt.SolverOpts = append(opt.SolverOpts, + solver.OverrideHint(spr.GkrInfo.SolveHintID, cs.GkrSolveHint(spr.GkrInfo, &gkrData)), + solver.OverrideHint(spr.GkrInfo.ProveHintID, cs.GkrProveHint(spr.GkrInfo.HashName, &gkrData))) + } + + // query l, r, o in Lagrange basis, not blinded + _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) 
+ if err != nil { + return nil, err + } + // TODO @gbotrel deal with that conversion lazily + lcCommitments := make([]*iop.Polynomial, len(cCommitments)) + for i := range cCommitments { + lcCommitments[i] = cCommitments[i].Clone(int(pk.Domain[1].Cardinality)).ToLagrangeCoset(&pk.Domain[1]) // lagrange coset form + } + solution := _solution.(*cs.SparseR1CSSolution) + evaluationLDomainSmall := []fr.Element(solution.L) + evaluationRDomainSmall := []fr.Element(solution.R) + evaluationODomainSmall := []fr.Element(solution.O) + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + // l, r, o and blinded versions + var ( + wliop, + wriop, + woiop, + bwliop, + bwriop, + bwoiop *iop.Polynomial + ) + var wgLRO sync.WaitGroup + wgLRO.Add(3) + go func() { + // we keep in lagrange regular form since iop.BuildRatioCopyConstraint prefers it in this form. + wliop = iop.NewPolynomial(&evaluationLDomainSmall, lagReg) + // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. 
+ bwliop = wliop.Clone(int(pk.Domain[1].Cardinality)).ToCanonical(&pk.Domain[0]).ToRegular().Blind(1) + wgLRO.Done() + }() + go func() { + wriop = iop.NewPolynomial(&evaluationRDomainSmall, lagReg) + bwriop = wriop.Clone(int(pk.Domain[1].Cardinality)).ToCanonical(&pk.Domain[0]).ToRegular().Blind(1) + wgLRO.Done() + }() + go func() { + woiop = iop.NewPolynomial(&evaluationODomainSmall, lagReg) + bwoiop = woiop.Clone(int(pk.Domain[1].Cardinality)).ToCanonical(&pk.Domain[0]).ToRegular().Blind(1) + wgLRO.Done() + }() + + fw, ok := fullWitness.Vector().(fr.Vector) + if !ok { + return nil, witness.ErrInvalidWitness + } + + // start computing lcqk + var lcqk *iop.Polynomial + chLcqk := make(chan struct{}, 1) + go func() { + // compute qk in canonical basis, completed with the public inputs + // We copy the coeffs of qk so that pk is not mutated + lqkcoef := pk.lQk.Coefficients() + qkCompletedCanonical := make([]fr.Element, len(lqkcoef)) + copy(qkCompletedCanonical, fw[:len(spr.Public)]) + copy(qkCompletedCanonical[len(spr.Public):], lqkcoef[len(spr.Public):]) + for i := range commitmentInfo { + qkCompletedCanonical[spr.GetNbPublicVariables()+commitmentInfo[i].CommitmentIndex] = commitmentVal[i] + } + pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) + fft.BitReverse(qkCompletedCanonical) + + canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} + lcqk = iop.NewPolynomial(&qkCompletedCanonical, canReg) + lcqk.ToLagrangeCoset(&pk.Domain[1]) + close(chLcqk) + }() + + // The first challenge is derived using the public data: the commitments to the permutation, + // the coefficients of the circuit, and the public inputs. 
+ // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) + if err := bindPublicData(&fs, "gamma", *pk.Vk, fw[:len(spr.Public)], proof.Bsb22Commitments); err != nil { + return nil, err + } + + // wait for polys to be blinded + wgLRO.Wait() + if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Kzg); err != nil { + return nil, err + } + + gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) // TODO @Tabaie @ThomasPiellard add BSB commitment here? + if err != nil { + return nil, err + } + + // Fiat Shamir this + bbeta, err := fs.ComputeChallenge("beta") + if err != nil { + return nil, err + } + var beta fr.Element + beta.SetBytes(bbeta) + + // l, r, o are already blinded + wgLRO.Add(3) + go func() { + bwliop.ToLagrangeCoset(&pk.Domain[1]) + wgLRO.Done() + }() + go func() { + bwriop.ToLagrangeCoset(&pk.Domain[1]) + wgLRO.Done() + }() + go func() { + bwoiop.ToLagrangeCoset(&pk.Domain[1]) + wgLRO.Done() + }() + + // compute the copy constraint's ratio + // note that wliop, wriop and woiop are fft'ed (mutated) in the process. 
+ ziop, err := iop.BuildRatioCopyConstraint( + []*iop.Polynomial{ + wliop, + wriop, + woiop, + }, + pk.trace.S, + beta, + gamma, + iop.Form{Basis: iop.Canonical, Layout: iop.Regular}, + &pk.Domain[0], + ) + if err != nil { + return proof, err + } + + // commit to the blinded version of z + chZ := make(chan error, 1) + var bwziop, bwsziop *iop.Polynomial + var alpha fr.Element + go func() { + bwziop = ziop // iop.NewWrappedPolynomial(&ziop) + bwziop.Blind(2) + proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Kzg, runtime.NumCPU()*2) + if err != nil { + chZ <- err + } + + // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) + alpha, err = deriveRandomness(&fs, "alpha", &proof.Z) + if err != nil { + chZ <- err + } + + // Store z(g*x), without reallocating a slice + bwsziop = bwziop.ShallowClone().Shift(1) + bwsziop.ToLagrangeCoset(&pk.Domain[1]) + chZ <- nil + close(chZ) + }() + + // Full capture using latest gnark crypto... + fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element, pi2QcPrime []fr.Element) fr.Element { // TODO @Tabaie make use of the fact that qCPrime is a selector: sparse and binary + + var ic, tmp fr.Element + + ic.Mul(&fql, &l) + tmp.Mul(&fqr, &r) + ic.Add(&ic, &tmp) + tmp.Mul(&fqm, &l).Mul(&tmp, &r) + ic.Add(&ic, &tmp) + tmp.Mul(&fqo, &o) + ic.Add(&ic, &tmp).Add(&ic, &fqk) + nbComms := len(commitmentInfo) + for i := range commitmentInfo { + tmp.Mul(&pi2QcPrime[i], &pi2QcPrime[i+nbComms]) + ic.Add(&ic, &tmp) + } + + return ic + } + + fo := func(l, r, o, fid, fs1, fs2, fs3, fz, fzs fr.Element) fr.Element { + u := &pk.Domain[0].FrMultiplicativeGen + var a, b, tmp fr.Element + b.Mul(&beta, &fid) + a.Add(&b, &l).Add(&a, &gamma) + b.Mul(&b, u) + tmp.Add(&b, &r).Add(&tmp, &gamma) + a.Mul(&a, &tmp) + tmp.Mul(&b, u).Add(&tmp, &o).Add(&tmp, &gamma) + a.Mul(&a, &tmp).Mul(&a, &fz) + + b.Mul(&beta, &fs1).Add(&b, &l).Add(&b, &gamma) + tmp.Mul(&beta, &fs2).Add(&tmp, &r).Add(&tmp, &gamma) + b.Mul(&b, &tmp) + tmp.Mul(&beta, &fs3).Add(&tmp, &o).Add(&tmp, 
&gamma) + b.Mul(&b, &tmp).Mul(&b, &fzs) + + b.Sub(&b, &a) + + return b + } + + fone := func(fz, flone fr.Element) fr.Element { + one := fr.One() + one.Sub(&fz, &one).Mul(&one, &flone) + return one + } + + // 0 , 1 , 2, 3 , 4 , 5 , 6 , 7, 8 , 9 , 10, 11, 12, 13, 14, 15:15+nbComm , 15+nbComm:15+2×nbComm + // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk ,lone, Bsb22Commitments, qCPrime + fm := func(x ...fr.Element) fr.Element { + + a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2], x[15:]) + b := fo(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8]) + c := fone(x[7], x[14]) + + c.Mul(&c, &alpha).Add(&c, &b).Mul(&c, &alpha).Add(&c, &a) + + return c + } + + // wait for lcqk + <-chLcqk + + // wait for Z part + if err := <-chZ; err != nil { + return proof, err + } + + // wait for l, r o lagrange coset conversion + wgLRO.Wait() + + toEval := []*iop.Polynomial{ + bwliop, + bwriop, + bwoiop, + pk.lcIdIOP, + pk.lcS1, + pk.lcS2, + pk.lcS3, + bwziop, + bwsziop, + pk.lcQl, + pk.lcQr, + pk.lcQm, + pk.lcQo, + lcqk, + pk.lLoneIOP, + } + toEval = append(toEval, lcCommitments...) // TODO: Add this at beginning + toEval = append(toEval, pk.lcQcp...) + systemEvaluation, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, toEval...) 
+ if err != nil { + return nil, err + } + // open blinded Z at zeta*z + chbwzIOP := make(chan struct{}, 1) + go func() { + bwziop.ToCanonical(&pk.Domain[1]).ToRegular() + close(chbwzIOP) + }() + + h, err := iop.DivideByXMinusOne(systemEvaluation, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) // TODO Rename to DivideByXNMinusOne or DivideByVanishingPoly etc + if err != nil { + return nil, err + } + + // compute kzg commitments of h1, h2 and h3 + if err := commitToQuotient( + h.Coefficients()[:pk.Domain[0].Cardinality+2], + h.Coefficients()[pk.Domain[0].Cardinality+2:2*(pk.Domain[0].Cardinality+2)], + h.Coefficients()[2*(pk.Domain[0].Cardinality+2):3*(pk.Domain[0].Cardinality+2)], + proof, pk.Kzg); err != nil { + return nil, err + } + + // derive zeta + zeta, err := deriveRandomness(&fs, "zeta", &proof.H[0], &proof.H[1], &proof.H[2]) + if err != nil { + return nil, err + } + + // compute evaluations of (blinded version of) l, r, o, z, qCPrime at zeta + var blzeta, brzeta, bozeta fr.Element + qcpzeta := make([]fr.Element, len(commitmentInfo)) + + var wgEvals sync.WaitGroup + wgEvals.Add(3) + evalAtZeta := func(poly *iop.Polynomial, res *fr.Element) { + poly.ToCanonical(&pk.Domain[1]).ToRegular() + *res = poly.Evaluate(zeta) + wgEvals.Done() + } + go evalAtZeta(bwliop, &blzeta) + go evalAtZeta(bwriop, &brzeta) + go evalAtZeta(bwoiop, &bozeta) + evalQcpAtZeta := func(begin, end int) { + for i := begin; i < end; i++ { + qcpzeta[i] = pk.trace.Qcp[i].Evaluate(zeta) + } + } + utils.Parallelize(len(commitmentInfo), evalQcpAtZeta) + + var zetaShifted fr.Element + zetaShifted.Mul(&zeta, &pk.Vk.Generator) + <-chbwzIOP + proof.ZShiftedOpening, err = kzg.Open( + bwziop.Coefficients()[:bwziop.BlindedSize()], + zetaShifted, + pk.Kzg, + ) + if err != nil { + return nil, err + } + + // start to compute foldedH and foldedHDigest while computeLinearizedPolynomial runs. 
+ computeFoldedH := make(chan struct{}, 1) + var foldedH []fr.Element + var foldedHDigest kzg.Digest + go func() { + // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) + var bZetaPowerm, bSize big.Int + bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + var zetaPowerm fr.Element + zetaPowerm.Exp(zeta, &bSize) + zetaPowerm.BigInt(&bZetaPowerm) + foldedHDigest = proof.H[2] + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) + + // foldedH = h1 + ζ*h2 + ζ²*h3 + foldedH = h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] + h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] + h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] + utils.Parallelize(len(foldedH), func(start, end int) { + for i := start; i < end; i++ { + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² + foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 + } + }) + close(computeFoldedH) + }() + + wgEvals.Wait() // wait for the evaluations + + var ( + linearizedPolynomialCanonical []fr.Element + linearizedPolynomialDigest curve.G1Affine + errLPoly error + ) + + // blinded z evaluated at u*zeta + bzuzeta := proof.ZShiftedOpening.ClaimedValue + + // compute the linearization polynomial r at zeta + // (goal: save committing separately to z, ql, qr, qm, qo, k + // note: we linearizedPolynomialCanonical reuses bwziop memory + linearizedPolynomialCanonical = computeLinearizedPolynomial( + blzeta, + brzeta, + bozeta, + alpha, + beta, + gamma, + zeta, + bzuzeta, + qcpzeta, + 
bwziop.Coefficients()[:bwziop.BlindedSize()], + coefficients(cCommitments), + pk, + ) + + // TODO this commitment is only necessary to derive the challenge, we should + // be able to avoid doing it and get the challenge in another way + linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Kzg, runtime.NumCPU()*2) + if errLPoly != nil { + return nil, errLPoly + } + + // wait for foldedH and foldedHDigest + <-computeFoldedH + + // Batch open the first list of polynomials + polysQcp := coefficients(pk.trace.Qcp) + polysToOpen := make([][]fr.Element, 7+len(polysQcp)) + copy(polysToOpen[7:], polysQcp) + // offset := len(polysQcp) + polysToOpen[0] = foldedH + polysToOpen[1] = linearizedPolynomialCanonical + polysToOpen[2] = bwliop.Coefficients()[:bwliop.BlindedSize()] + polysToOpen[3] = bwriop.Coefficients()[:bwriop.BlindedSize()] + polysToOpen[4] = bwoiop.Coefficients()[:bwoiop.BlindedSize()] + polysToOpen[5] = pk.trace.S1.Coefficients() + polysToOpen[6] = pk.trace.S2.Coefficients() + + digestsToOpen := make([]curve.G1Affine, len(pk.Vk.Qcp)+7) + copy(digestsToOpen[7:], pk.Vk.Qcp) + // offset = len(pk.Vk.Qcp) + digestsToOpen[0] = foldedHDigest + digestsToOpen[1] = linearizedPolynomialDigest + digestsToOpen[2] = proof.LRO[0] + digestsToOpen[3] = proof.LRO[1] + digestsToOpen[4] = proof.LRO[2] + digestsToOpen[5] = pk.Vk.S[0] + digestsToOpen[6] = pk.Vk.S[1] + + proof.BatchedProof, err = kzg.BatchOpenSinglePoint( + polysToOpen, + digestsToOpen, + zeta, + hFunc, + pk.Kzg, + ) + + log.Debug().Dur("took", time.Since(start)).Msg("prover done") + + if err != nil { + return nil, err + } + + return proof, nil + +} + +func coefficients(p []*iop.Polynomial) [][]fr.Element { + res := make([][]fr.Element, len(p)) + for i, pI := range p { + res[i] = pI.Coefficients() + } + return res +} + +// fills proof.LRO with kzg commits of bcl, bcr and bco +func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, kzgPk kzg.ProvingKey) error { + n := runtime.NumCPU() 
+ var err0, err1, err2 error + chCommit0 := make(chan struct{}, 1) + chCommit1 := make(chan struct{}, 1) + go func() { + proof.LRO[0], err0 = kzg.Commit(bcl, kzgPk, n) + close(chCommit0) + }() + go func() { + proof.LRO[1], err1 = kzg.Commit(bcr, kzgPk, n) + close(chCommit1) + }() + if proof.LRO[2], err2 = kzg.Commit(bco, kzgPk, n); err2 != nil { + return err2 + } + <-chCommit0 + <-chCommit1 + + if err0 != nil { + return err0 + } + + return err1 +} + +func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, kzgPk kzg.ProvingKey) error { + n := runtime.NumCPU() + var err0, err1, err2 error + chCommit0 := make(chan struct{}, 1) + chCommit1 := make(chan struct{}, 1) + go func() { + proof.H[0], err0 = kzg.Commit(h1, kzgPk, n) + close(chCommit0) + }() + go func() { + proof.H[1], err1 = kzg.Commit(h2, kzgPk, n) + close(chCommit1) + }() + if proof.H[2], err2 = kzg.Commit(h3, kzgPk, n); err2 != nil { + return err2 + } + <-chCommit0 + <-chCommit1 + + if err0 != nil { + return err0 + } + + return err1 +} + +// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. +// The purpose is to commit and open all in one ql, qr, qm, qo, qk. +// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta +// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z +// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. 
+// + +// The Linearized polynomial is: +// +// α²*L₁(ζ)*Z(X) +// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) +// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) +func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { + + // first part: individual constraints + var rl fr.Element + rl.Mul(&rZeta, &lZeta) + + // second part: + // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) + var s1, s2 fr.Element + chS1 := make(chan struct{}, 1) + go func() { + s1 = pk.trace.S1.Evaluate(zeta) // s1(ζ) + s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) + close(chS1) + }() + // ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) + tmp := pk.trace.S2.Evaluate(zeta) // s2(ζ) + tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) + <-chS1 + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*Z(μζ) + + var uzeta, uuzeta fr.Element + uzeta.Mul(&zeta, &pk.Vk.CosetShift) + uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) + + s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) + tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) + tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + // third part L₁(ζ)*α²*Z + var lagrangeZeta, one, den, frNbElmt fr.Element + one.SetOne() + nbElmt := int64(pk.Domain[0].Cardinality) + lagrangeZeta.Set(&zeta). + Exp(lagrangeZeta, big.NewInt(nbElmt)). 
+ Sub(&lagrangeZeta, &one) + frNbElmt.SetUint64(uint64(nbElmt)) + den.Sub(&zeta, &one). + Inverse(&den) + lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) + + s3canonical := pk.trace.S3.Coefficients() + utils.Parallelize(len(blindedZCanonical), func(start, end int) { + + var t, t0, t1 fr.Element + + for i := start; i < end; i++ { + + t.Mul(&blindedZCanonical[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + if i < len(s3canonical) { + + t0.Mul(&s3canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) + + t.Add(&t, &t0) + } + + t.Mul(&t, &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) + + cql := pk.trace.Ql.Coefficients() + cqr := pk.trace.Qr.Coefficients() + cqm := pk.trace.Qm.Coefficients() + cqo := pk.trace.Qo.Coefficients() + cqk := pk.trace.Qk.Coefficients() + if i < len(cqm) { + + t1.Mul(&cqm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) + t0.Mul(&cql[i], &lZeta) + t0.Add(&t0, &t1) + t.Add(&t, &t0) // linPol = linPol + l(ζ)*Ql(X) + + t0.Mul(&cqr[i], &rZeta) + t.Add(&t, &t0) // linPol = linPol + r(ζ)*Qr(X) + + t0.Mul(&cqo[i], &oZeta).Add(&t0, &cqk[i]) + t.Add(&t, &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) + + for j := range qcpZeta { + t0.Mul(&pi2Canonical[j][i], &qcpZeta[j]) + t.Add(&t, &t0) + } + } + + t0.Mul(&blindedZCanonical[i], &lagrangeZeta) + blindedZCanonical[i].Add(&t, &t0) // finish the computation + } + }) + return blindedZCanonical +} diff --git a/backend/plonk/bls12-377/setup.go b/backend/plonk/bls12-377/setup.go new file mode 100644 index 0000000000..9416161b6d --- /dev/null +++ b/backend/plonk/bls12-377/setup.go @@ -0,0 +1,485 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package plonk + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/iop" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/kzg" + "github.com/consensys/gnark/backend/plonk/internal" + "github.com/consensys/gnark/constraint" + cs "github.com/consensys/gnark/constraint/bls12-377" + "sync" +) + +// Trace stores a plonk trace as columns +type Trace struct { + + // Constants describing a plonk circuit. The first entries + // of LQk (whose index correspond to the public inputs) are set to 0, and are to be + // completed by the prover. At those indices i (so from 0 to nb_public_variables), LQl[i]=-1 + // so the first nb_public_variables constraints look like this: + // -1*Wire[i] + 0* + 0 . It is zero when the constant coefficient is replaced by Wire[i]. + Ql, Qr, Qm, Qo, Qk *iop.Polynomial + Qcp []*iop.Polynomial + + // Polynomials representing the splitted permutation. The full permutation's support is 3*N where N=nb wires. + // The set of interpolation is of size N, so to represent the permutation S we let S acts on the + // set A=(, u*, u^{2}*) of size 3*N, where u is outside (its use is to shift the set ). + // We obtain a permutation of A, A'. We split A' in 3 (A'_{1}, A'_{2}, A'_{3}), and S1, S2, S3 are + // respectively the interpolation of A'_{1}, A'_{2}, A'_{3} on . 
+ S1, S2, S3 *iop.Polynomial + + // S full permutation, i -> S[i] + S []int64 +} + +// VerifyingKey stores the data needed to verify a proof: +// * The commitment scheme +// * Commitments of ql prepended with as many ones as there are public inputs +// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs +// * Commitments to S1, S2, S3 +type VerifyingKey struct { + + // Size circuit + Size uint64 + SizeInv fr.Element + Generator fr.Element + NbPublicVariables uint64 + + // Commitment scheme that is used for an instantiation of PLONK + Kzg kzg.VerifyingKey + + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + + // S commitments to S1, S2, S3 + S [3]kzg.Digest + + // Commitments to ql, qr, qm, qo, qcp prepended with as many zeroes (ones for l) as there are public inputs. + // In particular Qk is not complete. + Ql, Qr, Qm, Qo, Qk kzg.Digest + Qcp []kzg.Digest + + CommitmentConstraintIndexes []uint64 +} + +// ProvingKey stores the data needed to generate a proof: +// * the commitment scheme +// * ql, prepended with as many ones as they are public inputs +// * qr, qm, qo prepended with as many zeroes as there are public inputs. +// * qk, prepended with as many zeroes as public inputs, to be completed by the prover +// with the list of public inputs. +// * sigma_1, sigma_2, sigma_3 in both basis +// * the copy constraint permutation +type ProvingKey struct { + + // stores ql, qr, qm, qo, qk (-> to be completed by the prover) + // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used + // for computing the opening proofs (hence the canonical form). The canonical version + // of qk incomplete is used in the linearisation polynomial. + // The polynomials in trace are in canonical basis. 
+ trace Trace + + Kzg kzg.ProvingKey + + // Verifying Key is embedded into the proving key (needed by Prove) + Vk *VerifyingKey + + // qr,ql,qm,qo,qcp in LagrangeCoset --> these are not serialized, but computed from Ql, Qr, Qm, Qo, Qcp once. + lcQl, lcQr, lcQm, lcQo *iop.Polynomial + lcQcp []*iop.Polynomial + + // LQk qk in Lagrange form -> to be completed by the prover. After being completed, + lQk *iop.Polynomial + + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + + // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. + lcS1, lcS2, lcS3 *iop.Polynomial + + // in lagrange coset basis --> not serialized id and L_{g^{0}} + lcIdIOP, lLoneIOP *iop.Polynomial +} + +func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { + + var pk ProvingKey + var vk VerifyingKey + pk.Vk = &vk + vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) + + // step 0: set the fft domains + pk.initDomains(spr) + + // step 1: set the verifying key + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) + vk.Size = pk.Domain[0].Cardinality + vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) + vk.Generator.Set(&pk.Domain[0].Generator) + vk.NbPublicVariables = uint64(len(spr.Public)) + if len(kzgSrs.Pk.G1) < int(vk.Size) { + return nil, nil, errors.New("kzg srs is too small") + } + pk.Kzg = kzgSrs.Pk + vk.Kzg = kzgSrs.Vk + + // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis + BuildTrace(spr, &pk.trace) + + // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. 
+ // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &pk.trace, nbVariables) + s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) + pk.trace.S1 = s[0] + pk.trace.S2 = s[1] + pk.trace.S3 = s[2] + + // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. + // All the above polynomials are expressed in canonical basis afterwards. This is why + // we save lqk before, because the prover needs to complete it in Lagrange form, and + // then express it on the Lagrange coset basis. + pk.lQk = pk.trace.Qk.Clone() // it will be completed by the prover, and the evaluated on the coset + err := commitTrace(&pk.trace, &pk) + if err != nil { + return nil, nil, err + } + + // step 5: evaluate ql, qr, qm, qo, s1, s2, s3 on LagrangeCoset (NOT qk) + // we clone them, because the canonical versions are going to be used in + // the opening proof + pk.computeLagrangeCosetPolys() + + return &pk, &vk, nil +} + +// computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset +// basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
+func (pk *ProvingKey) computeLagrangeCosetPolys() { + var wg sync.WaitGroup + wg.Add(7 + len(pk.trace.Qcp)) + n1 := int(pk.Domain[1].Cardinality) + pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) + for i, qcpI := range pk.trace.Qcp { + go func(i int, qcpI *iop.Polynomial) { + pk.lcQcp[i] = qcpI.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }(i, qcpI) + } + go func() { + pk.lcQl = pk.trace.Ql.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQr = pk.trace.Qr.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQm = pk.trace.Qm.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQo = pk.trace.Qo.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS1 = pk.trace.S1.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS2 = pk.trace.S2.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS3 = pk.trace.S3.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + // storing Id + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + id := make([]fr.Element, pk.Domain[1].Cardinality) + id[0].Set(&pk.Domain[1].FrMultiplicativeGen) + for i := 1; i < int(pk.Domain[1].Cardinality); i++ { + id[i].Mul(&id[i-1], &pk.Domain[1].Generator) + } + pk.lcIdIOP = iop.NewPolynomial(&id, lagReg) + + // L_{g^{0}} + cap := pk.Domain[1].Cardinality + if cap < pk.Domain[0].Cardinality { + cap = pk.Domain[0].Cardinality // sanity check + } + lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) + lone[0].SetOne() + pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). + ToRegular(). 
+ ToLagrangeCoset(&pk.Domain[1]) + + wg.Wait() +} + +// NbPublicWitness returns the expected public witness size (number of field elements) +func (vk *VerifyingKey) NbPublicWitness() int { + return int(vk.NbPublicVariables) +} + +// VerifyingKey returns pk.Vk +func (pk *ProvingKey) VerifyingKey() interface{} { + return pk.Vk +} + +// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. +// Size is the size of the system that is nb_constraints+nb_public_variables +func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) + size := ecc.NextPowerOfTwo(sizeSystem) + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments) + + ql := make([]fr.Element, size) + qr := make([]fr.Element, size) + qm := make([]fr.Element, size) + qo := make([]fr.Element, size) + qk := make([]fr.Element, size) + qcp := make([][]fr.Element, len(commitmentInfo)) + + for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistent + ql[i].SetOne().Neg(&ql[i]) + qr[i].SetZero() + qm[i].SetZero() + qo[i].SetZero() + qk[i].SetZero() // → to be completed by the prover + } + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + ql[offset+j].Set(&spr.Coefficients[c.QL]) + qr[offset+j].Set(&spr.Coefficients[c.QR]) + qm[offset+j].Set(&spr.Coefficients[c.QM]) + qo[offset+j].Set(&spr.Coefficients[c.QO]) + qk[offset+j].Set(&spr.Coefficients[c.QC]) + j++ + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + + pt.Ql = iop.NewPolynomial(&ql, lagReg) + pt.Qr = iop.NewPolynomial(&qr, lagReg) + pt.Qm = iop.NewPolynomial(&qm, lagReg) + pt.Qo = iop.NewPolynomial(&qo, lagReg) + pt.Qk = iop.NewPolynomial(&qk, lagReg) + pt.Qcp = make([]*iop.Polynomial, len(qcp)) + + for i := range commitmentInfo { + qcp[i] = make([]fr.Element, size) + for _, 
committed := range commitmentInfo[i].Committed { + qcp[i][offset+committed].SetOne() + } + pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + } +} + +// commitTrace commits to every polynomial in the trace, and puts +// the commitments in the verifying key. +func commitTrace(trace *Trace, pk *ProvingKey) error { + + trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete + trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() + + var err error + pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + for i := range trace.Qcp { + trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() + if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + return err + } + } + if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + return err + } + return nil +} + +func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for 
the placeholder constraints + pk.Domain[0] = *fft.NewDomain(sizeSystem) + + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) + } else { + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) + } + +} + +// buildPermutation builds the Permutation associated with a circuit. +// +// The permutation s is composed of cycles of maximum length such that +// +// s. (l∥r∥o) = (l∥r∥o) +// +// , where l∥r∥o is the concatenation of the indices of l, r, o in +// ql.l+qr.r+qm.l.r+qo.O+k = 0. +// +// The permutation is encoded as a slice s of size 3*size(l), where the +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab +// like this: for i in tab: tab[i] = tab[permutation[i]] +func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { + + // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + sizeSolution := len(pt.Ql.Coefficients()) + sizePermutation := 3 * sizeSolution + + // init permutation + permutation := make([]int64, sizePermutation) + for i := 0; i < len(permutation); i++ { + permutation[i] = -1 + } + + // init LRO position -> variable_ID + lro := make([]int, sizePermutation) // position -> variable_ID + for i := 0; i < len(spr.Public); i++ { + lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) + } + + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + + j++ + } + + // init cycle: + // map ID -> last position the ID was seen + cycle := make([]int64, nbVariables) + for i := 0; i < len(cycle); i++ { + cycle[i] = -1 + } + + for i := 0; i < len(lro); i++ { + if cycle[lro[i]] != -1 { + 
// if != -1, it means we already encountered this value + // so we need to set the corresponding permutation index. + permutation[i] = cycle[lro[i]] + } + cycle[lro[i]] = int64(i) + } + + // complete the Permutation by filling the first IDs encountered + for i := 0; i < sizePermutation; i++ { + if permutation[i] == -1 { + permutation[i] = cycle[lro[i]] + } + } + + pt.S = permutation +} + +// computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. +// We let the permutation act on || u || u^{2}, split the result in 3 parts, +// and interpolate each of the 3 parts on . +func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { + + nbElmts := int(domain.Cardinality) + + var res [3]*iop.Polynomial + + // Lagrange form of ID + evaluationIDSmallDomain := getSupportPermutation(domain) + + // Lagrange form of S1, S2, S3 + s1Canonical := make([]fr.Element, nbElmts) + s2Canonical := make([]fr.Element, nbElmts) + s3Canonical := make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + res[0] = iop.NewPolynomial(&s1Canonical, lagReg) + res[1] = iop.NewPolynomial(&s2Canonical, lagReg) + res[2] = iop.NewPolynomial(&s3Canonical, lagReg) + + return res +} + +// getSupportPermutation returns the support on which the permutation acts, it is +// || u || u^{2} +func getSupportPermutation(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], 
&domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + + return res +} diff --git a/internal/backend/bls12-377/plonk/verify.go b/backend/plonk/bls12-377/verify.go similarity index 77% rename from internal/backend/bls12-377/plonk/verify.go rename to backend/plonk/bls12-377/verify.go index ca61aa05c4..c28a894d02 100644 --- a/internal/backend/bls12-377/plonk/verify.go +++ b/backend/plonk/bls12-377/verify.go @@ -39,7 +39,7 @@ var ( ) func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - log := logger.Logger().With().Str("curve", "bls12_377").Str("backend", "plonk").Logger() + log := logger.Logger().With().Str("curve", "bls12-377").Str("backend", "plonk").Logger() start := time.Now() // pick a hash function to derive the challenge (the same as in the prover) @@ -51,7 +51,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { // The first challenge is derived using the public data: the commitments to the permutation, // the coefficients of the circuit, and the public inputs. // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *vk, publicWitness); err != nil { + if err := bindPublicData(&fs, "gamma", *vk, publicWitness, proof.Bsb22Commitments); err != nil { return err } gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -85,25 +85,51 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = ∑_{i of size N, so to represent the permutation S we let S acts on the + // set A=(, u*, u^{2}*) of size 3*N, where u is outside (its use is to shift the set ). + // We obtain a permutation of A, A'. We split A' in 3 (A'_{1}, A'_{2}, A'_{3}), and S1, S2, S3 are + // respectively the interpolation of A'_{1}, A'_{2}, A'_{3} on . 
+ S1, S2, S3 *iop.Polynomial + + // S full permutation, i -> S[i] + S []int64 +} + +// VerifyingKey stores the data needed to verify a proof: +// * The commitment scheme +// * Commitments of ql prepended with as many ones as there are public inputs +// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs +// * Commitments to S1, S2, S3 +type VerifyingKey struct { + + // Size circuit + Size uint64 + SizeInv fr.Element + Generator fr.Element + NbPublicVariables uint64 + + // Commitment scheme that is used for an instantiation of PLONK + Kzg kzg.VerifyingKey + + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + + // S commitments to S1, S2, S3 + S [3]kzg.Digest + + // Commitments to ql, qr, qm, qo, qcp prepended with as many zeroes (ones for l) as there are public inputs. + // In particular Qk is not complete. + Ql, Qr, Qm, Qo, Qk kzg.Digest + Qcp []kzg.Digest + + CommitmentConstraintIndexes []uint64 +} + +// ProvingKey stores the data needed to generate a proof: +// * the commitment scheme +// * ql, prepended with as many ones as they are public inputs +// * qr, qm, qo prepended with as many zeroes as there are public inputs. +// * qk, prepended with as many zeroes as public inputs, to be completed by the prover +// with the list of public inputs. +// * sigma_1, sigma_2, sigma_3 in both basis +// * the copy constraint permutation +type ProvingKey struct { + + // stores ql, qr, qm, qo, qk (-> to be completed by the prover) + // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used + // for computing the opening proofs (hence the canonical form). The canonical version + // of qk incomplete is used in the linearisation polynomial. + // The polynomials in trace are in canonical basis. 
+ trace Trace
+
+ Kzg kzg.ProvingKey
+
+ // Verifying Key is embedded into the proving key (needed by Prove)
+ Vk *VerifyingKey
+
+ // qr,ql,qm,qo,qcp in LagrangeCoset --> these are not serialized, but computed from Ql, Qr, Qm, Qo, Qcp once.
+ lcQl, lcQr, lcQm, lcQo *iop.Polynomial
+ lcQcp []*iop.Polynomial
+
+ // lQk: qk in Lagrange form -> to be completed by the prover. Once completed, it is evaluated on the Lagrange coset basis.
+ lQk *iop.Polynomial
+
+ // Domains used for the FFTs.
+ // Domain[0] = small Domain
+ // Domain[1] = big Domain
+ Domain [2]fft.Domain
+
+ // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once.
+ lcS1, lcS2, lcS3 *iop.Polynomial
+
+ // in lagrange coset basis --> not serialized id and L_{g^{0}}
+ lcIdIOP, lLoneIOP *iop.Polynomial
+}
+
+func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) {
+
+ var pk ProvingKey
+ var vk VerifyingKey
+ pk.Vk = &vk
+ vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes())
+
+ // step 0: set the fft domains
+ pk.initDomains(spr)
+
+ // step 1: set the verifying key
+ pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen)
+ vk.Size = pk.Domain[0].Cardinality
+ vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv)
+ vk.Generator.Set(&pk.Domain[0].Generator)
+ vk.NbPublicVariables = uint64(len(spr.Public))
+ if len(kzgSrs.Pk.G1) < int(vk.Size) {
+ return nil, nil, errors.New("kzg srs is too small")
+ }
+ pk.Kzg = kzgSrs.Pk
+ vk.Kzg = kzgSrs.Vk
+
+ // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis
+ BuildTrace(spr, &pk.trace)
+
+ // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. 
+ // Note: at this stage, the permutation takes into account the placeholders
+ nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret)
+ buildPermutation(spr, &pk.trace, nbVariables)
+ s := computePermutationPolynomials(&pk.trace, &pk.Domain[0])
+ pk.trace.S1 = s[0]
+ pk.trace.S2 = s[1]
+ pk.trace.S3 = s[2]
+
+ // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk.
+ // All the above polynomials are expressed in canonical basis afterwards. This is why
+ // we save lqk before, because the prover needs to complete it in Lagrange form, and
+ // then express it on the Lagrange coset basis.
+ pk.lQk = pk.trace.Qk.Clone() // it will be completed by the prover, and then evaluated on the coset
+ err := commitTrace(&pk.trace, &pk)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // step 5: evaluate ql, qr, qm, qo, s1, s2, s3 on LagrangeCoset (NOT qk)
+ // we clone them, because the canonical versions are going to be used in
+ // the opening proof
+ pk.computeLagrangeCosetPolys()
+
+ return &pk, &vk, nil
+}
+
+// computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset
+// basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
+func (pk *ProvingKey) computeLagrangeCosetPolys() { + var wg sync.WaitGroup + wg.Add(7 + len(pk.trace.Qcp)) + n1 := int(pk.Domain[1].Cardinality) + pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) + for i, qcpI := range pk.trace.Qcp { + go func(i int, qcpI *iop.Polynomial) { + pk.lcQcp[i] = qcpI.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }(i, qcpI) + } + go func() { + pk.lcQl = pk.trace.Ql.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQr = pk.trace.Qr.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQm = pk.trace.Qm.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQo = pk.trace.Qo.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS1 = pk.trace.S1.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS2 = pk.trace.S2.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS3 = pk.trace.S3.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + // storing Id + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + id := make([]fr.Element, pk.Domain[1].Cardinality) + id[0].Set(&pk.Domain[1].FrMultiplicativeGen) + for i := 1; i < int(pk.Domain[1].Cardinality); i++ { + id[i].Mul(&id[i-1], &pk.Domain[1].Generator) + } + pk.lcIdIOP = iop.NewPolynomial(&id, lagReg) + + // L_{g^{0}} + cap := pk.Domain[1].Cardinality + if cap < pk.Domain[0].Cardinality { + cap = pk.Domain[0].Cardinality // sanity check + } + lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) + lone[0].SetOne() + pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). + ToRegular(). 
+ ToLagrangeCoset(&pk.Domain[1]) + + wg.Wait() +} + +// NbPublicWitness returns the expected public witness size (number of field elements) +func (vk *VerifyingKey) NbPublicWitness() int { + return int(vk.NbPublicVariables) +} + +// VerifyingKey returns pk.Vk +func (pk *ProvingKey) VerifyingKey() interface{} { + return pk.Vk +} + +// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. +// Size is the size of the system that is nb_constraints+nb_public_variables +func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) + size := ecc.NextPowerOfTwo(sizeSystem) + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments) + + ql := make([]fr.Element, size) + qr := make([]fr.Element, size) + qm := make([]fr.Element, size) + qo := make([]fr.Element, size) + qk := make([]fr.Element, size) + qcp := make([][]fr.Element, len(commitmentInfo)) + + for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistent + ql[i].SetOne().Neg(&ql[i]) + qr[i].SetZero() + qm[i].SetZero() + qo[i].SetZero() + qk[i].SetZero() // → to be completed by the prover + } + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + ql[offset+j].Set(&spr.Coefficients[c.QL]) + qr[offset+j].Set(&spr.Coefficients[c.QR]) + qm[offset+j].Set(&spr.Coefficients[c.QM]) + qo[offset+j].Set(&spr.Coefficients[c.QO]) + qk[offset+j].Set(&spr.Coefficients[c.QC]) + j++ + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + + pt.Ql = iop.NewPolynomial(&ql, lagReg) + pt.Qr = iop.NewPolynomial(&qr, lagReg) + pt.Qm = iop.NewPolynomial(&qm, lagReg) + pt.Qo = iop.NewPolynomial(&qo, lagReg) + pt.Qk = iop.NewPolynomial(&qk, lagReg) + pt.Qcp = make([]*iop.Polynomial, len(qcp)) + + for i := range commitmentInfo { + qcp[i] = make([]fr.Element, size) + for _, 
committed := range commitmentInfo[i].Committed {
+ qcp[i][offset+committed].SetOne()
+ }
+ pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg)
+ }
+}
+
+// commitTrace commits to every polynomial in the trace, and puts
+// the commitments into the verifying key.
+func commitTrace(trace *Trace, pk *ProvingKey) error {
+
+ trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular()
+ trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular()
+ trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular()
+ trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular()
+ trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete
+ trace.S1.ToCanonical(&pk.Domain[0]).ToRegular()
+ trace.S2.ToCanonical(&pk.Domain[0]).ToRegular()
+ trace.S3.ToCanonical(&pk.Domain[0]).ToRegular()
+
+ var err error
+ pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp))
+ for i := range trace.Qcp {
+ trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular()
+ if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ }
+ if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) {
+
+ nbConstraints := spr.GetNbConstraints()
+ sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for 
the placeholder constraints + pk.Domain[0] = *fft.NewDomain(sizeSystem) + + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) + } else { + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) + } + +} + +// buildPermutation builds the Permutation associated with a circuit. +// +// The permutation s is composed of cycles of maximum length such that +// +// s. (l∥r∥o) = (l∥r∥o) +// +// , where l∥r∥o is the concatenation of the indices of l, r, o in +// ql.l+qr.r+qm.l.r+qo.O+k = 0. +// +// The permutation is encoded as a slice s of size 3*size(l), where the +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab +// like this: for i in tab: tab[i] = tab[permutation[i]] +func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { + + // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + sizeSolution := len(pt.Ql.Coefficients()) + sizePermutation := 3 * sizeSolution + + // init permutation + permutation := make([]int64, sizePermutation) + for i := 0; i < len(permutation); i++ { + permutation[i] = -1 + } + + // init LRO position -> variable_ID + lro := make([]int, sizePermutation) // position -> variable_ID + for i := 0; i < len(spr.Public); i++ { + lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) + } + + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + + j++ + } + + // init cycle: + // map ID -> last position the ID was seen + cycle := make([]int64, nbVariables) + for i := 0; i < len(cycle); i++ { + cycle[i] = -1 + } + + for i := 0; i < len(lro); i++ { + if cycle[lro[i]] != -1 { + 
// if != -1, it means we already encountered this value + // so we need to set the corresponding permutation index. + permutation[i] = cycle[lro[i]] + } + cycle[lro[i]] = int64(i) + } + + // complete the Permutation by filling the first IDs encountered + for i := 0; i < sizePermutation; i++ { + if permutation[i] == -1 { + permutation[i] = cycle[lro[i]] + } + } + + pt.S = permutation +} + +// computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. +// We let the permutation act on || u || u^{2}, split the result in 3 parts, +// and interpolate each of the 3 parts on . +func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { + + nbElmts := int(domain.Cardinality) + + var res [3]*iop.Polynomial + + // Lagrange form of ID + evaluationIDSmallDomain := getSupportPermutation(domain) + + // Lagrange form of S1, S2, S3 + s1Canonical := make([]fr.Element, nbElmts) + s2Canonical := make([]fr.Element, nbElmts) + s3Canonical := make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + res[0] = iop.NewPolynomial(&s1Canonical, lagReg) + res[1] = iop.NewPolynomial(&s2Canonical, lagReg) + res[2] = iop.NewPolynomial(&s3Canonical, lagReg) + + return res +} + +// getSupportPermutation returns the support on which the permutation acts, it is +// || u || u^{2} +func getSupportPermutation(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], 
&domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + + return res +} diff --git a/internal/backend/bls12-381/plonk/verify.go b/backend/plonk/bls12-381/verify.go similarity index 77% rename from internal/backend/bls12-381/plonk/verify.go rename to backend/plonk/bls12-381/verify.go index bb1d6ebedb..3124e92c5e 100644 --- a/internal/backend/bls12-381/plonk/verify.go +++ b/backend/plonk/bls12-381/verify.go @@ -39,7 +39,7 @@ var ( ) func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - log := logger.Logger().With().Str("curve", "bls12_381").Str("backend", "plonk").Logger() + log := logger.Logger().With().Str("curve", "bls12-381").Str("backend", "plonk").Logger() start := time.Now() // pick a hash function to derive the challenge (the same as in the prover) @@ -51,7 +51,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { // The first challenge is derived using the public data: the commitments to the permutation, // the coefficients of the circuit, and the public inputs. // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *vk, publicWitness); err != nil { + if err := bindPublicData(&fs, "gamma", *vk, publicWitness, proof.Bsb22Commitments); err != nil { return err } gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -85,25 +85,51 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = ∑_{i of size N, so to represent the permutation S we let S acts on the + // set A=(, u*, u^{2}*) of size 3*N, where u is outside (its use is to shift the set ). + // We obtain a permutation of A, A'. We split A' in 3 (A'_{1}, A'_{2}, A'_{3}), and S1, S2, S3 are + // respectively the interpolation of A'_{1}, A'_{2}, A'_{3} on . 
+ S1, S2, S3 *iop.Polynomial + + // S full permutation, i -> S[i] + S []int64 +} + +// VerifyingKey stores the data needed to verify a proof: +// * The commitment scheme +// * Commitments of ql prepended with as many ones as there are public inputs +// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs +// * Commitments to S1, S2, S3 +type VerifyingKey struct { + + // Size circuit + Size uint64 + SizeInv fr.Element + Generator fr.Element + NbPublicVariables uint64 + + // Commitment scheme that is used for an instantiation of PLONK + Kzg kzg.VerifyingKey + + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + + // S commitments to S1, S2, S3 + S [3]kzg.Digest + + // Commitments to ql, qr, qm, qo, qcp prepended with as many zeroes (ones for l) as there are public inputs. + // In particular Qk is not complete. + Ql, Qr, Qm, Qo, Qk kzg.Digest + Qcp []kzg.Digest + + CommitmentConstraintIndexes []uint64 +} + +// ProvingKey stores the data needed to generate a proof: +// * the commitment scheme +// * ql, prepended with as many ones as they are public inputs +// * qr, qm, qo prepended with as many zeroes as there are public inputs. +// * qk, prepended with as many zeroes as public inputs, to be completed by the prover +// with the list of public inputs. +// * sigma_1, sigma_2, sigma_3 in both basis +// * the copy constraint permutation +type ProvingKey struct { + + // stores ql, qr, qm, qo, qk (-> to be completed by the prover) + // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used + // for computing the opening proofs (hence the canonical form). The canonical version + // of qk incomplete is used in the linearisation polynomial. + // The polynomials in trace are in canonical basis. 
+ trace Trace
+
+ Kzg kzg.ProvingKey
+
+ // Verifying Key is embedded into the proving key (needed by Prove)
+ Vk *VerifyingKey
+
+ // qr,ql,qm,qo,qcp in LagrangeCoset --> these are not serialized, but computed from Ql, Qr, Qm, Qo, Qcp once.
+ lcQl, lcQr, lcQm, lcQo *iop.Polynomial
+ lcQcp []*iop.Polynomial
+
+ // lQk: qk in Lagrange form -> to be completed by the prover. Once completed, it is evaluated on the Lagrange coset basis.
+ lQk *iop.Polynomial
+
+ // Domains used for the FFTs.
+ // Domain[0] = small Domain
+ // Domain[1] = big Domain
+ Domain [2]fft.Domain
+
+ // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once.
+ lcS1, lcS2, lcS3 *iop.Polynomial
+
+ // in lagrange coset basis --> not serialized id and L_{g^{0}}
+ lcIdIOP, lLoneIOP *iop.Polynomial
+}
+
+func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) {
+
+ var pk ProvingKey
+ var vk VerifyingKey
+ pk.Vk = &vk
+ vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes())
+
+ // step 0: set the fft domains
+ pk.initDomains(spr)
+
+ // step 1: set the verifying key
+ pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen)
+ vk.Size = pk.Domain[0].Cardinality
+ vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv)
+ vk.Generator.Set(&pk.Domain[0].Generator)
+ vk.NbPublicVariables = uint64(len(spr.Public))
+ if len(kzgSrs.Pk.G1) < int(vk.Size) {
+ return nil, nil, errors.New("kzg srs is too small")
+ }
+ pk.Kzg = kzgSrs.Pk
+ vk.Kzg = kzgSrs.Vk
+
+ // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis
+ BuildTrace(spr, &pk.trace)
+
+ // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. 
+ // Note: at this stage, the permutation takes into account the placeholders
+ nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret)
+ buildPermutation(spr, &pk.trace, nbVariables)
+ s := computePermutationPolynomials(&pk.trace, &pk.Domain[0])
+ pk.trace.S1 = s[0]
+ pk.trace.S2 = s[1]
+ pk.trace.S3 = s[2]
+
+ // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk.
+ // All the above polynomials are expressed in canonical basis afterwards. This is why
+ // we save lqk before, because the prover needs to complete it in Lagrange form, and
+ // then express it on the Lagrange coset basis.
+ pk.lQk = pk.trace.Qk.Clone() // it will be completed by the prover, and then evaluated on the coset
+ err := commitTrace(&pk.trace, &pk)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // step 5: evaluate ql, qr, qm, qo, s1, s2, s3 on LagrangeCoset (NOT qk)
+ // we clone them, because the canonical versions are going to be used in
+ // the opening proof
+ pk.computeLagrangeCosetPolys()
+
+ return &pk, &vk, nil
+}
+
+// computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset
+// basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
+func (pk *ProvingKey) computeLagrangeCosetPolys() { + var wg sync.WaitGroup + wg.Add(7 + len(pk.trace.Qcp)) + n1 := int(pk.Domain[1].Cardinality) + pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) + for i, qcpI := range pk.trace.Qcp { + go func(i int, qcpI *iop.Polynomial) { + pk.lcQcp[i] = qcpI.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }(i, qcpI) + } + go func() { + pk.lcQl = pk.trace.Ql.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQr = pk.trace.Qr.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQm = pk.trace.Qm.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQo = pk.trace.Qo.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS1 = pk.trace.S1.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS2 = pk.trace.S2.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS3 = pk.trace.S3.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + // storing Id + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + id := make([]fr.Element, pk.Domain[1].Cardinality) + id[0].Set(&pk.Domain[1].FrMultiplicativeGen) + for i := 1; i < int(pk.Domain[1].Cardinality); i++ { + id[i].Mul(&id[i-1], &pk.Domain[1].Generator) + } + pk.lcIdIOP = iop.NewPolynomial(&id, lagReg) + + // L_{g^{0}} + cap := pk.Domain[1].Cardinality + if cap < pk.Domain[0].Cardinality { + cap = pk.Domain[0].Cardinality // sanity check + } + lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) + lone[0].SetOne() + pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). + ToRegular(). 
+ ToLagrangeCoset(&pk.Domain[1]) + + wg.Wait() +} + +// NbPublicWitness returns the expected public witness size (number of field elements) +func (vk *VerifyingKey) NbPublicWitness() int { + return int(vk.NbPublicVariables) +} + +// VerifyingKey returns pk.Vk +func (pk *ProvingKey) VerifyingKey() interface{} { + return pk.Vk +} + +// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. +// Size is the size of the system that is nb_constraints+nb_public_variables +func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) + size := ecc.NextPowerOfTwo(sizeSystem) + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments) + + ql := make([]fr.Element, size) + qr := make([]fr.Element, size) + qm := make([]fr.Element, size) + qo := make([]fr.Element, size) + qk := make([]fr.Element, size) + qcp := make([][]fr.Element, len(commitmentInfo)) + + for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistent + ql[i].SetOne().Neg(&ql[i]) + qr[i].SetZero() + qm[i].SetZero() + qo[i].SetZero() + qk[i].SetZero() // → to be completed by the prover + } + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + ql[offset+j].Set(&spr.Coefficients[c.QL]) + qr[offset+j].Set(&spr.Coefficients[c.QR]) + qm[offset+j].Set(&spr.Coefficients[c.QM]) + qo[offset+j].Set(&spr.Coefficients[c.QO]) + qk[offset+j].Set(&spr.Coefficients[c.QC]) + j++ + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + + pt.Ql = iop.NewPolynomial(&ql, lagReg) + pt.Qr = iop.NewPolynomial(&qr, lagReg) + pt.Qm = iop.NewPolynomial(&qm, lagReg) + pt.Qo = iop.NewPolynomial(&qo, lagReg) + pt.Qk = iop.NewPolynomial(&qk, lagReg) + pt.Qcp = make([]*iop.Polynomial, len(qcp)) + + for i := range commitmentInfo { + qcp[i] = make([]fr.Element, size) + for _, 
committed := range commitmentInfo[i].Committed {
+ qcp[i][offset+committed].SetOne()
+ }
+ pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg)
+ }
+}
+
+// commitTrace commits to every polynomial in the trace, and puts
+// the commitments into the verifying key.
+func commitTrace(trace *Trace, pk *ProvingKey) error {
+
+ trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular()
+ trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular()
+ trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular()
+ trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular()
+ trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete
+ trace.S1.ToCanonical(&pk.Domain[0]).ToRegular()
+ trace.S2.ToCanonical(&pk.Domain[0]).ToRegular()
+ trace.S3.ToCanonical(&pk.Domain[0]).ToRegular()
+
+ var err error
+ pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp))
+ for i := range trace.Qcp {
+ trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular()
+ if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ }
+ if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) {
+
+ nbConstraints := spr.GetNbConstraints()
+ sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for 
the placeholder constraints + pk.Domain[0] = *fft.NewDomain(sizeSystem) + + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) + } else { + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) + } + +} + +// buildPermutation builds the Permutation associated with a circuit. +// +// The permutation s is composed of cycles of maximum length such that +// +// s. (l∥r∥o) = (l∥r∥o) +// +// , where l∥r∥o is the concatenation of the indices of l, r, o in +// ql.l+qr.r+qm.l.r+qo.O+k = 0. +// +// The permutation is encoded as a slice s of size 3*size(l), where the +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab +// like this: for i in tab: tab[i] = tab[permutation[i]] +func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { + + // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + sizeSolution := len(pt.Ql.Coefficients()) + sizePermutation := 3 * sizeSolution + + // init permutation + permutation := make([]int64, sizePermutation) + for i := 0; i < len(permutation); i++ { + permutation[i] = -1 + } + + // init LRO position -> variable_ID + lro := make([]int, sizePermutation) // position -> variable_ID + for i := 0; i < len(spr.Public); i++ { + lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) + } + + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + + j++ + } + + // init cycle: + // map ID -> last position the ID was seen + cycle := make([]int64, nbVariables) + for i := 0; i < len(cycle); i++ { + cycle[i] = -1 + } + + for i := 0; i < len(lro); i++ { + if cycle[lro[i]] != -1 { + 
// if != -1, it means we already encountered this value + // so we need to set the corresponding permutation index. + permutation[i] = cycle[lro[i]] + } + cycle[lro[i]] = int64(i) + } + + // complete the Permutation by filling the first IDs encountered + for i := 0; i < sizePermutation; i++ { + if permutation[i] == -1 { + permutation[i] = cycle[lro[i]] + } + } + + pt.S = permutation +} + +// computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. +// We let the permutation act on || u || u^{2}, split the result in 3 parts, +// and interpolate each of the 3 parts on . +func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { + + nbElmts := int(domain.Cardinality) + + var res [3]*iop.Polynomial + + // Lagrange form of ID + evaluationIDSmallDomain := getSupportPermutation(domain) + + // Lagrange form of S1, S2, S3 + s1Canonical := make([]fr.Element, nbElmts) + s2Canonical := make([]fr.Element, nbElmts) + s3Canonical := make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + res[0] = iop.NewPolynomial(&s1Canonical, lagReg) + res[1] = iop.NewPolynomial(&s2Canonical, lagReg) + res[2] = iop.NewPolynomial(&s3Canonical, lagReg) + + return res +} + +// getSupportPermutation returns the support on which the permutation acts, it is +// || u || u^{2} +func getSupportPermutation(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], 
&domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + + return res +} diff --git a/internal/backend/bls24-315/plonk/verify.go b/backend/plonk/bls24-315/verify.go similarity index 77% rename from internal/backend/bls24-315/plonk/verify.go rename to backend/plonk/bls24-315/verify.go index 9e2965275c..81dc747e0b 100644 --- a/internal/backend/bls24-315/plonk/verify.go +++ b/backend/plonk/bls24-315/verify.go @@ -39,7 +39,7 @@ var ( ) func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - log := logger.Logger().With().Str("curve", "bls24_315").Str("backend", "plonk").Logger() + log := logger.Logger().With().Str("curve", "bls24-315").Str("backend", "plonk").Logger() start := time.Now() // pick a hash function to derive the challenge (the same as in the prover) @@ -51,7 +51,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { // The first challenge is derived using the public data: the commitments to the permutation, // the coefficients of the circuit, and the public inputs. // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *vk, publicWitness); err != nil { + if err := bindPublicData(&fs, "gamma", *vk, publicWitness, proof.Bsb22Commitments); err != nil { return err } gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -85,25 +85,51 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = ∑_{i of size N, so to represent the permutation S we let S acts on the + // set A=(, u*, u^{2}*) of size 3*N, where u is outside (its use is to shift the set ). + // We obtain a permutation of A, A'. We split A' in 3 (A'_{1}, A'_{2}, A'_{3}), and S1, S2, S3 are + // respectively the interpolation of A'_{1}, A'_{2}, A'_{3} on . 
+ S1, S2, S3 *iop.Polynomial + + // S full permutation, i -> S[i] + S []int64 +} + +// VerifyingKey stores the data needed to verify a proof: +// * The commitment scheme +// * Commitments of ql prepended with as many ones as there are public inputs +// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs +// * Commitments to S1, S2, S3 +type VerifyingKey struct { + + // Size circuit + Size uint64 + SizeInv fr.Element + Generator fr.Element + NbPublicVariables uint64 + + // Commitment scheme that is used for an instantiation of PLONK + Kzg kzg.VerifyingKey + + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + + // S commitments to S1, S2, S3 + S [3]kzg.Digest + + // Commitments to ql, qr, qm, qo, qcp prepended with as many zeroes (ones for l) as there are public inputs. + // In particular Qk is not complete. + Ql, Qr, Qm, Qo, Qk kzg.Digest + Qcp []kzg.Digest + + CommitmentConstraintIndexes []uint64 +} + +// ProvingKey stores the data needed to generate a proof: +// * the commitment scheme +// * ql, prepended with as many ones as they are public inputs +// * qr, qm, qo prepended with as many zeroes as there are public inputs. +// * qk, prepended with as many zeroes as public inputs, to be completed by the prover +// with the list of public inputs. +// * sigma_1, sigma_2, sigma_3 in both basis +// * the copy constraint permutation +type ProvingKey struct { + + // stores ql, qr, qm, qo, qk (-> to be completed by the prover) + // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used + // for computing the opening proofs (hence the canonical form). The canonical version + // of qk incomplete is used in the linearisation polynomial. + // The polynomials in trace are in canonical basis. 
+ trace Trace + + Kzg kzg.ProvingKey + + // Verifying Key is embedded into the proving key (needed by Prove) + Vk *VerifyingKey + + // qr,ql,qm,qo,qcp in LagrangeCoset --> these are not serialized, but computed from Ql, Qr, Qm, Qo, Qcp once. + lcQl, lcQr, lcQm, lcQo *iop.Polynomial + lcQcp []*iop.Polynomial + + // LQk qk in Lagrange form -> to be completed by the prover. After being completed, + lQk *iop.Polynomial + + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + + // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. + lcS1, lcS2, lcS3 *iop.Polynomial + + // in lagrange coset basis --> not serialized id and L_{g^{0}} + lcIdIOP, lLoneIOP *iop.Polynomial +} + +func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { + + var pk ProvingKey + var vk VerifyingKey + pk.Vk = &vk + vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) + + // step 0: set the fft domains + pk.initDomains(spr) + + // step 1: set the verifying key + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) + vk.Size = pk.Domain[0].Cardinality + vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) + vk.Generator.Set(&pk.Domain[0].Generator) + vk.NbPublicVariables = uint64(len(spr.Public)) + if len(kzgSrs.Pk.G1) < int(vk.Size) { + return nil, nil, errors.New("kzg srs is too small") + } + pk.Kzg = kzgSrs.Pk + vk.Kzg = kzgSrs.Vk + + // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis + BuildTrace(spr, &pk.trace) + + // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. 
+ // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &pk.trace, nbVariables) + s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) + pk.trace.S1 = s[0] + pk.trace.S2 = s[1] + pk.trace.S3 = s[2] + + // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. + // All the above polynomials are expressed in canonical basis afterwards. This is why + // we save lqk before, because the prover needs to complete it in Lagrange form, and + // then express it on the Lagrange coset basis. + pk.lQk = pk.trace.Qk.Clone() // it will be completed by the prover, and the evaluated on the coset + err := commitTrace(&pk.trace, &pk) + if err != nil { + return nil, nil, err + } + + // step 5: evaluate ql, qr, qm, qo, s1, s2, s3 on LagrangeCoset (NOT qk) + // we clone them, because the canonical versions are going to be used in + // the opening proof + pk.computeLagrangeCosetPolys() + + return &pk, &vk, nil +} + +// computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset +// basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
+func (pk *ProvingKey) computeLagrangeCosetPolys() { + var wg sync.WaitGroup + wg.Add(7 + len(pk.trace.Qcp)) + n1 := int(pk.Domain[1].Cardinality) + pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) + for i, qcpI := range pk.trace.Qcp { + go func(i int, qcpI *iop.Polynomial) { + pk.lcQcp[i] = qcpI.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }(i, qcpI) + } + go func() { + pk.lcQl = pk.trace.Ql.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQr = pk.trace.Qr.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQm = pk.trace.Qm.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQo = pk.trace.Qo.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS1 = pk.trace.S1.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS2 = pk.trace.S2.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS3 = pk.trace.S3.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + // storing Id + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + id := make([]fr.Element, pk.Domain[1].Cardinality) + id[0].Set(&pk.Domain[1].FrMultiplicativeGen) + for i := 1; i < int(pk.Domain[1].Cardinality); i++ { + id[i].Mul(&id[i-1], &pk.Domain[1].Generator) + } + pk.lcIdIOP = iop.NewPolynomial(&id, lagReg) + + // L_{g^{0}} + cap := pk.Domain[1].Cardinality + if cap < pk.Domain[0].Cardinality { + cap = pk.Domain[0].Cardinality // sanity check + } + lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) + lone[0].SetOne() + pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). + ToRegular(). 
+ ToLagrangeCoset(&pk.Domain[1]) + + wg.Wait() +} + +// NbPublicWitness returns the expected public witness size (number of field elements) +func (vk *VerifyingKey) NbPublicWitness() int { + return int(vk.NbPublicVariables) +} + +// VerifyingKey returns pk.Vk +func (pk *ProvingKey) VerifyingKey() interface{} { + return pk.Vk +} + +// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. +// Size is the size of the system that is nb_constraints+nb_public_variables +func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) + size := ecc.NextPowerOfTwo(sizeSystem) + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments) + + ql := make([]fr.Element, size) + qr := make([]fr.Element, size) + qm := make([]fr.Element, size) + qo := make([]fr.Element, size) + qk := make([]fr.Element, size) + qcp := make([][]fr.Element, len(commitmentInfo)) + + for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistent + ql[i].SetOne().Neg(&ql[i]) + qr[i].SetZero() + qm[i].SetZero() + qo[i].SetZero() + qk[i].SetZero() // → to be completed by the prover + } + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + ql[offset+j].Set(&spr.Coefficients[c.QL]) + qr[offset+j].Set(&spr.Coefficients[c.QR]) + qm[offset+j].Set(&spr.Coefficients[c.QM]) + qo[offset+j].Set(&spr.Coefficients[c.QO]) + qk[offset+j].Set(&spr.Coefficients[c.QC]) + j++ + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + + pt.Ql = iop.NewPolynomial(&ql, lagReg) + pt.Qr = iop.NewPolynomial(&qr, lagReg) + pt.Qm = iop.NewPolynomial(&qm, lagReg) + pt.Qo = iop.NewPolynomial(&qo, lagReg) + pt.Qk = iop.NewPolynomial(&qk, lagReg) + pt.Qcp = make([]*iop.Polynomial, len(qcp)) + + for i := range commitmentInfo { + qcp[i] = make([]fr.Element, size) + for _, 
committed := range commitmentInfo[i].Committed { + qcp[i][offset+committed].SetOne() + } + pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + } +} + +// commitTrace commits to every polynomial in the trace, and put +// the commitments int the verifying key. +func commitTrace(trace *Trace, pk *ProvingKey) error { + + trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete + trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() + + var err error + pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + for i := range trace.Qcp { + trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() + if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + return err + } + } + if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + return err + } + return nil +} + +func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for 
the placeholder constraints + pk.Domain[0] = *fft.NewDomain(sizeSystem) + + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) + } else { + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) + } + +} + +// buildPermutation builds the Permutation associated with a circuit. +// +// The permutation s is composed of cycles of maximum length such that +// +// s. (l∥r∥o) = (l∥r∥o) +// +// , where l∥r∥o is the concatenation of the indices of l, r, o in +// ql.l+qr.r+qm.l.r+qo.O+k = 0. +// +// The permutation is encoded as a slice s of size 3*size(l), where the +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab +// like this: for i in tab: tab[i] = tab[permutation[i]] +func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { + + // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + sizeSolution := len(pt.Ql.Coefficients()) + sizePermutation := 3 * sizeSolution + + // init permutation + permutation := make([]int64, sizePermutation) + for i := 0; i < len(permutation); i++ { + permutation[i] = -1 + } + + // init LRO position -> variable_ID + lro := make([]int, sizePermutation) // position -> variable_ID + for i := 0; i < len(spr.Public); i++ { + lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) + } + + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + + j++ + } + + // init cycle: + // map ID -> last position the ID was seen + cycle := make([]int64, nbVariables) + for i := 0; i < len(cycle); i++ { + cycle[i] = -1 + } + + for i := 0; i < len(lro); i++ { + if cycle[lro[i]] != -1 { + 
// if != -1, it means we already encountered this value + // so we need to set the corresponding permutation index. + permutation[i] = cycle[lro[i]] + } + cycle[lro[i]] = int64(i) + } + + // complete the Permutation by filling the first IDs encountered + for i := 0; i < sizePermutation; i++ { + if permutation[i] == -1 { + permutation[i] = cycle[lro[i]] + } + } + + pt.S = permutation +} + +// computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. +// We let the permutation act on || u || u^{2}, split the result in 3 parts, +// and interpolate each of the 3 parts on . +func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { + + nbElmts := int(domain.Cardinality) + + var res [3]*iop.Polynomial + + // Lagrange form of ID + evaluationIDSmallDomain := getSupportPermutation(domain) + + // Lagrange form of S1, S2, S3 + s1Canonical := make([]fr.Element, nbElmts) + s2Canonical := make([]fr.Element, nbElmts) + s3Canonical := make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + res[0] = iop.NewPolynomial(&s1Canonical, lagReg) + res[1] = iop.NewPolynomial(&s2Canonical, lagReg) + res[2] = iop.NewPolynomial(&s3Canonical, lagReg) + + return res +} + +// getSupportPermutation returns the support on which the permutation acts, it is +// || u || u^{2} +func getSupportPermutation(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], 
&domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + + return res +} diff --git a/internal/backend/bls24-317/plonk/verify.go b/backend/plonk/bls24-317/verify.go similarity index 77% rename from internal/backend/bls24-317/plonk/verify.go rename to backend/plonk/bls24-317/verify.go index ccd8642cb2..a6a7479e08 100644 --- a/internal/backend/bls24-317/plonk/verify.go +++ b/backend/plonk/bls24-317/verify.go @@ -39,7 +39,7 @@ var ( ) func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - log := logger.Logger().With().Str("curve", "bls24_317").Str("backend", "plonk").Logger() + log := logger.Logger().With().Str("curve", "bls24-317").Str("backend", "plonk").Logger() start := time.Now() // pick a hash function to derive the challenge (the same as in the prover) @@ -51,7 +51,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { // The first challenge is derived using the public data: the commitments to the permutation, // the coefficients of the circuit, and the public inputs. // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *vk, publicWitness); err != nil { + if err := bindPublicData(&fs, "gamma", *vk, publicWitness, proof.Bsb22Commitments); err != nil { return err } gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -85,25 +85,51 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = ∑_{i of size N, so to represent the permutation S we let S acts on the + // set A=(, u*, u^{2}*) of size 3*N, where u is outside (its use is to shift the set ). + // We obtain a permutation of A, A'. We split A' in 3 (A'_{1}, A'_{2}, A'_{3}), and S1, S2, S3 are + // respectively the interpolation of A'_{1}, A'_{2}, A'_{3} on . 
+ S1, S2, S3 *iop.Polynomial + + // S full permutation, i -> S[i] + S []int64 +} + +// VerifyingKey stores the data needed to verify a proof: +// * The commitment scheme +// * Commitments of ql prepended with as many ones as there are public inputs +// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs +// * Commitments to S1, S2, S3 +type VerifyingKey struct { + + // Size circuit + Size uint64 + SizeInv fr.Element + Generator fr.Element + NbPublicVariables uint64 + + // Commitment scheme that is used for an instantiation of PLONK + Kzg kzg.VerifyingKey + + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + + // S commitments to S1, S2, S3 + S [3]kzg.Digest + + // Commitments to ql, qr, qm, qo, qcp prepended with as many zeroes (ones for l) as there are public inputs. + // In particular Qk is not complete. + Ql, Qr, Qm, Qo, Qk kzg.Digest + Qcp []kzg.Digest + + CommitmentConstraintIndexes []uint64 +} + +// ProvingKey stores the data needed to generate a proof: +// * the commitment scheme +// * ql, prepended with as many ones as they are public inputs +// * qr, qm, qo prepended with as many zeroes as there are public inputs. +// * qk, prepended with as many zeroes as public inputs, to be completed by the prover +// with the list of public inputs. +// * sigma_1, sigma_2, sigma_3 in both basis +// * the copy constraint permutation +type ProvingKey struct { + + // stores ql, qr, qm, qo, qk (-> to be completed by the prover) + // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used + // for computing the opening proofs (hence the canonical form). The canonical version + // of qk incomplete is used in the linearisation polynomial. + // The polynomials in trace are in canonical basis. 
+ trace Trace + + Kzg kzg.ProvingKey + + // Verifying Key is embedded into the proving key (needed by Prove) + Vk *VerifyingKey + + // qr,ql,qm,qo,qcp in LagrangeCoset --> these are not serialized, but computed from Ql, Qr, Qm, Qo, Qcp once. + lcQl, lcQr, lcQm, lcQo *iop.Polynomial + lcQcp []*iop.Polynomial + + // LQk qk in Lagrange form -> to be completed by the prover. After being completed, + lQk *iop.Polynomial + + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + + // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. + lcS1, lcS2, lcS3 *iop.Polynomial + + // in lagrange coset basis --> not serialized id and L_{g^{0}} + lcIdIOP, lLoneIOP *iop.Polynomial +} + +func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { + + var pk ProvingKey + var vk VerifyingKey + pk.Vk = &vk + vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) + + // step 0: set the fft domains + pk.initDomains(spr) + + // step 1: set the verifying key + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) + vk.Size = pk.Domain[0].Cardinality + vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) + vk.Generator.Set(&pk.Domain[0].Generator) + vk.NbPublicVariables = uint64(len(spr.Public)) + if len(kzgSrs.Pk.G1) < int(vk.Size) { + return nil, nil, errors.New("kzg srs is too small") + } + pk.Kzg = kzgSrs.Pk + vk.Kzg = kzgSrs.Vk + + // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis + BuildTrace(spr, &pk.trace) + + // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. 
+ // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &pk.trace, nbVariables) + s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) + pk.trace.S1 = s[0] + pk.trace.S2 = s[1] + pk.trace.S3 = s[2] + + // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. + // All the above polynomials are expressed in canonical basis afterwards. This is why + // we save lqk before, because the prover needs to complete it in Lagrange form, and + // then express it on the Lagrange coset basis. + pk.lQk = pk.trace.Qk.Clone() // it will be completed by the prover, and the evaluated on the coset + err := commitTrace(&pk.trace, &pk) + if err != nil { + return nil, nil, err + } + + // step 5: evaluate ql, qr, qm, qo, s1, s2, s3 on LagrangeCoset (NOT qk) + // we clone them, because the canonical versions are going to be used in + // the opening proof + pk.computeLagrangeCosetPolys() + + return &pk, &vk, nil +} + +// computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset +// basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
+func (pk *ProvingKey) computeLagrangeCosetPolys() { + var wg sync.WaitGroup + wg.Add(7 + len(pk.trace.Qcp)) + n1 := int(pk.Domain[1].Cardinality) + pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) + for i, qcpI := range pk.trace.Qcp { + go func(i int, qcpI *iop.Polynomial) { + pk.lcQcp[i] = qcpI.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }(i, qcpI) + } + go func() { + pk.lcQl = pk.trace.Ql.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQr = pk.trace.Qr.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQm = pk.trace.Qm.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQo = pk.trace.Qo.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS1 = pk.trace.S1.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS2 = pk.trace.S2.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS3 = pk.trace.S3.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + // storing Id + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + id := make([]fr.Element, pk.Domain[1].Cardinality) + id[0].Set(&pk.Domain[1].FrMultiplicativeGen) + for i := 1; i < int(pk.Domain[1].Cardinality); i++ { + id[i].Mul(&id[i-1], &pk.Domain[1].Generator) + } + pk.lcIdIOP = iop.NewPolynomial(&id, lagReg) + + // L_{g^{0}} + cap := pk.Domain[1].Cardinality + if cap < pk.Domain[0].Cardinality { + cap = pk.Domain[0].Cardinality // sanity check + } + lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) + lone[0].SetOne() + pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). + ToRegular(). 
+ ToLagrangeCoset(&pk.Domain[1]) + + wg.Wait() +} + +// NbPublicWitness returns the expected public witness size (number of field elements) +func (vk *VerifyingKey) NbPublicWitness() int { + return int(vk.NbPublicVariables) +} + +// VerifyingKey returns pk.Vk +func (pk *ProvingKey) VerifyingKey() interface{} { + return pk.Vk +} + +// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. +// Size is the size of the system that is nb_constraints+nb_public_variables +func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) + size := ecc.NextPowerOfTwo(sizeSystem) + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments) + + ql := make([]fr.Element, size) + qr := make([]fr.Element, size) + qm := make([]fr.Element, size) + qo := make([]fr.Element, size) + qk := make([]fr.Element, size) + qcp := make([][]fr.Element, len(commitmentInfo)) + + for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistent + ql[i].SetOne().Neg(&ql[i]) + qr[i].SetZero() + qm[i].SetZero() + qo[i].SetZero() + qk[i].SetZero() // → to be completed by the prover + } + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + ql[offset+j].Set(&spr.Coefficients[c.QL]) + qr[offset+j].Set(&spr.Coefficients[c.QR]) + qm[offset+j].Set(&spr.Coefficients[c.QM]) + qo[offset+j].Set(&spr.Coefficients[c.QO]) + qk[offset+j].Set(&spr.Coefficients[c.QC]) + j++ + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + + pt.Ql = iop.NewPolynomial(&ql, lagReg) + pt.Qr = iop.NewPolynomial(&qr, lagReg) + pt.Qm = iop.NewPolynomial(&qm, lagReg) + pt.Qo = iop.NewPolynomial(&qo, lagReg) + pt.Qk = iop.NewPolynomial(&qk, lagReg) + pt.Qcp = make([]*iop.Polynomial, len(qcp)) + + for i := range commitmentInfo { + qcp[i] = make([]fr.Element, size) + for _, 
committed := range commitmentInfo[i].Committed { + qcp[i][offset+committed].SetOne() + } + pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + } +} + +// commitTrace commits to every polynomial in the trace, and put +// the commitments int the verifying key. +func commitTrace(trace *Trace, pk *ProvingKey) error { + + trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete + trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() + + var err error + pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + for i := range trace.Qcp { + trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() + if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + return err + } + } + if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + return err + } + return nil +} + +func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for 
the placeholder constraints + pk.Domain[0] = *fft.NewDomain(sizeSystem) + + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) + } else { + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) + } + +} + +// buildPermutation builds the Permutation associated with a circuit. +// +// The permutation s is composed of cycles of maximum length such that +// +// s. (l∥r∥o) = (l∥r∥o) +// +// , where l∥r∥o is the concatenation of the indices of l, r, o in +// ql.l+qr.r+qm.l.r+qo.O+k = 0. +// +// The permutation is encoded as a slice s of size 3*size(l), where the +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab +// like this: for i in tab: tab[i] = tab[permutation[i]] +func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { + + // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + sizeSolution := len(pt.Ql.Coefficients()) + sizePermutation := 3 * sizeSolution + + // init permutation + permutation := make([]int64, sizePermutation) + for i := 0; i < len(permutation); i++ { + permutation[i] = -1 + } + + // init LRO position -> variable_ID + lro := make([]int, sizePermutation) // position -> variable_ID + for i := 0; i < len(spr.Public); i++ { + lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) + } + + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + + j++ + } + + // init cycle: + // map ID -> last position the ID was seen + cycle := make([]int64, nbVariables) + for i := 0; i < len(cycle); i++ { + cycle[i] = -1 + } + + for i := 0; i < len(lro); i++ { + if cycle[lro[i]] != -1 { + 
// if != -1, it means we already encountered this value + // so we need to set the corresponding permutation index. + permutation[i] = cycle[lro[i]] + } + cycle[lro[i]] = int64(i) + } + + // complete the Permutation by filling the first IDs encountered + for i := 0; i < sizePermutation; i++ { + if permutation[i] == -1 { + permutation[i] = cycle[lro[i]] + } + } + + pt.S = permutation +} + +// computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. +// We let the permutation act on || u || u^{2}, split the result in 3 parts, +// and interpolate each of the 3 parts on . +func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { + + nbElmts := int(domain.Cardinality) + + var res [3]*iop.Polynomial + + // Lagrange form of ID + evaluationIDSmallDomain := getSupportPermutation(domain) + + // Lagrange form of S1, S2, S3 + s1Canonical := make([]fr.Element, nbElmts) + s2Canonical := make([]fr.Element, nbElmts) + s3Canonical := make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + res[0] = iop.NewPolynomial(&s1Canonical, lagReg) + res[1] = iop.NewPolynomial(&s2Canonical, lagReg) + res[2] = iop.NewPolynomial(&s3Canonical, lagReg) + + return res +} + +// getSupportPermutation returns the support on which the permutation acts, it is +// || u || u^{2} +func getSupportPermutation(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], 
&domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + + return res +} diff --git a/backend/plonk/bn254/solidity.go b/backend/plonk/bn254/solidity.go new file mode 100644 index 0000000000..b43dd3c219 --- /dev/null +++ b/backend/plonk/bn254/solidity.go @@ -0,0 +1,1271 @@ +package plonk + +const tmplSolidityVerifier = `// SPDX-License-Identifier: Apache-2.0 + +// Copyright 2023 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +pragma solidity ^0.8.19; + +contract PlonkVerifier { + + uint256 private constant r_mod = 21888242871839275222246405745257275088548364400416034343698204186575808495617; + uint256 private constant p_mod = 21888242871839275222246405745257275088696311157297823662689037894645226208583; + {{ range $index, $element := .Kzg.G2 }} + uint256 private constant g2_srs_{{ $index }}_x_0 = {{ (fpstr $element.X.A1) }}; + uint256 private constant g2_srs_{{ $index }}_x_1 = {{ (fpstr $element.X.A0) }}; + uint256 private constant g2_srs_{{ $index }}_y_0 = {{ (fpstr $element.Y.A1) }}; + uint256 private constant g2_srs_{{ $index }}_y_1 = {{ (fpstr $element.Y.A0) }}; + {{ end }} + // ----------------------- vk --------------------- + uint256 private constant vk_domain_size = {{ .Size }}; + uint256 private constant vk_inv_domain_size = {{ (frstr .SizeInv) }}; + uint256 private constant vk_omega = {{ (frstr .Generator) }}; + uint256 private constant vk_ql_com_x = {{ (fpstr .Ql.X) }}; + uint256 private constant vk_ql_com_y = {{ (fpstr .Ql.Y) }}; + uint256 private constant vk_qr_com_x = {{ (fpstr .Qr.X) }}; + uint256 private constant vk_qr_com_y = {{ (fpstr .Qr.Y) }}; + uint256 private constant vk_qm_com_x = {{ (fpstr .Qm.X) }}; + uint256 private constant vk_qm_com_y = {{ (fpstr .Qm.Y) }}; + uint256 private constant vk_qo_com_x = {{ (fpstr .Qo.X) }}; + uint256 private constant vk_qo_com_y = {{ (fpstr .Qo.Y) }}; + uint256 private constant vk_qk_com_x = {{ (fpstr .Qk.X) }}; + uint256 private constant vk_qk_com_y = {{ (fpstr .Qk.Y) }}; + {{ range $index, $element := .S }} + uint256 private constant vk_s{{ inc $index }}_com_x = {{ (fpstr $element.X) }}; + uint256 private constant vk_s{{ inc $index }}_com_y = {{ (fpstr $element.Y) }}; + {{ end }} + uint256 private constant vk_coset_shift = 5; + + {{ range $index, $element := .Qcp}} + uint256 private constant vk_selector_commitments_commit_api_{{ $index }}_x = {{ (fpstr $element.X) }}; + uint256 private 
constant vk_selector_commitments_commit_api_{{ $index }}_y = {{ (fpstr $element.Y) }}; + {{ end }} + + {{ range $index, $element := .CommitmentConstraintIndexes -}} + uint256 private constant vk_index_commit_api_{{ $index }} = {{ $element }}; + {{ end }} + + uint256 private constant vk_nb_commitments_commit_api = {{ len .CommitmentConstraintIndexes }}; + + // ------------------------------------------------ + + // offset proof + uint256 private constant proof_l_com_x = 0x00; + uint256 private constant proof_l_com_y = 0x20; + uint256 private constant proof_r_com_x = 0x40; + uint256 private constant proof_r_com_y = 0x60; + uint256 private constant proof_o_com_x = 0x80; + uint256 private constant proof_o_com_y = 0xa0; + + // h = h_0 + x^{n+2}h_1 + x^{2(n+2)}h_2 + uint256 private constant proof_h_0_x = 0xc0; + uint256 private constant proof_h_0_y = 0xe0; + uint256 private constant proof_h_1_x = 0x100; + uint256 private constant proof_h_1_y = 0x120; + uint256 private constant proof_h_2_x = 0x140; + uint256 private constant proof_h_2_y = 0x160; + + // wire values at zeta + uint256 private constant proof_l_at_zeta = 0x180; + uint256 private constant proof_r_at_zeta = 0x1a0; + uint256 private constant proof_o_at_zeta = 0x1c0; + + //uint256[STATE_WIDTH-1] permutation_polynomials_at_zeta; // Sσ1(zeta),Sσ2(zeta) + uint256 private constant proof_s1_at_zeta = 0x1e0; // Sσ1(zeta) + uint256 private constant proof_s2_at_zeta = 0x200; // Sσ2(zeta) + + //Bn254.G1Point grand_product_commitment; // [z(x)] + uint256 private constant proof_grand_product_commitment_x = 0x220; + uint256 private constant proof_grand_product_commitment_y = 0x240; + + uint256 private constant proof_grand_product_at_zeta_omega = 0x260; // z(w*zeta) + uint256 private constant proof_quotient_polynomial_at_zeta = 0x280; // t(zeta) + uint256 private constant proof_linearised_polynomial_at_zeta = 0x2a0; // r(zeta) + + // Folded proof for the opening of H, linearised poly, l, r, o, s_1, s_2, qcp + uint256 private 
constant proof_batch_opening_at_zeta_x = 0x2c0; // [Wzeta] + uint256 private constant proof_batch_opening_at_zeta_y = 0x2e0; + + //Bn254.G1Point opening_at_zeta_omega_proof; // [Wzeta*omega] + uint256 private constant proof_opening_at_zeta_omega_x = 0x300; + uint256 private constant proof_opening_at_zeta_omega_y = 0x320; + + uint256 private constant proof_openings_selector_commit_api_at_zeta = 0x340; + // -> next part of proof is + // [ openings_selector_commits || commitments_wires_commit_api] + + // -------- offset state + + // challenges to check the claimed quotient + uint256 private constant state_alpha = 0x00; + uint256 private constant state_beta = 0x20; + uint256 private constant state_gamma = 0x40; + uint256 private constant state_zeta = 0x60; + + // reusable value + uint256 private constant state_alpha_square_lagrange_0 = 0x80; + + // commitment to H + uint256 private constant state_folded_h_x = 0xa0; + uint256 private constant state_folded_h_y = 0xc0; + + // commitment to the linearised polynomial + uint256 private constant state_linearised_polynomial_x = 0xe0; + uint256 private constant state_linearised_polynomial_y = 0x100; + + // Folded proof for the opening of H, linearised poly, l, r, o, s_1, s_2, qcp + uint256 private constant state_folded_claimed_values = 0x120; + + // folded digests of H, linearised poly, l, r, o, s_1, s_2, qcp + // Bn254.G1Point folded_digests; + uint256 private constant state_folded_digests_x = 0x140; + uint256 private constant state_folded_digests_y = 0x160; + + uint256 private constant state_pi = 0x180; + + uint256 private constant state_zeta_power_n_minus_one = 0x1a0; + + uint256 private constant state_gamma_kzg = 0x1c0; + + uint256 private constant state_success = 0x1e0; + uint256 private constant state_check_var = 0x200; // /!\ this slot is used for debugging only + + uint256 private constant state_last_mem = 0x220; + + // -------- errors + uint256 private constant error_string_id = 
0x08c379a000000000000000000000000000000000000000000000000000000000; // selector for function Error(string) + + {{ if (gt (len .CommitmentConstraintIndexes) 0 )}} + // -------- utils (for hash_fr) + uint256 private constant bb = 340282366920938463463374607431768211456; // 2**128 + uint256 private constant zero_uint256 = 0; + + uint8 private constant lenInBytes = 48; + uint8 private constant sizeDomain = 11; + uint8 private constant one = 1; + uint8 private constant two = 2; + {{ end }} + + function Verify(bytes calldata proof, uint256[] calldata public_inputs) + public view returns(bool success) { + + assembly { + + let mem := mload(0x40) + let freeMem := add(mem, state_last_mem) + + // sanity checks + check_inputs_size(public_inputs.length, public_inputs.offset) + check_proof_size(proof.length) + check_proof_openings_size(proof.offset) + + // compute the challenges + let prev_challenge_non_reduced + prev_challenge_non_reduced := derive_gamma(proof.offset, public_inputs.length, public_inputs.offset) + prev_challenge_non_reduced := derive_beta(prev_challenge_non_reduced) + prev_challenge_non_reduced := derive_alpha(proof.offset, prev_challenge_non_reduced) + derive_zeta(proof.offset, prev_challenge_non_reduced) + + // evaluation of Z=Xⁿ-1 at ζ, we save this value + let zeta := mload(add(mem, state_zeta)) + let zeta_power_n_minus_one := addmod(pow(zeta, vk_domain_size, freeMem), sub(r_mod, 1), r_mod) + mstore(add(mem, state_zeta_power_n_minus_one), zeta_power_n_minus_one) + + // public inputs contribution + let l_pi := sum_pi_wo_api_commit(public_inputs.offset, public_inputs.length, freeMem) + {{ if (gt (len .CommitmentConstraintIndexes) 0 ) -}} + let l_wocommit := sum_pi_commit(proof.offset, public_inputs.length, freeMem) + l_pi := addmod(l_wocommit, l_pi, r_mod) + {{ end -}} + mstore(add(mem, state_pi), l_pi) + + compute_alpha_square_lagrange_0() + verify_quotient_poly_eval_at_zeta(proof.offset) + fold_h(proof.offset) + 
compute_commitment_linearised_polynomial(proof.offset) + compute_gamma_kzg(proof.offset) + fold_state(proof.offset) + batch_verify_multi_points(proof.offset) + + success := mload(add(mem, state_success)) + + // Beginning errors ------------------------------------------------- + function error_ec_op() { + let ptError := mload(0x40) + mstore(ptError, error_string_id) // selector for function Error(string) + mstore(add(ptError, 0x4), 0x20) + mstore(add(ptError, 0x24), 0x12) + mstore(add(ptError, 0x44), "error ec operation") + revert(ptError, 0x64) + } + + function error_inputs_size() { + let ptError := mload(0x40) + mstore(ptError, error_string_id) // selector for function Error(string) + mstore(add(ptError, 0x4), 0x20) + mstore(add(ptError, 0x24), 0x18) + mstore(add(ptError, 0x44), "inputs are bigger than r") + revert(ptError, 0x64) + } + + function error_proof_size() { + let ptError := mload(0x40) + mstore(ptError, error_string_id) // selector for function Error(string) + mstore(add(ptError, 0x4), 0x20) + mstore(add(ptError, 0x24), 0x10) + mstore(add(ptError, 0x44), "wrong proof size") + revert(ptError, 0x64) + } + + function error_proof_openings_size() { + let ptError := mload(0x40) + mstore(ptError, error_string_id) // selector for function Error(string) + mstore(add(ptError, 0x4), 0x20) + mstore(add(ptError, 0x24), 0x16) + mstore(add(ptError, 0x44), "openings bigger than r") + revert(ptError, 0x64) + } + + function error_verify() { + let ptError := mload(0x40) + mstore(ptError, error_string_id) // selector for function Error(string) + mstore(add(ptError, 0x4), 0x20) + mstore(add(ptError, 0x24), 0xc) + mstore(add(ptError, 0x44), "error verify") + revert(ptError, 0x64) + } + // end errors ------------------------------------------------- + + // Beginning checks ------------------------------------------------- + + // s number of public inputs, p pointer the public inputs + function check_inputs_size(s, p) { + let input_checks := 1 + for {let i} lt(i, s) 
{i:=add(i,1)} + { + input_checks := and(input_checks,lt(calldataload(p), r_mod)) + p := add(p, 0x20) + } + if iszero(input_checks) { + error_inputs_size() + } + } + + function check_proof_size(actual_proof_size) { + let expected_proof_size := add(0x340, mul(vk_nb_commitments_commit_api,0x60)) + if iszero(eq(actual_proof_size, expected_proof_size)) { + error_proof_size() + } + } + + function check_proof_openings_size(aproof) { + + let openings_check := 1 + + // linearised polynomial at zeta + let p := add(aproof, proof_linearised_polynomial_at_zeta) + openings_check := and(openings_check, lt(calldataload(p), r_mod)) + + // quotient polynomial at zeta + p := add(aproof, proof_quotient_polynomial_at_zeta) + openings_check := and(openings_check, lt(calldataload(p), r_mod)) + + // proof_l_at_zeta + p := add(aproof, proof_l_at_zeta) + openings_check := and(openings_check, lt(calldataload(p), r_mod)) + + // proof_r_at_zeta + p := add(aproof, proof_r_at_zeta) + openings_check := and(openings_check, lt(calldataload(p), r_mod)) + + // proof_o_at_zeta + p := add(aproof, proof_o_at_zeta) + openings_check := and(openings_check, lt(calldataload(p), r_mod)) + + // proof_s1_at_zeta + p := add(aproof, proof_s1_at_zeta) + openings_check := and(openings_check, lt(calldataload(p), r_mod)) + + // proof_s2_at_zeta + p := add(aproof, proof_s2_at_zeta) + openings_check := and(openings_check, lt(calldataload(p), r_mod)) + + // proof_grand_product_at_zeta_omega + p := add(aproof, proof_grand_product_at_zeta_omega) + openings_check := and(openings_check, lt(calldataload(p), r_mod)) + + // proof_openings_selector_commit_api_at_zeta + + p := add(aproof, proof_openings_selector_commit_api_at_zeta) + for {let i:=0} lt(i, vk_nb_commitments_commit_api) {i:=add(i,1)} + { + openings_check := and(openings_check, lt(calldataload(p), r_mod)) + p := add(p, 0x20) + } + + if iszero(openings_check) { + error_proof_openings_size() + } + + } + // end checks ------------------------------------------------- + 
+ // Beginning challenges ------------------------------------------------- + + // Derive gamma as Sha256() + // where transcript is the concatenation (in this order) of: + // * the word "gamma" in ascii, equal to [0x67,0x61,0x6d, 0x6d, 0x61] and encoded as a uint256. + // * the commitments to the permutation polynomials S1, S2, S3, where we concatenate the coordinates of those points + // * the commitments of Ql, Qr, Qm, Qo, Qk + // * the public inputs + // * the commitments of the wires related to the custom gates (commitments_wires_commit_api) + // * commitments to L, R, O (proof__com_) + // The data described above is written starting at mPtr. "gamma" lies on 5 bytes, + // and is encoded as a uint256 number n. In basis b = 256, the number looks like this + // [0 0 0 .. 0x67 0x61 0x6d, 0x6d, 0x61]. The first non zero entry is at position 27=0x1b + // nb_pi, pi respectively number of public inputs and public inputs + function derive_gamma(aproof, nb_pi, pi)->gamma_not_reduced { + + let state := mload(0x40) + let mPtr := add(state, state_last_mem) + + // gamma + // gamma in ascii is [0x67,0x61,0x6d, 0x6d, 0x61] + // (same for alpha, beta, zeta) + mstore(mPtr, 0x67616d6d61) // "gamma" + + mstore(add(mPtr, 0x20), vk_s1_com_x) + mstore(add(mPtr, 0x40), vk_s1_com_y) + mstore(add(mPtr, 0x60), vk_s2_com_x) + mstore(add(mPtr, 0x80), vk_s2_com_y) + mstore(add(mPtr, 0xa0), vk_s3_com_x) + mstore(add(mPtr, 0xc0), vk_s3_com_y) + mstore(add(mPtr, 0xe0), vk_ql_com_x) + mstore(add(mPtr, 0x100), vk_ql_com_y) + mstore(add(mPtr, 0x120), vk_qr_com_x) + mstore(add(mPtr, 0x140), vk_qr_com_y) + mstore(add(mPtr, 0x160), vk_qm_com_x) + mstore(add(mPtr, 0x180), vk_qm_com_y) + mstore(add(mPtr, 0x1a0), vk_qo_com_x) + mstore(add(mPtr, 0x1c0), vk_qo_com_y) + mstore(add(mPtr, 0x1e0), vk_qk_com_x) + mstore(add(mPtr, 0x200), vk_qk_com_y) + + // public inputs + let _mPtr := add(mPtr, 0x220) + let size_pi_in_bytes := mul(nb_pi, 0x20) + calldatacopy(_mPtr, pi, size_pi_in_bytes) + _mPtr := 
add(_mPtr, size_pi_in_bytes) + + // wire commitment commit api + let _proof := add(aproof, proof_openings_selector_commit_api_at_zeta) + _proof := add(_proof, mul(vk_nb_commitments_commit_api, 0x20)) + let size_wire_commitments_commit_api_in_bytes := mul(vk_nb_commitments_commit_api, 0x40) + calldatacopy(_mPtr, _proof, size_wire_commitments_commit_api_in_bytes) + _mPtr := add(_mPtr, size_wire_commitments_commit_api_in_bytes) + + // commitments to l, r, o + let size_commitments_lro_in_bytes := 0xc0 + calldatacopy(_mPtr, aproof, size_commitments_lro_in_bytes) + _mPtr := add(_mPtr, size_commitments_lro_in_bytes) + + let size := add(0x2c5, mul(nb_pi, 0x20)) // 0x2c5 = 22*32+5 + size := add(size, mul(vk_nb_commitments_commit_api, 0x40)) + let l_success := staticcall(gas(), 0x2, add(mPtr, 0x1b), size, mPtr, 0x20) //0x1b -> 000.."gamma" + if iszero(l_success) { + error_verify() + } + gamma_not_reduced := mload(mPtr) + mstore(add(state, state_gamma), mod(gamma_not_reduced, r_mod)) + } + + function derive_beta(gamma_not_reduced)->beta_not_reduced{ + + let state := mload(0x40) + let mPtr := add(mload(0x40), state_last_mem) + + // beta + mstore(mPtr, 0x62657461) // "beta" + mstore(add(mPtr, 0x20), gamma_not_reduced) + let l_success := staticcall(gas(), 0x2, add(mPtr, 0x1c), 0x24, mPtr, 0x20) //0x1b -> 000.."gamma" + if iszero(l_success) { + error_verify() + } + beta_not_reduced := mload(mPtr) + mstore(add(state, state_beta), mod(beta_not_reduced, r_mod)) + } + + // alpha depends on the previous challenge (beta) and on the commitment to the grand product polynomial + function derive_alpha(aproof, beta_not_reduced)->alpha_not_reduced { + + let state := mload(0x40) + let mPtr := add(mload(0x40), state_last_mem) + + // alpha + mstore(mPtr, 0x616C706861) // "alpha" + mstore(add(mPtr, 0x20), beta_not_reduced) + calldatacopy(add(mPtr, 0x40), add(aproof, proof_grand_product_commitment_x), 0x40) + let l_success := staticcall(gas(), 0x2, add(mPtr, 0x1b), 0x65, mPtr, 0x20) //0x1b -> 
000.."gamma"
+ if iszero(l_success) {
+ error_verify()
+ }
+ alpha_not_reduced := mload(mPtr)
+ mstore(add(state, state_alpha), mod(alpha_not_reduced, r_mod))
+ }
+
+ // zeta depends on the previous challenge (alpha) and on the commitment to the quotient polynomial
+ function derive_zeta(aproof, alpha_not_reduced) {
+
+ let state := mload(0x40)
+ let mPtr := add(mload(0x40), state_last_mem)
+
+ // zeta
+ mstore(mPtr, 0x7a657461) // "zeta"
+ mstore(add(mPtr, 0x20), alpha_not_reduced)
+ calldatacopy(add(mPtr, 0x40), add(aproof, proof_h_0_x), 0xc0)
+ let l_success := staticcall(gas(), 0x2, add(mPtr, 0x1c), 0xe4, mPtr, 0x20)
+ if iszero(l_success) {
+ error_verify()
+ }
+ let zeta_not_reduced := mload(mPtr)
+ mstore(add(state, state_zeta), mod(zeta_not_reduced, r_mod))
+ }
+ // END challenges -------------------------------------------------
+
+ // BEGINNING compute_pi -------------------------------------------------
+
+ // public input (not coming from the commit api) contribution
+ // ins, n are the public inputs and number of public inputs respectively
+ function sum_pi_wo_api_commit(ins, n, mPtr)->pi_wo_commit {
+
+ let state := mload(0x40)
+ let z := mload(add(state, state_zeta))
+ let zpnmo := mload(add(state, state_zeta_power_n_minus_one))
+
+ let li := mPtr
+ batch_compute_lagranges_at_z(z, zpnmo, n, li)
+
+ let tmp := 0
+ for {let i:=0} lt(i,n) {i:=add(i,1)}
+ {
+ tmp := mulmod(mload(li), calldataload(ins), r_mod)
+ pi_wo_commit := addmod(pi_wo_commit, tmp, r_mod)
+ li := add(li, 0x20)
+ ins := add(ins, 0x20)
+ }
+
+ }
+
+ // mPtr <- [L_0(z), .., L_{n-1}(z)]
+ //
+ // Here L_i(zeta) = ωⁱ/n * (ζⁿ-1)/(ζ-ωⁱ) where:
+ // * n = vk_domain_size
+ // * ω = vk_omega (generator of the multiplicative cyclic group of order n in (ℤ/rℤ)*)
+ // * ζ = z (challenge derived with Fiat Shamir)
+ // * zpnmo = 'zeta power n minus one' (ζⁿ-1) which has been precomputed
+ function batch_compute_lagranges_at_z(z, zpnmo, n, mPtr) {
+
+ let zn := mulmod(zpnmo, vk_inv_domain_size,
r_mod) // 1/n * (ζⁿ - 1) + + let _w := 1 + let _mPtr := mPtr + for {let i:=0} lt(i,n) {i:=add(i,1)} + { + mstore(_mPtr, addmod(z,sub(r_mod, _w), r_mod)) + _w := mulmod(_w, vk_omega, r_mod) + _mPtr := add(_mPtr, 0x20) + } + batch_invert(mPtr, n, _mPtr) + _mPtr := mPtr + _w := 1 + for {let i:=0} lt(i,n) {i:=add(i,1)} + { + mstore(_mPtr, mulmod(mulmod(mload(_mPtr), zn , r_mod), _w, r_mod)) + _mPtr := add(_mPtr, 0x20) + _w := mulmod(_w, vk_omega, r_mod) + } + } + + // batch invert (modulo r) in place the nb_ins uint256 inputs starting at ins. + function batch_invert(ins, nb_ins, mPtr) { + mstore(mPtr, 1) + let offset := 0 + for {let i:=0} lt(i, nb_ins) {i:=add(i,1)} + { + let prev := mload(add(mPtr, offset)) + let cur := mload(add(ins, offset)) + cur := mulmod(prev, cur, r_mod) + offset := add(offset, 0x20) + mstore(add(mPtr, offset), cur) + } + ins := add(ins, sub(offset, 0x20)) + mPtr := add(mPtr, offset) + let inv := pow(mload(mPtr), sub(r_mod,2), add(mPtr, 0x20)) + for {let i:=0} lt(i, nb_ins) {i:=add(i,1)} + { + mPtr := sub(mPtr, 0x20) + let tmp := mload(ins) + let cur := mulmod(inv, mload(mPtr), r_mod) + mstore(ins, cur) + inv := mulmod(inv, tmp, r_mod) + ins := sub(ins, 0x20) + } + } + + {{ if (gt (len .CommitmentConstraintIndexes) 0 )}} + // mPtr free memory. 
Computes the public input contribution related to the commit + function sum_pi_commit(aproof, nb_public_inputs, mPtr)->pi_commit { + + let state := mload(0x40) + let z := mload(add(state, state_zeta)) + let zpnmo := mload(add(state, state_zeta_power_n_minus_one)) + + let p := add(aproof, proof_openings_selector_commit_api_at_zeta) + p := add(p, mul(vk_nb_commitments_commit_api, 0x20)) // p points now to the wire commitments + + let h_fr, ith_lagrange + + {{ range $index, $element := .CommitmentConstraintIndexes}} + h_fr := hash_fr(calldataload(p), calldataload(add(p, 0x20)), mPtr) + ith_lagrange := compute_ith_lagrange_at_z(z, zpnmo, add(nb_public_inputs, vk_index_commit_api_{{ $index }}), mPtr) + pi_commit := addmod(pi_commit, mulmod(h_fr, ith_lagrange, r_mod), r_mod) + p := add(p, 0x40) + {{ end }} + + } + + // z zeta + // zpmno ζⁿ-1 + // i i-th lagrange + // mPtr free memory + // Computes L_i(zeta) = ωⁱ/n * (ζⁿ-1)/(ζ-ωⁱ) where: + function compute_ith_lagrange_at_z(z, zpnmo, i, mPtr)->res { + + let w := pow(vk_omega, i, mPtr) // w**i + i := addmod(z, sub(r_mod, w), r_mod) // z-w**i + w := mulmod(w, vk_inv_domain_size, r_mod) // w**i/n + i := pow(i, sub(r_mod,2), mPtr) // (z-w**i)**-1 + w := mulmod(w, i, r_mod) // w**i/n*(z-w)**-1 + res := mulmod(w, zpnmo, r_mod) + + } + + // (x, y) point on bn254, both on 32bytes + // mPtr free memory + function hash_fr(x, y, mPtr)->res { + + // [0x00, .. 
, 0x00 || x, y, || 0, 48, 0, dst, sizeDomain] + // <- 64 bytes -> <-64b -> <- 1 bytes each -> + + // [0x00, .., 0x00] 64 bytes of zero + mstore(mPtr, zero_uint256) + mstore(add(mPtr, 0x20), zero_uint256) + + // msg = x || y , both on 32 bytes + mstore(add(mPtr, 0x40), x) + mstore(add(mPtr, 0x60), y) + + // 0 || 48 || 0 all on 1 byte + mstore8(add(mPtr, 0x80), 0) + mstore8(add(mPtr, 0x81), lenInBytes) + mstore8(add(mPtr, 0x82), 0) + + // "BSB22-Plonk" = [42, 53, 42, 32, 32, 2d, 50, 6c, 6f, 6e, 6b,] + mstore8(add(mPtr, 0x83), 0x42) + mstore8(add(mPtr, 0x84), 0x53) + mstore8(add(mPtr, 0x85), 0x42) + mstore8(add(mPtr, 0x86), 0x32) + mstore8(add(mPtr, 0x87), 0x32) + mstore8(add(mPtr, 0x88), 0x2d) + mstore8(add(mPtr, 0x89), 0x50) + mstore8(add(mPtr, 0x8a), 0x6c) + mstore8(add(mPtr, 0x8b), 0x6f) + mstore8(add(mPtr, 0x8c), 0x6e) + mstore8(add(mPtr, 0x8d), 0x6b) + + // size domain + mstore8(add(mPtr, 0x8e), sizeDomain) + + let l_success := staticcall(gas(), 0x2, mPtr, 0x8f, mPtr, 0x20) + if iszero(l_success) { + error_verify() + } + + let b0 := mload(mPtr) + + // [b0 || one || dst || sizeDomain] + // <-64bytes -> <- 1 byte each -> + mstore8(add(mPtr, 0x20), one) // 1 + + mstore8(add(mPtr, 0x21), 0x42) // dst + mstore8(add(mPtr, 0x22), 0x53) + mstore8(add(mPtr, 0x23), 0x42) + mstore8(add(mPtr, 0x24), 0x32) + mstore8(add(mPtr, 0x25), 0x32) + mstore8(add(mPtr, 0x26), 0x2d) + mstore8(add(mPtr, 0x27), 0x50) + mstore8(add(mPtr, 0x28), 0x6c) + mstore8(add(mPtr, 0x29), 0x6f) + mstore8(add(mPtr, 0x2a), 0x6e) + mstore8(add(mPtr, 0x2b), 0x6b) + + mstore8(add(mPtr, 0x2c), sizeDomain) // size domain + l_success := staticcall(gas(), 0x2, mPtr, 0x2d, mPtr, 0x20) + if iszero(l_success) { + error_verify() + } + + // b1 is located at mPtr. 
We store b2 at add(mPtr, 0x20) + + // [b0^b1 || two || dst || sizeDomain] + // <-64bytes -> <- 1 byte each -> + mstore(add(mPtr, 0x20), xor(mload(mPtr), b0)) + mstore8(add(mPtr, 0x40), two) + + mstore8(add(mPtr, 0x41), 0x42) // dst + mstore8(add(mPtr, 0x42), 0x53) + mstore8(add(mPtr, 0x43), 0x42) + mstore8(add(mPtr, 0x44), 0x32) + mstore8(add(mPtr, 0x45), 0x32) + mstore8(add(mPtr, 0x46), 0x2d) + mstore8(add(mPtr, 0x47), 0x50) + mstore8(add(mPtr, 0x48), 0x6c) + mstore8(add(mPtr, 0x49), 0x6f) + mstore8(add(mPtr, 0x4a), 0x6e) + mstore8(add(mPtr, 0x4b), 0x6b) + + mstore8(add(mPtr, 0x4c), sizeDomain) // size domain + + let offset := add(mPtr, 0x20) + l_success := staticcall(gas(), 0x2, offset, 0x2d, offset, 0x20) + if iszero(l_success) { + error_verify() + } + + // at this point we have mPtr = [ b1 || b2] where b1 is on 32byes and b2 in 16bytes. + // we interpret it as a big integer mod r in big endian (similar to regular decimal notation) + // the result is then 2**(8*16)*mPtr[32:] + mPtr[32:48] + res := mulmod(mload(mPtr), bb, r_mod) // <- res = 2**128 * mPtr[:32] + offset := add(mPtr, 0x10) + for {let i:=0} lt(i, 0x10) {i:=add(i,1)} // mPtr <- [xx, xx, .., | 0, 0, .. 
0 || b2 ]
+ {
+ mstore8(offset, 0x00)
+ offset := add(offset, 0x1)
+ }
+ let b1 := mload(add(mPtr, 0x10)) // b1 <- [0, 0, .., 0 || b2[:16] ]
+ res := addmod(res, b1, r_mod)
+
+ }
+ {{ end }}
+ // END compute_pi -------------------------------------------------
+
+ // compute α² * 1/n * (ζⁿ-1)/(ζ - 1) where
+ // * α = challenge derived in derive_gamma_beta_alpha_zeta
+ // * n = vk_domain_size
+ // * ω = vk_omega (generator of the multiplicative cyclic group of order n in (ℤ/rℤ)*)
+ // * ζ = zeta (challenge derived with Fiat Shamir)
+ function compute_alpha_square_lagrange_0() {
+ let state := mload(0x40)
+ let mPtr := add(mload(0x40), state_last_mem)
+
+ let res := mload(add(state, state_zeta_power_n_minus_one))
+ let den := addmod(mload(add(state, state_zeta)), sub(r_mod, 1), r_mod)
+ den := pow(den, sub(r_mod, 2), mPtr)
+ den := mulmod(den, vk_inv_domain_size, r_mod)
+ res := mulmod(den, res, r_mod)
+
+ let l_alpha := mload(add(state, state_alpha))
+ res := mulmod(res, l_alpha, r_mod)
+ res := mulmod(res, l_alpha, r_mod)
+ mstore(add(state, state_alpha_square_lagrange_0), res)
+ }
+
+ // follows alg. p.13 of https://eprint.iacr.org/2019/953.pdf
+ // with t₁ = t₂ = 1, and the proofs are ([digest] + [quotient] +purported evaluation):
+ // * [state_folded_state_digests], [proof_batch_opening_at_zeta_x], state_folded_evals
+ // * [proof_grand_product_commitment], [proof_opening_at_zeta_omega_x], [proof_grand_product_at_zeta_omega]
+ function batch_verify_multi_points(aproof) {
+ let state := mload(0x40)
+ let mPtr := add(state, state_last_mem)
+
+ // here the random is not a challenge, hence no need to use Fiat Shamir, we just
+ // need an unpredictable result.
+ let random := mod(keccak256(state, 0x20), r_mod) + + let folded_quotients := mPtr + mPtr := add(folded_quotients, 0x40) + mstore(folded_quotients, calldataload(add(aproof, proof_batch_opening_at_zeta_x))) + mstore(add(folded_quotients, 0x20), calldataload(add(aproof, proof_batch_opening_at_zeta_y))) + point_acc_mul_calldata(folded_quotients, add(aproof, proof_opening_at_zeta_omega_x), random, mPtr) + + let folded_digests := add(state, state_folded_digests_x) + point_acc_mul_calldata(folded_digests, add(aproof, proof_grand_product_commitment_x), random, mPtr) + + let folded_evals := add(state, state_folded_claimed_values) + fr_acc_mul_calldata(folded_evals, add(aproof, proof_grand_product_at_zeta_omega), random) + + let folded_evals_commit := mPtr + mPtr := add(folded_evals_commit, 0x40) + mstore(folded_evals_commit, {{ fpstr .Kzg.G1.X }}) + mstore(add(folded_evals_commit, 0x20), {{ fpstr .Kzg.G1.Y }}) + mstore(add(folded_evals_commit, 0x40), mload(folded_evals)) + let check_staticcall := staticcall(gas(), 7, folded_evals_commit, 0x60, folded_evals_commit, 0x40) + if eq(check_staticcall, 0) { + error_verify() + } + + let folded_evals_commit_y := add(folded_evals_commit, 0x20) + mstore(folded_evals_commit_y, sub(p_mod, mload(folded_evals_commit_y))) + point_add(folded_digests, folded_digests, folded_evals_commit, mPtr) + + let folded_points_quotients := mPtr + mPtr := add(mPtr, 0x40) + point_mul_calldata( + folded_points_quotients, + add(aproof, proof_batch_opening_at_zeta_x), + mload(add(state, state_zeta)), + mPtr + ) + let zeta_omega := mulmod(mload(add(state, state_zeta)), vk_omega, r_mod) + random := mulmod(random, zeta_omega, r_mod) + point_acc_mul_calldata(folded_points_quotients, add(aproof, proof_opening_at_zeta_omega_x), random, mPtr) + + point_add(folded_digests, folded_digests, folded_points_quotients, mPtr) + + let folded_quotients_y := add(folded_quotients, 0x20) + mstore(folded_quotients_y, sub(p_mod, mload(folded_quotients_y))) + + mstore(mPtr, 
mload(folded_digests))
+ mstore(add(mPtr, 0x20), mload(add(folded_digests, 0x20)))
+ mstore(add(mPtr, 0x40), g2_srs_0_x_0) // the 4 lines are the canonical G2 point on BN254
+ mstore(add(mPtr, 0x60), g2_srs_0_x_1)
+ mstore(add(mPtr, 0x80), g2_srs_0_y_0)
+ mstore(add(mPtr, 0xa0), g2_srs_0_y_1)
+ mstore(add(mPtr, 0xc0), mload(folded_quotients))
+ mstore(add(mPtr, 0xe0), mload(add(folded_quotients, 0x20)))
+ mstore(add(mPtr, 0x100), g2_srs_1_x_0)
+ mstore(add(mPtr, 0x120), g2_srs_1_x_1)
+ mstore(add(mPtr, 0x140), g2_srs_1_y_0)
+ mstore(add(mPtr, 0x160), g2_srs_1_y_1)
+ check_pairing_kzg(mPtr)
+ }
+
+ // check_pairing_kzg checks the result of the final pairing product of the batched
+ // kzg verification. The purpose of this function is to avoid exhausting the stack
+ // in the function batch_verify_multi_points.
+ // mPtr: pointer storing the tuple of pairs
+ function check_pairing_kzg(mPtr) {
+ let state := mload(0x40)
+
+ // TODO test the staticcall using the method from audit_4-5
+ let l_success := staticcall(gas(), 8, mPtr, 0x180, 0x00, 0x20)
+ let res_pairing := mload(0x00)
+ let s_success := mload(add(state, state_success))
+ res_pairing := and(and(res_pairing, l_success), s_success)
+ mstore(add(state, state_success), res_pairing)
+ }
+
+ // Fold the opening proofs at ζ:
+ // * at state+state_folded_digest we store: [H] + γ[Linearised_polynomial]+γ²[L] + γ³[R] + γ⁴[O] + γ⁵[S₁] +γ⁶[S₂] + ∑ᵢγ⁶⁺ⁱ[Pi_{i}]
+ // * at state+state_folded_claimed_values we store: H(ζ) + γLinearised_polynomial(ζ)+γ²L(ζ) + γ³R(ζ)+ γ⁴O(ζ) + γ⁵S₁(ζ) +γ⁶S₂(ζ) + ∑ᵢγ⁶⁺ⁱPi_{i}(ζ)
+ // acc_gamma stores the γⁱ
+ function fold_state(aproof) {
+ let state := mload(0x40)
+ let mPtr := add(mload(0x40), state_last_mem)
+
+ let l_gamma_kzg := mload(add(state, state_gamma_kzg))
+ let acc_gamma := l_gamma_kzg
+
+ let offset := add(0x200, mul(vk_nb_commitments_commit_api, 0x40)) // 0x40 = 2*0x20
+ let mPtrOffset := add(mPtr, offset)
+
+ mstore(add(state, state_folded_digests_x), mload(add(mPtr, 0x40)))
+
mstore(add(state, state_folded_digests_y), mload(add(mPtr, 0x60))) + mstore(add(state, state_folded_claimed_values), calldataload(add(aproof, proof_quotient_polynomial_at_zeta))) + + point_acc_mul(add(state, state_folded_digests_x), add(mPtr, 0x80), acc_gamma, mPtrOffset) + fr_acc_mul_calldata(add(state, state_folded_claimed_values), add(aproof, proof_linearised_polynomial_at_zeta), acc_gamma) + + acc_gamma := mulmod(acc_gamma, l_gamma_kzg, r_mod) + point_acc_mul(add(state, state_folded_digests_x), add(mPtr, 0xc0), acc_gamma, mPtrOffset) + fr_acc_mul_calldata(add(state, state_folded_claimed_values), add(aproof, proof_l_at_zeta), acc_gamma) + + acc_gamma := mulmod(acc_gamma, l_gamma_kzg, r_mod) + point_acc_mul(add(state, state_folded_digests_x), add(mPtr, 0x100), acc_gamma, add(mPtr, offset)) + fr_acc_mul_calldata(add(state, state_folded_claimed_values), add(aproof, proof_r_at_zeta), acc_gamma) + + acc_gamma := mulmod(acc_gamma, l_gamma_kzg, r_mod) + point_acc_mul(add(state, state_folded_digests_x), add(mPtr, 0x140), acc_gamma, add(mPtr, offset)) + fr_acc_mul_calldata(add(state, state_folded_claimed_values), add(aproof, proof_o_at_zeta), acc_gamma) + + acc_gamma := mulmod(acc_gamma, l_gamma_kzg, r_mod) + point_acc_mul(add(state, state_folded_digests_x), add(mPtr, 0x180), acc_gamma, add(mPtr, offset)) + fr_acc_mul_calldata(add(state, state_folded_claimed_values), add(aproof, proof_s1_at_zeta), acc_gamma) + + acc_gamma := mulmod(acc_gamma, l_gamma_kzg, r_mod) + point_acc_mul(add(state, state_folded_digests_x), add(mPtr, 0x1c0), acc_gamma, add(mPtr, offset)) + fr_acc_mul_calldata(add(state, state_folded_claimed_values), add(aproof, proof_s2_at_zeta), acc_gamma) + + let poscaz := add(aproof, proof_openings_selector_commit_api_at_zeta) + let opca := add(mPtr, 0x200) // offset_proof_commits_api + for {let i := 0} lt(i, vk_nb_commitments_commit_api) {i := add(i, 1)} + { + acc_gamma := mulmod(acc_gamma, l_gamma_kzg, r_mod) + point_acc_mul(add(state, state_folded_digests_x), 
opca, acc_gamma, add(mPtr, offset)) + fr_acc_mul_calldata(add(state, state_folded_claimed_values), poscaz, acc_gamma) + poscaz := add(poscaz, 0x20) + opca := add(opca, 0x40) + } + } + + // generate the challenge (using Fiat Shamir) to fold the opening proofs + // at ζ. + // The process for deriving γ is the same as in derive_gamma but this time the inputs are + // in this order (the [] means it's a commitment): + // * ζ + // * [H] ( = H₁ + ζᵐ⁺²*H₂ + ζ²⁽ᵐ⁺²⁾*H₃ ) + // * [Linearised polynomial] + // * [L], [R], [O] + // * [S₁] [S₂] + // * [Pi_{i}] (wires associated to custom gates) + // Then there are the purported evaluations of the previous committed polynomials: + // * H(ζ) + // * Linearised_polynomial(ζ) + // * L(ζ), R(ζ), O(ζ), S₁(ζ), S₂(ζ) + // * Pi_{i}(ζ) + function compute_gamma_kzg(aproof) { + + let state := mload(0x40) + let mPtr := add(mload(0x40), state_last_mem) + mstore(mPtr, 0x67616d6d61) // "gamma" + mstore(add(mPtr, 0x20), mload(add(state, state_zeta))) + mstore(add(mPtr,0x40), mload(add(state, state_folded_h_x))) + mstore(add(mPtr,0x60), mload(add(state, state_folded_h_y))) + mstore(add(mPtr,0x80), mload(add(state, state_linearised_polynomial_x))) + mstore(add(mPtr,0xa0), mload(add(state, state_linearised_polynomial_y))) + calldatacopy(add(mPtr, 0xc0), add(aproof, proof_l_com_x), 0xc0) + mstore(add(mPtr,0x180), vk_s1_com_x) + mstore(add(mPtr,0x1a0), vk_s1_com_y) + mstore(add(mPtr,0x1c0), vk_s2_com_x) + mstore(add(mPtr,0x1e0), vk_s2_com_y) + + let offset := 0x200 + {{ range $index, $element := .CommitmentConstraintIndexes }} + mstore(add(mPtr,offset), vk_selector_commitments_commit_api_{{ $index }}_x) + mstore(add(mPtr,add(offset, 0x20)), vk_selector_commitments_commit_api_{{ $index }}_y) + offset := add(offset, 0x40) + {{ end }} + + mstore(add(mPtr, offset), calldataload(add(aproof, proof_quotient_polynomial_at_zeta))) + mstore(add(mPtr, add(offset, 0x20)), calldataload(add(aproof, proof_linearised_polynomial_at_zeta))) + mstore(add(mPtr, 
add(offset, 0x40)), calldataload(add(aproof, proof_l_at_zeta))) + mstore(add(mPtr, add(offset, 0x60)), calldataload(add(aproof, proof_r_at_zeta))) + mstore(add(mPtr, add(offset, 0x80)), calldataload(add(aproof, proof_o_at_zeta))) + mstore(add(mPtr, add(offset, 0xa0)), calldataload(add(aproof, proof_s1_at_zeta))) + mstore(add(mPtr, add(offset, 0xc0)), calldataload(add(aproof, proof_s2_at_zeta))) + + {{ if (gt (len .CommitmentConstraintIndexes) 0 )}} + let _mPtr := add(mPtr, add(offset, 0xe0)) + let _poscaz := add(aproof, proof_openings_selector_commit_api_at_zeta) + for {let i:=0} lt(i, vk_nb_commitments_commit_api) {i:=add(i,1)} + { + mstore(_mPtr, calldataload(_poscaz)) + _poscaz := add(_poscaz, 0x20) + _mPtr := add(_mPtr, 0x20) + } + {{ end }} + + let start_input := 0x1b // 00.."gamma" + let size_input := add(0x16, mul(vk_nb_commitments_commit_api,3)) // number of 32bytes elmts = 0x16 (zeta+2*7+7 for the digests+openings) + 2*vk_nb_commitments_commit_api (for the commitments of the selectors) + vk_nb_commitments_commit_api (for the openings of the selectors) + size_input := add(0x5, mul(size_input, 0x20)) // size in bytes: 15*32 bytes + 5 bytes for gamma + let check_staticcall := staticcall(gas(), 0x2, add(mPtr,start_input), size_input, add(state, state_gamma_kzg), 0x20) + if eq(check_staticcall, 0) { + error_verify() + } + mstore(add(state, state_gamma_kzg), mod(mload(add(state, state_gamma_kzg)), r_mod)) + } + + function compute_commitment_linearised_polynomial_ec(aproof, s1, s2) { + let state := mload(0x40) + let mPtr := add(mload(0x40), state_last_mem) + + mstore(mPtr, vk_ql_com_x) + mstore(add(mPtr, 0x20), vk_ql_com_y) + point_mul( + add(state, state_linearised_polynomial_x), + mPtr, + calldataload(add(aproof, proof_l_at_zeta)), + add(mPtr, 0x40) + ) + + mstore(mPtr, vk_qr_com_x) + mstore(add(mPtr, 0x20), vk_qr_com_y) + point_acc_mul( + add(state, state_linearised_polynomial_x), + mPtr, + calldataload(add(aproof, proof_r_at_zeta)), + add(mPtr, 0x40) + ) + + 
let rl := mulmod(calldataload(add(aproof, proof_l_at_zeta)), calldataload(add(aproof, proof_r_at_zeta)), r_mod) + mstore(mPtr, vk_qm_com_x) + mstore(add(mPtr, 0x20), vk_qm_com_y) + point_acc_mul(add(state, state_linearised_polynomial_x), mPtr, rl, add(mPtr, 0x40)) + + mstore(mPtr, vk_qo_com_x) + mstore(add(mPtr, 0x20), vk_qo_com_y) + point_acc_mul( + add(state, state_linearised_polynomial_x), + mPtr, + calldataload(add(aproof, proof_o_at_zeta)), + add(mPtr, 0x40) + ) + + mstore(mPtr, vk_qk_com_x) + mstore(add(mPtr, 0x20), vk_qk_com_y) + point_add( + add(state, state_linearised_polynomial_x), + add(state, state_linearised_polynomial_x), + mPtr, + add(mPtr, 0x40) + ) + + let commits_api_at_zeta := add(aproof, proof_openings_selector_commit_api_at_zeta) + let commits_api := add( + aproof, + add(proof_openings_selector_commit_api_at_zeta, mul(vk_nb_commitments_commit_api, 0x20)) + ) + for { + let i := 0 + } lt(i, vk_nb_commitments_commit_api) { + i := add(i, 1) + } { + mstore(mPtr, calldataload(commits_api)) + mstore(add(mPtr, 0x20), calldataload(add(commits_api, 0x20))) + point_acc_mul( + add(state, state_linearised_polynomial_x), + mPtr, + calldataload(commits_api_at_zeta), + add(mPtr, 0x40) + ) + commits_api_at_zeta := add(commits_api_at_zeta, 0x20) + commits_api := add(commits_api, 0x40) + } + + mstore(mPtr, vk_s3_com_x) + mstore(add(mPtr, 0x20), vk_s3_com_y) + point_acc_mul(add(state, state_linearised_polynomial_x), mPtr, s1, add(mPtr, 0x40)) + + mstore(mPtr, calldataload(add(aproof, proof_grand_product_commitment_x))) + mstore(add(mPtr, 0x20), calldataload(add(aproof, proof_grand_product_commitment_y))) + point_acc_mul(add(state, state_linearised_polynomial_x), mPtr, s2, add(mPtr, 0x40)) + } + + // Compute the commitment to the linearized polynomial equal to + // L(ζ)[Qₗ]+r(ζ)[Qᵣ]+R(ζ)L(ζ)[Qₘ]+O(ζ)[Qₒ]+[Qₖ]+Σᵢqc'ᵢ(ζ)[BsbCommitmentᵢ] + + // α*( Z(μζ)(L(ζ)+β*S₁(ζ)+γ)*(R(ζ)+β*S₂(ζ)+γ)[S₃]-[Z](L(ζ)+β*id_{1}(ζ)+γ)*(R(ζ)+β*id_{2(ζ)+γ)*(O(ζ)+β*id_{3}(ζ)+γ) ) + + // 
α²*L₁(ζ)[Z] + // where + // * id_1 = id, id_2 = vk_coset_shift*id, id_3 = vk_coset_shift^{2}*id + // * the [] means that it's a commitment (i.e. a point on Bn254(F_p)) + function compute_commitment_linearised_polynomial(aproof) { + let state := mload(0x40) + let l_beta := mload(add(state, state_beta)) + let l_gamma := mload(add(state, state_gamma)) + let l_zeta := mload(add(state, state_zeta)) + let l_alpha := mload(add(state, state_alpha)) + + let u := mulmod(calldataload(add(aproof, proof_grand_product_at_zeta_omega)), l_beta, r_mod) + let v := mulmod(l_beta, calldataload(add(aproof, proof_s1_at_zeta)), r_mod) + v := addmod(v, calldataload(add(aproof, proof_l_at_zeta)), r_mod) + v := addmod(v, l_gamma, r_mod) + + let w := mulmod(l_beta, calldataload(add(aproof, proof_s2_at_zeta)), r_mod) + w := addmod(w, calldataload(add(aproof, proof_r_at_zeta)), r_mod) + w := addmod(w, l_gamma, r_mod) + + let s1 := mulmod(u, v, r_mod) + s1 := mulmod(s1, w, r_mod) + s1 := mulmod(s1, l_alpha, r_mod) + + let coset_square := mulmod(vk_coset_shift, vk_coset_shift, r_mod) + let betazeta := mulmod(l_beta, l_zeta, r_mod) + u := addmod(betazeta, calldataload(add(aproof, proof_l_at_zeta)), r_mod) + u := addmod(u, l_gamma, r_mod) + + v := mulmod(betazeta, vk_coset_shift, r_mod) + v := addmod(v, calldataload(add(aproof, proof_r_at_zeta)), r_mod) + v := addmod(v, l_gamma, r_mod) + + w := mulmod(betazeta, coset_square, r_mod) + w := addmod(w, calldataload(add(aproof, proof_o_at_zeta)), r_mod) + w := addmod(w, l_gamma, r_mod) + + let s2 := mulmod(u, v, r_mod) + s2 := mulmod(s2, w, r_mod) + s2 := sub(r_mod, s2) + s2 := mulmod(s2, l_alpha, r_mod) + s2 := addmod(s2, mload(add(state, state_alpha_square_lagrange_0)), r_mod) + + // at this stage: + // * s₁ = α*Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β + // * s₂ = -α*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + α²*L₁(ζ) + + compute_commitment_linearised_polynomial_ec(aproof, s1, s2) + } + + // compute H₁ + ζᵐ⁺²*H₂ + ζ²⁽ᵐ⁺²⁾*H₃ and store the result 
at + // state + state_folded_h + function fold_h(aproof) { + let state := mload(0x40) + let n_plus_two := add(vk_domain_size, 2) + let mPtr := add(mload(0x40), state_last_mem) + let zeta_power_n_plus_two := pow(mload(add(state, state_zeta)), n_plus_two, mPtr) + point_mul_calldata(add(state, state_folded_h_x), add(aproof, proof_h_2_x), zeta_power_n_plus_two, mPtr) + point_add_calldata(add(state, state_folded_h_x), add(state, state_folded_h_x), add(aproof, proof_h_1_x), mPtr) + point_mul(add(state, state_folded_h_x), add(state, state_folded_h_x), zeta_power_n_plus_two, mPtr) + point_add_calldata(add(state, state_folded_h_x), add(state, state_folded_h_x), add(aproof, proof_h_0_x), mPtr) + } + + // check that + // L(ζ)Qₗ(ζ)+r(ζ)Qᵣ(ζ)+R(ζ)L(ζ)Qₘ(ζ)+O(ζ)Qₒ(ζ)+Qₖ(ζ)+Σᵢqc'ᵢ(ζ)BsbCommitmentᵢ(ζ) + + // α*( Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β*s₃(X)-Z(X)(l(ζ)+β*id_1(ζ)+γ)*(r(ζ)+β*id_2(ζ)+γ)*(o(ζ)+β*id_3(ζ)+γ) ) ) + // + α²*L₁(ζ) = + // (ζⁿ-1)H(ζ) + function verify_quotient_poly_eval_at_zeta(aproof) { + let state := mload(0x40) + + // (l(ζ)+β*s1(ζ)+γ) + let s1 := add(mload(0x40), state_last_mem) + mstore(s1, mulmod(calldataload(add(aproof, proof_s1_at_zeta)), mload(add(state, state_beta)), r_mod)) + mstore(s1, addmod(mload(s1), mload(add(state, state_gamma)), r_mod)) + mstore(s1, addmod(mload(s1), calldataload(add(aproof, proof_l_at_zeta)), r_mod)) + + // (r(ζ)+β*s2(ζ)+γ) + let s2 := add(s1, 0x20) + mstore(s2, mulmod(calldataload(add(aproof, proof_s2_at_zeta)), mload(add(state, state_beta)), r_mod)) + mstore(s2, addmod(mload(s2), mload(add(state, state_gamma)), r_mod)) + mstore(s2, addmod(mload(s2), calldataload(add(aproof, proof_r_at_zeta)), r_mod)) + // _s2 := mload(s2) + + // (o(ζ)+γ) + let o := add(s1, 0x40) + mstore(o, addmod(calldataload(add(aproof, proof_o_at_zeta)), mload(add(state, state_gamma)), r_mod)) + + // α*(Z(μζ))*(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*(o(ζ)+γ) + mstore(s1, mulmod(mload(s1), mload(s2), r_mod)) + mstore(s1, mulmod(mload(s1), mload(o), r_mod)) + 
mstore(s1, mulmod(mload(s1), mload(add(state, state_alpha)), r_mod)) + mstore(s1, mulmod(mload(s1), calldataload(add(aproof, proof_grand_product_at_zeta_omega)), r_mod)) + + let computed_quotient := add(s1, 0x60) + + // linearizedpolynomial + pi(zeta) + mstore(computed_quotient,addmod(calldataload(add(aproof, proof_linearised_polynomial_at_zeta)), mload(add(state, state_pi)), r_mod)) + mstore(computed_quotient, addmod(mload(computed_quotient), mload(s1), r_mod)) + mstore(computed_quotient,addmod(mload(computed_quotient), sub(r_mod, mload(add(state, state_alpha_square_lagrange_0))), r_mod)) + mstore(s2,mulmod(calldataload(add(aproof, proof_quotient_polynomial_at_zeta)),mload(add(state, state_zeta_power_n_minus_one)),r_mod)) + + mstore(add(state, state_success), eq(mload(computed_quotient), mload(s2))) + } + + // BEGINNING utils math functions ------------------------------------------------- + function point_add(dst, p, q, mPtr) { + let state := mload(0x40) + mstore(mPtr, mload(p)) + mstore(add(mPtr, 0x20), mload(add(p, 0x20))) + mstore(add(mPtr, 0x40), mload(q)) + mstore(add(mPtr, 0x60), mload(add(q, 0x20))) + let l_success := staticcall(gas(),6,mPtr,0x80,dst,0x40) + if iszero(l_success) { + error_ec_op() + } + } + + function point_add_calldata(dst, p, q, mPtr) { + let state := mload(0x40) + mstore(mPtr, mload(p)) + mstore(add(mPtr, 0x20), mload(add(p, 0x20))) + mstore(add(mPtr, 0x40), calldataload(q)) + mstore(add(mPtr, 0x60), calldataload(add(q, 0x20))) + let l_success := staticcall(gas(), 6, mPtr, 0x80, dst, 0x40) + if iszero(l_success) { + error_ec_op() + } + } + + // dst <- [s]src + function point_mul(dst,src,s, mPtr) { + let state := mload(0x40) + mstore(mPtr,mload(src)) + mstore(add(mPtr,0x20),mload(add(src,0x20))) + mstore(add(mPtr,0x40),s) + let l_success := staticcall(gas(),7,mPtr,0x60,dst,0x40) + if iszero(l_success) { + error_ec_op() + } + } + + // dst <- [s]src + function point_mul_calldata(dst, src, s, mPtr) { + let state := mload(0x40) + mstore(mPtr, 
calldataload(src)) + mstore(add(mPtr, 0x20), calldataload(add(src, 0x20))) + mstore(add(mPtr, 0x40), s) + let l_success := staticcall(gas(), 7, mPtr, 0x60, dst, 0x40) + if iszero(l_success) { + error_ec_op() + } + } + + // dst <- dst + [s]src (Elliptic curve) + function point_acc_mul(dst,src,s, mPtr) { + let state := mload(0x40) + mstore(mPtr,mload(src)) + mstore(add(mPtr,0x20),mload(add(src,0x20))) + mstore(add(mPtr,0x40),s) + let l_success := staticcall(gas(),7,mPtr,0x60,mPtr,0x40) + mstore(add(mPtr,0x40),mload(dst)) + mstore(add(mPtr,0x60),mload(add(dst,0x20))) + l_success := and(l_success, staticcall(gas(),6,mPtr,0x80,dst, 0x40)) + if iszero(l_success) { + error_ec_op() + } + } + + // dst <- dst + [s]src (Elliptic curve) + function point_acc_mul_calldata(dst, src, s, mPtr) { + let state := mload(0x40) + mstore(mPtr, calldataload(src)) + mstore(add(mPtr, 0x20), calldataload(add(src, 0x20))) + mstore(add(mPtr, 0x40), s) + let l_success := staticcall(gas(), 7, mPtr, 0x60, mPtr, 0x40) + mstore(add(mPtr, 0x40), mload(dst)) + mstore(add(mPtr, 0x60), mload(add(dst, 0x20))) + l_success := and(l_success, staticcall(gas(), 6, mPtr, 0x80, dst, 0x40)) + if iszero(l_success) { + error_ec_op() + } + } + + // dst <- dst + src (Fr) dst,src are addresses, s is a value + function fr_acc_mul_calldata(dst, src, s) { + let tmp := mulmod(calldataload(src), s, r_mod) + mstore(dst, addmod(mload(dst), tmp, r_mod)) + } + + // dst <- x ** e mod r (x, e are values, not pointers) + function pow(x, e, mPtr)->res { + mstore(mPtr, 0x20) + mstore(add(mPtr, 0x20), 0x20) + mstore(add(mPtr, 0x40), 0x20) + mstore(add(mPtr, 0x60), x) + mstore(add(mPtr, 0x80), e) + mstore(add(mPtr, 0xa0), r_mod) + let check_staticcall := staticcall(gas(),0x05,mPtr,0xc0,mPtr,0x20) + if eq(check_staticcall, 0) { + error_verify() + } + res := mload(mPtr) + } + } + } +} +` + +// MarshalSolidity converts a proof to a byte array that can be used in a +// Solidity contract. 
+func (proof *Proof) MarshalSolidity() []byte { + + res := make([]byte, 0, 1024) + + // uint256 l_com_x; + // uint256 l_com_y; + // uint256 r_com_x; + // uint256 r_com_y; + // uint256 o_com_x; + // uint256 o_com_y; + var tmp64 [64]byte + for i := 0; i < 3; i++ { + tmp64 = proof.LRO[i].RawBytes() + res = append(res, tmp64[:]...) + } + + // uint256 h_0_x; + // uint256 h_0_y; + // uint256 h_1_x; + // uint256 h_1_y; + // uint256 h_2_x; + // uint256 h_2_y; + for i := 0; i < 3; i++ { + tmp64 = proof.H[i].RawBytes() + res = append(res, tmp64[:]...) + } + var tmp32 [32]byte + + // uint256 l_at_zeta; + // uint256 r_at_zeta; + // uint256 o_at_zeta; + // uint256 s1_at_zeta; + // uint256 s2_at_zeta; + for i := 2; i < 7; i++ { + tmp32 = proof.BatchedProof.ClaimedValues[i].Bytes() + res = append(res, tmp32[:]...) + } + + // uint256 grand_product_commitment_x; + // uint256 grand_product_commitment_y; + tmp64 = proof.Z.RawBytes() + res = append(res, tmp64[:]...) + + // uint256 grand_product_at_zeta_omega; + tmp32 = proof.ZShiftedOpening.ClaimedValue.Bytes() + res = append(res, tmp32[:]...) + + // uint256 quotient_polynomial_at_zeta; + // uint256 linearization_polynomial_at_zeta; + tmp32 = proof.BatchedProof.ClaimedValues[0].Bytes() + res = append(res, tmp32[:]...) + tmp32 = proof.BatchedProof.ClaimedValues[1].Bytes() + res = append(res, tmp32[:]...) + + // uint256 opening_at_zeta_proof_x; + // uint256 opening_at_zeta_proof_y; + tmp64 = proof.BatchedProof.H.RawBytes() + res = append(res, tmp64[:]...) + + // uint256 opening_at_zeta_omega_proof_x; + // uint256 opening_at_zeta_omega_proof_y; + tmp64 = proof.ZShiftedOpening.H.RawBytes() + res = append(res, tmp64[:]...) + + // uint256[] selector_commit_api_at_zeta; + // uint256[] wire_committed_commitments; + if len(proof.Bsb22Commitments) > 0 { + for i := 0; i < len(proof.Bsb22Commitments); i++ { + tmp32 = proof.BatchedProof.ClaimedValues[7+i].Bytes() + res = append(res, tmp32[:]...) 
+ } + + for _, bc := range proof.Bsb22Commitments { + tmp64 = bc.RawBytes() + res = append(res, tmp64[:]...) + } + } + + return res +} diff --git a/internal/backend/bn254/plonk/verify.go b/backend/plonk/bn254/verify.go similarity index 73% rename from internal/backend/bn254/plonk/verify.go rename to backend/plonk/bn254/verify.go index 74b56c5caf..88bd903f5e 100644 --- a/internal/backend/bn254/plonk/verify.go +++ b/backend/plonk/bn254/verify.go @@ -25,6 +25,8 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fp" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/kzg" curve "github.com/consensys/gnark-crypto/ecc/bn254" @@ -53,7 +55,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { // The first challenge is derived using the public data: the commitments to the permutation, // the coefficients of the circuit, and the public inputs. // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *vk, publicWitness); err != nil { + if err := bindPublicData(&fs, "gamma", *vk, publicWitness, proof.Bsb22Commitments); err != nil { return err } gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -87,25 +89,51 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = ∑_{i of size N, so to represent the permutation S we let S acts on the + // set A=(, u*, u^{2}*) of size 3*N, where u is outside (its use is to shift the set ). + // We obtain a permutation of A, A'. We split A' in 3 (A'_{1}, A'_{2}, A'_{3}), and S1, S2, S3 are + // respectively the interpolation of A'_{1}, A'_{2}, A'_{3} on . 
+ S1, S2, S3 *iop.Polynomial + + // S full permutation, i -> S[i] + S []int64 +} + +// VerifyingKey stores the data needed to verify a proof: +// * The commitment scheme +// * Commitments of ql prepended with as many ones as there are public inputs +// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs +// * Commitments to S1, S2, S3 +type VerifyingKey struct { + + // Size circuit + Size uint64 + SizeInv fr.Element + Generator fr.Element + NbPublicVariables uint64 + + // Commitment scheme that is used for an instantiation of PLONK + Kzg kzg.VerifyingKey + + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + + // S commitments to S1, S2, S3 + S [3]kzg.Digest + + // Commitments to ql, qr, qm, qo, qcp prepended with as many zeroes (ones for l) as there are public inputs. + // In particular Qk is not complete. + Ql, Qr, Qm, Qo, Qk kzg.Digest + Qcp []kzg.Digest + + CommitmentConstraintIndexes []uint64 +} + +// ProvingKey stores the data needed to generate a proof: +// * the commitment scheme +// * ql, prepended with as many ones as they are public inputs +// * qr, qm, qo prepended with as many zeroes as there are public inputs. +// * qk, prepended with as many zeroes as public inputs, to be completed by the prover +// with the list of public inputs. +// * sigma_1, sigma_2, sigma_3 in both basis +// * the copy constraint permutation +type ProvingKey struct { + + // stores ql, qr, qm, qo, qk (-> to be completed by the prover) + // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used + // for computing the opening proofs (hence the canonical form). The canonical version + // of qk incomplete is used in the linearisation polynomial. + // The polynomials in trace are in canonical basis. 
+ trace Trace + + Kzg kzg.ProvingKey + + // Verifying Key is embedded into the proving key (needed by Prove) + Vk *VerifyingKey + + // qr,ql,qm,qo,qcp in LagrangeCoset --> these are not serialized, but computed from Ql, Qr, Qm, Qo, Qcp once. + lcQl, lcQr, lcQm, lcQo *iop.Polynomial + lcQcp []*iop.Polynomial + + // LQk qk in Lagrange form -> to be completed by the prover. After being completed, + lQk *iop.Polynomial + + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + + // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. + lcS1, lcS2, lcS3 *iop.Polynomial + + // in lagrange coset basis --> not serialized id and L_{g^{0}} + lcIdIOP, lLoneIOP *iop.Polynomial +} + +func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { + + var pk ProvingKey + var vk VerifyingKey + pk.Vk = &vk + vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) + + // step 0: set the fft domains + pk.initDomains(spr) + + // step 1: set the verifying key + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) + vk.Size = pk.Domain[0].Cardinality + vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) + vk.Generator.Set(&pk.Domain[0].Generator) + vk.NbPublicVariables = uint64(len(spr.Public)) + if len(kzgSrs.Pk.G1) < int(vk.Size) { + return nil, nil, errors.New("kzg srs is too small") + } + pk.Kzg = kzgSrs.Pk + vk.Kzg = kzgSrs.Vk + + // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis + BuildTrace(spr, &pk.trace) + + // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. 
+ // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &pk.trace, nbVariables) + s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) + pk.trace.S1 = s[0] + pk.trace.S2 = s[1] + pk.trace.S3 = s[2] + + // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. + // All the above polynomials are expressed in canonical basis afterwards. This is why + // we save lqk before, because the prover needs to complete it in Lagrange form, and + // then express it on the Lagrange coset basis. + pk.lQk = pk.trace.Qk.Clone() // it will be completed by the prover, and the evaluated on the coset + err := commitTrace(&pk.trace, &pk) + if err != nil { + return nil, nil, err + } + + // step 5: evaluate ql, qr, qm, qo, s1, s2, s3 on LagrangeCoset (NOT qk) + // we clone them, because the canonical versions are going to be used in + // the opening proof + pk.computeLagrangeCosetPolys() + + return &pk, &vk, nil +} + +// computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset +// basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
+func (pk *ProvingKey) computeLagrangeCosetPolys() { + var wg sync.WaitGroup + wg.Add(7 + len(pk.trace.Qcp)) + n1 := int(pk.Domain[1].Cardinality) + pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) + for i, qcpI := range pk.trace.Qcp { + go func(i int, qcpI *iop.Polynomial) { + pk.lcQcp[i] = qcpI.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }(i, qcpI) + } + go func() { + pk.lcQl = pk.trace.Ql.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQr = pk.trace.Qr.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQm = pk.trace.Qm.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQo = pk.trace.Qo.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS1 = pk.trace.S1.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS2 = pk.trace.S2.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS3 = pk.trace.S3.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + // storing Id + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + id := make([]fr.Element, pk.Domain[1].Cardinality) + id[0].Set(&pk.Domain[1].FrMultiplicativeGen) + for i := 1; i < int(pk.Domain[1].Cardinality); i++ { + id[i].Mul(&id[i-1], &pk.Domain[1].Generator) + } + pk.lcIdIOP = iop.NewPolynomial(&id, lagReg) + + // L_{g^{0}} + cap := pk.Domain[1].Cardinality + if cap < pk.Domain[0].Cardinality { + cap = pk.Domain[0].Cardinality // sanity check + } + lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) + lone[0].SetOne() + pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). + ToRegular(). 
+ ToLagrangeCoset(&pk.Domain[1]) + + wg.Wait() +} + +// NbPublicWitness returns the expected public witness size (number of field elements) +func (vk *VerifyingKey) NbPublicWitness() int { + return int(vk.NbPublicVariables) +} + +// VerifyingKey returns pk.Vk +func (pk *ProvingKey) VerifyingKey() interface{} { + return pk.Vk +} + +// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. +// Size is the size of the system that is nb_constraints+nb_public_variables +func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) + size := ecc.NextPowerOfTwo(sizeSystem) + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments) + + ql := make([]fr.Element, size) + qr := make([]fr.Element, size) + qm := make([]fr.Element, size) + qo := make([]fr.Element, size) + qk := make([]fr.Element, size) + qcp := make([][]fr.Element, len(commitmentInfo)) + + for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistent + ql[i].SetOne().Neg(&ql[i]) + qr[i].SetZero() + qm[i].SetZero() + qo[i].SetZero() + qk[i].SetZero() // → to be completed by the prover + } + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + ql[offset+j].Set(&spr.Coefficients[c.QL]) + qr[offset+j].Set(&spr.Coefficients[c.QR]) + qm[offset+j].Set(&spr.Coefficients[c.QM]) + qo[offset+j].Set(&spr.Coefficients[c.QO]) + qk[offset+j].Set(&spr.Coefficients[c.QC]) + j++ + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + + pt.Ql = iop.NewPolynomial(&ql, lagReg) + pt.Qr = iop.NewPolynomial(&qr, lagReg) + pt.Qm = iop.NewPolynomial(&qm, lagReg) + pt.Qo = iop.NewPolynomial(&qo, lagReg) + pt.Qk = iop.NewPolynomial(&qk, lagReg) + pt.Qcp = make([]*iop.Polynomial, len(qcp)) + + for i := range commitmentInfo { + qcp[i] = make([]fr.Element, size) + for _, 
committed := range commitmentInfo[i].Committed { + qcp[i][offset+committed].SetOne() + } + pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + } +} + +// commitTrace commits to every polynomial in the trace, and put +// the commitments int the verifying key. +func commitTrace(trace *Trace, pk *ProvingKey) error { + + trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete + trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() + + var err error + pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + for i := range trace.Qcp { + trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() + if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + return err + } + } + if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + return err + } + return nil +} + +func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for 
the placeholder constraints + pk.Domain[0] = *fft.NewDomain(sizeSystem) + + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) + } else { + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) + } + +} + +// buildPermutation builds the Permutation associated with a circuit. +// +// The permutation s is composed of cycles of maximum length such that +// +// s. (l∥r∥o) = (l∥r∥o) +// +// , where l∥r∥o is the concatenation of the indices of l, r, o in +// ql.l+qr.r+qm.l.r+qo.O+k = 0. +// +// The permutation is encoded as a slice s of size 3*size(l), where the +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab +// like this: for i in tab: tab[i] = tab[permutation[i]] +func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { + + // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + sizeSolution := len(pt.Ql.Coefficients()) + sizePermutation := 3 * sizeSolution + + // init permutation + permutation := make([]int64, sizePermutation) + for i := 0; i < len(permutation); i++ { + permutation[i] = -1 + } + + // init LRO position -> variable_ID + lro := make([]int, sizePermutation) // position -> variable_ID + for i := 0; i < len(spr.Public); i++ { + lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) + } + + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + + j++ + } + + // init cycle: + // map ID -> last position the ID was seen + cycle := make([]int64, nbVariables) + for i := 0; i < len(cycle); i++ { + cycle[i] = -1 + } + + for i := 0; i < len(lro); i++ { + if cycle[lro[i]] != -1 { + 
// if != -1, it means we already encountered this value + // so we need to set the corresponding permutation index. + permutation[i] = cycle[lro[i]] + } + cycle[lro[i]] = int64(i) + } + + // complete the Permutation by filling the first IDs encountered + for i := 0; i < sizePermutation; i++ { + if permutation[i] == -1 { + permutation[i] = cycle[lro[i]] + } + } + + pt.S = permutation +} + +// computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. +// We let the permutation act on || u || u^{2}, split the result in 3 parts, +// and interpolate each of the 3 parts on . +func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { + + nbElmts := int(domain.Cardinality) + + var res [3]*iop.Polynomial + + // Lagrange form of ID + evaluationIDSmallDomain := getSupportPermutation(domain) + + // Lagrange form of S1, S2, S3 + s1Canonical := make([]fr.Element, nbElmts) + s2Canonical := make([]fr.Element, nbElmts) + s3Canonical := make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + res[0] = iop.NewPolynomial(&s1Canonical, lagReg) + res[1] = iop.NewPolynomial(&s2Canonical, lagReg) + res[2] = iop.NewPolynomial(&s3Canonical, lagReg) + + return res +} + +// getSupportPermutation returns the support on which the permutation acts, it is +// || u || u^{2} +func getSupportPermutation(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], 
&domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + + return res +} diff --git a/internal/backend/bw6-633/plonk/verify.go b/backend/plonk/bw6-633/verify.go similarity index 77% rename from internal/backend/bw6-633/plonk/verify.go rename to backend/plonk/bw6-633/verify.go index 4d8d69d051..63c4b1dbfb 100644 --- a/internal/backend/bw6-633/plonk/verify.go +++ b/backend/plonk/bw6-633/verify.go @@ -39,7 +39,7 @@ var ( ) func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - log := logger.Logger().With().Str("curve", "bw6_633").Str("backend", "plonk").Logger() + log := logger.Logger().With().Str("curve", "bw6-633").Str("backend", "plonk").Logger() start := time.Now() // pick a hash function to derive the challenge (the same as in the prover) @@ -51,7 +51,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { // The first challenge is derived using the public data: the commitments to the permutation, // the coefficients of the circuit, and the public inputs. // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *vk, publicWitness); err != nil { + if err := bindPublicData(&fs, "gamma", *vk, publicWitness, proof.Bsb22Commitments); err != nil { return err } gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -85,25 +85,51 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = ∑_{i of size N, so to represent the permutation S we let S acts on the + // set A=(, u*, u^{2}*) of size 3*N, where u is outside (its use is to shift the set ). + // We obtain a permutation of A, A'. We split A' in 3 (A'_{1}, A'_{2}, A'_{3}), and S1, S2, S3 are + // respectively the interpolation of A'_{1}, A'_{2}, A'_{3} on . 
+ S1, S2, S3 *iop.Polynomial + + // S full permutation, i -> S[i] + S []int64 +} + +// VerifyingKey stores the data needed to verify a proof: +// * The commitment scheme +// * Commitments of ql prepended with as many ones as there are public inputs +// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs +// * Commitments to S1, S2, S3 +type VerifyingKey struct { + + // Size circuit + Size uint64 + SizeInv fr.Element + Generator fr.Element + NbPublicVariables uint64 + + // Commitment scheme that is used for an instantiation of PLONK + Kzg kzg.VerifyingKey + + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + + // S commitments to S1, S2, S3 + S [3]kzg.Digest + + // Commitments to ql, qr, qm, qo, qcp prepended with as many zeroes (ones for l) as there are public inputs. + // In particular Qk is not complete. + Ql, Qr, Qm, Qo, Qk kzg.Digest + Qcp []kzg.Digest + + CommitmentConstraintIndexes []uint64 +} + +// ProvingKey stores the data needed to generate a proof: +// * the commitment scheme +// * ql, prepended with as many ones as they are public inputs +// * qr, qm, qo prepended with as many zeroes as there are public inputs. +// * qk, prepended with as many zeroes as public inputs, to be completed by the prover +// with the list of public inputs. +// * sigma_1, sigma_2, sigma_3 in both basis +// * the copy constraint permutation +type ProvingKey struct { + + // stores ql, qr, qm, qo, qk (-> to be completed by the prover) + // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used + // for computing the opening proofs (hence the canonical form). The canonical version + // of qk incomplete is used in the linearisation polynomial. + // The polynomials in trace are in canonical basis. 
+ trace Trace + + Kzg kzg.ProvingKey + + // Verifying Key is embedded into the proving key (needed by Prove) + Vk *VerifyingKey + + // qr,ql,qm,qo,qcp in LagrangeCoset --> these are not serialized, but computed from Ql, Qr, Qm, Qo, Qcp once. + lcQl, lcQr, lcQm, lcQo *iop.Polynomial + lcQcp []*iop.Polynomial + + // LQk qk in Lagrange form -> to be completed by the prover. After being completed, + lQk *iop.Polynomial + + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + + // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. + lcS1, lcS2, lcS3 *iop.Polynomial + + // in lagrange coset basis --> not serialized id and L_{g^{0}} + lcIdIOP, lLoneIOP *iop.Polynomial +} + +func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { + + var pk ProvingKey + var vk VerifyingKey + pk.Vk = &vk + vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) + + // step 0: set the fft domains + pk.initDomains(spr) + + // step 1: set the verifying key + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) + vk.Size = pk.Domain[0].Cardinality + vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) + vk.Generator.Set(&pk.Domain[0].Generator) + vk.NbPublicVariables = uint64(len(spr.Public)) + if len(kzgSrs.Pk.G1) < int(vk.Size) { + return nil, nil, errors.New("kzg srs is too small") + } + pk.Kzg = kzgSrs.Pk + vk.Kzg = kzgSrs.Vk + + // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis + BuildTrace(spr, &pk.trace) + + // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. 
+ // Note: at this stage, the permutation takes into account the placeholders
+ nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret)
+ buildPermutation(spr, &pk.trace, nbVariables)
+ s := computePermutationPolynomials(&pk.trace, &pk.Domain[0])
+ pk.trace.S1 = s[0]
+ pk.trace.S2 = s[1]
+ pk.trace.S3 = s[2]
+
+ // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk.
+ // All the above polynomials are expressed in canonical basis afterwards. This is why
+ // we save lqk before, because the prover needs to complete it in Lagrange form, and
+ // then express it on the Lagrange coset basis.
+ pk.lQk = pk.trace.Qk.Clone() // it will be completed by the prover, and then evaluated on the coset
+ err := commitTrace(&pk.trace, &pk)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // step 5: evaluate ql, qr, qm, qo, s1, s2, s3 on LagrangeCoset (NOT qk)
+ // we clone them, because the canonical versions are going to be used in
+ // the opening proof
+ pk.computeLagrangeCosetPolys()
+
+ return &pk, &vk, nil
+}
+
+// computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset
+// basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
+func (pk *ProvingKey) computeLagrangeCosetPolys() { + var wg sync.WaitGroup + wg.Add(7 + len(pk.trace.Qcp)) + n1 := int(pk.Domain[1].Cardinality) + pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) + for i, qcpI := range pk.trace.Qcp { + go func(i int, qcpI *iop.Polynomial) { + pk.lcQcp[i] = qcpI.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }(i, qcpI) + } + go func() { + pk.lcQl = pk.trace.Ql.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQr = pk.trace.Qr.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQm = pk.trace.Qm.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQo = pk.trace.Qo.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS1 = pk.trace.S1.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS2 = pk.trace.S2.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS3 = pk.trace.S3.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + // storing Id + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + id := make([]fr.Element, pk.Domain[1].Cardinality) + id[0].Set(&pk.Domain[1].FrMultiplicativeGen) + for i := 1; i < int(pk.Domain[1].Cardinality); i++ { + id[i].Mul(&id[i-1], &pk.Domain[1].Generator) + } + pk.lcIdIOP = iop.NewPolynomial(&id, lagReg) + + // L_{g^{0}} + cap := pk.Domain[1].Cardinality + if cap < pk.Domain[0].Cardinality { + cap = pk.Domain[0].Cardinality // sanity check + } + lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) + lone[0].SetOne() + pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). + ToRegular(). 
+ ToLagrangeCoset(&pk.Domain[1])
+
+ wg.Wait()
+}
+
+// NbPublicWitness returns the expected public witness size (number of field elements)
+func (vk *VerifyingKey) NbPublicWitness() int {
+ return int(vk.NbPublicVariables)
+}
+
+// VerifyingKey returns pk.Vk
+func (pk *ProvingKey) VerifyingKey() interface{} {
+ return pk.Vk
+}
+
+// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the SparseR1CS.
+// Size is the size of the system that is nb_constraints+nb_public_variables
+func BuildTrace(spr *cs.SparseR1CS, pt *Trace) {
+
+ nbConstraints := spr.GetNbConstraints()
+ sizeSystem := uint64(nbConstraints + len(spr.Public))
+ size := ecc.NextPowerOfTwo(sizeSystem)
+ commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments)
+
+ ql := make([]fr.Element, size)
+ qr := make([]fr.Element, size)
+ qm := make([]fr.Element, size)
+ qo := make([]fr.Element, size)
+ qk := make([]fr.Element, size)
+ qcp := make([][]fr.Element, len(commitmentInfo))
+
+ for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error if size is inconsistent
+ ql[i].SetOne().Neg(&ql[i])
+ qr[i].SetZero()
+ qm[i].SetZero()
+ qo[i].SetZero()
+ qk[i].SetZero() // → to be completed by the prover
+ }
+ offset := len(spr.Public)
+
+ j := 0
+ it := spr.GetSparseR1CIterator()
+ for c := it.Next(); c != nil; c = it.Next() {
+ ql[offset+j].Set(&spr.Coefficients[c.QL])
+ qr[offset+j].Set(&spr.Coefficients[c.QR])
+ qm[offset+j].Set(&spr.Coefficients[c.QM])
+ qo[offset+j].Set(&spr.Coefficients[c.QO])
+ qk[offset+j].Set(&spr.Coefficients[c.QC])
+ j++
+ }
+
+ lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular}
+
+ pt.Ql = iop.NewPolynomial(&ql, lagReg)
+ pt.Qr = iop.NewPolynomial(&qr, lagReg)
+ pt.Qm = iop.NewPolynomial(&qm, lagReg)
+ pt.Qo = iop.NewPolynomial(&qo, lagReg)
+ pt.Qk = iop.NewPolynomial(&qk, lagReg)
+ pt.Qcp = make([]*iop.Polynomial, len(qcp))
+
+ for i := range commitmentInfo {
+ qcp[i] = make([]fr.Element, size)
+ for _, 
committed := range commitmentInfo[i].Committed {
+ qcp[i][offset+committed].SetOne()
+ }
+ pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg)
+ }
+}
+
+// commitTrace commits to every polynomial in the trace, and puts
+// the commitments into the verifying key.
+func commitTrace(trace *Trace, pk *ProvingKey) error {
+
+ trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular()
+ trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular()
+ trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular()
+ trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular()
+ trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete
+ trace.S1.ToCanonical(&pk.Domain[0]).ToRegular()
+ trace.S2.ToCanonical(&pk.Domain[0]).ToRegular()
+ trace.S3.ToCanonical(&pk.Domain[0]).ToRegular()
+
+ var err error
+ pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp))
+ for i := range trace.Qcp {
+ trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular()
+ if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ }
+ if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) {
+
+ nbConstraints := spr.GetNbConstraints()
+ sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for 
the placeholder constraints + pk.Domain[0] = *fft.NewDomain(sizeSystem) + + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) + } else { + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) + } + +} + +// buildPermutation builds the Permutation associated with a circuit. +// +// The permutation s is composed of cycles of maximum length such that +// +// s. (l∥r∥o) = (l∥r∥o) +// +// , where l∥r∥o is the concatenation of the indices of l, r, o in +// ql.l+qr.r+qm.l.r+qo.O+k = 0. +// +// The permutation is encoded as a slice s of size 3*size(l), where the +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab +// like this: for i in tab: tab[i] = tab[permutation[i]] +func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { + + // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + sizeSolution := len(pt.Ql.Coefficients()) + sizePermutation := 3 * sizeSolution + + // init permutation + permutation := make([]int64, sizePermutation) + for i := 0; i < len(permutation); i++ { + permutation[i] = -1 + } + + // init LRO position -> variable_ID + lro := make([]int, sizePermutation) // position -> variable_ID + for i := 0; i < len(spr.Public); i++ { + lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) + } + + offset := len(spr.Public) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + + j++ + } + + // init cycle: + // map ID -> last position the ID was seen + cycle := make([]int64, nbVariables) + for i := 0; i < len(cycle); i++ { + cycle[i] = -1 + } + + for i := 0; i < len(lro); i++ { + if cycle[lro[i]] != -1 { + 
// if != -1, it means we already encountered this value + // so we need to set the corresponding permutation index. + permutation[i] = cycle[lro[i]] + } + cycle[lro[i]] = int64(i) + } + + // complete the Permutation by filling the first IDs encountered + for i := 0; i < sizePermutation; i++ { + if permutation[i] == -1 { + permutation[i] = cycle[lro[i]] + } + } + + pt.S = permutation +} + +// computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. +// We let the permutation act on || u || u^{2}, split the result in 3 parts, +// and interpolate each of the 3 parts on . +func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { + + nbElmts := int(domain.Cardinality) + + var res [3]*iop.Polynomial + + // Lagrange form of ID + evaluationIDSmallDomain := getSupportPermutation(domain) + + // Lagrange form of S1, S2, S3 + s1Canonical := make([]fr.Element, nbElmts) + s2Canonical := make([]fr.Element, nbElmts) + s3Canonical := make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + } + + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + res[0] = iop.NewPolynomial(&s1Canonical, lagReg) + res[1] = iop.NewPolynomial(&s2Canonical, lagReg) + res[2] = iop.NewPolynomial(&s3Canonical, lagReg) + + return res +} + +// getSupportPermutation returns the support on which the permutation acts, it is +// || u || u^{2} +func getSupportPermutation(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], 
&domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + + return res +} diff --git a/internal/backend/bw6-761/plonk/verify.go b/backend/plonk/bw6-761/verify.go similarity index 77% rename from internal/backend/bw6-761/plonk/verify.go rename to backend/plonk/bw6-761/verify.go index 9a1dd7650f..51c0719b8d 100644 --- a/internal/backend/bw6-761/plonk/verify.go +++ b/backend/plonk/bw6-761/verify.go @@ -39,7 +39,7 @@ var ( ) func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - log := logger.Logger().With().Str("curve", "bw6_761").Str("backend", "plonk").Logger() + log := logger.Logger().With().Str("curve", "bw6-761").Str("backend", "plonk").Logger() start := time.Now() // pick a hash function to derive the challenge (the same as in the prover) @@ -51,7 +51,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { // The first challenge is derived using the public data: the commitments to the permutation, // the coefficients of the circuit, and the public inputs. 
// derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *vk, publicWitness); err != nil { + if err := bindPublicData(&fs, "gamma", *vk, publicWitness, proof.Bsb22Commitments); err != nil { return err } gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -85,25 +85,51 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = ∑_{i to be completed by the prover } offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.EvaluationQlDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.EvaluationQrDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.EvaluationQmDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). - Mul(&pk.EvaluationQmDomainBigBitReversed[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.EvaluationQoDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.LQkIncompleteDomainSmall[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.CQkIncomplete[offset+i].Set(&pk.LQkIncompleteDomainSmall[offset+i]) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + pk.EvaluationQlDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QL]) + pk.EvaluationQrDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QR]) + pk.EvaluationQmDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QM]) + pk.EvaluationQoDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QO]) + pk.LQkIncompleteDomainSmall[offset+j].Set(&spr.Coefficients[c.QC]) + pk.CQkIncomplete[offset+j].Set(&pk.LQkIncompleteDomainSmall[offset+j]) + + j++ } pk.Domain[0].FFTInverse(pk.EvaluationQlDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) @@ -209,10 +212,10 
@@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { return &pk, &vk, err } - pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, fft.OnCoset()) // build permutation. Note: at this stage, the permutation takes in account the placeholders buildPermutation(spr, &pk) @@ -257,10 +260,14 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + j++ } // init cycle: @@ -349,9 +356,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, fft.OnCoset()) 
pk.Domain[0].FFTInverse(pk.EvaluationS1BigDomain[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationS2BigDomain[:pk.Domain[0].Cardinality], fft.DIF) @@ -379,9 +386,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, fft.OnCoset()) return nil diff --git a/internal/backend/bls12-377/plonkfri/verify.go b/backend/plonkfri/bls12-377/verify.go similarity index 100% rename from internal/backend/bls12-377/plonkfri/verify.go rename to backend/plonkfri/bls12-377/verify.go diff --git a/internal/backend/bls12-381/plonkfri/prove.go b/backend/plonkfri/bls12-381/prove.go similarity index 92% rename from internal/backend/bls12-381/plonkfri/prove.go rename to backend/plonkfri/bls12-381/prove.go index 8909988986..d121b59a37 100644 --- a/internal/backend/bls12-381/plonkfri/prove.go +++ b/backend/plonkfri/bls12-381/prove.go @@ -22,11 +22,13 @@ import ( "math/bits" "runtime" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" - "github.com/consensys/gnark/constraint/bls12-381" + cs "github.com/consensys/gnark/constraint/bls12-381" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fri" @@ -36,7 +38,6 @@ import ( ) type Proof struct { - // commitments to the solution vectors LROpp [3]fri.ProofOfProximity @@ -67,7 +68,11 @@ type Proof struct { OpeningsId1Id2Id3mp [3]fri.OpeningProof } -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk 
*ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) + if err != nil { + return nil, err + } var proof Proof @@ -78,23 +83,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // 1 - solve the system - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } + _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) + if err != nil { + return nil, err } - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) + solution := _solution.(*cs.SparseR1CSSolution) + evaluationLDomainSmall := []fr.Element(solution.L) + evaluationRDomainSmall := []fr.Element(solution.R) + evaluationODomainSmall := []fr.Element(solution.O) // 2 - commit to lro blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( @@ -119,9 +116,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen } // 3 - compute Z, challenges are derived using L, R, O + public inputs + fw, ok := fullWitness.Vector().(fr.Vector) + if !ok { + return nil, witness.ErrInvalidWitness + } dataFiatShamir := make([][fr.Bytes]byte, len(spr.Public)+3) for i := 0; i < len(spr.Public); i++ { - copy(dataFiatShamir[i][:], fullWitness[i].Marshal()) + copy(dataFiatShamir[i][:], fw[i].Marshal()) } copy(dataFiatShamir[len(spr.Public)][:], proof.LROpp[0].ID) copy(dataFiatShamir[len(spr.Public)+1][:], proof.LROpp[1].ID) @@ -164,7 +165,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen // 
alpha.SetUint64(11) evaluationQkCompleteDomainBigBitReversed := make([]fr.Element, pk.Domain[1].Cardinality) - copy(evaluationQkCompleteDomainBigBitReversed, fullWitness[:len(spr.Public)]) + copy(evaluationQkCompleteDomainBigBitReversed, fw[:len(spr.Public)]) copy(evaluationQkCompleteDomainBigBitReversed[len(spr.Public):], pk.LQkIncompleteDomainSmall[len(spr.Public):]) pk.Domain[0].FFTInverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) fft.BitReverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality]) @@ -487,7 +488,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse for i := 0; i < int(pk.Domain[0].Cardinality); i++ { startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.Domain[1].FFT(startsAtOne, fft.DIF, true) + pk.Domain[1].FFT(startsAtOne, fft.DIF, fft.OnCoset()) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain @@ -515,7 +516,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.Domain[1].FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, fft.OnCoset()) // degree of hi is n+2 because of the blinding h1 := h[:pk.Domain[0].Cardinality+2] @@ -586,41 +587,6 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma } -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - // computeBlindedLROCanonical // l, r, o in canonical basis with blinding func computeBlindedLROCanonical( diff --git a/internal/backend/bls12-381/plonkfri/setup.go b/backend/plonkfri/bls12-381/setup.go similarity index 91% rename from internal/backend/bls12-381/plonkfri/setup.go rename to backend/plonkfri/bls12-381/setup.go index da30df7b0f..11f574f5f2 100644 --- a/internal/backend/bls12-381/plonkfri/setup.go +++ b/backend/plonkfri/bls12-381/setup.go @@ -21,7 +21,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fri" - "github.com/consensys/gnark/constraint/bls12-381" + cs "github.com/consensys/gnark/constraint/bls12-381" ) // ProvingKey stores the data needed to generate a proof: @@ -106,7 +106,7 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { // The verifying key shares data with the proving key pk.Vk = &vk - nbConstraints := len(spr.Constraints) + nbConstraints := 
spr.GetNbConstraints() // fft domains sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints @@ -156,15 +156,18 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { pk.CQkIncomplete[i].Set(&pk.LQkIncompleteDomainSmall[i]) // --> to be completed by the prover } offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.EvaluationQlDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.EvaluationQrDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.EvaluationQmDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). - Mul(&pk.EvaluationQmDomainBigBitReversed[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.EvaluationQoDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.LQkIncompleteDomainSmall[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.CQkIncomplete[offset+i].Set(&pk.LQkIncompleteDomainSmall[offset+i]) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + pk.EvaluationQlDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QL]) + pk.EvaluationQrDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QR]) + pk.EvaluationQmDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QM]) + pk.EvaluationQoDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QO]) + pk.LQkIncompleteDomainSmall[offset+j].Set(&spr.Coefficients[c.QC]) + pk.CQkIncomplete[offset+j].Set(&pk.LQkIncompleteDomainSmall[offset+j]) + + j++ } pk.Domain[0].FFTInverse(pk.EvaluationQlDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) @@ -209,10 +212,10 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { return &pk, &vk, err } - pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, 
true) - pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, fft.OnCoset()) // build permutation. Note: at this stage, the permutation takes in account the placeholders buildPermutation(spr, &pk) @@ -257,10 +260,14 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + j++ } // init cycle: @@ -349,9 +356,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, fft.OnCoset()) pk.Domain[0].FFTInverse(pk.EvaluationS1BigDomain[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationS2BigDomain[:pk.Domain[0].Cardinality], fft.DIF) @@ -379,9 +386,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err 
} - pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, fft.OnCoset()) return nil diff --git a/internal/backend/bls12-381/plonkfri/verify.go b/backend/plonkfri/bls12-381/verify.go similarity index 100% rename from internal/backend/bls12-381/plonkfri/verify.go rename to backend/plonkfri/bls12-381/verify.go diff --git a/internal/backend/bls24-315/plonkfri/prove.go b/backend/plonkfri/bls24-315/prove.go similarity index 92% rename from internal/backend/bls24-315/plonkfri/prove.go rename to backend/plonkfri/bls24-315/prove.go index a440c5953f..cb1f43fcee 100644 --- a/internal/backend/bls24-315/plonkfri/prove.go +++ b/backend/plonkfri/bls24-315/prove.go @@ -22,11 +22,13 @@ import ( "math/bits" "runtime" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" - "github.com/consensys/gnark/constraint/bls24-315" + cs "github.com/consensys/gnark/constraint/bls24-315" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fri" @@ -36,7 +38,6 @@ import ( ) type Proof struct { - // commitments to the solution vectors LROpp [3]fri.ProofOfProximity @@ -67,7 +68,11 @@ type Proof struct { OpeningsId1Id2Id3mp [3]fri.OpeningProof } -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } var proof Proof @@ -78,23 +83,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // 1 - solve the system - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } + _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) + if err != nil { + return nil, err } - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) + solution := _solution.(*cs.SparseR1CSSolution) + evaluationLDomainSmall := []fr.Element(solution.L) + evaluationRDomainSmall := []fr.Element(solution.R) + evaluationODomainSmall := []fr.Element(solution.O) // 2 - commit to lro blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( @@ -119,9 +116,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen } // 3 - compute Z, challenges are derived using L, R, O + public inputs + fw, ok := fullWitness.Vector().(fr.Vector) + if !ok { + return nil, witness.ErrInvalidWitness + } dataFiatShamir := make([][fr.Bytes]byte, len(spr.Public)+3) for i := 0; i < len(spr.Public); i++ { - copy(dataFiatShamir[i][:], fullWitness[i].Marshal()) + copy(dataFiatShamir[i][:], fw[i].Marshal()) } copy(dataFiatShamir[len(spr.Public)][:], proof.LROpp[0].ID) copy(dataFiatShamir[len(spr.Public)+1][:], proof.LROpp[1].ID) @@ -164,7 +165,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen // alpha.SetUint64(11) evaluationQkCompleteDomainBigBitReversed := make([]fr.Element, pk.Domain[1].Cardinality) - 
copy(evaluationQkCompleteDomainBigBitReversed, fullWitness[:len(spr.Public)]) + copy(evaluationQkCompleteDomainBigBitReversed, fw[:len(spr.Public)]) copy(evaluationQkCompleteDomainBigBitReversed[len(spr.Public):], pk.LQkIncompleteDomainSmall[len(spr.Public):]) pk.Domain[0].FFTInverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) fft.BitReverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality]) @@ -487,7 +488,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse for i := 0; i < int(pk.Domain[0].Cardinality); i++ { startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.Domain[1].FFT(startsAtOne, fft.DIF, true) + pk.Domain[1].FFT(startsAtOne, fft.DIF, fft.OnCoset()) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain @@ -515,7 +516,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.Domain[1].FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, fft.OnCoset()) // degree of hi is n+2 because of the blinding h1 := h[:pk.Domain[0].Cardinality+2] @@ -586,41 +587,6 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma } -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - // computeBlindedLROCanonical // l, r, o in canonical basis with blinding func computeBlindedLROCanonical( diff --git a/internal/backend/bls24-315/plonkfri/setup.go b/backend/plonkfri/bls24-315/setup.go similarity index 91% rename from internal/backend/bls24-315/plonkfri/setup.go rename to backend/plonkfri/bls24-315/setup.go index d02f7f504f..c3c0837f90 100644 --- a/internal/backend/bls24-315/plonkfri/setup.go +++ b/backend/plonkfri/bls24-315/setup.go @@ -21,7 +21,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fri" - "github.com/consensys/gnark/constraint/bls24-315" + cs "github.com/consensys/gnark/constraint/bls24-315" ) // ProvingKey stores the data needed to generate a proof: @@ -106,7 +106,7 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { // The verifying key shares data with the proving key pk.Vk = &vk - nbConstraints := len(spr.Constraints) + nbConstraints := 
spr.GetNbConstraints() // fft domains sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints @@ -156,15 +156,18 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { pk.CQkIncomplete[i].Set(&pk.LQkIncompleteDomainSmall[i]) // --> to be completed by the prover } offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.EvaluationQlDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.EvaluationQrDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.EvaluationQmDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). - Mul(&pk.EvaluationQmDomainBigBitReversed[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.EvaluationQoDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.LQkIncompleteDomainSmall[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.CQkIncomplete[offset+i].Set(&pk.LQkIncompleteDomainSmall[offset+i]) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + pk.EvaluationQlDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QL]) + pk.EvaluationQrDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QR]) + pk.EvaluationQmDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QM]) + pk.EvaluationQoDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QO]) + pk.LQkIncompleteDomainSmall[offset+j].Set(&spr.Coefficients[c.QC]) + pk.CQkIncomplete[offset+j].Set(&pk.LQkIncompleteDomainSmall[offset+j]) + + j++ } pk.Domain[0].FFTInverse(pk.EvaluationQlDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) @@ -209,10 +212,10 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { return &pk, &vk, err } - pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, 
true) - pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, fft.OnCoset()) // build permutation. Note: at this stage, the permutation takes in account the placeholders buildPermutation(spr, &pk) @@ -257,10 +260,14 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + j++ } // init cycle: @@ -349,9 +356,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, fft.OnCoset()) pk.Domain[0].FFTInverse(pk.EvaluationS1BigDomain[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationS2BigDomain[:pk.Domain[0].Cardinality], fft.DIF) @@ -379,9 +386,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err 
} - pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, fft.OnCoset()) return nil diff --git a/internal/backend/bls24-315/plonkfri/verify.go b/backend/plonkfri/bls24-315/verify.go similarity index 100% rename from internal/backend/bls24-315/plonkfri/verify.go rename to backend/plonkfri/bls24-315/verify.go diff --git a/internal/backend/bls24-317/plonkfri/prove.go b/backend/plonkfri/bls24-317/prove.go similarity index 92% rename from internal/backend/bls24-317/plonkfri/prove.go rename to backend/plonkfri/bls24-317/prove.go index 752c825069..5fc6cbf713 100644 --- a/internal/backend/bls24-317/plonkfri/prove.go +++ b/backend/plonkfri/bls24-317/prove.go @@ -22,11 +22,13 @@ import ( "math/bits" "runtime" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" - "github.com/consensys/gnark/constraint/bls24-317" + cs "github.com/consensys/gnark/constraint/bls24-317" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fri" @@ -36,7 +38,6 @@ import ( ) type Proof struct { - // commitments to the solution vectors LROpp [3]fri.ProofOfProximity @@ -67,7 +68,11 @@ type Proof struct { OpeningsId1Id2Id3mp [3]fri.OpeningProof } -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } var proof Proof @@ -78,23 +83,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // 1 - solve the system - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } + _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) + if err != nil { + return nil, err } - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) + solution := _solution.(*cs.SparseR1CSSolution) + evaluationLDomainSmall := []fr.Element(solution.L) + evaluationRDomainSmall := []fr.Element(solution.R) + evaluationODomainSmall := []fr.Element(solution.O) // 2 - commit to lro blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( @@ -119,9 +116,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen } // 3 - compute Z, challenges are derived using L, R, O + public inputs + fw, ok := fullWitness.Vector().(fr.Vector) + if !ok { + return nil, witness.ErrInvalidWitness + } dataFiatShamir := make([][fr.Bytes]byte, len(spr.Public)+3) for i := 0; i < len(spr.Public); i++ { - copy(dataFiatShamir[i][:], fullWitness[i].Marshal()) + copy(dataFiatShamir[i][:], fw[i].Marshal()) } copy(dataFiatShamir[len(spr.Public)][:], proof.LROpp[0].ID) copy(dataFiatShamir[len(spr.Public)+1][:], proof.LROpp[1].ID) @@ -164,7 +165,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen // alpha.SetUint64(11) evaluationQkCompleteDomainBigBitReversed := make([]fr.Element, pk.Domain[1].Cardinality) - 
copy(evaluationQkCompleteDomainBigBitReversed, fullWitness[:len(spr.Public)]) + copy(evaluationQkCompleteDomainBigBitReversed, fw[:len(spr.Public)]) copy(evaluationQkCompleteDomainBigBitReversed[len(spr.Public):], pk.LQkIncompleteDomainSmall[len(spr.Public):]) pk.Domain[0].FFTInverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) fft.BitReverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality]) @@ -487,7 +488,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse for i := 0; i < int(pk.Domain[0].Cardinality); i++ { startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.Domain[1].FFT(startsAtOne, fft.DIF, true) + pk.Domain[1].FFT(startsAtOne, fft.DIF, fft.OnCoset()) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain @@ -515,7 +516,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.Domain[1].FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, fft.OnCoset()) // degree of hi is n+2 because of the blinding h1 := h[:pk.Domain[0].Cardinality+2] @@ -586,41 +587,6 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma } -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - // computeBlindedLROCanonical // l, r, o in canonical basis with blinding func computeBlindedLROCanonical( diff --git a/internal/backend/bls24-317/plonkfri/setup.go b/backend/plonkfri/bls24-317/setup.go similarity index 91% rename from internal/backend/bls24-317/plonkfri/setup.go rename to backend/plonkfri/bls24-317/setup.go index 7cd1fb80ff..54668b98f8 100644 --- a/internal/backend/bls24-317/plonkfri/setup.go +++ b/backend/plonkfri/bls24-317/setup.go @@ -21,7 +21,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fri" - "github.com/consensys/gnark/constraint/bls24-317" + cs "github.com/consensys/gnark/constraint/bls24-317" ) // ProvingKey stores the data needed to generate a proof: @@ -106,7 +106,7 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { // The verifying key shares data with the proving key pk.Vk = &vk - nbConstraints := len(spr.Constraints) + nbConstraints := 
spr.GetNbConstraints() // fft domains sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints @@ -156,15 +156,18 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { pk.CQkIncomplete[i].Set(&pk.LQkIncompleteDomainSmall[i]) // --> to be completed by the prover } offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.EvaluationQlDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.EvaluationQrDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.EvaluationQmDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). - Mul(&pk.EvaluationQmDomainBigBitReversed[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.EvaluationQoDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.LQkIncompleteDomainSmall[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.CQkIncomplete[offset+i].Set(&pk.LQkIncompleteDomainSmall[offset+i]) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + pk.EvaluationQlDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QL]) + pk.EvaluationQrDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QR]) + pk.EvaluationQmDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QM]) + pk.EvaluationQoDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QO]) + pk.LQkIncompleteDomainSmall[offset+j].Set(&spr.Coefficients[c.QC]) + pk.CQkIncomplete[offset+j].Set(&pk.LQkIncompleteDomainSmall[offset+j]) + + j++ } pk.Domain[0].FFTInverse(pk.EvaluationQlDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) @@ -209,10 +212,10 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { return &pk, &vk, err } - pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, 
true) - pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, fft.OnCoset()) // build permutation. Note: at this stage, the permutation takes in account the placeholders buildPermutation(spr, &pk) @@ -257,10 +260,14 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + j++ } // init cycle: @@ -349,9 +356,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, fft.OnCoset()) pk.Domain[0].FFTInverse(pk.EvaluationS1BigDomain[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationS2BigDomain[:pk.Domain[0].Cardinality], fft.DIF) @@ -379,9 +386,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err 
} - pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, fft.OnCoset()) return nil diff --git a/internal/backend/bls24-317/plonkfri/verify.go b/backend/plonkfri/bls24-317/verify.go similarity index 100% rename from internal/backend/bls24-317/plonkfri/verify.go rename to backend/plonkfri/bls24-317/verify.go diff --git a/internal/backend/bn254/plonkfri/prove.go b/backend/plonkfri/bn254/prove.go similarity index 92% rename from internal/backend/bn254/plonkfri/prove.go rename to backend/plonkfri/bn254/prove.go index 1166b5cc30..161ad667f4 100644 --- a/internal/backend/bn254/plonkfri/prove.go +++ b/backend/plonkfri/bn254/prove.go @@ -22,11 +22,13 @@ import ( "math/bits" "runtime" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" - "github.com/consensys/gnark/constraint/bn254" + cs "github.com/consensys/gnark/constraint/bn254" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fri" @@ -36,7 +38,6 @@ import ( ) type Proof struct { - // commitments to the solution vectors LROpp [3]fri.ProofOfProximity @@ -67,7 +68,11 @@ type Proof struct { OpeningsId1Id2Id3mp [3]fri.OpeningProof } -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } var proof Proof @@ -78,23 +83,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // 1 - solve the system - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } + _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) + if err != nil { + return nil, err } - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) + solution := _solution.(*cs.SparseR1CSSolution) + evaluationLDomainSmall := []fr.Element(solution.L) + evaluationRDomainSmall := []fr.Element(solution.R) + evaluationODomainSmall := []fr.Element(solution.O) // 2 - commit to lro blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( @@ -119,9 +116,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen } // 3 - compute Z, challenges are derived using L, R, O + public inputs + fw, ok := fullWitness.Vector().(fr.Vector) + if !ok { + return nil, witness.ErrInvalidWitness + } dataFiatShamir := make([][fr.Bytes]byte, len(spr.Public)+3) for i := 0; i < len(spr.Public); i++ { - copy(dataFiatShamir[i][:], fullWitness[i].Marshal()) + copy(dataFiatShamir[i][:], fw[i].Marshal()) } copy(dataFiatShamir[len(spr.Public)][:], proof.LROpp[0].ID) copy(dataFiatShamir[len(spr.Public)+1][:], proof.LROpp[1].ID) @@ -164,7 +165,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen // alpha.SetUint64(11) evaluationQkCompleteDomainBigBitReversed := make([]fr.Element, pk.Domain[1].Cardinality) - 
copy(evaluationQkCompleteDomainBigBitReversed, fullWitness[:len(spr.Public)]) + copy(evaluationQkCompleteDomainBigBitReversed, fw[:len(spr.Public)]) copy(evaluationQkCompleteDomainBigBitReversed[len(spr.Public):], pk.LQkIncompleteDomainSmall[len(spr.Public):]) pk.Domain[0].FFTInverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) fft.BitReverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality]) @@ -487,7 +488,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse for i := 0; i < int(pk.Domain[0].Cardinality); i++ { startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.Domain[1].FFT(startsAtOne, fft.DIF, true) + pk.Domain[1].FFT(startsAtOne, fft.DIF, fft.OnCoset()) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain @@ -515,7 +516,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.Domain[1].FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, fft.OnCoset()) // degree of hi is n+2 because of the blinding h1 := h[:pk.Domain[0].Cardinality+2] @@ -586,41 +587,6 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma } -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - // computeBlindedLROCanonical // l, r, o in canonical basis with blinding func computeBlindedLROCanonical( diff --git a/internal/backend/bn254/plonkfri/setup.go b/backend/plonkfri/bn254/setup.go similarity index 91% rename from internal/backend/bn254/plonkfri/setup.go rename to backend/plonkfri/bn254/setup.go index 4a47a70839..9f69648ed1 100644 --- a/internal/backend/bn254/plonkfri/setup.go +++ b/backend/plonkfri/bn254/setup.go @@ -21,7 +21,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fri" - "github.com/consensys/gnark/constraint/bn254" + cs "github.com/consensys/gnark/constraint/bn254" ) // ProvingKey stores the data needed to generate a proof: @@ -106,7 +106,7 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { // The verifying key shares data with the proving key pk.Vk = &vk - nbConstraints := len(spr.Constraints) + nbConstraints := spr.GetNbConstraints() // fft domains 
sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints @@ -156,15 +156,18 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { pk.CQkIncomplete[i].Set(&pk.LQkIncompleteDomainSmall[i]) // --> to be completed by the prover } offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.EvaluationQlDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.EvaluationQrDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.EvaluationQmDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). - Mul(&pk.EvaluationQmDomainBigBitReversed[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.EvaluationQoDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.LQkIncompleteDomainSmall[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.CQkIncomplete[offset+i].Set(&pk.LQkIncompleteDomainSmall[offset+i]) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + pk.EvaluationQlDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QL]) + pk.EvaluationQrDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QR]) + pk.EvaluationQmDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QM]) + pk.EvaluationQoDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QO]) + pk.LQkIncompleteDomainSmall[offset+j].Set(&spr.Coefficients[c.QC]) + pk.CQkIncomplete[offset+j].Set(&pk.LQkIncompleteDomainSmall[offset+j]) + + j++ } pk.Domain[0].FFTInverse(pk.EvaluationQlDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) @@ -209,10 +212,10 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { return &pk, &vk, err } - pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, true) - 
pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, fft.OnCoset()) // build permutation. Note: at this stage, the permutation takes in account the placeholders buildPermutation(spr, &pk) @@ -257,10 +260,14 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + j++ } // init cycle: @@ -349,9 +356,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, fft.OnCoset()) pk.Domain[0].FFTInverse(pk.EvaluationS1BigDomain[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationS2BigDomain[:pk.Domain[0].Cardinality], fft.DIF) @@ -379,9 +386,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - 
pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, fft.OnCoset()) return nil diff --git a/internal/backend/bn254/plonkfri/verify.go b/backend/plonkfri/bn254/verify.go similarity index 100% rename from internal/backend/bn254/plonkfri/verify.go rename to backend/plonkfri/bn254/verify.go diff --git a/internal/backend/bw6-633/plonkfri/prove.go b/backend/plonkfri/bw6-633/prove.go similarity index 92% rename from internal/backend/bw6-633/plonkfri/prove.go rename to backend/plonkfri/bw6-633/prove.go index aa02bc934b..e71df6e7aa 100644 --- a/internal/backend/bw6-633/plonkfri/prove.go +++ b/backend/plonkfri/bw6-633/prove.go @@ -22,11 +22,13 @@ import ( "math/bits" "runtime" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" - "github.com/consensys/gnark/constraint/bw6-633" + cs "github.com/consensys/gnark/constraint/bw6-633" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fri" @@ -36,7 +38,6 @@ import ( ) type Proof struct { - // commitments to the solution vectors LROpp [3]fri.ProofOfProximity @@ -67,7 +68,11 @@ type Proof struct { OpeningsId1Id2Id3mp [3]fri.OpeningProof } -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } var proof Proof @@ -78,23 +83,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // 1 - solve the system - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } + _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) + if err != nil { + return nil, err } - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) + solution := _solution.(*cs.SparseR1CSSolution) + evaluationLDomainSmall := []fr.Element(solution.L) + evaluationRDomainSmall := []fr.Element(solution.R) + evaluationODomainSmall := []fr.Element(solution.O) // 2 - commit to lro blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( @@ -119,9 +116,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen } // 3 - compute Z, challenges are derived using L, R, O + public inputs + fw, ok := fullWitness.Vector().(fr.Vector) + if !ok { + return nil, witness.ErrInvalidWitness + } dataFiatShamir := make([][fr.Bytes]byte, len(spr.Public)+3) for i := 0; i < len(spr.Public); i++ { - copy(dataFiatShamir[i][:], fullWitness[i].Marshal()) + copy(dataFiatShamir[i][:], fw[i].Marshal()) } copy(dataFiatShamir[len(spr.Public)][:], proof.LROpp[0].ID) copy(dataFiatShamir[len(spr.Public)+1][:], proof.LROpp[1].ID) @@ -164,7 +165,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen // alpha.SetUint64(11) evaluationQkCompleteDomainBigBitReversed := make([]fr.Element, pk.Domain[1].Cardinality) - 
copy(evaluationQkCompleteDomainBigBitReversed, fullWitness[:len(spr.Public)]) + copy(evaluationQkCompleteDomainBigBitReversed, fw[:len(spr.Public)]) copy(evaluationQkCompleteDomainBigBitReversed[len(spr.Public):], pk.LQkIncompleteDomainSmall[len(spr.Public):]) pk.Domain[0].FFTInverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) fft.BitReverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality]) @@ -487,7 +488,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse for i := 0; i < int(pk.Domain[0].Cardinality); i++ { startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.Domain[1].FFT(startsAtOne, fft.DIF, true) + pk.Domain[1].FFT(startsAtOne, fft.DIF, fft.OnCoset()) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain @@ -515,7 +516,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.Domain[1].FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, fft.OnCoset()) // degree of hi is n+2 because of the blinding h1 := h[:pk.Domain[0].Cardinality+2] @@ -586,41 +587,6 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma } -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - // computeBlindedLROCanonical // l, r, o in canonical basis with blinding func computeBlindedLROCanonical( diff --git a/internal/backend/bw6-633/plonkfri/setup.go b/backend/plonkfri/bw6-633/setup.go similarity index 91% rename from internal/backend/bw6-633/plonkfri/setup.go rename to backend/plonkfri/bw6-633/setup.go index 04d0062321..1384eed0bd 100644 --- a/internal/backend/bw6-633/plonkfri/setup.go +++ b/backend/plonkfri/bw6-633/setup.go @@ -21,7 +21,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fri" - "github.com/consensys/gnark/constraint/bw6-633" + cs "github.com/consensys/gnark/constraint/bw6-633" ) // ProvingKey stores the data needed to generate a proof: @@ -106,7 +106,7 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { // The verifying key shares data with the proving key pk.Vk = &vk - nbConstraints := len(spr.Constraints) + nbConstraints := spr.GetNbConstraints() 
// fft domains sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints @@ -156,15 +156,18 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { pk.CQkIncomplete[i].Set(&pk.LQkIncompleteDomainSmall[i]) // --> to be completed by the prover } offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.EvaluationQlDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.EvaluationQrDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.EvaluationQmDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). - Mul(&pk.EvaluationQmDomainBigBitReversed[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.EvaluationQoDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.LQkIncompleteDomainSmall[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.CQkIncomplete[offset+i].Set(&pk.LQkIncompleteDomainSmall[offset+i]) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + pk.EvaluationQlDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QL]) + pk.EvaluationQrDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QR]) + pk.EvaluationQmDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QM]) + pk.EvaluationQoDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QO]) + pk.LQkIncompleteDomainSmall[offset+j].Set(&spr.Coefficients[c.QC]) + pk.CQkIncomplete[offset+j].Set(&pk.LQkIncompleteDomainSmall[offset+j]) + + j++ } pk.Domain[0].FFTInverse(pk.EvaluationQlDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) @@ -209,10 +212,10 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { return &pk, &vk, err } - pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, true) - 
pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, fft.OnCoset()) // build permutation. Note: at this stage, the permutation takes in account the placeholders buildPermutation(spr, &pk) @@ -257,10 +260,14 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + j++ } // init cycle: @@ -349,9 +356,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, fft.OnCoset()) pk.Domain[0].FFTInverse(pk.EvaluationS1BigDomain[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationS2BigDomain[:pk.Domain[0].Cardinality], fft.DIF) @@ -379,9 +386,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - 
pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, fft.OnCoset()) return nil diff --git a/internal/backend/bw6-633/plonkfri/verify.go b/backend/plonkfri/bw6-633/verify.go similarity index 100% rename from internal/backend/bw6-633/plonkfri/verify.go rename to backend/plonkfri/bw6-633/verify.go diff --git a/internal/backend/bw6-761/plonkfri/prove.go b/backend/plonkfri/bw6-761/prove.go similarity index 92% rename from internal/backend/bw6-761/plonkfri/prove.go rename to backend/plonkfri/bw6-761/prove.go index e5f9cb5b78..9092580485 100644 --- a/internal/backend/bw6-761/plonkfri/prove.go +++ b/backend/plonkfri/bw6-761/prove.go @@ -22,11 +22,13 @@ import ( "math/bits" "runtime" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" - "github.com/consensys/gnark/constraint/bw6-761" + cs "github.com/consensys/gnark/constraint/bw6-761" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fri" @@ -36,7 +38,6 @@ import ( ) type Proof struct { - // commitments to the solution vectors LROpp [3]fri.ProofOfProximity @@ -67,7 +68,11 @@ type Proof struct { OpeningsId1Id2Id3mp [3]fri.OpeningProof } -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } var proof Proof @@ -78,23 +83,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // 1 - solve the system - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } + _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) + if err != nil { + return nil, err } - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) + solution := _solution.(*cs.SparseR1CSSolution) + evaluationLDomainSmall := []fr.Element(solution.L) + evaluationRDomainSmall := []fr.Element(solution.R) + evaluationODomainSmall := []fr.Element(solution.O) // 2 - commit to lro blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( @@ -119,9 +116,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen } // 3 - compute Z, challenges are derived using L, R, O + public inputs + fw, ok := fullWitness.Vector().(fr.Vector) + if !ok { + return nil, witness.ErrInvalidWitness + } dataFiatShamir := make([][fr.Bytes]byte, len(spr.Public)+3) for i := 0; i < len(spr.Public); i++ { - copy(dataFiatShamir[i][:], fullWitness[i].Marshal()) + copy(dataFiatShamir[i][:], fw[i].Marshal()) } copy(dataFiatShamir[len(spr.Public)][:], proof.LROpp[0].ID) copy(dataFiatShamir[len(spr.Public)+1][:], proof.LROpp[1].ID) @@ -164,7 +165,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen // alpha.SetUint64(11) evaluationQkCompleteDomainBigBitReversed := make([]fr.Element, pk.Domain[1].Cardinality) - 
copy(evaluationQkCompleteDomainBigBitReversed, fullWitness[:len(spr.Public)]) + copy(evaluationQkCompleteDomainBigBitReversed, fw[:len(spr.Public)]) copy(evaluationQkCompleteDomainBigBitReversed[len(spr.Public):], pk.LQkIncompleteDomainSmall[len(spr.Public):]) pk.Domain[0].FFTInverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) fft.BitReverse(evaluationQkCompleteDomainBigBitReversed[:pk.Domain[0].Cardinality]) @@ -487,7 +488,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse for i := 0; i < int(pk.Domain[0].Cardinality); i++ { startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.Domain[1].FFT(startsAtOne, fft.DIF, true) + pk.Domain[1].FFT(startsAtOne, fft.DIF, fft.OnCoset()) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain @@ -515,7 +516,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.Domain[1].FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, fft.OnCoset()) // degree of hi is n+2 because of the blinding h1 := h[:pk.Domain[0].Cardinality+2] @@ -586,41 +587,6 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma } -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - // computeBlindedLROCanonical // l, r, o in canonical basis with blinding func computeBlindedLROCanonical( diff --git a/internal/backend/bw6-761/plonkfri/setup.go b/backend/plonkfri/bw6-761/setup.go similarity index 91% rename from internal/backend/bw6-761/plonkfri/setup.go rename to backend/plonkfri/bw6-761/setup.go index 32be4d26fe..a08a528d80 100644 --- a/internal/backend/bw6-761/plonkfri/setup.go +++ b/backend/plonkfri/bw6-761/setup.go @@ -21,7 +21,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fri" - "github.com/consensys/gnark/constraint/bw6-761" + cs "github.com/consensys/gnark/constraint/bw6-761" ) // ProvingKey stores the data needed to generate a proof: @@ -106,7 +106,7 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { // The verifying key shares data with the proving key pk.Vk = &vk - nbConstraints := len(spr.Constraints) + nbConstraints := spr.GetNbConstraints() 
// fft domains sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints @@ -156,15 +156,18 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { pk.CQkIncomplete[i].Set(&pk.LQkIncompleteDomainSmall[i]) // --> to be completed by the prover } offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.EvaluationQlDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.EvaluationQrDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.EvaluationQmDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). - Mul(&pk.EvaluationQmDomainBigBitReversed[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.EvaluationQoDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.LQkIncompleteDomainSmall[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.CQkIncomplete[offset+i].Set(&pk.LQkIncompleteDomainSmall[offset+i]) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + pk.EvaluationQlDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QL]) + pk.EvaluationQrDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QR]) + pk.EvaluationQmDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QM]) + pk.EvaluationQoDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QO]) + pk.LQkIncompleteDomainSmall[offset+j].Set(&spr.Coefficients[c.QC]) + pk.CQkIncomplete[offset+j].Set(&pk.LQkIncompleteDomainSmall[offset+j]) + + j++ } pk.Domain[0].FFTInverse(pk.EvaluationQlDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) @@ -209,10 +212,10 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { return &pk, &vk, err } - pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, true) - 
pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, fft.OnCoset()) // build permutation. Note: at this stage, the permutation takes in account the placeholders buildPermutation(spr, &pk) @@ -257,10 +260,14 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c != nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + j++ } // init cycle: @@ -349,9 +356,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, fft.OnCoset()) pk.Domain[0].FFTInverse(pk.EvaluationS1BigDomain[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationS2BigDomain[:pk.Domain[0].Cardinality], fft.DIF) @@ -379,9 +386,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - 
pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, fft.OnCoset()) return nil diff --git a/internal/backend/bw6-761/plonkfri/verify.go b/backend/plonkfri/bw6-761/verify.go similarity index 100% rename from internal/backend/bw6-761/plonkfri/verify.go rename to backend/plonkfri/bw6-761/verify.go diff --git a/backend/plonkfri/plonkfri.go b/backend/plonkfri/plonkfri.go index b2cab82fd6..5fcb3760eb 100644 --- a/backend/plonkfri/plonkfri.go +++ b/backend/plonkfri/plonkfri.go @@ -29,12 +29,13 @@ import ( cs_bw6633 "github.com/consensys/gnark/constraint/bw6-633" cs_bw6761 "github.com/consensys/gnark/constraint/bw6-761" - plonk_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/plonkfri" - plonk_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/plonkfri" - plonk_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/plonkfri" - plonk_bn254 "github.com/consensys/gnark/internal/backend/bn254/plonkfri" - plonk_bw6633 "github.com/consensys/gnark/internal/backend/bw6-633/plonkfri" - plonk_bw6761 "github.com/consensys/gnark/internal/backend/bw6-761/plonkfri" + plonk_bls12377 "github.com/consensys/gnark/backend/plonkfri/bls12-377" + plonk_bls12381 "github.com/consensys/gnark/backend/plonkfri/bls12-381" + plonk_bls24315 "github.com/consensys/gnark/backend/plonkfri/bls24-315" + plonk_bls24317 "github.com/consensys/gnark/backend/plonkfri/bls24-317" + plonk_bn254 "github.com/consensys/gnark/backend/plonkfri/bn254" + plonk_bw6633 "github.com/consensys/gnark/backend/plonkfri/bw6-633" + plonk_bw6761 "github.com/consensys/gnark/backend/plonkfri/bw6-761" fr_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" fr_bls12381 
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr" @@ -43,8 +44,6 @@ import ( fr_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr" fr_bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" fr_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - - plonk_bls24317 "github.com/consensys/gnark/internal/backend/bls24-317/plonkfri" ) // Proof represents a Plonk proof generated by plonk.Prove @@ -106,60 +105,28 @@ func Setup(ccs constraint.ConstraintSystem) (ProvingKey, VerifyingKey, error) { // internally, the solution vector to the SparseR1CS will be filled with random values which may impact benchmarking func Prove(ccs constraint.ConstraintSystem, pk ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (Proof, error) { - // apply options - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return nil, err - } - switch tccs := ccs.(type) { case *cs_bn254.SparseR1CS: - w, ok := fullWitness.Vector().(fr_bn254.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return plonk_bn254.Prove(tccs, pk.(*plonk_bn254.ProvingKey), w, opt) + return plonk_bn254.Prove(tccs, pk.(*plonk_bn254.ProvingKey), fullWitness, opts...) case *cs_bls12381.SparseR1CS: - w, ok := fullWitness.Vector().(fr_bls12381.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return plonk_bls12381.Prove(tccs, pk.(*plonk_bls12381.ProvingKey), w, opt) + return plonk_bls12381.Prove(tccs, pk.(*plonk_bls12381.ProvingKey), fullWitness, opts...) case *cs_bls12377.SparseR1CS: - w, ok := fullWitness.Vector().(fr_bls12377.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return plonk_bls12377.Prove(tccs, pk.(*plonk_bls12377.ProvingKey), w, opt) + return plonk_bls12377.Prove(tccs, pk.(*plonk_bls12377.ProvingKey), fullWitness, opts...) 
case *cs_bw6761.SparseR1CS: - w, ok := fullWitness.Vector().(fr_bw6761.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return plonk_bw6761.Prove(tccs, pk.(*plonk_bw6761.ProvingKey), w, opt) + return plonk_bw6761.Prove(tccs, pk.(*plonk_bw6761.ProvingKey), fullWitness, opts...) case *cs_bw6633.SparseR1CS: - w, ok := fullWitness.Vector().(fr_bw6633.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return plonk_bw6633.Prove(tccs, pk.(*plonk_bw6633.ProvingKey), w, opt) + return plonk_bw6633.Prove(tccs, pk.(*plonk_bw6633.ProvingKey), fullWitness, opts...) case *cs_bls24315.SparseR1CS: - w, ok := fullWitness.Vector().(fr_bls24315.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return plonk_bls24315.Prove(tccs, pk.(*plonk_bls24315.ProvingKey), w, opt) + return plonk_bls24315.Prove(tccs, pk.(*plonk_bls24315.ProvingKey), fullWitness, opts...) + case *cs_bls24317.SparseR1CS: - w, ok := fullWitness.Vector().(fr_bls24317.Vector) - if !ok { - return nil, witness.ErrInvalidWitness - } - return plonk_bls24317.Prove(tccs, pk.(*plonk_bls24317.ProvingKey), w, opt) + return plonk_bls24317.Prove(tccs, pk.(*plonk_bls24317.ProvingKey), fullWitness, opts...) 
+ default: panic("unrecognized SparseR1CS curve type") } diff --git a/backend/witness/witness.go b/backend/witness/witness.go index 437ffcd35c..2db96aede3 100644 --- a/backend/witness/witness.go +++ b/backend/witness/witness.go @@ -113,7 +113,7 @@ func New(field *big.Int) (Witness, error) { } func (w *witness) Fill(nbPublic, nbSecret int, values <-chan any) error { - n := int(nbPublic + nbSecret) + n := nbPublic + nbSecret w.vector = resize(w.vector, n) w.nbPublic = uint32(nbPublic) w.nbSecret = uint32(nbSecret) diff --git a/backend/witness/witness_test.go b/backend/witness/witness_test.go index ba127b3240..40446abe81 100644 --- a/backend/witness/witness_test.go +++ b/backend/witness/witness_test.go @@ -1,6 +1,8 @@ package witness_test import ( + "bytes" + "encoding/json" "fmt" "reflect" "testing" @@ -47,11 +49,17 @@ func ExampleWitness() { // first get the circuit expected schema schema, _ := frontend.NewSchema(assignment) - json, _ := reconstructed.ToJSON(schema) + ret, _ := reconstructed.ToJSON(schema) - fmt.Println(string(json)) + var b bytes.Buffer + json.Indent(&b, ret, "", "\t") + fmt.Println(b.String()) // Output: - // {"X":42,"Y":8000,"E":1} + // { + // "X": 42, + // "Y": 8000, + // "E": 1 + // } } @@ -148,3 +156,34 @@ func roundTripMarshalJSON(assert *require.Assertions, assignment circuit, public assert.True(reflect.DeepEqual(rw, w), "witness json round trip serialization") } + +type initableVariable struct { + Val []frontend.Variable +} + +func (iv *initableVariable) GnarkInitHook() { + if iv.Val == nil { + iv.Val = []frontend.Variable{1, 2} // need to init value as are assigning to witness + } +} + +type initableCircuit struct { + X [2]initableVariable + Y []initableVariable + Z initableVariable +} + +func (c *initableCircuit) Define(api frontend.API) error { + panic("not called") +} + +func TestVariableInitHook(t *testing.T) { + assert := require.New(t) + + assignment := &initableCircuit{Y: make([]initableVariable, 2)} + w, err := 
frontend.NewWitness(assignment, ecc.BN254.ScalarField()) + assert.NoError(err) + fw, ok := w.Vector().(fr.Vector) + assert.True(ok) + assert.Len(fw, 10, "invalid length") +} diff --git a/constraint/bls12-377/coeff.go b/constraint/bls12-377/coeff.go index cc8036f26f..a1b86b75c4 100644 --- a/constraint/bls12-377/coeff.go +++ b/constraint/bls12-377/coeff.go @@ -46,7 +46,7 @@ func newCoeffTable(capacity int) CoeffTable { } -func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constraint.Term { +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { c := (*fr.Element)(coeff[:]) var cID uint32 if c.IsZero() { @@ -69,7 +69,11 @@ func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constrai ct.mCoeffs[cc] = cID } } + return cID +} +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) return constraint.Term{VID: uint32(variableID), CID: cID} } @@ -78,7 +82,10 @@ func (ct *CoeffTable) CoeffToString(cID int) string { return ct.Coefficients[cID].String() } -var _ constraint.CoeffEngine = &arithEngine{} +// implements constraint.Field +type field struct{} + +var _ constraint.Field = &field{} var ( two fr.Element @@ -94,10 +101,7 @@ func init() { minusTwo.Neg(&two) } -// implements constraint.CoeffEngine -type arithEngine struct{} - -func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { +func (engine *field) FromInterface(i interface{}) constraint.Element { var e fr.Element if _, err := e.SetInterface(i); err != nil { // need to clean that --> some code path are dissimilar @@ -106,55 +110,83 @@ func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { b := utils.FromInterface(i) e.SetBigInt(&b) } - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) ToBigInt(c *constraint.Coeff) *big.Int { +func (engine *field) ToBigInt(c constraint.Element) *big.Int { e := 
(*fr.Element)(c[:]) r := new(big.Int) e.BigInt(r) return r } -func (engine *arithEngine) Mul(a, b *constraint.Coeff) { +func (engine *field) Mul(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Mul(_a, _b) + return a } -func (engine *arithEngine) Add(a, b *constraint.Coeff) { + +func (engine *field) Add(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Add(_a, _b) + return a } -func (engine *arithEngine) Sub(a, b *constraint.Coeff) { +func (engine *field) Sub(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Sub(_a, _b) + return a } -func (engine *arithEngine) Neg(a *constraint.Coeff) { +func (engine *field) Neg(a constraint.Element) constraint.Element { e := (*fr.Element)(a[:]) e.Neg(e) + return a } -func (engine *arithEngine) Inverse(a *constraint.Coeff) { +func (engine *field) Inverse(a constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + e.Inverse(e) + return a, true } -func (engine *arithEngine) IsOne(a *constraint.Coeff) bool { +func (engine *field) IsOne(a constraint.Element) bool { e := (*fr.Element)(a[:]) return e.IsOne() } -func (engine *arithEngine) One() constraint.Coeff { +func (engine *field) One() constraint.Element { e := fr.One() - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) String(a *constraint.Coeff) string { +func (engine *field) String(a constraint.Element) string { e := (*fr.Element)(a[:]) return e.String() } + +func (engine *field) Uint64(a constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true +} diff --git a/constraint/bls12-377/gkr.go 
b/constraint/bls12-377/gkr.go new file mode 100644 index 0000000000..ef851cfb8f --- /dev/null +++ b/constraint/bls12-377/gkr.go @@ -0,0 +1,196 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/gkr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + hint "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/utils/algo_utils" + "hash" + "math/big" +) + +type GkrSolvingData struct { + assignments gkr.WireAssignment + circuit gkr.Circuit + memoryPool polynomial.Pool + workers *utils.WorkerPool +} + +func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { + resCircuit := make(gkr.Circuit, len(noPtr)) + var found bool + for i := range noPtr { + if resCircuit[i].Gate, found = gkr.Gates[noPtr[i].Gate]; !found && noPtr[i].Gate != "" { + return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) + } + resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) + } + return resCircuit, nil +} + +func (d *GkrSolvingData) init(info constraint.GkrInfo) (assignment gkrAssignment, err error) { + if d.circuit, 
err = convertCircuit(info.Circuit); err != nil { + return + } + d.memoryPool = polynomial.NewPool(d.circuit.MemoryRequirements(info.NbInstances)...) + d.workers = utils.NewWorkerPool() + + assignment = make(gkrAssignment, len(d.circuit)) + d.assignments = make(gkr.WireAssignment, len(d.circuit)) + for i := range assignment { + assignment[i] = d.memoryPool.Make(info.NbInstances) + d.assignments[&d.circuit[i]] = assignment[i] + } + return +} + +func (d *GkrSolvingData) dumpAssignments() { + for _, p := range d.assignments { + d.memoryPool.Dump(p) + } +} + +// this module assumes that wire and instance indexes respect dependencies + +type gkrAssignment [][]fr.Element //gkrAssignment is indexed wire first, instance second + +func (a gkrAssignment) setOuts(circuit constraint.GkrCircuit, outs []*big.Int) { + outsI := 0 + for i := range circuit { + if circuit[i].IsOutput() { + for j := range a[i] { + a[i][j].BigInt(outs[outsI]) + outsI++ + } + } + } + // Check if outsI == len(outs)? +} + +func GkrSolveHint(info constraint.GkrInfo, solvingData *GkrSolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + // assumes assignmentVector is arranged wire first, instance second in order of solution + circuit := info.Circuit + nbInstances := info.NbInstances + offsets := info.AssignmentOffsets() + assignment, err := solvingData.init(info) + if err != nil { + return err + } + chunks := circuit.Chunks(nbInstances) + + solveTask := func(chunkOffset int) utils.Task { + return func(startInChunk, endInChunk int) { + start := startInChunk + chunkOffset + end := endInChunk + chunkOffset + inputs := solvingData.memoryPool.Make(info.MaxNIns) + dependencyHeads := make([]int, len(circuit)) + for wI, w := range circuit { + dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { + return w.Dependencies[i].InputInstance + }, len(w.Dependencies), start) + } + + for instanceI := start; instanceI < end; instanceI++ { + for wireI, wire := range circuit { + if 
wire.IsInput() { + if dependencyHeads[wireI] < len(wire.Dependencies) && instanceI == wire.Dependencies[dependencyHeads[wireI]].InputInstance { + dep := wire.Dependencies[dependencyHeads[wireI]] + assignment[wireI][instanceI].Set(&assignment[dep.OutputWire][dep.OutputInstance]) + dependencyHeads[wireI]++ + } else { + assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) + } + } else { + // assemble the inputs + inputIndexes := info.Circuit[wireI].Inputs + for i, inputI := range inputIndexes { + inputs[i].Set(&assignment[inputI][instanceI]) + } + gate := solvingData.circuit[wireI].Gate + assignment[wireI][instanceI] = gate.Evaluate(inputs[:len(inputIndexes)]...) + } + } + } + solvingData.memoryPool.Dump(inputs) + } + } + + start := 0 + for _, end := range chunks { + solvingData.workers.Submit(end-start, solveTask(start), 1024).Wait() + start = end + } + + assignment.setOuts(info.Circuit, outs) + + return nil + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} + +func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { + + return func(_ *big.Int, ins, outs []*big.Int) error { + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + b := make([]byte, fr.Bytes) + i.FillBytes(b) + return b[:] + }) + + hsh := HashBuilderRegistry[hashName]() + + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + if err != nil { + return err + } + + // serialize proof: TODO: In gnark-crypto? 
+ offset := 0 + for i := range proof { + for _, poly := range proof[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if proof[i].FinalEvalProof != nil { + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } + + data.dumpAssignments() + + return nil + + } +} + +// TODO: Move to gnark-crypto +var HashBuilderRegistry = make(map[string]func() hash.Hash) diff --git a/constraint/bls12-377/r1cs.go b/constraint/bls12-377/r1cs.go deleted file mode 100644 index 8d678db082..0000000000 --- a/constraint/bls12-377/r1cs.go +++ /dev/null @@ -1,451 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/fxamacker/cbor/v2" - "io" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc" - "math" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" -) - -// R1CS describes a set of R1CS constraint -type R1CS struct { - constraint.R1CSCore - CoeffTable - arithEngine -} - -// NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values -// -// capacity pre-allocates memory for capacity nbConstraints -func NewR1CS(capacity int) *R1CS { - r := R1CS{ - R1CSCore: constraint.R1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.R1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - return &r -} - -func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, r1c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - - cs.UpdateLevel(cID, &r1c) - - return cID -} - -// Solve sets all the wires and returns the a, b, c vectors. -// the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. -// a, b, c vectors: ab-c = hz -// witness = [publicWires | secretWires] (without the ONE_WIRE !) 
-// returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() - - nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables - solution, err := newSolution(nbWires, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return make(fr.Vector, nbWires), err - } - start := time.Now() - - if len(witness) != len(cs.Public)-1+len(cs.Secret) { // - 1 for ONE_WIRE - err = fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), int(len(cs.Public)-1+len(cs.Secret)), len(cs.Public)-1, len(cs.Secret)) - log.Err(err).Send() - return solution.values, err - } - - // compute the wires and the a, b, c polynomials - if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { - err = errors.New("invalid input size: len(a, b, c) == len(Constraints)") - log.Err(err).Send() - return solution.values, err - } - - solution.solved[0] = true // ONE_WIRE - solution.values[0].SetOne() - copy(solution.values[1:], witness) - for i := range witness { - solution.solved[i+1] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness) + 1) - - // now that we know all inputs are set, defer log printing once all solution.values are computed - // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - if err := cs.parallelSolve(a, b, c, &solution); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity 
check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil -} - -func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - chError <- &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - return &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// IsSolved returns nil if given witness solves the R1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) 
- if err != nil { - return err - } - - a := make(fr.Vector, len(cs.Constraints)) - b := make(fr.Vector, len(cs.Constraints)) - c := make(fr.Vector, len(cs.Constraints)) - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, a, b, c, opt) - return err -} - -// divByCoeff sets res = res / t.Coeff -func (cs *R1CS) divByCoeff(res *fr.Element, t constraint.Term) { - cID := t.CoeffID() - switch cID { - case constraint.CoeffIdOne: - return - case constraint.CoeffIdMinusOne: - res.Neg(res) - case constraint.CoeffIdZero: - panic("division by 0") - default: - // this is slow, but shouldn't happen as divByCoeff is called to - // remove the coeff of an unsolved wire - // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 - res.Div(res, &cs.Coefficients[cID]) - } -} - -// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly -// -// returns an error if the solver called a hint function that errored -// returns false, nil if there was no wire to solve -// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that -// the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r constraint.R1C, solution *solution, a, b, c *fr.Element) error { - - // the index of the non-zero entry shows if L, R or O has an uninstantiated wire - // the content is the ID of the wire non instantiated - var loc uint8 - - var termToCompute constraint.Term - - processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) error { - for _, t := range l { - vID := t.WireID() - - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - solution.accumulateInto(t, val) - continue - } - - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err - } - // now that the wire is saved, accumulate it into a, b or c - solution.accumulateInto(t, val) - continue - } - - if loc != 0 { - panic("found more than one wire to instantiate") - } - termToCompute = t - loc = locValue - } - return nil - } - - if err := processLExp(r.L, a, 1); err != nil { - return err - } - - if err := processLExp(r.R, b, 2); err != nil { - return err - } - - if err := processLExp(r.O, c, 3); err != nil { - return err - } - - if loc == 0 { - // there is nothing to solve, may happen if we have an assertion - // (ie a constraints that doesn't yield any output) - // or if we solved the unsolved wires with hint functions - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - return nil - } - - // we compute the wire value and instantiate it - wID := termToCompute.WireID() - - // solver result - var wire fr.Element - - switch loc { - case 1: - if !b.IsZero() { - wire.Div(c, b). - Sub(&wire, a) - a.Add(a, &wire) - } else { - // we didn't actually ensure that a * b == c - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 2: - if !a.IsZero() { - wire.Div(c, a). 
- Sub(&wire, b) - b.Add(b, &wire) - } else { - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 3: - wire.Mul(a, b). - Sub(&wire, c) - - c.Add(c, &wire) - } - - // wire is the term (coeff * value) - // but in the solution we want to store the value only - // note that in gnark frontend, coeff here is always 1 or -1 - cs.divByCoeff(&wire, termToCompute) - solution.set(wID, wire) - - return nil -} - -// GetConstraints return the list of R1C and a coefficient resolver -func (cs *R1CS) GetConstraints() ([]constraint.R1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *R1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto -func (cs *R1CS) CurveID() ecc.ID { - return ecc.BLS12_377 -} - -// WriteTo encodes R1CS into provided io.Writer using cbor -func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode R1CS from io.Reader using cbor -func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(&cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bls12-377/r1cs_sparse.go 
b/constraint/bls12-377/r1cs_sparse.go deleted file mode 100644 index ed01b36846..0000000000 --- a/constraint/bls12-377/r1cs_sparse.go +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/fxamacker/cbor/v2" - "io" - "math" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" -) - -// SparseR1CS represents a Plonk like circuit -type SparseR1CS struct { - constraint.SparseR1CSCore - CoeffTable - arithEngine -} - -// NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values -func NewSparseR1CS(capacity int) *SparseR1CS { - cs := SparseR1CS{ - SparseR1CSCore: constraint.SparseR1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.SparseR1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - - return &cs -} - -func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = 
append(cs.Constraints, c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - cs.UpdateLevel(cID, &c) - - return cID -} - -// Solve sets all the wires. -// solution.values = [publicInputs | secretInputs | internalVariables ] -// witness: contains the input variables -// it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() - - // set the slices holding the solution.values and monitoring which variables have been solved - nbVariables := cs.NbInternalVariables + len(cs.Secret) + len(cs.Public) - - start := time.Now() - - expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) - if len(witness) != expectedWitnessSize { - return make(fr.Vector, nbVariables), fmt.Errorf( - "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", - len(witness), - expectedWitnessSize, - len(cs.Public), - len(cs.Secret), - ) - } - - // keep track of wire that have a value - solution, err := newSolution(nbVariables, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return solution.values, err - } - - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs - copy(solution.values, witness) - for i := 0; i < len(witness); i++ { - solution.solved[i] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness)) - - // defer log printing once all solution.values are computed - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - // batch invert the coefficients to avoid many divisions in the solver - coefficientsNegInv := 
fr.BatchInvert(cs.Coefficients) - for i := 0; i < len(coefficientsNegInv); i++ { - coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) - } - - if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil - -} - -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - wg.Done() - return - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - chError <- &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } else { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - } - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - return &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// computeHints computes wires associated with a hint function, if any -// if there is no remaining wire to solve, returns -1 -// else returns the wire position (L -> 0, R -> 1, O -> 2) -func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) (int, error) { - r := -1 - lID, rID, oID := c.L.WireID(), c.R.WireID(), c.O.WireID() - - if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { - // check if it's a hint - if hint, ok := cs.MHints[lID]; ok { - if err := solution.solveWithHint(lID, hint); err != nil { - return -1, err - } - } else { - r = 0 - } - - } - - if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { - // check if it's a hint - if hint, ok := cs.MHints[rID]; ok { - if err := solution.solveWithHint(rID, hint); err != nil { - return -1, err - } - } else { - r = 1 - } - } - - if (c.O.CoeffID() != 0) && !solution.solved[oID] { - // check if it's a hint - if hint, ok := cs.MHints[oID]; ok { - if err := solution.solveWithHint(oID, hint); err != nil { - return -1, err - } - } else { - r = 2 - } - } - 
return r, nil -} - -// solveConstraint solve any unsolved wire in given constraint and update the solution -// a SparseR1C may have up to one unsolved wire (excluding hints) -// if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { - - lro, err := cs.computeHints(c, solution) - if err != nil { - return err - } - if lro == -1 { - // no unsolved wire - // can happen if the constraint contained only hint wires. - return nil - } - if lro == 1 { // we solve for R: u1L+u2R+u3LR+u4O+k=0 => R(u2+u3L)+u1L+u4O+k = 0 - if !solution.solved[c.L.WireID()] { - panic("L wire should be instantiated when we solve R") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.L.WireID()]).Add(&den, &u2) - - v1 = solution.computeTerm(c.L) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - } - - if lro == 0 { // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 - if !solution.solved[c.R.WireID()] { - panic("R wire should be instantiated when we solve L") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.R.WireID()]).Add(&den, &u1) - - v1 = solution.computeTerm(c.R) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - - } - // O 
we solve for O - var o fr.Element - cID, vID := c.O.CoeffID(), c.O.WireID() - - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - - // o = - ((m0 * m1) + l + r + c.K) / c.O - o.Mul(&m0, &m1).Add(&o, &l).Add(&o, &r).Add(&o, &cs.Coefficients[c.K]) - o.Mul(&o, &coefficientsNegInv[cID]) - - solution.set(vID, o) - - return nil -} - -// IsSolved returns nil if given witness solves the SparseR1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return err - } - - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, opt) - return err -} - -// GetConstraints return the list of SparseR1C and a coefficient resolver -func (cs *SparseR1CS) GetConstraints() ([]constraint.SparseR1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// checkConstraint verifies that the constraint holds -func (cs *SparseR1CS) checkConstraint(c constraint.SparseR1C, solution *solution) error { - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - o := solution.computeTerm(c.O) - - // l + r + (m0 * m1) + o + c.K == 0 - var t fr.Element - t.Mul(&m0, &m1).Add(&t, &l).Add(&t, &r).Add(&t, &o).Add(&t, &cs.Coefficients[c.K]) - if !t.IsZero() { - return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + (%s × %s) + %s != 0", - l.String(), - r.String(), - o.String(), - m0.String(), - m1.String(), - cs.Coefficients[c.K].String(), - ) - } - return nil - -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto (ecc.BLS12-377) -func (cs 
*SparseR1CS) CurveID() ecc.ID { - return ecc.BLS12_377 -} - -// WriteTo encodes SparseR1CS into provided io.Writer using cbor -func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode SparseR1CS from io.Reader using cbor -func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bls12-377/r1cs_test.go b/constraint/bls12-377/r1cs_test.go index 926f9c1bb6..044c55d31a 100644 --- a/constraint/bls12-377/r1cs_test.go +++ b/constraint/bls12-377/r1cs_test.go @@ -20,6 +20,7 @@ import ( "bytes" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "reflect" "testing" @@ -27,7 +28,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/consensys/gnark/constraint/bls12-377" + cs "github.com/consensys/gnark/constraint/bls12-377" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" ) @@ -48,7 +49,7 @@ func TestSerialization(t *testing.T) { return } - // copmpile a second time to ensure determinism + // compile a second time to ensure determinism r1cs2, err := frontend.Compile(fr.Modulus(), 
r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) @@ -76,10 +77,10 @@ func TestSerialization(t *testing.T) { if diff := cmp.Diff(r1cs1, &reconstructed, cmpopts.IgnoreFields(cs.R1CS{}, "System.q", - "arithEngine", + "field", "CoeffTable.mCoeffs", "System.lbWireLevel", - "System.lbHints", + "System.genericHint", "System.SymbolTable", "System.lbOutputs", "System.bitLen")); diff != "" { @@ -147,12 +148,6 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { - var c circuit - ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c) - if err != nil { - b.Fatal(err) - } - var w circuit w.X = 1 w.Y = 1 @@ -161,8 +156,32 @@ func BenchmarkSolve(b *testing.B) { b.Fatal(err) } - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ccs.IsSolved(witness) - } + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + } diff --git a/constraint/bls12-377/solution.go b/constraint/bls12-377/solution.go deleted file mode 100644 index 8849a30acb..0000000000 --- a/constraint/bls12-377/solution.go +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/debug" - "github.com/rs/zerolog" - "math/big" - "strconv" - "strings" - "sync/atomic" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" -) - -// solution represents elements needed to compute -// a solution to a R1CS or SparseR1CS -type solution struct { - values, coefficients []fr.Element - solved []bool - nbSolved uint64 - mHintsFunctions map[hint.ID]hint.Function // maps hintID to hint function - mHints map[int]*constraint.Hint // maps wireID to hint - st *debug.SymbolTable -} - -func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, mHints map[int]*constraint.Hint, coefficients []fr.Element, st *debug.SymbolTable) (solution, error) { - - s := solution{ - st: st, - values: make([]fr.Element, nbWires), - coefficients: coefficients, - solved: make([]bool, nbWires), - mHintsFunctions: hintFunctions, - mHints: mHints, - } - - // hintsDependencies is from compile time; it contains the list of hints the solver **needs** - var missing []string - for hintUUID, hintID := range hintsDependencies { - if _, ok := s.mHintsFunctions[hintUUID]; !ok { - missing = append(missing, hintID) - } - } - - if len(missing) > 0 { - return s, fmt.Errorf("solver missing hint(s): %v", missing) - } - - return s, nil -} - -func (s *solution) set(id int, value fr.Element) { - if s.solved[id] { - panic("solving the 
same wire twice should never happen.") - } - s.values[id] = value - s.solved[id] = true - atomic.AddUint64(&s.nbSolved, 1) - // s.nbSolved++ -} - -func (s *solution) isValid() bool { - return int(s.nbSolved) == len(s.values) -} - -// computeTerm computes coeff*variable -func (s *solution) computeTerm(t constraint.Term) fr.Element { - cID, vID := t.CoeffID(), t.WireID() - if cID != 0 && !s.solved[vID] { - panic("computing a term with an unsolved wire") - } - switch cID { - case constraint.CoeffIdZero: - return fr.Element{} - case constraint.CoeffIdOne: - return s.values[vID] - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - return res - case constraint.CoeffIdMinusOne: - var res fr.Element - res.Neg(&s.values[vID]) - return res - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - return res - } -} - -// r += (t.coeff*t.value) -func (s *solution) accumulateInto(t constraint.Term, r *fr.Element) { - cID := t.CoeffID() - - if t.IsConstant() { - // needed for logs, we may want to not put this in the hot path if we need to - // optimize constraint system solver further. 
- r.Add(r, &s.coefficients[cID]) - return - } - - vID := t.WireID() - switch cID { - case constraint.CoeffIdZero: - return - case constraint.CoeffIdOne: - r.Add(r, &s.values[vID]) - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - r.Add(r, &res) - case constraint.CoeffIdMinusOne: - r.Sub(r, &s.values[vID]) - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - r.Add(r, &res) - } -} - -// solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveWithHint(vID int, h *constraint.Hint) error { - // skip if the wire is already solved by a call to the same hint - // function on the same inputs - if s.solved[vID] { - return nil - } - // ensure hint function was provided - f, ok := s.mHintsFunctions[h.ID] - if !ok { - return errors.New("missing hint function") - } - - // tmp IO big int memory - nbInputs := len(h.Inputs) - nbOutputs := len(h.Wires) - inputs := make([]*big.Int, nbInputs) - outputs := make([]*big.Int, nbOutputs) - for i := 0; i < nbOutputs; i++ { - outputs[i] = big.NewInt(0) - } - - q := fr.Modulus() - - // for each input, we set its big int value, IF all the wires are solved - // the only case where all wires may not be solved, is if one of the input of this hint - // is the output of another hint. - // it is safe to recursively solve this with the parallel solver, since all hints-output wires - // that we can solve this way are marked to be solved with the current constraint we are processing. - recursiveSolve := func(t constraint.Term) error { - if t.IsConstant() { - return nil - } - wID := t.WireID() - if s.solved[wID] { - return nil - } - // unsolved dependency - if h, ok := s.mHints[wID]; ok { - // solve recursively. - return s.solveWithHint(wID, h) - } - - // it's not a hint, we panic. 
- panic("solver can't compute hint; one or more input wires are unsolved") - } - - for i := 0; i < nbInputs; i++ { - inputs[i] = big.NewInt(0) - - var v fr.Element - for _, term := range h.Inputs[i] { - if err := recursiveSolve(term); err != nil { - return err - } - s.accumulateInto(term, &v) - } - v.BigInt(inputs[i]) - } - - err := f(q, inputs, outputs) - - var v fr.Element - for i := range outputs { - v.SetBigInt(outputs[i]) - s.set(h.Wires[i], v) - } - - return err -} - -func (s *solution) printLogs(log zerolog.Logger, logs []constraint.LogEntry) { - if log.GetLevel() == zerolog.Disabled { - return - } - - for i := 0; i < len(logs); i++ { - logLine := s.logValue(logs[i]) - log.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) - } -} - -const unsolvedVariable = "" - -func (s *solution) logValue(log constraint.LogEntry) string { - var toResolve []interface{} - var ( - eval fr.Element - missingValue bool - ) - for j := 0; j < len(log.ToResolve); j++ { - // before eval le - - missingValue = false - eval.SetZero() - - for _, t := range log.ToResolve[j] { - // for each term in the linear expression - - cID, vID := t.CoeffID(), t.WireID() - if t.IsConstant() { - // just add the constant - eval.Add(&eval, &s.coefficients[cID]) - continue - } - - if !s.solved[vID] { - missingValue = true - break // stop the loop we can't evaluate. 
- } - - tv := s.computeTerm(t) - eval.Add(&eval, &tv) - } - - // after - if missingValue { - toResolve = append(toResolve, unsolvedVariable) - } else { - // we have to append our accumulator - toResolve = append(toResolve, eval.String()) - } - - } - if len(log.Stack) > 0 { - var sbb strings.Builder - for _, lID := range log.Stack { - location := s.st.Locations[lID] - function := s.st.Functions[location.FunctionID] - - sbb.WriteString(function.Name) - sbb.WriteByte('\n') - sbb.WriteByte('\t') - sbb.WriteString(function.Filename) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(int(location.Line))) - sbb.WriteByte('\n') - } - toResolve = append(toResolve, sbb.String()) - } - return fmt.Sprintf(log.Format, toResolve...) -} - -// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint -type UnsatisfiedConstraintError struct { - Err error - CID int // constraint ID - DebugInfo *string // optional debug info -} - -func (r *UnsatisfiedConstraintError) Error() string { - if r.DebugInfo != nil { - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) - } - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) -} diff --git a/constraint/bls12-377/solver.go b/constraint/bls12-377/solver.go new file mode 100644 index 0000000000..c992d55b4a --- /dev/null +++ b/constraint/bls12-377/solver.go @@ -0,0 +1,647 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/pool" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/rs/zerolog" + "math" + "math/big" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" +) + +// solver represent the state of the solver during a call to System.Solve(...) +type solver struct { + *system + + // values and solved are index by the wire (variable) id + values []fr.Element + solved []bool + nbSolved uint64 + + // maps hintID to hint function + mHintsFunctions map[csolver.HintID]csolver.Hint + + // used to out api.Println + logger zerolog.Logger + + a, b, c fr.Vector // R1CS solver will compute the a,b,c matrices + + q *big.Int +} + +func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { + // parse options + opt, err := csolver.NewConfig(opts...) 
+ if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public) - witnessOffset + len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + witnessOffset) + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + +// computeTerm computes 
coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. 
+func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start)+i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. 
+ } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID: cID, VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. 
+func (s *solver) Read(calldata []uint32) (constraint.Element, int) { + if s.Type == constraint.SystemSparseR1CS { + if calldata[0] != 1 { + panic("invalid calldata") + } + return s.GetValue(calldata[1], calldata[2]), 3 + } + var r fr.Element + n := int(calldata[0]) + j := 1 + for k := 0; k < n; k++ { + // we read k Terms + s.accumulateInto(constraint.Term{CID: calldata[j], VID: calldata[j+1]}, &r) + j += 2 + } + + var ret constraint.Element + copy(ret[:], r[:]) + return ret, j +} + +// processInstruction decodes the instruction and execute blueprint-defined logic. +// an instruction can encode a hint, a custom constraint or a generic constraint. +func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error { + // fetch the blueprint + blueprint := solver.Blueprints[pi.BlueprintID] + inst := pi.Unpack(&solver.System) + cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only + + if solver.Type == constraint.SystemR1CS { + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + // TODO @gbotrel we use the solveR1C method for now, having user-defined + // blueprint for R1CS would require constraint.Solver interface to add methods + // to set a,b,c since it's more efficient to compute these while we solve. + bc.DecompressR1C(&scratch.tR1C, inst) + return solver.solveR1C(cID, &scratch.tR1C) + } + } + + // blueprint declared "I know how to solve this." + if bc, ok := blueprint.(constraint.BlueprintSolvable); ok { + if err := bc.Solve(solver, inst); err != nil { + return solver.wrapErrWithDebugInfo(cID, err) + } + return nil + } + + // blueprint encodes a hint, we execute. + // TODO @gbotrel may be worth it to move hint logic in blueprint "solve" + if bc, ok := blueprint.(constraint.BlueprintHint); ok { + bc.DecompressHint(&scratch.tHint, inst) + return solver.solveWithHint(&scratch.tHint) + } + + return nil +} + +// run runs the solver. 
it return an error if a constraint is not satisfied or if not all wires +// were instantiated. +func (solver *solver) run() error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks. + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + var scratch scratch + for t := range chTasks { + for _, i := range t { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + var scratch scratch + + // for each level, we push the tasks + for _, level := range solver.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + return err + } + } + continue + } + + // number of tasks for this level is set to number of CPU + // but if we don't have enough work for all our CPU, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + if int(solver.nbSolved) != len(solver.values) { + return errors.New("solver didn't assign a value to all wires") + } + + return nil +} + +// solveR1C compute unsolved wires in the constraint, if any and set the solver accordingly +// +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. 
+func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error { + a, b, c := &solver.a[cID], &solver.b[cID], &solver.c[cID] + + // the index of the non-zero entry shows if L, R or O has an uninstantiated wire + // the content is the ID of the wire non instantiated + var loc uint8 + + var termToCompute constraint.Term + + processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) { + for _, t := range l { + vID := t.WireID() + + // wire is already computed, we just accumulate in val + if solver.solved[vID] { + solver.accumulateInto(t, val) + continue + } + + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue + } + } + + processLExp(r.L, a, 1) + processLExp(r.R, b, 2) + processLExp(r.O, c, 3) + + if loc == 0 { + // there is nothing to solve, may happen if we have an assertion + // (ie a constraints that doesn't yield any output) + // or if we solved the unsolved wires with hint functions + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + return nil + } + + // we compute the wire value and instantiate it + wID := termToCompute.WireID() + + // solver result + var wire fr.Element + + switch loc { + case 1: + if !b.IsZero() { + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) + } else { + // we didn't actually ensure that a * b == c + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 2: + if !a.IsZero() { + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) + } else { + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 3: + wire.Mul(a, b). 
+ Sub(&wire, c) + + c.Add(c, &wire) + } + + // wire is the term (coeff * value) + // but in the solver we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + solver.divByCoeff(&wire, termToCompute.CID) + solver.set(wID, wire) + + return nil +} + +// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint +type UnsatisfiedConstraintError struct { + Err error + CID int // constraint ID + DebugInfo *string // optional debug info +} + +func (r *UnsatisfiedConstraintError) Error() string { + if r.DebugInfo != nil { + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) + } + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) +} + +func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { + var debugInfo *string + if dID, ok := solver.MDebug[int(cID)]; ok { + debugInfo = new(string) + *debugInfo = solver.logValue(solver.DebugInfo[dID]) + } + return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo} +} + +// temporary variables to avoid memallocs in hotloop +type scratch struct { + tR1C constraint.R1C + tHint constraint.HintMapping +} diff --git a/constraint/bls12-377/system.go b/constraint/bls12-377/system.go new file mode 100644 index 0000000000..78b99eb451 --- /dev/null +++ b/constraint/bls12-377/system.go @@ -0,0 +1,379 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/fxamacker/cbor/v2" + "io" + "time" + + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/internal/backend/ioutils" + "github.com/consensys/gnark/logger" + "reflect" + + "github.com/consensys/gnark-crypto/ecc" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" +) + +type R1CS = system +type SparseR1CS = system + +// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element) +type system struct { + constraint.System + CoeffTable + field +} + +func NewR1CS(capacity int) *R1CS { + return newSystem(capacity, constraint.SystemR1CS) +} + +func NewSparseR1CS(capacity int) *SparseR1CS { + return newSystem(capacity, constraint.SystemSparseR1CS) +} + +func newSystem(capacity int, t constraint.SystemType) *system { + return &system{ + System: constraint.NewSystem(fr.Modulus(), capacity, t), + CoeffTable: newCoeffTable(capacity / 10), + } +} + +// Solve solves the constraint system with provided witness. +// If it's a R1CS returns R1CSSolution +// If it's a SparseR1CS returns SparseR1CSSolution +func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) { + log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger() + start := time.Now() + + v := witness.Vector().(fr.Vector) + + // init the solver + solver, err := newSolver(cs, v, opts...) + if err != nil { + log.Err(err).Send() + return nil, err + } + + // defer log printing once all solver.values are computed + // (or sooner, if a constraint is not satisfied) + defer solver.printLogs(cs.Logs) + + // run it. 
+ if err := solver.run(); err != nil { + log.Err(err).Send() + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") + + // format the solution + // TODO @gbotrel revisit post-refactor + if cs.Type == constraint.SystemR1CS { + var res R1CSSolution + res.W = solver.values + res.A = solver.a + res.B = solver.b + res.C = solver.c + return &res, nil + } else { + // sparse R1CS + var res SparseR1CSSolution + // query l, r, o in Lagrange basis, not blinded + res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values) + + return &res, nil + } + +} + +// IsSolved +// Deprecated: use _, err := Solve(...) instead +func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error { + _, err := cs.Solve(witness, opts...) + return err +} + +// GetR1Cs return the list of R1C +func (cs *system) GetR1Cs() []constraint.R1C { + toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + var r1c constraint.R1C + bc.DecompressR1C(&r1c, inst.Unpack(&cs.System)) + toReturn = append(toReturn, r1c) + } else { + panic("not implemented") + } + } + return toReturn +} + +// GetNbCoefficients return the number of unique coefficients needed in the R1CS +func (cs *system) GetNbCoefficients() int { + return len(cs.Coefficients) +} + +// CurveID returns curve ID as defined in gnark-crypto +func (cs *system) CurveID() ecc.ID { + return ecc.BLS12_377 +} + +// WriteTo encodes R1CS into provided io.Writer using cbor +func (cs *system) WriteTo(w io.Writer) (int64, error) { + _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written + ts := getTagSet() + enc, err := cbor.CoreDetEncOptions().EncModeWithTags(ts) + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) + + // encode our object + err = encoder.Encode(cs) + return _w.N, err +} + +// ReadFrom 
attempts to decode R1CS from io.Reader using cbor +func (cs *system) ReadFrom(r io.Reader) (int64, error) { + ts := getTagSet() + dm, err := cbor.DecOptions{ + MaxArrayElements: 2147483647, + MaxMapPairs: 2147483647, + }.DecModeWithTags(ts) + + if err != nil { + return 0, err + } + decoder := dm.NewDecoder(r) + + // initialize coeff table + cs.CoeffTable = newCoeffTable(0) + + if err := decoder.Decode(&cs); err != nil { + return int64(decoder.NumBytesRead()), err + } + + if err := cs.CheckSerializationHeader(); err != nil { + return int64(decoder.NumBytesRead()), err + } + + switch v := cs.CommitmentInfo.(type) { + case *constraint.Groth16Commitments: + cs.CommitmentInfo = *v + case *constraint.PlonkCommitments: + cs.CommitmentInfo = *v + } + + return int64(decoder.NumBytesRead()), nil +} + +func (cs *system) GetCoefficient(i int) (r constraint.Element) { + copy(r[:], cs.Coefficients[i][:]) + return +} + +// GetSparseR1Cs return the list of SparseR1C +func (cs *system) GetSparseR1Cs() []constraint.SparseR1C { + + toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + var sparseR1C constraint.SparseR1C + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + toReturn = append(toReturn, sparseR1C) + } else { + panic("not implemented") + } + } + return toReturn +} + +// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form. 
+// solver = [ public | secret | internal ] +// TODO @gbotrel refactor; this seems to be a small util function for plonk +func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { + + //s := int(pk.Domain[0].Cardinality) + s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints + s = int(ecc.NextPowerOfTwo(uint64(s))) + + var l, r, o []fr.Element + l = make([]fr.Element, s) + r = make([]fr.Element, s) + o = make([]fr.Element, s) + s0 := solution[0] + + for i := 0; i < len(cs.Public); i++ { // placeholders + l[i] = solution[i] + r[i] = s0 + o[i] = s0 + } + offset := len(cs.Public) + nbConstraints := cs.GetNbConstraints() + + var sparseR1C constraint.SparseR1C + j := 0 + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + + l[offset+j] = solution[sparseR1C.XA] + r[offset+j] = solution[sparseR1C.XB] + o[offset+j] = solution[sparseR1C.XC] + j++ + } + } + + offset += nbConstraints + + for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0]) + l[offset+i] = s0 + r[offset+i] = s0 + o[offset+i] = s0 + } + + return l, r, o + +} + +// R1CSSolution represent a valid assignment to all the variables in the constraint system. 
+// The vector W such that Aw o Bw - Cw = 0 +type R1CSSolution struct { + W fr.Vector + A, B, C fr.Vector +} + +func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.W.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.A.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.B.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.C.WriteTo(w) + n += a + return n, err +} + +func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.W.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.A.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.B.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.C.ReadFrom(r) + n += a + return n, err +} + +// SparseR1CSSolution represent a valid assignment to all the variables in the constraint system. +type SparseR1CSSolution struct { + L, R, O fr.Vector +} + +func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.L.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.R.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.O.WriteTo(w) + n += a + return n, err + +} + +func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.L.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.R.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.O.ReadFrom(r) + n += a + return n, err +} + +func getTagSet() cbor.TagSet { + // temporary for refactor + ts := cbor.NewTagSet() + // https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml + // 65536-15309735 Unassigned + tagNum := uint64(5309735) + addType := func(t reflect.Type) { + if err := ts.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + t, + tagNum, + ); err != nil { + panic(err) + } + tagNum++ + } + + addType(reflect.TypeOf(constraint.BlueprintGenericHint{})) + addType(reflect.TypeOf(constraint.BlueprintGenericR1C{})) + 
addType(reflect.TypeOf(constraint.BlueprintGenericSparseR1C{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CAdd{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CMul{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CBool{})) + addType(reflect.TypeOf(constraint.BlueprintLookupHint{})) + addType(reflect.TypeOf(constraint.Groth16Commitments{})) + addType(reflect.TypeOf(constraint.PlonkCommitments{})) + + return ts +} + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} diff --git a/constraint/bls12-381/coeff.go b/constraint/bls12-381/coeff.go index f0ea2ac85f..9eb3786806 100644 --- a/constraint/bls12-381/coeff.go +++ b/constraint/bls12-381/coeff.go @@ -46,7 +46,7 @@ func newCoeffTable(capacity int) CoeffTable { } -func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constraint.Term { +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { c := (*fr.Element)(coeff[:]) var cID uint32 if c.IsZero() { @@ -69,7 +69,11 @@ func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constrai ct.mCoeffs[cc] = cID } } + return cID +} +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) return constraint.Term{VID: uint32(variableID), CID: cID} } @@ -78,7 +82,10 @@ func (ct *CoeffTable) CoeffToString(cID int) string { return ct.Coefficients[cID].String() } -var _ constraint.CoeffEngine = &arithEngine{} +// implements constraint.Field +type field struct{} + +var _ constraint.Field = &field{} var ( two fr.Element @@ -94,10 +101,7 @@ func init() { minusTwo.Neg(&two) } -// implements constraint.CoeffEngine -type arithEngine struct{} - -func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { +func (engine *field) FromInterface(i interface{}) constraint.Element { var e fr.Element if _, err := e.SetInterface(i); err != nil { // need to clean that --> some code path are dissimilar @@ -106,55 
+110,83 @@ func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { b := utils.FromInterface(i) e.SetBigInt(&b) } - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) ToBigInt(c *constraint.Coeff) *big.Int { +func (engine *field) ToBigInt(c constraint.Element) *big.Int { e := (*fr.Element)(c[:]) r := new(big.Int) e.BigInt(r) return r } -func (engine *arithEngine) Mul(a, b *constraint.Coeff) { +func (engine *field) Mul(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Mul(_a, _b) + return a } -func (engine *arithEngine) Add(a, b *constraint.Coeff) { + +func (engine *field) Add(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Add(_a, _b) + return a } -func (engine *arithEngine) Sub(a, b *constraint.Coeff) { +func (engine *field) Sub(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Sub(_a, _b) + return a } -func (engine *arithEngine) Neg(a *constraint.Coeff) { +func (engine *field) Neg(a constraint.Element) constraint.Element { e := (*fr.Element)(a[:]) e.Neg(e) + return a } -func (engine *arithEngine) Inverse(a *constraint.Coeff) { +func (engine *field) Inverse(a constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + e.Inverse(e) + return a, true } -func (engine *arithEngine) IsOne(a *constraint.Coeff) bool { +func (engine *field) IsOne(a constraint.Element) bool { e := (*fr.Element)(a[:]) return e.IsOne() } -func (engine *arithEngine) One() constraint.Coeff { +func (engine *field) One() constraint.Element { e := fr.One() - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) String(a 
*constraint.Coeff) string { +func (engine *field) String(a constraint.Element) string { e := (*fr.Element)(a[:]) return e.String() } + +func (engine *field) Uint64(a constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true +} diff --git a/constraint/bls12-381/gkr.go b/constraint/bls12-381/gkr.go new file mode 100644 index 0000000000..f70fb34a48 --- /dev/null +++ b/constraint/bls12-381/gkr.go @@ -0,0 +1,196 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/gkr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + hint "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/utils/algo_utils" + "hash" + "math/big" +) + +type GkrSolvingData struct { + assignments gkr.WireAssignment + circuit gkr.Circuit + memoryPool polynomial.Pool + workers *utils.WorkerPool +} + +func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { + resCircuit := make(gkr.Circuit, len(noPtr)) + var found bool + for i := range noPtr { + if resCircuit[i].Gate, found = gkr.Gates[noPtr[i].Gate]; !found && noPtr[i].Gate != "" { + return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) + } + resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) + } + return resCircuit, nil +} + +func (d *GkrSolvingData) init(info constraint.GkrInfo) (assignment gkrAssignment, err error) { + if d.circuit, err = convertCircuit(info.Circuit); err != nil { + return + } + d.memoryPool = polynomial.NewPool(d.circuit.MemoryRequirements(info.NbInstances)...) 
+ d.workers = utils.NewWorkerPool() + + assignment = make(gkrAssignment, len(d.circuit)) + d.assignments = make(gkr.WireAssignment, len(d.circuit)) + for i := range assignment { + assignment[i] = d.memoryPool.Make(info.NbInstances) + d.assignments[&d.circuit[i]] = assignment[i] + } + return +} + +func (d *GkrSolvingData) dumpAssignments() { + for _, p := range d.assignments { + d.memoryPool.Dump(p) + } +} + +// this module assumes that wire and instance indexes respect dependencies + +type gkrAssignment [][]fr.Element //gkrAssignment is indexed wire first, instance second + +func (a gkrAssignment) setOuts(circuit constraint.GkrCircuit, outs []*big.Int) { + outsI := 0 + for i := range circuit { + if circuit[i].IsOutput() { + for j := range a[i] { + a[i][j].BigInt(outs[outsI]) + outsI++ + } + } + } + // Check if outsI == len(outs)? +} + +func GkrSolveHint(info constraint.GkrInfo, solvingData *GkrSolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + // assumes assignmentVector is arranged wire first, instance second in order of solution + circuit := info.Circuit + nbInstances := info.NbInstances + offsets := info.AssignmentOffsets() + assignment, err := solvingData.init(info) + if err != nil { + return err + } + chunks := circuit.Chunks(nbInstances) + + solveTask := func(chunkOffset int) utils.Task { + return func(startInChunk, endInChunk int) { + start := startInChunk + chunkOffset + end := endInChunk + chunkOffset + inputs := solvingData.memoryPool.Make(info.MaxNIns) + dependencyHeads := make([]int, len(circuit)) + for wI, w := range circuit { + dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { + return w.Dependencies[i].InputInstance + }, len(w.Dependencies), start) + } + + for instanceI := start; instanceI < end; instanceI++ { + for wireI, wire := range circuit { + if wire.IsInput() { + if dependencyHeads[wireI] < len(wire.Dependencies) && instanceI == wire.Dependencies[dependencyHeads[wireI]].InputInstance { + dep := 
wire.Dependencies[dependencyHeads[wireI]] + assignment[wireI][instanceI].Set(&assignment[dep.OutputWire][dep.OutputInstance]) + dependencyHeads[wireI]++ + } else { + assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) + } + } else { + // assemble the inputs + inputIndexes := info.Circuit[wireI].Inputs + for i, inputI := range inputIndexes { + inputs[i].Set(&assignment[inputI][instanceI]) + } + gate := solvingData.circuit[wireI].Gate + assignment[wireI][instanceI] = gate.Evaluate(inputs[:len(inputIndexes)]...) + } + } + } + solvingData.memoryPool.Dump(inputs) + } + } + + start := 0 + for _, end := range chunks { + solvingData.workers.Submit(end-start, solveTask(start), 1024).Wait() + start = end + } + + assignment.setOuts(info.Circuit, outs) + + return nil + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} + +func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { + + return func(_ *big.Int, ins, outs []*big.Int) error { + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + b := make([]byte, fr.Bytes) + i.FillBytes(b) + return b[:] + }) + + hsh := HashBuilderRegistry[hashName]() + + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + if err != nil { + return err + } + + // serialize proof: TODO: In gnark-crypto? 
+ offset := 0 + for i := range proof { + for _, poly := range proof[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if proof[i].FinalEvalProof != nil { + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } + + data.dumpAssignments() + + return nil + + } +} + +// TODO: Move to gnark-crypto +var HashBuilderRegistry = make(map[string]func() hash.Hash) diff --git a/constraint/bls12-381/r1cs.go b/constraint/bls12-381/r1cs.go deleted file mode 100644 index f818fb9f42..0000000000 --- a/constraint/bls12-381/r1cs.go +++ /dev/null @@ -1,451 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/fxamacker/cbor/v2" - "io" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc" - "math" - - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" -) - -// R1CS describes a set of R1CS constraint -type R1CS struct { - constraint.R1CSCore - CoeffTable - arithEngine -} - -// NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values -// -// capacity pre-allocates memory for capacity nbConstraints -func NewR1CS(capacity int) *R1CS { - r := R1CS{ - R1CSCore: constraint.R1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.R1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - return &r -} - -func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, r1c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - - cs.UpdateLevel(cID, &r1c) - - return cID -} - -// Solve sets all the wires and returns the a, b, c vectors. -// the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. -// a, b, c vectors: ab-c = hz -// witness = [publicWires | secretWires] (without the ONE_WIRE !) 
-// returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() - - nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables - solution, err := newSolution(nbWires, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return make(fr.Vector, nbWires), err - } - start := time.Now() - - if len(witness) != len(cs.Public)-1+len(cs.Secret) { // - 1 for ONE_WIRE - err = fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), int(len(cs.Public)-1+len(cs.Secret)), len(cs.Public)-1, len(cs.Secret)) - log.Err(err).Send() - return solution.values, err - } - - // compute the wires and the a, b, c polynomials - if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { - err = errors.New("invalid input size: len(a, b, c) == len(Constraints)") - log.Err(err).Send() - return solution.values, err - } - - solution.solved[0] = true // ONE_WIRE - solution.values[0].SetOne() - copy(solution.values[1:], witness) - for i := range witness { - solution.solved[i+1] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness) + 1) - - // now that we know all inputs are set, defer log printing once all solution.values are computed - // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - if err := cs.parallelSolve(a, b, c, &solution); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity 
check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil -} - -func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - chError <- &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - return &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// IsSolved returns nil if given witness solves the R1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) 
- if err != nil { - return err - } - - a := make(fr.Vector, len(cs.Constraints)) - b := make(fr.Vector, len(cs.Constraints)) - c := make(fr.Vector, len(cs.Constraints)) - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, a, b, c, opt) - return err -} - -// divByCoeff sets res = res / t.Coeff -func (cs *R1CS) divByCoeff(res *fr.Element, t constraint.Term) { - cID := t.CoeffID() - switch cID { - case constraint.CoeffIdOne: - return - case constraint.CoeffIdMinusOne: - res.Neg(res) - case constraint.CoeffIdZero: - panic("division by 0") - default: - // this is slow, but shouldn't happen as divByCoeff is called to - // remove the coeff of an unsolved wire - // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 - res.Div(res, &cs.Coefficients[cID]) - } -} - -// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly -// -// returns an error if the solver called a hint function that errored -// returns false, nil if there was no wire to solve -// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that -// the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r constraint.R1C, solution *solution, a, b, c *fr.Element) error { - - // the index of the non-zero entry shows if L, R or O has an uninstantiated wire - // the content is the ID of the wire non instantiated - var loc uint8 - - var termToCompute constraint.Term - - processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) error { - for _, t := range l { - vID := t.WireID() - - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - solution.accumulateInto(t, val) - continue - } - - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err - } - // now that the wire is saved, accumulate it into a, b or c - solution.accumulateInto(t, val) - continue - } - - if loc != 0 { - panic("found more than one wire to instantiate") - } - termToCompute = t - loc = locValue - } - return nil - } - - if err := processLExp(r.L, a, 1); err != nil { - return err - } - - if err := processLExp(r.R, b, 2); err != nil { - return err - } - - if err := processLExp(r.O, c, 3); err != nil { - return err - } - - if loc == 0 { - // there is nothing to solve, may happen if we have an assertion - // (ie a constraints that doesn't yield any output) - // or if we solved the unsolved wires with hint functions - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - return nil - } - - // we compute the wire value and instantiate it - wID := termToCompute.WireID() - - // solver result - var wire fr.Element - - switch loc { - case 1: - if !b.IsZero() { - wire.Div(c, b). - Sub(&wire, a) - a.Add(a, &wire) - } else { - // we didn't actually ensure that a * b == c - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 2: - if !a.IsZero() { - wire.Div(c, a). 
- Sub(&wire, b) - b.Add(b, &wire) - } else { - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 3: - wire.Mul(a, b). - Sub(&wire, c) - - c.Add(c, &wire) - } - - // wire is the term (coeff * value) - // but in the solution we want to store the value only - // note that in gnark frontend, coeff here is always 1 or -1 - cs.divByCoeff(&wire, termToCompute) - solution.set(wID, wire) - - return nil -} - -// GetConstraints return the list of R1C and a coefficient resolver -func (cs *R1CS) GetConstraints() ([]constraint.R1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *R1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto -func (cs *R1CS) CurveID() ecc.ID { - return ecc.BLS12_381 -} - -// WriteTo encodes R1CS into provided io.Writer using cbor -func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode R1CS from io.Reader using cbor -func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(&cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bls12-381/r1cs_sparse.go 
b/constraint/bls12-381/r1cs_sparse.go deleted file mode 100644 index 393287d780..0000000000 --- a/constraint/bls12-381/r1cs_sparse.go +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/fxamacker/cbor/v2" - "io" - "math" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" -) - -// SparseR1CS represents a Plonk like circuit -type SparseR1CS struct { - constraint.SparseR1CSCore - CoeffTable - arithEngine -} - -// NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values -func NewSparseR1CS(capacity int) *SparseR1CS { - cs := SparseR1CS{ - SparseR1CSCore: constraint.SparseR1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.SparseR1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - - return &cs -} - -func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = 
append(cs.Constraints, c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - cs.UpdateLevel(cID, &c) - - return cID -} - -// Solve sets all the wires. -// solution.values = [publicInputs | secretInputs | internalVariables ] -// witness: contains the input variables -// it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() - - // set the slices holding the solution.values and monitoring which variables have been solved - nbVariables := cs.NbInternalVariables + len(cs.Secret) + len(cs.Public) - - start := time.Now() - - expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) - if len(witness) != expectedWitnessSize { - return make(fr.Vector, nbVariables), fmt.Errorf( - "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", - len(witness), - expectedWitnessSize, - len(cs.Public), - len(cs.Secret), - ) - } - - // keep track of wire that have a value - solution, err := newSolution(nbVariables, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return solution.values, err - } - - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs - copy(solution.values, witness) - for i := 0; i < len(witness); i++ { - solution.solved[i] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness)) - - // defer log printing once all solution.values are computed - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - // batch invert the coefficients to avoid many divisions in the solver - coefficientsNegInv := 
fr.BatchInvert(cs.Coefficients) - for i := 0; i < len(coefficientsNegInv); i++ { - coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) - } - - if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil - -} - -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - wg.Done() - return - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - chError <- &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } else { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - } - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - return &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// computeHints computes wires associated with a hint function, if any -// if there is no remaining wire to solve, returns -1 -// else returns the wire position (L -> 0, R -> 1, O -> 2) -func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) (int, error) { - r := -1 - lID, rID, oID := c.L.WireID(), c.R.WireID(), c.O.WireID() - - if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { - // check if it's a hint - if hint, ok := cs.MHints[lID]; ok { - if err := solution.solveWithHint(lID, hint); err != nil { - return -1, err - } - } else { - r = 0 - } - - } - - if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { - // check if it's a hint - if hint, ok := cs.MHints[rID]; ok { - if err := solution.solveWithHint(rID, hint); err != nil { - return -1, err - } - } else { - r = 1 - } - } - - if (c.O.CoeffID() != 0) && !solution.solved[oID] { - // check if it's a hint - if hint, ok := cs.MHints[oID]; ok { - if err := solution.solveWithHint(oID, hint); err != nil { - return -1, err - } - } else { - r = 2 - } - } - 
return r, nil -} - -// solveConstraint solve any unsolved wire in given constraint and update the solution -// a SparseR1C may have up to one unsolved wire (excluding hints) -// if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { - - lro, err := cs.computeHints(c, solution) - if err != nil { - return err - } - if lro == -1 { - // no unsolved wire - // can happen if the constraint contained only hint wires. - return nil - } - if lro == 1 { // we solve for R: u1L+u2R+u3LR+u4O+k=0 => R(u2+u3L)+u1L+u4O+k = 0 - if !solution.solved[c.L.WireID()] { - panic("L wire should be instantiated when we solve R") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.L.WireID()]).Add(&den, &u2) - - v1 = solution.computeTerm(c.L) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - } - - if lro == 0 { // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 - if !solution.solved[c.R.WireID()] { - panic("R wire should be instantiated when we solve L") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.R.WireID()]).Add(&den, &u1) - - v1 = solution.computeTerm(c.R) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - - } - // O 
we solve for O - var o fr.Element - cID, vID := c.O.CoeffID(), c.O.WireID() - - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - - // o = - ((m0 * m1) + l + r + c.K) / c.O - o.Mul(&m0, &m1).Add(&o, &l).Add(&o, &r).Add(&o, &cs.Coefficients[c.K]) - o.Mul(&o, &coefficientsNegInv[cID]) - - solution.set(vID, o) - - return nil -} - -// IsSolved returns nil if given witness solves the SparseR1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return err - } - - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, opt) - return err -} - -// GetConstraints return the list of SparseR1C and a coefficient resolver -func (cs *SparseR1CS) GetConstraints() ([]constraint.SparseR1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// checkConstraint verifies that the constraint holds -func (cs *SparseR1CS) checkConstraint(c constraint.SparseR1C, solution *solution) error { - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - o := solution.computeTerm(c.O) - - // l + r + (m0 * m1) + o + c.K == 0 - var t fr.Element - t.Mul(&m0, &m1).Add(&t, &l).Add(&t, &r).Add(&t, &o).Add(&t, &cs.Coefficients[c.K]) - if !t.IsZero() { - return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + (%s × %s) + %s != 0", - l.String(), - r.String(), - o.String(), - m0.String(), - m1.String(), - cs.Coefficients[c.K].String(), - ) - } - return nil - -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto (ecc.BLS12-381) -func (cs 
*SparseR1CS) CurveID() ecc.ID { - return ecc.BLS12_381 -} - -// WriteTo encodes SparseR1CS into provided io.Writer using cbor -func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode SparseR1CS from io.Reader using cbor -func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bls12-381/r1cs_test.go b/constraint/bls12-381/r1cs_test.go index 90c9e1e849..28f77b956b 100644 --- a/constraint/bls12-381/r1cs_test.go +++ b/constraint/bls12-381/r1cs_test.go @@ -20,6 +20,7 @@ import ( "bytes" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "reflect" "testing" @@ -27,7 +28,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/consensys/gnark/constraint/bls12-381" + cs "github.com/consensys/gnark/constraint/bls12-381" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" ) @@ -48,7 +49,7 @@ func TestSerialization(t *testing.T) { return } - // copmpile a second time to ensure determinism + // compile a second time to ensure determinism r1cs2, err := frontend.Compile(fr.Modulus(), 
r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) @@ -76,10 +77,10 @@ func TestSerialization(t *testing.T) { if diff := cmp.Diff(r1cs1, &reconstructed, cmpopts.IgnoreFields(cs.R1CS{}, "System.q", - "arithEngine", + "field", "CoeffTable.mCoeffs", "System.lbWireLevel", - "System.lbHints", + "System.genericHint", "System.SymbolTable", "System.lbOutputs", "System.bitLen")); diff != "" { @@ -147,12 +148,6 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { - var c circuit - ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c) - if err != nil { - b.Fatal(err) - } - var w circuit w.X = 1 w.Y = 1 @@ -161,8 +156,32 @@ func BenchmarkSolve(b *testing.B) { b.Fatal(err) } - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ccs.IsSolved(witness) - } + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + } diff --git a/constraint/bls12-381/solution.go b/constraint/bls12-381/solution.go deleted file mode 100644 index dbf96fadc4..0000000000 --- a/constraint/bls12-381/solution.go +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/debug" - "github.com/rs/zerolog" - "math/big" - "strconv" - "strings" - "sync/atomic" - - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" -) - -// solution represents elements needed to compute -// a solution to a R1CS or SparseR1CS -type solution struct { - values, coefficients []fr.Element - solved []bool - nbSolved uint64 - mHintsFunctions map[hint.ID]hint.Function // maps hintID to hint function - mHints map[int]*constraint.Hint // maps wireID to hint - st *debug.SymbolTable -} - -func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, mHints map[int]*constraint.Hint, coefficients []fr.Element, st *debug.SymbolTable) (solution, error) { - - s := solution{ - st: st, - values: make([]fr.Element, nbWires), - coefficients: coefficients, - solved: make([]bool, nbWires), - mHintsFunctions: hintFunctions, - mHints: mHints, - } - - // hintsDependencies is from compile time; it contains the list of hints the solver **needs** - var missing []string - for hintUUID, hintID := range hintsDependencies { - if _, ok := s.mHintsFunctions[hintUUID]; !ok { - missing = append(missing, hintID) - } - } - - if len(missing) > 0 { - return s, fmt.Errorf("solver missing hint(s): %v", missing) - } - - return s, nil -} - -func (s *solution) set(id int, value fr.Element) { - if s.solved[id] { - panic("solving the 
same wire twice should never happen.") - } - s.values[id] = value - s.solved[id] = true - atomic.AddUint64(&s.nbSolved, 1) - // s.nbSolved++ -} - -func (s *solution) isValid() bool { - return int(s.nbSolved) == len(s.values) -} - -// computeTerm computes coeff*variable -func (s *solution) computeTerm(t constraint.Term) fr.Element { - cID, vID := t.CoeffID(), t.WireID() - if cID != 0 && !s.solved[vID] { - panic("computing a term with an unsolved wire") - } - switch cID { - case constraint.CoeffIdZero: - return fr.Element{} - case constraint.CoeffIdOne: - return s.values[vID] - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - return res - case constraint.CoeffIdMinusOne: - var res fr.Element - res.Neg(&s.values[vID]) - return res - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - return res - } -} - -// r += (t.coeff*t.value) -func (s *solution) accumulateInto(t constraint.Term, r *fr.Element) { - cID := t.CoeffID() - - if t.IsConstant() { - // needed for logs, we may want to not put this in the hot path if we need to - // optimize constraint system solver further. 
- r.Add(r, &s.coefficients[cID]) - return - } - - vID := t.WireID() - switch cID { - case constraint.CoeffIdZero: - return - case constraint.CoeffIdOne: - r.Add(r, &s.values[vID]) - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - r.Add(r, &res) - case constraint.CoeffIdMinusOne: - r.Sub(r, &s.values[vID]) - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - r.Add(r, &res) - } -} - -// solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveWithHint(vID int, h *constraint.Hint) error { - // skip if the wire is already solved by a call to the same hint - // function on the same inputs - if s.solved[vID] { - return nil - } - // ensure hint function was provided - f, ok := s.mHintsFunctions[h.ID] - if !ok { - return errors.New("missing hint function") - } - - // tmp IO big int memory - nbInputs := len(h.Inputs) - nbOutputs := len(h.Wires) - inputs := make([]*big.Int, nbInputs) - outputs := make([]*big.Int, nbOutputs) - for i := 0; i < nbOutputs; i++ { - outputs[i] = big.NewInt(0) - } - - q := fr.Modulus() - - // for each input, we set its big int value, IF all the wires are solved - // the only case where all wires may not be solved, is if one of the input of this hint - // is the output of another hint. - // it is safe to recursively solve this with the parallel solver, since all hints-output wires - // that we can solve this way are marked to be solved with the current constraint we are processing. - recursiveSolve := func(t constraint.Term) error { - if t.IsConstant() { - return nil - } - wID := t.WireID() - if s.solved[wID] { - return nil - } - // unsolved dependency - if h, ok := s.mHints[wID]; ok { - // solve recursively. - return s.solveWithHint(wID, h) - } - - // it's not a hint, we panic. 
- panic("solver can't compute hint; one or more input wires are unsolved") - } - - for i := 0; i < nbInputs; i++ { - inputs[i] = big.NewInt(0) - - var v fr.Element - for _, term := range h.Inputs[i] { - if err := recursiveSolve(term); err != nil { - return err - } - s.accumulateInto(term, &v) - } - v.BigInt(inputs[i]) - } - - err := f(q, inputs, outputs) - - var v fr.Element - for i := range outputs { - v.SetBigInt(outputs[i]) - s.set(h.Wires[i], v) - } - - return err -} - -func (s *solution) printLogs(log zerolog.Logger, logs []constraint.LogEntry) { - if log.GetLevel() == zerolog.Disabled { - return - } - - for i := 0; i < len(logs); i++ { - logLine := s.logValue(logs[i]) - log.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) - } -} - -const unsolvedVariable = "" - -func (s *solution) logValue(log constraint.LogEntry) string { - var toResolve []interface{} - var ( - eval fr.Element - missingValue bool - ) - for j := 0; j < len(log.ToResolve); j++ { - // before eval le - - missingValue = false - eval.SetZero() - - for _, t := range log.ToResolve[j] { - // for each term in the linear expression - - cID, vID := t.CoeffID(), t.WireID() - if t.IsConstant() { - // just add the constant - eval.Add(&eval, &s.coefficients[cID]) - continue - } - - if !s.solved[vID] { - missingValue = true - break // stop the loop we can't evaluate. 
- } - - tv := s.computeTerm(t) - eval.Add(&eval, &tv) - } - - // after - if missingValue { - toResolve = append(toResolve, unsolvedVariable) - } else { - // we have to append our accumulator - toResolve = append(toResolve, eval.String()) - } - - } - if len(log.Stack) > 0 { - var sbb strings.Builder - for _, lID := range log.Stack { - location := s.st.Locations[lID] - function := s.st.Functions[location.FunctionID] - - sbb.WriteString(function.Name) - sbb.WriteByte('\n') - sbb.WriteByte('\t') - sbb.WriteString(function.Filename) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(int(location.Line))) - sbb.WriteByte('\n') - } - toResolve = append(toResolve, sbb.String()) - } - return fmt.Sprintf(log.Format, toResolve...) -} - -// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint -type UnsatisfiedConstraintError struct { - Err error - CID int // constraint ID - DebugInfo *string // optional debug info -} - -func (r *UnsatisfiedConstraintError) Error() string { - if r.DebugInfo != nil { - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) - } - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) -} diff --git a/constraint/bls12-381/solver.go b/constraint/bls12-381/solver.go new file mode 100644 index 0000000000..125ce56e8e --- /dev/null +++ b/constraint/bls12-381/solver.go @@ -0,0 +1,647 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/pool" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/rs/zerolog" + "math" + "math/big" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" +) + +// solver represent the state of the solver during a call to System.Solve(...) +type solver struct { + *system + + // values and solved are index by the wire (variable) id + values []fr.Element + solved []bool + nbSolved uint64 + + // maps hintID to hint function + mHintsFunctions map[csolver.HintID]csolver.Hint + + // used to out api.Println + logger zerolog.Logger + + a, b, c fr.Vector // R1CS solver will compute the a,b,c matrices + + q *big.Int +} + +func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { + // parse options + opt, err := csolver.NewConfig(opts...) 
+ if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public) - witnessOffset + len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + witnessOffset) + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + +// computeTerm computes 
coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. 
+func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start)+i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. 
+ } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID: cID, VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. 
func (s *solver) Read(calldata []uint32) (constraint.Element, int) {
	if s.Type == constraint.SystemSparseR1CS {
		// Plonkish system: calldata encodes exactly one Term; first word must be 1.
		if calldata[0] != 1 {
			panic("invalid calldata")
		}
		return s.GetValue(calldata[1], calldata[2]), 3
	}
	// R1CS: calldata encodes a LinearExpression;
	// first word is the number of terms n, followed by n (cID, vID) pairs.
	var r fr.Element
	n := int(calldata[0])
	j := 1
	for k := 0; k < n; k++ {
		// read the k-th term and accumulate its evaluation (coeff * value) into r
		s.accumulateInto(constraint.Term{CID: calldata[j], VID: calldata[j+1]}, &r)
		j += 2
	}

	var ret constraint.Element
	copy(ret[:], r[:])
	return ret, j
}

// processInstruction decodes the instruction and executes blueprint-defined logic.
// an instruction can encode a hint, a custom constraint or a generic constraint.
func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error {
	// fetch the blueprint
	blueprint := solver.Blueprints[pi.BlueprintID]
	inst := pi.Unpack(&solver.System)
	cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only

	if solver.Type == constraint.SystemR1CS {
		if bc, ok := blueprint.(constraint.BlueprintR1C); ok {
			// TODO @gbotrel we use the solveR1C method for now, having user-defined
			// blueprint for R1CS would require constraint.Solver interface to add methods
			// to set a,b,c since it's more efficient to compute these while we solve.
			bc.DecompressR1C(&scratch.tR1C, inst)
			return solver.solveR1C(cID, &scratch.tR1C)
		}
	}

	// blueprint declared "I know how to solve this."
	if bc, ok := blueprint.(constraint.BlueprintSolvable); ok {
		if err := bc.Solve(solver, inst); err != nil {
			return solver.wrapErrWithDebugInfo(cID, err)
		}
		return nil
	}

	// blueprint encodes a hint, we execute.
	// TODO @gbotrel may be worth it to move hint logic in blueprint "solve"
	if bc, ok := blueprint.(constraint.BlueprintHint); ok {
		bc.DecompressHint(&scratch.tHint, inst)
		return solver.solveWithHint(&scratch.tHint)
	}

	return nil
}

// run runs the solver. it returns an error if a constraint is not satisfied or if not all wires
// were instantiated.
func (solver *solver) run() error {
	// minWorkPerCPU is the minimum target number of constraints a task should hold;
	// in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed
	// sequentially without sync.
	const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks.

	// cs.Levels has a list of levels, where all constraints in a level l(n) are independent
	// and may only have dependencies on previous levels
	// for each constraint
	// we are guaranteed that each R1C contains at most one unsolved wire
	// first we solve the unsolved wire (if any)
	// then we check that the constraint is valid
	// if a[i] * b[i] != c[i]; it means the constraint is not satisfied
	var wg sync.WaitGroup
	chTasks := make(chan []int, runtime.NumCPU())
	chError := make(chan error, runtime.NumCPU())

	// start a worker pool
	// each worker waits on chTasks
	// a task is a slice of constraint indexes to be solved
	for i := 0; i < runtime.NumCPU(); i++ {
		go func() {
			var scratch scratch
			for t := range chTasks {
				for _, i := range t {
					if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil {
						// report the error and exit this worker; run() returns after wg.Wait()
						chError <- err
						wg.Done()
						return
					}
				}
				wg.Done()
			}
		}()
	}

	// clean up pool go routines
	defer func() {
		close(chTasks)
		close(chError)
	}()

	var scratch scratch

	// for each level, we push the tasks
	for _, level := range solver.Levels {

		// max CPU to use
		maxCPU := float64(len(level)) / minWorkPerCPU

		if maxCPU <= 1.0 {
			// not enough work to amortize synchronization; we do it sequentially
			for _, i := range level {
				if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil {
					return err
				}
			}
			continue
		}

		// number of tasks for this level is set to number of CPU
		// but if we don't have enough work for all our CPU, it can be lower.
		nbTasks := runtime.NumCPU()
		maxTasks := int(math.Ceil(maxCPU))
		if nbTasks > maxTasks {
			nbTasks = maxTasks
		}
		nbIterationsPerCpus := len(level) / nbTasks

		// more CPUs than tasks: a CPU will work on exactly one iteration
		// note: this depends on minWorkPerCPU constant
		if nbIterationsPerCpus < 1 {
			nbIterationsPerCpus = 1
			nbTasks = len(level)
		}

		// spread the remainder (extraTasks) one-by-one over the first tasks
		extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
		extraTasksOffset := 0

		for i := 0; i < nbTasks; i++ {
			wg.Add(1)
			_start := i*nbIterationsPerCpus + extraTasksOffset
			_end := _start + nbIterationsPerCpus
			if extraTasks > 0 {
				_end++
				extraTasks--
				extraTasksOffset++
			}
			// since we're never pushing more than num CPU tasks
			// we will never be blocked here
			chTasks <- level[_start:_end]
		}

		// wait for the level to be done
		wg.Wait()

		if len(chError) > 0 {
			return <-chError
		}
	}

	if int(solver.nbSolved) != len(solver.values) {
		return errors.New("solver didn't assign a value to all wires")
	}

	return nil
}

// solveR1C compute unsolved wires in the constraint, if any and set the solver accordingly
//
// returns an error if the solver called a hint function that errored
// returns false, nil if there was no wire to solve
// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that
// the constraint is satisfied later.
+func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error { + a, b, c := &solver.a[cID], &solver.b[cID], &solver.c[cID] + + // the index of the non-zero entry shows if L, R or O has an uninstantiated wire + // the content is the ID of the wire non instantiated + var loc uint8 + + var termToCompute constraint.Term + + processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) { + for _, t := range l { + vID := t.WireID() + + // wire is already computed, we just accumulate in val + if solver.solved[vID] { + solver.accumulateInto(t, val) + continue + } + + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue + } + } + + processLExp(r.L, a, 1) + processLExp(r.R, b, 2) + processLExp(r.O, c, 3) + + if loc == 0 { + // there is nothing to solve, may happen if we have an assertion + // (ie a constraints that doesn't yield any output) + // or if we solved the unsolved wires with hint functions + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + return nil + } + + // we compute the wire value and instantiate it + wID := termToCompute.WireID() + + // solver result + var wire fr.Element + + switch loc { + case 1: + if !b.IsZero() { + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) + } else { + // we didn't actually ensure that a * b == c + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 2: + if !a.IsZero() { + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) + } else { + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 3: + wire.Mul(a, b). 
+ Sub(&wire, c) + + c.Add(c, &wire) + } + + // wire is the term (coeff * value) + // but in the solver we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + solver.divByCoeff(&wire, termToCompute.CID) + solver.set(wID, wire) + + return nil +} + +// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint +type UnsatisfiedConstraintError struct { + Err error + CID int // constraint ID + DebugInfo *string // optional debug info +} + +func (r *UnsatisfiedConstraintError) Error() string { + if r.DebugInfo != nil { + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) + } + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) +} + +func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { + var debugInfo *string + if dID, ok := solver.MDebug[int(cID)]; ok { + debugInfo = new(string) + *debugInfo = solver.logValue(solver.DebugInfo[dID]) + } + return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo} +} + +// temporary variables to avoid memallocs in hotloop +type scratch struct { + tR1C constraint.R1C + tHint constraint.HintMapping +} diff --git a/constraint/bls12-381/system.go b/constraint/bls12-381/system.go new file mode 100644 index 0000000000..399494e2f9 --- /dev/null +++ b/constraint/bls12-381/system.go @@ -0,0 +1,379 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/fxamacker/cbor/v2" + "io" + "time" + + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/internal/backend/ioutils" + "github.com/consensys/gnark/logger" + "reflect" + + "github.com/consensys/gnark-crypto/ecc" + + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" +) + +type R1CS = system +type SparseR1CS = system + +// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element) +type system struct { + constraint.System + CoeffTable + field +} + +func NewR1CS(capacity int) *R1CS { + return newSystem(capacity, constraint.SystemR1CS) +} + +func NewSparseR1CS(capacity int) *SparseR1CS { + return newSystem(capacity, constraint.SystemSparseR1CS) +} + +func newSystem(capacity int, t constraint.SystemType) *system { + return &system{ + System: constraint.NewSystem(fr.Modulus(), capacity, t), + CoeffTable: newCoeffTable(capacity / 10), + } +} + +// Solve solves the constraint system with provided witness. +// If it's a R1CS returns R1CSSolution +// If it's a SparseR1CS returns SparseR1CSSolution +func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) { + log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger() + start := time.Now() + + v := witness.Vector().(fr.Vector) + + // init the solver + solver, err := newSolver(cs, v, opts...) + if err != nil { + log.Err(err).Send() + return nil, err + } + + // defer log printing once all solver.values are computed + // (or sooner, if a constraint is not satisfied) + defer solver.printLogs(cs.Logs) + + // run it. 
+ if err := solver.run(); err != nil { + log.Err(err).Send() + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") + + // format the solution + // TODO @gbotrel revisit post-refactor + if cs.Type == constraint.SystemR1CS { + var res R1CSSolution + res.W = solver.values + res.A = solver.a + res.B = solver.b + res.C = solver.c + return &res, nil + } else { + // sparse R1CS + var res SparseR1CSSolution + // query l, r, o in Lagrange basis, not blinded + res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values) + + return &res, nil + } + +} + +// IsSolved +// Deprecated: use _, err := Solve(...) instead +func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error { + _, err := cs.Solve(witness, opts...) + return err +} + +// GetR1Cs return the list of R1C +func (cs *system) GetR1Cs() []constraint.R1C { + toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + var r1c constraint.R1C + bc.DecompressR1C(&r1c, inst.Unpack(&cs.System)) + toReturn = append(toReturn, r1c) + } else { + panic("not implemented") + } + } + return toReturn +} + +// GetNbCoefficients return the number of unique coefficients needed in the R1CS +func (cs *system) GetNbCoefficients() int { + return len(cs.Coefficients) +} + +// CurveID returns curve ID as defined in gnark-crypto +func (cs *system) CurveID() ecc.ID { + return ecc.BLS12_381 +} + +// WriteTo encodes R1CS into provided io.Writer using cbor +func (cs *system) WriteTo(w io.Writer) (int64, error) { + _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written + ts := getTagSet() + enc, err := cbor.CoreDetEncOptions().EncModeWithTags(ts) + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) + + // encode our object + err = encoder.Encode(cs) + return _w.N, err +} + +// ReadFrom 
attempts to decode R1CS from io.Reader using cbor
func (cs *system) ReadFrom(r io.Reader) (int64, error) {
	ts := getTagSet()
	dm, err := cbor.DecOptions{
		MaxArrayElements: 2147483647,
		MaxMapPairs:      2147483647,
	}.DecModeWithTags(ts)

	if err != nil {
		return 0, err
	}
	decoder := dm.NewDecoder(r)

	// initialize coeff table
	cs.CoeffTable = newCoeffTable(0)

	if err := decoder.Decode(&cs); err != nil {
		return int64(decoder.NumBytesRead()), err
	}

	if err := cs.CheckSerializationHeader(); err != nil {
		return int64(decoder.NumBytesRead()), err
	}

	// the decoder produces pointer values for the tagged commitment types;
	// dereference them so CommitmentInfo holds concrete values
	switch v := cs.CommitmentInfo.(type) {
	case *constraint.Groth16Commitments:
		cs.CommitmentInfo = *v
	case *constraint.PlonkCommitments:
		cs.CommitmentInfo = *v
	}

	return int64(decoder.NumBytesRead()), nil
}

// GetCoefficient returns the coefficient at index i, copied into a constraint.Element.
func (cs *system) GetCoefficient(i int) (r constraint.Element) {
	copy(r[:], cs.Coefficients[i][:])
	return
}

// GetSparseR1Cs return the list of SparseR1C
func (cs *system) GetSparseR1Cs() []constraint.SparseR1C {

	toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints())

	for _, inst := range cs.Instructions {
		blueprint := cs.Blueprints[inst.BlueprintID]
		if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok {
			var sparseR1C constraint.SparseR1C
			bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System))
			toReturn = append(toReturn, sparseR1C)
		} else {
			panic("not implemented")
		}
	}
	return toReturn
}

// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form.
// solver = [ public | secret | internal ]
// TODO @gbotrel refactor; this seems to be a small util function for plonk
func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) {

	//s := int(pk.Domain[0].Cardinality)
	s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints
	s = int(ecc.NextPowerOfTwo(uint64(s)))

	var l, r, o []fr.Element
	l = make([]fr.Element, s)
	r = make([]fr.Element, s)
	o = make([]fr.Element, s)
	s0 := solution[0]

	for i := 0; i < len(cs.Public); i++ { // placeholders
		l[i] = solution[i]
		r[i] = s0
		o[i] = s0
	}
	offset := len(cs.Public)
	nbConstraints := cs.GetNbConstraints()

	// fill l, r, o with the wire values referenced by each sparse constraint,
	// in instruction order
	var sparseR1C constraint.SparseR1C
	j := 0
	for _, inst := range cs.Instructions {
		blueprint := cs.Blueprints[inst.BlueprintID]
		if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok {
			bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System))

			l[offset+j] = solution[sparseR1C.XA]
			r[offset+j] = solution[sparseR1C.XB]
			o[offset+j] = solution[sparseR1C.XC]
			j++
		}
	}

	offset += nbConstraints

	for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0])
		l[offset+i] = s0
		r[offset+i] = s0
		o[offset+i] = s0
	}

	return l, r, o

}

// R1CSSolution represents a valid assignment to all the variables in the constraint system.
+// The vector W such that Aw o Bw - Cw = 0 +type R1CSSolution struct { + W fr.Vector + A, B, C fr.Vector +} + +func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.W.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.A.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.B.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.C.WriteTo(w) + n += a + return n, err +} + +func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.W.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.A.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.B.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.C.ReadFrom(r) + n += a + return n, err +} + +// SparseR1CSSolution represent a valid assignment to all the variables in the constraint system. +type SparseR1CSSolution struct { + L, R, O fr.Vector +} + +func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.L.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.R.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.O.WriteTo(w) + n += a + return n, err + +} + +func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.L.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.R.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.O.ReadFrom(r) + n += a + return n, err +} + +func getTagSet() cbor.TagSet { + // temporary for refactor + ts := cbor.NewTagSet() + // https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml + // 65536-15309735 Unassigned + tagNum := uint64(5309735) + addType := func(t reflect.Type) { + if err := ts.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + t, + tagNum, + ); err != nil { + panic(err) + } + tagNum++ + } + + addType(reflect.TypeOf(constraint.BlueprintGenericHint{})) + addType(reflect.TypeOf(constraint.BlueprintGenericR1C{})) + 
addType(reflect.TypeOf(constraint.BlueprintGenericSparseR1C{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CAdd{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CMul{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CBool{})) + addType(reflect.TypeOf(constraint.BlueprintLookupHint{})) + addType(reflect.TypeOf(constraint.Groth16Commitments{})) + addType(reflect.TypeOf(constraint.PlonkCommitments{})) + + return ts +} + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} diff --git a/constraint/bls24-315/coeff.go b/constraint/bls24-315/coeff.go index d91528f93e..084652f545 100644 --- a/constraint/bls24-315/coeff.go +++ b/constraint/bls24-315/coeff.go @@ -46,7 +46,7 @@ func newCoeffTable(capacity int) CoeffTable { } -func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constraint.Term { +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { c := (*fr.Element)(coeff[:]) var cID uint32 if c.IsZero() { @@ -69,7 +69,11 @@ func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constrai ct.mCoeffs[cc] = cID } } + return cID +} +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) return constraint.Term{VID: uint32(variableID), CID: cID} } @@ -78,7 +82,10 @@ func (ct *CoeffTable) CoeffToString(cID int) string { return ct.Coefficients[cID].String() } -var _ constraint.CoeffEngine = &arithEngine{} +// implements constraint.Field +type field struct{} + +var _ constraint.Field = &field{} var ( two fr.Element @@ -94,10 +101,7 @@ func init() { minusTwo.Neg(&two) } -// implements constraint.CoeffEngine -type arithEngine struct{} - -func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { +func (engine *field) FromInterface(i interface{}) constraint.Element { var e fr.Element if _, err := e.SetInterface(i); err != nil { // need to clean that --> some code path are dissimilar @@ -106,55 
+110,83 @@ func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { b := utils.FromInterface(i) e.SetBigInt(&b) } - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) ToBigInt(c *constraint.Coeff) *big.Int { +func (engine *field) ToBigInt(c constraint.Element) *big.Int { e := (*fr.Element)(c[:]) r := new(big.Int) e.BigInt(r) return r } -func (engine *arithEngine) Mul(a, b *constraint.Coeff) { +func (engine *field) Mul(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Mul(_a, _b) + return a } -func (engine *arithEngine) Add(a, b *constraint.Coeff) { + +func (engine *field) Add(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Add(_a, _b) + return a } -func (engine *arithEngine) Sub(a, b *constraint.Coeff) { +func (engine *field) Sub(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Sub(_a, _b) + return a } -func (engine *arithEngine) Neg(a *constraint.Coeff) { +func (engine *field) Neg(a constraint.Element) constraint.Element { e := (*fr.Element)(a[:]) e.Neg(e) + return a } -func (engine *arithEngine) Inverse(a *constraint.Coeff) { +func (engine *field) Inverse(a constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + e.Inverse(e) + return a, true } -func (engine *arithEngine) IsOne(a *constraint.Coeff) bool { +func (engine *field) IsOne(a constraint.Element) bool { e := (*fr.Element)(a[:]) return e.IsOne() } -func (engine *arithEngine) One() constraint.Coeff { +func (engine *field) One() constraint.Element { e := fr.One() - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) String(a 
*constraint.Coeff) string { +func (engine *field) String(a constraint.Element) string { e := (*fr.Element)(a[:]) return e.String() } + +func (engine *field) Uint64(a constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true +} diff --git a/constraint/bls24-315/gkr.go b/constraint/bls24-315/gkr.go new file mode 100644 index 0000000000..c7c2c9ed2a --- /dev/null +++ b/constraint/bls24-315/gkr.go @@ -0,0 +1,196 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/gkr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + hint "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/utils/algo_utils" + "hash" + "math/big" +) + +type GkrSolvingData struct { + assignments gkr.WireAssignment + circuit gkr.Circuit + memoryPool polynomial.Pool + workers *utils.WorkerPool +} + +func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { + resCircuit := make(gkr.Circuit, len(noPtr)) + var found bool + for i := range noPtr { + if resCircuit[i].Gate, found = gkr.Gates[noPtr[i].Gate]; !found && noPtr[i].Gate != "" { + return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) + } + resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) + } + return resCircuit, nil +} + +func (d *GkrSolvingData) init(info constraint.GkrInfo) (assignment gkrAssignment, err error) { + if d.circuit, err = convertCircuit(info.Circuit); err != nil { + return + } + d.memoryPool = polynomial.NewPool(d.circuit.MemoryRequirements(info.NbInstances)...) 
+ d.workers = utils.NewWorkerPool() + + assignment = make(gkrAssignment, len(d.circuit)) + d.assignments = make(gkr.WireAssignment, len(d.circuit)) + for i := range assignment { + assignment[i] = d.memoryPool.Make(info.NbInstances) + d.assignments[&d.circuit[i]] = assignment[i] + } + return +} + +func (d *GkrSolvingData) dumpAssignments() { + for _, p := range d.assignments { + d.memoryPool.Dump(p) + } +} + +// this module assumes that wire and instance indexes respect dependencies + +type gkrAssignment [][]fr.Element //gkrAssignment is indexed wire first, instance second + +func (a gkrAssignment) setOuts(circuit constraint.GkrCircuit, outs []*big.Int) { + outsI := 0 + for i := range circuit { + if circuit[i].IsOutput() { + for j := range a[i] { + a[i][j].BigInt(outs[outsI]) + outsI++ + } + } + } + // Check if outsI == len(outs)? +} + +func GkrSolveHint(info constraint.GkrInfo, solvingData *GkrSolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + // assumes assignmentVector is arranged wire first, instance second in order of solution + circuit := info.Circuit + nbInstances := info.NbInstances + offsets := info.AssignmentOffsets() + assignment, err := solvingData.init(info) + if err != nil { + return err + } + chunks := circuit.Chunks(nbInstances) + + solveTask := func(chunkOffset int) utils.Task { + return func(startInChunk, endInChunk int) { + start := startInChunk + chunkOffset + end := endInChunk + chunkOffset + inputs := solvingData.memoryPool.Make(info.MaxNIns) + dependencyHeads := make([]int, len(circuit)) + for wI, w := range circuit { + dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { + return w.Dependencies[i].InputInstance + }, len(w.Dependencies), start) + } + + for instanceI := start; instanceI < end; instanceI++ { + for wireI, wire := range circuit { + if wire.IsInput() { + if dependencyHeads[wireI] < len(wire.Dependencies) && instanceI == wire.Dependencies[dependencyHeads[wireI]].InputInstance { + dep := 
wire.Dependencies[dependencyHeads[wireI]] + assignment[wireI][instanceI].Set(&assignment[dep.OutputWire][dep.OutputInstance]) + dependencyHeads[wireI]++ + } else { + assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) + } + } else { + // assemble the inputs + inputIndexes := info.Circuit[wireI].Inputs + for i, inputI := range inputIndexes { + inputs[i].Set(&assignment[inputI][instanceI]) + } + gate := solvingData.circuit[wireI].Gate + assignment[wireI][instanceI] = gate.Evaluate(inputs[:len(inputIndexes)]...) + } + } + } + solvingData.memoryPool.Dump(inputs) + } + } + + start := 0 + for _, end := range chunks { + solvingData.workers.Submit(end-start, solveTask(start), 1024).Wait() + start = end + } + + assignment.setOuts(info.Circuit, outs) + + return nil + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} + +func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { + + return func(_ *big.Int, ins, outs []*big.Int) error { + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + b := make([]byte, fr.Bytes) + i.FillBytes(b) + return b[:] + }) + + hsh := HashBuilderRegistry[hashName]() + + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + if err != nil { + return err + } + + // serialize proof: TODO: In gnark-crypto? 
+ offset := 0 + for i := range proof { + for _, poly := range proof[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if proof[i].FinalEvalProof != nil { + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } + + data.dumpAssignments() + + return nil + + } +} + +// TODO: Move to gnark-crypto +var HashBuilderRegistry = make(map[string]func() hash.Hash) diff --git a/constraint/bls24-315/r1cs.go b/constraint/bls24-315/r1cs.go deleted file mode 100644 index 76f1c63a4e..0000000000 --- a/constraint/bls24-315/r1cs.go +++ /dev/null @@ -1,451 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/fxamacker/cbor/v2" - "io" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc" - "math" - - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" -) - -// R1CS describes a set of R1CS constraint -type R1CS struct { - constraint.R1CSCore - CoeffTable - arithEngine -} - -// NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values -// -// capacity pre-allocates memory for capacity nbConstraints -func NewR1CS(capacity int) *R1CS { - r := R1CS{ - R1CSCore: constraint.R1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.R1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - return &r -} - -func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, r1c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - - cs.UpdateLevel(cID, &r1c) - - return cID -} - -// Solve sets all the wires and returns the a, b, c vectors. -// the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. -// a, b, c vectors: ab-c = hz -// witness = [publicWires | secretWires] (without the ONE_WIRE !) 
-// returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() - - nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables - solution, err := newSolution(nbWires, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return make(fr.Vector, nbWires), err - } - start := time.Now() - - if len(witness) != len(cs.Public)-1+len(cs.Secret) { // - 1 for ONE_WIRE - err = fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), int(len(cs.Public)-1+len(cs.Secret)), len(cs.Public)-1, len(cs.Secret)) - log.Err(err).Send() - return solution.values, err - } - - // compute the wires and the a, b, c polynomials - if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { - err = errors.New("invalid input size: len(a, b, c) == len(Constraints)") - log.Err(err).Send() - return solution.values, err - } - - solution.solved[0] = true // ONE_WIRE - solution.values[0].SetOne() - copy(solution.values[1:], witness) - for i := range witness { - solution.solved[i+1] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness) + 1) - - // now that we know all inputs are set, defer log printing once all solution.values are computed - // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - if err := cs.parallelSolve(a, b, c, &solution); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity 
check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil -} - -func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - chError <- &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - return &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// IsSolved returns nil if given witness solves the R1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) 
- if err != nil { - return err - } - - a := make(fr.Vector, len(cs.Constraints)) - b := make(fr.Vector, len(cs.Constraints)) - c := make(fr.Vector, len(cs.Constraints)) - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, a, b, c, opt) - return err -} - -// divByCoeff sets res = res / t.Coeff -func (cs *R1CS) divByCoeff(res *fr.Element, t constraint.Term) { - cID := t.CoeffID() - switch cID { - case constraint.CoeffIdOne: - return - case constraint.CoeffIdMinusOne: - res.Neg(res) - case constraint.CoeffIdZero: - panic("division by 0") - default: - // this is slow, but shouldn't happen as divByCoeff is called to - // remove the coeff of an unsolved wire - // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 - res.Div(res, &cs.Coefficients[cID]) - } -} - -// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly -// -// returns an error if the solver called a hint function that errored -// returns false, nil if there was no wire to solve -// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that -// the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r constraint.R1C, solution *solution, a, b, c *fr.Element) error { - - // the index of the non-zero entry shows if L, R or O has an uninstantiated wire - // the content is the ID of the wire non instantiated - var loc uint8 - - var termToCompute constraint.Term - - processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) error { - for _, t := range l { - vID := t.WireID() - - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - solution.accumulateInto(t, val) - continue - } - - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err - } - // now that the wire is saved, accumulate it into a, b or c - solution.accumulateInto(t, val) - continue - } - - if loc != 0 { - panic("found more than one wire to instantiate") - } - termToCompute = t - loc = locValue - } - return nil - } - - if err := processLExp(r.L, a, 1); err != nil { - return err - } - - if err := processLExp(r.R, b, 2); err != nil { - return err - } - - if err := processLExp(r.O, c, 3); err != nil { - return err - } - - if loc == 0 { - // there is nothing to solve, may happen if we have an assertion - // (ie a constraints that doesn't yield any output) - // or if we solved the unsolved wires with hint functions - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - return nil - } - - // we compute the wire value and instantiate it - wID := termToCompute.WireID() - - // solver result - var wire fr.Element - - switch loc { - case 1: - if !b.IsZero() { - wire.Div(c, b). - Sub(&wire, a) - a.Add(a, &wire) - } else { - // we didn't actually ensure that a * b == c - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 2: - if !a.IsZero() { - wire.Div(c, a). 
- Sub(&wire, b) - b.Add(b, &wire) - } else { - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 3: - wire.Mul(a, b). - Sub(&wire, c) - - c.Add(c, &wire) - } - - // wire is the term (coeff * value) - // but in the solution we want to store the value only - // note that in gnark frontend, coeff here is always 1 or -1 - cs.divByCoeff(&wire, termToCompute) - solution.set(wID, wire) - - return nil -} - -// GetConstraints return the list of R1C and a coefficient resolver -func (cs *R1CS) GetConstraints() ([]constraint.R1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *R1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto -func (cs *R1CS) CurveID() ecc.ID { - return ecc.BLS24_315 -} - -// WriteTo encodes R1CS into provided io.Writer using cbor -func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode R1CS from io.Reader using cbor -func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(&cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bls24-315/r1cs_sparse.go 
b/constraint/bls24-315/r1cs_sparse.go deleted file mode 100644 index 69e3d01d03..0000000000 --- a/constraint/bls24-315/r1cs_sparse.go +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/fxamacker/cbor/v2" - "io" - "math" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" -) - -// SparseR1CS represents a Plonk like circuit -type SparseR1CS struct { - constraint.SparseR1CSCore - CoeffTable - arithEngine -} - -// NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values -func NewSparseR1CS(capacity int) *SparseR1CS { - cs := SparseR1CS{ - SparseR1CSCore: constraint.SparseR1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.SparseR1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - - return &cs -} - -func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = 
append(cs.Constraints, c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - cs.UpdateLevel(cID, &c) - - return cID -} - -// Solve sets all the wires. -// solution.values = [publicInputs | secretInputs | internalVariables ] -// witness: contains the input variables -// it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() - - // set the slices holding the solution.values and monitoring which variables have been solved - nbVariables := cs.NbInternalVariables + len(cs.Secret) + len(cs.Public) - - start := time.Now() - - expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) - if len(witness) != expectedWitnessSize { - return make(fr.Vector, nbVariables), fmt.Errorf( - "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", - len(witness), - expectedWitnessSize, - len(cs.Public), - len(cs.Secret), - ) - } - - // keep track of wire that have a value - solution, err := newSolution(nbVariables, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return solution.values, err - } - - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs - copy(solution.values, witness) - for i := 0; i < len(witness); i++ { - solution.solved[i] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness)) - - // defer log printing once all solution.values are computed - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - // batch invert the coefficients to avoid many divisions in the solver - coefficientsNegInv := 
fr.BatchInvert(cs.Coefficients) - for i := 0; i < len(coefficientsNegInv); i++ { - coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) - } - - if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil - -} - -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - wg.Done() - return - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - chError <- &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } else { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - } - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - return &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// computeHints computes wires associated with a hint function, if any -// if there is no remaining wire to solve, returns -1 -// else returns the wire position (L -> 0, R -> 1, O -> 2) -func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) (int, error) { - r := -1 - lID, rID, oID := c.L.WireID(), c.R.WireID(), c.O.WireID() - - if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { - // check if it's a hint - if hint, ok := cs.MHints[lID]; ok { - if err := solution.solveWithHint(lID, hint); err != nil { - return -1, err - } - } else { - r = 0 - } - - } - - if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { - // check if it's a hint - if hint, ok := cs.MHints[rID]; ok { - if err := solution.solveWithHint(rID, hint); err != nil { - return -1, err - } - } else { - r = 1 - } - } - - if (c.O.CoeffID() != 0) && !solution.solved[oID] { - // check if it's a hint - if hint, ok := cs.MHints[oID]; ok { - if err := solution.solveWithHint(oID, hint); err != nil { - return -1, err - } - } else { - r = 2 - } - } - 
return r, nil -} - -// solveConstraint solve any unsolved wire in given constraint and update the solution -// a SparseR1C may have up to one unsolved wire (excluding hints) -// if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { - - lro, err := cs.computeHints(c, solution) - if err != nil { - return err - } - if lro == -1 { - // no unsolved wire - // can happen if the constraint contained only hint wires. - return nil - } - if lro == 1 { // we solve for R: u1L+u2R+u3LR+u4O+k=0 => R(u2+u3L)+u1L+u4O+k = 0 - if !solution.solved[c.L.WireID()] { - panic("L wire should be instantiated when we solve R") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.L.WireID()]).Add(&den, &u2) - - v1 = solution.computeTerm(c.L) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - } - - if lro == 0 { // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 - if !solution.solved[c.R.WireID()] { - panic("R wire should be instantiated when we solve L") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.R.WireID()]).Add(&den, &u1) - - v1 = solution.computeTerm(c.R) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - - } - // O 
we solve for O - var o fr.Element - cID, vID := c.O.CoeffID(), c.O.WireID() - - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - - // o = - ((m0 * m1) + l + r + c.K) / c.O - o.Mul(&m0, &m1).Add(&o, &l).Add(&o, &r).Add(&o, &cs.Coefficients[c.K]) - o.Mul(&o, &coefficientsNegInv[cID]) - - solution.set(vID, o) - - return nil -} - -// IsSolved returns nil if given witness solves the SparseR1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return err - } - - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, opt) - return err -} - -// GetConstraints return the list of SparseR1C and a coefficient resolver -func (cs *SparseR1CS) GetConstraints() ([]constraint.SparseR1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// checkConstraint verifies that the constraint holds -func (cs *SparseR1CS) checkConstraint(c constraint.SparseR1C, solution *solution) error { - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - o := solution.computeTerm(c.O) - - // l + r + (m0 * m1) + o + c.K == 0 - var t fr.Element - t.Mul(&m0, &m1).Add(&t, &l).Add(&t, &r).Add(&t, &o).Add(&t, &cs.Coefficients[c.K]) - if !t.IsZero() { - return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + (%s × %s) + %s != 0", - l.String(), - r.String(), - o.String(), - m0.String(), - m1.String(), - cs.Coefficients[c.K].String(), - ) - } - return nil - -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto (ecc.BLS24-315) -func (cs 
*SparseR1CS) CurveID() ecc.ID { - return ecc.BLS24_315 -} - -// WriteTo encodes SparseR1CS into provided io.Writer using cbor -func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode SparseR1CS from io.Reader using cbor -func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bls24-315/r1cs_test.go b/constraint/bls24-315/r1cs_test.go index a5cf07e4a3..4c42f78ee5 100644 --- a/constraint/bls24-315/r1cs_test.go +++ b/constraint/bls24-315/r1cs_test.go @@ -20,6 +20,7 @@ import ( "bytes" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "reflect" "testing" @@ -27,7 +28,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/consensys/gnark/constraint/bls24-315" + cs "github.com/consensys/gnark/constraint/bls24-315" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" ) @@ -48,7 +49,7 @@ func TestSerialization(t *testing.T) { return } - // copmpile a second time to ensure determinism + // compile a second time to ensure determinism r1cs2, err := frontend.Compile(fr.Modulus(), 
r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) @@ -76,10 +77,10 @@ func TestSerialization(t *testing.T) { if diff := cmp.Diff(r1cs1, &reconstructed, cmpopts.IgnoreFields(cs.R1CS{}, "System.q", - "arithEngine", + "field", "CoeffTable.mCoeffs", "System.lbWireLevel", - "System.lbHints", + "System.genericHint", "System.SymbolTable", "System.lbOutputs", "System.bitLen")); diff != "" { @@ -147,12 +148,6 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { - var c circuit - ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c) - if err != nil { - b.Fatal(err) - } - var w circuit w.X = 1 w.Y = 1 @@ -161,8 +156,32 @@ func BenchmarkSolve(b *testing.B) { b.Fatal(err) } - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ccs.IsSolved(witness) - } + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + } diff --git a/constraint/bls24-315/solution.go b/constraint/bls24-315/solution.go deleted file mode 100644 index f65a2821b1..0000000000 --- a/constraint/bls24-315/solution.go +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/debug" - "github.com/rs/zerolog" - "math/big" - "strconv" - "strings" - "sync/atomic" - - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" -) - -// solution represents elements needed to compute -// a solution to a R1CS or SparseR1CS -type solution struct { - values, coefficients []fr.Element - solved []bool - nbSolved uint64 - mHintsFunctions map[hint.ID]hint.Function // maps hintID to hint function - mHints map[int]*constraint.Hint // maps wireID to hint - st *debug.SymbolTable -} - -func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, mHints map[int]*constraint.Hint, coefficients []fr.Element, st *debug.SymbolTable) (solution, error) { - - s := solution{ - st: st, - values: make([]fr.Element, nbWires), - coefficients: coefficients, - solved: make([]bool, nbWires), - mHintsFunctions: hintFunctions, - mHints: mHints, - } - - // hintsDependencies is from compile time; it contains the list of hints the solver **needs** - var missing []string - for hintUUID, hintID := range hintsDependencies { - if _, ok := s.mHintsFunctions[hintUUID]; !ok { - missing = append(missing, hintID) - } - } - - if len(missing) > 0 { - return s, fmt.Errorf("solver missing hint(s): %v", missing) - } - - return s, nil -} - -func (s *solution) set(id int, value fr.Element) { - if s.solved[id] { - panic("solving the 
same wire twice should never happen.") - } - s.values[id] = value - s.solved[id] = true - atomic.AddUint64(&s.nbSolved, 1) - // s.nbSolved++ -} - -func (s *solution) isValid() bool { - return int(s.nbSolved) == len(s.values) -} - -// computeTerm computes coeff*variable -func (s *solution) computeTerm(t constraint.Term) fr.Element { - cID, vID := t.CoeffID(), t.WireID() - if cID != 0 && !s.solved[vID] { - panic("computing a term with an unsolved wire") - } - switch cID { - case constraint.CoeffIdZero: - return fr.Element{} - case constraint.CoeffIdOne: - return s.values[vID] - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - return res - case constraint.CoeffIdMinusOne: - var res fr.Element - res.Neg(&s.values[vID]) - return res - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - return res - } -} - -// r += (t.coeff*t.value) -func (s *solution) accumulateInto(t constraint.Term, r *fr.Element) { - cID := t.CoeffID() - - if t.IsConstant() { - // needed for logs, we may want to not put this in the hot path if we need to - // optimize constraint system solver further. 
- r.Add(r, &s.coefficients[cID]) - return - } - - vID := t.WireID() - switch cID { - case constraint.CoeffIdZero: - return - case constraint.CoeffIdOne: - r.Add(r, &s.values[vID]) - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - r.Add(r, &res) - case constraint.CoeffIdMinusOne: - r.Sub(r, &s.values[vID]) - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - r.Add(r, &res) - } -} - -// solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveWithHint(vID int, h *constraint.Hint) error { - // skip if the wire is already solved by a call to the same hint - // function on the same inputs - if s.solved[vID] { - return nil - } - // ensure hint function was provided - f, ok := s.mHintsFunctions[h.ID] - if !ok { - return errors.New("missing hint function") - } - - // tmp IO big int memory - nbInputs := len(h.Inputs) - nbOutputs := len(h.Wires) - inputs := make([]*big.Int, nbInputs) - outputs := make([]*big.Int, nbOutputs) - for i := 0; i < nbOutputs; i++ { - outputs[i] = big.NewInt(0) - } - - q := fr.Modulus() - - // for each input, we set its big int value, IF all the wires are solved - // the only case where all wires may not be solved, is if one of the input of this hint - // is the output of another hint. - // it is safe to recursively solve this with the parallel solver, since all hints-output wires - // that we can solve this way are marked to be solved with the current constraint we are processing. - recursiveSolve := func(t constraint.Term) error { - if t.IsConstant() { - return nil - } - wID := t.WireID() - if s.solved[wID] { - return nil - } - // unsolved dependency - if h, ok := s.mHints[wID]; ok { - // solve recursively. - return s.solveWithHint(wID, h) - } - - // it's not a hint, we panic. 
- panic("solver can't compute hint; one or more input wires are unsolved") - } - - for i := 0; i < nbInputs; i++ { - inputs[i] = big.NewInt(0) - - var v fr.Element - for _, term := range h.Inputs[i] { - if err := recursiveSolve(term); err != nil { - return err - } - s.accumulateInto(term, &v) - } - v.BigInt(inputs[i]) - } - - err := f(q, inputs, outputs) - - var v fr.Element - for i := range outputs { - v.SetBigInt(outputs[i]) - s.set(h.Wires[i], v) - } - - return err -} - -func (s *solution) printLogs(log zerolog.Logger, logs []constraint.LogEntry) { - if log.GetLevel() == zerolog.Disabled { - return - } - - for i := 0; i < len(logs); i++ { - logLine := s.logValue(logs[i]) - log.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) - } -} - -const unsolvedVariable = "" - -func (s *solution) logValue(log constraint.LogEntry) string { - var toResolve []interface{} - var ( - eval fr.Element - missingValue bool - ) - for j := 0; j < len(log.ToResolve); j++ { - // before eval le - - missingValue = false - eval.SetZero() - - for _, t := range log.ToResolve[j] { - // for each term in the linear expression - - cID, vID := t.CoeffID(), t.WireID() - if t.IsConstant() { - // just add the constant - eval.Add(&eval, &s.coefficients[cID]) - continue - } - - if !s.solved[vID] { - missingValue = true - break // stop the loop we can't evaluate. 
- } - - tv := s.computeTerm(t) - eval.Add(&eval, &tv) - } - - // after - if missingValue { - toResolve = append(toResolve, unsolvedVariable) - } else { - // we have to append our accumulator - toResolve = append(toResolve, eval.String()) - } - - } - if len(log.Stack) > 0 { - var sbb strings.Builder - for _, lID := range log.Stack { - location := s.st.Locations[lID] - function := s.st.Functions[location.FunctionID] - - sbb.WriteString(function.Name) - sbb.WriteByte('\n') - sbb.WriteByte('\t') - sbb.WriteString(function.Filename) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(int(location.Line))) - sbb.WriteByte('\n') - } - toResolve = append(toResolve, sbb.String()) - } - return fmt.Sprintf(log.Format, toResolve...) -} - -// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint -type UnsatisfiedConstraintError struct { - Err error - CID int // constraint ID - DebugInfo *string // optional debug info -} - -func (r *UnsatisfiedConstraintError) Error() string { - if r.DebugInfo != nil { - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) - } - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) -} diff --git a/constraint/bls24-315/solver.go b/constraint/bls24-315/solver.go new file mode 100644 index 0000000000..41f89208ce --- /dev/null +++ b/constraint/bls24-315/solver.go @@ -0,0 +1,647 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/pool" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/rs/zerolog" + "math" + "math/big" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" +) + +// solver represent the state of the solver during a call to System.Solve(...) +type solver struct { + *system + + // values and solved are index by the wire (variable) id + values []fr.Element + solved []bool + nbSolved uint64 + + // maps hintID to hint function + mHintsFunctions map[csolver.HintID]csolver.Hint + + // used to out api.Println + logger zerolog.Logger + + a, b, c fr.Vector // R1CS solver will compute the a,b,c matrices + + q *big.Int +} + +func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { + // parse options + opt, err := csolver.NewConfig(opts...) 
+ if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public) - witnessOffset + len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + witnessOffset) + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + +// computeTerm computes 
coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. 
+func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start)+i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. 
+ } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID: cID, VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. 
+func (s *solver) Read(calldata []uint32) (constraint.Element, int) { + if s.Type == constraint.SystemSparseR1CS { + if calldata[0] != 1 { + panic("invalid calldata") + } + return s.GetValue(calldata[1], calldata[2]), 3 + } + var r fr.Element + n := int(calldata[0]) + j := 1 + for k := 0; k < n; k++ { + // we read k Terms + s.accumulateInto(constraint.Term{CID: calldata[j], VID: calldata[j+1]}, &r) + j += 2 + } + + var ret constraint.Element + copy(ret[:], r[:]) + return ret, j +} + +// processInstruction decodes the instruction and execute blueprint-defined logic. +// an instruction can encode a hint, a custom constraint or a generic constraint. +func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error { + // fetch the blueprint + blueprint := solver.Blueprints[pi.BlueprintID] + inst := pi.Unpack(&solver.System) + cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only + + if solver.Type == constraint.SystemR1CS { + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + // TODO @gbotrel we use the solveR1C method for now, having user-defined + // blueprint for R1CS would require constraint.Solver interface to add methods + // to set a,b,c since it's more efficient to compute these while we solve. + bc.DecompressR1C(&scratch.tR1C, inst) + return solver.solveR1C(cID, &scratch.tR1C) + } + } + + // blueprint declared "I know how to solve this." + if bc, ok := blueprint.(constraint.BlueprintSolvable); ok { + if err := bc.Solve(solver, inst); err != nil { + return solver.wrapErrWithDebugInfo(cID, err) + } + return nil + } + + // blueprint encodes a hint, we execute. + // TODO @gbotrel may be worth it to move hint logic in blueprint "solve" + if bc, ok := blueprint.(constraint.BlueprintHint); ok { + bc.DecompressHint(&scratch.tHint, inst) + return solver.solveWithHint(&scratch.tHint) + } + + return nil +} + +// run runs the solver. 
it returns an error if a constraint is not satisfied or if not all wires
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + if int(solver.nbSolved) != len(solver.values) { + return errors.New("solver didn't assign a value to all wires") + } + + return nil +} + +// solveR1C compute unsolved wires in the constraint, if any and set the solver accordingly +// +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. 
+func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error { + a, b, c := &solver.a[cID], &solver.b[cID], &solver.c[cID] + + // the index of the non-zero entry shows if L, R or O has an uninstantiated wire + // the content is the ID of the wire non instantiated + var loc uint8 + + var termToCompute constraint.Term + + processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) { + for _, t := range l { + vID := t.WireID() + + // wire is already computed, we just accumulate in val + if solver.solved[vID] { + solver.accumulateInto(t, val) + continue + } + + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue + } + } + + processLExp(r.L, a, 1) + processLExp(r.R, b, 2) + processLExp(r.O, c, 3) + + if loc == 0 { + // there is nothing to solve, may happen if we have an assertion + // (ie a constraints that doesn't yield any output) + // or if we solved the unsolved wires with hint functions + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + return nil + } + + // we compute the wire value and instantiate it + wID := termToCompute.WireID() + + // solver result + var wire fr.Element + + switch loc { + case 1: + if !b.IsZero() { + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) + } else { + // we didn't actually ensure that a * b == c + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 2: + if !a.IsZero() { + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) + } else { + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 3: + wire.Mul(a, b). 
+ Sub(&wire, c) + + c.Add(c, &wire) + } + + // wire is the term (coeff * value) + // but in the solver we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + solver.divByCoeff(&wire, termToCompute.CID) + solver.set(wID, wire) + + return nil +} + +// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint +type UnsatisfiedConstraintError struct { + Err error + CID int // constraint ID + DebugInfo *string // optional debug info +} + +func (r *UnsatisfiedConstraintError) Error() string { + if r.DebugInfo != nil { + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) + } + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) +} + +func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { + var debugInfo *string + if dID, ok := solver.MDebug[int(cID)]; ok { + debugInfo = new(string) + *debugInfo = solver.logValue(solver.DebugInfo[dID]) + } + return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo} +} + +// temporary variables to avoid memallocs in hotloop +type scratch struct { + tR1C constraint.R1C + tHint constraint.HintMapping +} diff --git a/constraint/bls24-315/system.go b/constraint/bls24-315/system.go new file mode 100644 index 0000000000..3c097951ca --- /dev/null +++ b/constraint/bls24-315/system.go @@ -0,0 +1,379 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/fxamacker/cbor/v2" + "io" + "time" + + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/internal/backend/ioutils" + "github.com/consensys/gnark/logger" + "reflect" + + "github.com/consensys/gnark-crypto/ecc" + + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" +) + +type R1CS = system +type SparseR1CS = system + +// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element) +type system struct { + constraint.System + CoeffTable + field +} + +func NewR1CS(capacity int) *R1CS { + return newSystem(capacity, constraint.SystemR1CS) +} + +func NewSparseR1CS(capacity int) *SparseR1CS { + return newSystem(capacity, constraint.SystemSparseR1CS) +} + +func newSystem(capacity int, t constraint.SystemType) *system { + return &system{ + System: constraint.NewSystem(fr.Modulus(), capacity, t), + CoeffTable: newCoeffTable(capacity / 10), + } +} + +// Solve solves the constraint system with provided witness. +// If it's a R1CS returns R1CSSolution +// If it's a SparseR1CS returns SparseR1CSSolution +func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) { + log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger() + start := time.Now() + + v := witness.Vector().(fr.Vector) + + // init the solver + solver, err := newSolver(cs, v, opts...) + if err != nil { + log.Err(err).Send() + return nil, err + } + + // defer log printing once all solver.values are computed + // (or sooner, if a constraint is not satisfied) + defer solver.printLogs(cs.Logs) + + // run it. 
+ if err := solver.run(); err != nil { + log.Err(err).Send() + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") + + // format the solution + // TODO @gbotrel revisit post-refactor + if cs.Type == constraint.SystemR1CS { + var res R1CSSolution + res.W = solver.values + res.A = solver.a + res.B = solver.b + res.C = solver.c + return &res, nil + } else { + // sparse R1CS + var res SparseR1CSSolution + // query l, r, o in Lagrange basis, not blinded + res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values) + + return &res, nil + } + +} + +// IsSolved +// Deprecated: use _, err := Solve(...) instead +func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error { + _, err := cs.Solve(witness, opts...) + return err +} + +// GetR1Cs return the list of R1C +func (cs *system) GetR1Cs() []constraint.R1C { + toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + var r1c constraint.R1C + bc.DecompressR1C(&r1c, inst.Unpack(&cs.System)) + toReturn = append(toReturn, r1c) + } else { + panic("not implemented") + } + } + return toReturn +} + +// GetNbCoefficients return the number of unique coefficients needed in the R1CS +func (cs *system) GetNbCoefficients() int { + return len(cs.Coefficients) +} + +// CurveID returns curve ID as defined in gnark-crypto +func (cs *system) CurveID() ecc.ID { + return ecc.BLS24_315 +} + +// WriteTo encodes R1CS into provided io.Writer using cbor +func (cs *system) WriteTo(w io.Writer) (int64, error) { + _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written + ts := getTagSet() + enc, err := cbor.CoreDetEncOptions().EncModeWithTags(ts) + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) + + // encode our object + err = encoder.Encode(cs) + return _w.N, err +} + +// ReadFrom 
attempts to decode R1CS from io.Reader using cbor +func (cs *system) ReadFrom(r io.Reader) (int64, error) { + ts := getTagSet() + dm, err := cbor.DecOptions{ + MaxArrayElements: 2147483647, + MaxMapPairs: 2147483647, + }.DecModeWithTags(ts) + + if err != nil { + return 0, err + } + decoder := dm.NewDecoder(r) + + // initialize coeff table + cs.CoeffTable = newCoeffTable(0) + + if err := decoder.Decode(&cs); err != nil { + return int64(decoder.NumBytesRead()), err + } + + if err := cs.CheckSerializationHeader(); err != nil { + return int64(decoder.NumBytesRead()), err + } + + switch v := cs.CommitmentInfo.(type) { + case *constraint.Groth16Commitments: + cs.CommitmentInfo = *v + case *constraint.PlonkCommitments: + cs.CommitmentInfo = *v + } + + return int64(decoder.NumBytesRead()), nil +} + +func (cs *system) GetCoefficient(i int) (r constraint.Element) { + copy(r[:], cs.Coefficients[i][:]) + return +} + +// GetSparseR1Cs return the list of SparseR1C +func (cs *system) GetSparseR1Cs() []constraint.SparseR1C { + + toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + var sparseR1C constraint.SparseR1C + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + toReturn = append(toReturn, sparseR1C) + } else { + panic("not implemented") + } + } + return toReturn +} + +// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form. 
+// solver = [ public | secret | internal ] +// TODO @gbotrel refactor; this seems to be a small util function for plonk +func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { + + //s := int(pk.Domain[0].Cardinality) + s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints + s = int(ecc.NextPowerOfTwo(uint64(s))) + + var l, r, o []fr.Element + l = make([]fr.Element, s) + r = make([]fr.Element, s) + o = make([]fr.Element, s) + s0 := solution[0] + + for i := 0; i < len(cs.Public); i++ { // placeholders + l[i] = solution[i] + r[i] = s0 + o[i] = s0 + } + offset := len(cs.Public) + nbConstraints := cs.GetNbConstraints() + + var sparseR1C constraint.SparseR1C + j := 0 + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + + l[offset+j] = solution[sparseR1C.XA] + r[offset+j] = solution[sparseR1C.XB] + o[offset+j] = solution[sparseR1C.XC] + j++ + } + } + + offset += nbConstraints + + for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0]) + l[offset+i] = s0 + r[offset+i] = s0 + o[offset+i] = s0 + } + + return l, r, o + +} + +// R1CSSolution represent a valid assignment to all the variables in the constraint system. 
+// The vector W such that Aw o Bw - Cw = 0 +type R1CSSolution struct { + W fr.Vector + A, B, C fr.Vector +} + +func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.W.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.A.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.B.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.C.WriteTo(w) + n += a + return n, err +} + +func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.W.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.A.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.B.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.C.ReadFrom(r) + n += a + return n, err +} + +// SparseR1CSSolution represent a valid assignment to all the variables in the constraint system. +type SparseR1CSSolution struct { + L, R, O fr.Vector +} + +func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.L.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.R.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.O.WriteTo(w) + n += a + return n, err + +} + +func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.L.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.R.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.O.ReadFrom(r) + n += a + return n, err +} + +func getTagSet() cbor.TagSet { + // temporary for refactor + ts := cbor.NewTagSet() + // https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml + // 65536-15309735 Unassigned + tagNum := uint64(5309735) + addType := func(t reflect.Type) { + if err := ts.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + t, + tagNum, + ); err != nil { + panic(err) + } + tagNum++ + } + + addType(reflect.TypeOf(constraint.BlueprintGenericHint{})) + addType(reflect.TypeOf(constraint.BlueprintGenericR1C{})) + 
addType(reflect.TypeOf(constraint.BlueprintGenericSparseR1C{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CAdd{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CMul{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CBool{})) + addType(reflect.TypeOf(constraint.BlueprintLookupHint{})) + addType(reflect.TypeOf(constraint.Groth16Commitments{})) + addType(reflect.TypeOf(constraint.PlonkCommitments{})) + + return ts +} + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} diff --git a/constraint/bls24-317/coeff.go b/constraint/bls24-317/coeff.go index d011d2bffb..5df92edf1b 100644 --- a/constraint/bls24-317/coeff.go +++ b/constraint/bls24-317/coeff.go @@ -46,7 +46,7 @@ func newCoeffTable(capacity int) CoeffTable { } -func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constraint.Term { +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { c := (*fr.Element)(coeff[:]) var cID uint32 if c.IsZero() { @@ -69,7 +69,11 @@ func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constrai ct.mCoeffs[cc] = cID } } + return cID +} +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) return constraint.Term{VID: uint32(variableID), CID: cID} } @@ -78,7 +82,10 @@ func (ct *CoeffTable) CoeffToString(cID int) string { return ct.Coefficients[cID].String() } -var _ constraint.CoeffEngine = &arithEngine{} +// implements constraint.Field +type field struct{} + +var _ constraint.Field = &field{} var ( two fr.Element @@ -94,10 +101,7 @@ func init() { minusTwo.Neg(&two) } -// implements constraint.CoeffEngine -type arithEngine struct{} - -func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { +func (engine *field) FromInterface(i interface{}) constraint.Element { var e fr.Element if _, err := e.SetInterface(i); err != nil { // need to clean that --> some code path are dissimilar @@ -106,55 
+110,83 @@ func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { b := utils.FromInterface(i) e.SetBigInt(&b) } - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) ToBigInt(c *constraint.Coeff) *big.Int { +func (engine *field) ToBigInt(c constraint.Element) *big.Int { e := (*fr.Element)(c[:]) r := new(big.Int) e.BigInt(r) return r } -func (engine *arithEngine) Mul(a, b *constraint.Coeff) { +func (engine *field) Mul(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Mul(_a, _b) + return a } -func (engine *arithEngine) Add(a, b *constraint.Coeff) { + +func (engine *field) Add(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Add(_a, _b) + return a } -func (engine *arithEngine) Sub(a, b *constraint.Coeff) { +func (engine *field) Sub(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Sub(_a, _b) + return a } -func (engine *arithEngine) Neg(a *constraint.Coeff) { +func (engine *field) Neg(a constraint.Element) constraint.Element { e := (*fr.Element)(a[:]) e.Neg(e) + return a } -func (engine *arithEngine) Inverse(a *constraint.Coeff) { +func (engine *field) Inverse(a constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + e.Inverse(e) + return a, true } -func (engine *arithEngine) IsOne(a *constraint.Coeff) bool { +func (engine *field) IsOne(a constraint.Element) bool { e := (*fr.Element)(a[:]) return e.IsOne() } -func (engine *arithEngine) One() constraint.Coeff { +func (engine *field) One() constraint.Element { e := fr.One() - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) String(a 
*constraint.Coeff) string { +func (engine *field) String(a constraint.Element) string { e := (*fr.Element)(a[:]) return e.String() } + +func (engine *field) Uint64(a constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true +} diff --git a/constraint/bls24-317/gkr.go b/constraint/bls24-317/gkr.go new file mode 100644 index 0000000000..f38ac45c92 --- /dev/null +++ b/constraint/bls24-317/gkr.go @@ -0,0 +1,196 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/gkr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + hint "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/utils/algo_utils" + "hash" + "math/big" +) + +type GkrSolvingData struct { + assignments gkr.WireAssignment + circuit gkr.Circuit + memoryPool polynomial.Pool + workers *utils.WorkerPool +} + +func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { + resCircuit := make(gkr.Circuit, len(noPtr)) + var found bool + for i := range noPtr { + if resCircuit[i].Gate, found = gkr.Gates[noPtr[i].Gate]; !found && noPtr[i].Gate != "" { + return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) + } + resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) + } + return resCircuit, nil +} + +func (d *GkrSolvingData) init(info constraint.GkrInfo) (assignment gkrAssignment, err error) { + if d.circuit, err = convertCircuit(info.Circuit); err != nil { + return + } + d.memoryPool = polynomial.NewPool(d.circuit.MemoryRequirements(info.NbInstances)...) 
+ d.workers = utils.NewWorkerPool() + + assignment = make(gkrAssignment, len(d.circuit)) + d.assignments = make(gkr.WireAssignment, len(d.circuit)) + for i := range assignment { + assignment[i] = d.memoryPool.Make(info.NbInstances) + d.assignments[&d.circuit[i]] = assignment[i] + } + return +} + +func (d *GkrSolvingData) dumpAssignments() { + for _, p := range d.assignments { + d.memoryPool.Dump(p) + } +} + +// this module assumes that wire and instance indexes respect dependencies + +type gkrAssignment [][]fr.Element //gkrAssignment is indexed wire first, instance second + +func (a gkrAssignment) setOuts(circuit constraint.GkrCircuit, outs []*big.Int) { + outsI := 0 + for i := range circuit { + if circuit[i].IsOutput() { + for j := range a[i] { + a[i][j].BigInt(outs[outsI]) + outsI++ + } + } + } + // Check if outsI == len(outs)? +} + +func GkrSolveHint(info constraint.GkrInfo, solvingData *GkrSolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + // assumes assignmentVector is arranged wire first, instance second in order of solution + circuit := info.Circuit + nbInstances := info.NbInstances + offsets := info.AssignmentOffsets() + assignment, err := solvingData.init(info) + if err != nil { + return err + } + chunks := circuit.Chunks(nbInstances) + + solveTask := func(chunkOffset int) utils.Task { + return func(startInChunk, endInChunk int) { + start := startInChunk + chunkOffset + end := endInChunk + chunkOffset + inputs := solvingData.memoryPool.Make(info.MaxNIns) + dependencyHeads := make([]int, len(circuit)) + for wI, w := range circuit { + dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { + return w.Dependencies[i].InputInstance + }, len(w.Dependencies), start) + } + + for instanceI := start; instanceI < end; instanceI++ { + for wireI, wire := range circuit { + if wire.IsInput() { + if dependencyHeads[wireI] < len(wire.Dependencies) && instanceI == wire.Dependencies[dependencyHeads[wireI]].InputInstance { + dep := 
wire.Dependencies[dependencyHeads[wireI]] + assignment[wireI][instanceI].Set(&assignment[dep.OutputWire][dep.OutputInstance]) + dependencyHeads[wireI]++ + } else { + assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) + } + } else { + // assemble the inputs + inputIndexes := info.Circuit[wireI].Inputs + for i, inputI := range inputIndexes { + inputs[i].Set(&assignment[inputI][instanceI]) + } + gate := solvingData.circuit[wireI].Gate + assignment[wireI][instanceI] = gate.Evaluate(inputs[:len(inputIndexes)]...) + } + } + } + solvingData.memoryPool.Dump(inputs) + } + } + + start := 0 + for _, end := range chunks { + solvingData.workers.Submit(end-start, solveTask(start), 1024).Wait() + start = end + } + + assignment.setOuts(info.Circuit, outs) + + return nil + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} + +func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { + + return func(_ *big.Int, ins, outs []*big.Int) error { + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + b := make([]byte, fr.Bytes) + i.FillBytes(b) + return b[:] + }) + + hsh := HashBuilderRegistry[hashName]() + + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + if err != nil { + return err + } + + // serialize proof: TODO: In gnark-crypto? 
+ offset := 0 + for i := range proof { + for _, poly := range proof[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if proof[i].FinalEvalProof != nil { + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } + + data.dumpAssignments() + + return nil + + } +} + +// TODO: Move to gnark-crypto +var HashBuilderRegistry = make(map[string]func() hash.Hash) diff --git a/constraint/bls24-317/r1cs.go b/constraint/bls24-317/r1cs.go deleted file mode 100644 index 889ec4bc13..0000000000 --- a/constraint/bls24-317/r1cs.go +++ /dev/null @@ -1,451 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/fxamacker/cbor/v2" - "io" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc" - "math" - - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" -) - -// R1CS describes a set of R1CS constraint -type R1CS struct { - constraint.R1CSCore - CoeffTable - arithEngine -} - -// NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values -// -// capacity pre-allocates memory for capacity nbConstraints -func NewR1CS(capacity int) *R1CS { - r := R1CS{ - R1CSCore: constraint.R1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.R1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - return &r -} - -func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, r1c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - - cs.UpdateLevel(cID, &r1c) - - return cID -} - -// Solve sets all the wires and returns the a, b, c vectors. -// the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. -// a, b, c vectors: ab-c = hz -// witness = [publicWires | secretWires] (without the ONE_WIRE !) 
-// returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() - - nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables - solution, err := newSolution(nbWires, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return make(fr.Vector, nbWires), err - } - start := time.Now() - - if len(witness) != len(cs.Public)-1+len(cs.Secret) { // - 1 for ONE_WIRE - err = fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), int(len(cs.Public)-1+len(cs.Secret)), len(cs.Public)-1, len(cs.Secret)) - log.Err(err).Send() - return solution.values, err - } - - // compute the wires and the a, b, c polynomials - if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { - err = errors.New("invalid input size: len(a, b, c) == len(Constraints)") - log.Err(err).Send() - return solution.values, err - } - - solution.solved[0] = true // ONE_WIRE - solution.values[0].SetOne() - copy(solution.values[1:], witness) - for i := range witness { - solution.solved[i+1] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness) + 1) - - // now that we know all inputs are set, defer log printing once all solution.values are computed - // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - if err := cs.parallelSolve(a, b, c, &solution); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity 
check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil -} - -func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - chError <- &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - return &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// IsSolved returns nil if given witness solves the R1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) 
- if err != nil { - return err - } - - a := make(fr.Vector, len(cs.Constraints)) - b := make(fr.Vector, len(cs.Constraints)) - c := make(fr.Vector, len(cs.Constraints)) - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, a, b, c, opt) - return err -} - -// divByCoeff sets res = res / t.Coeff -func (cs *R1CS) divByCoeff(res *fr.Element, t constraint.Term) { - cID := t.CoeffID() - switch cID { - case constraint.CoeffIdOne: - return - case constraint.CoeffIdMinusOne: - res.Neg(res) - case constraint.CoeffIdZero: - panic("division by 0") - default: - // this is slow, but shouldn't happen as divByCoeff is called to - // remove the coeff of an unsolved wire - // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 - res.Div(res, &cs.Coefficients[cID]) - } -} - -// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly -// -// returns an error if the solver called a hint function that errored -// returns false, nil if there was no wire to solve -// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that -// the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r constraint.R1C, solution *solution, a, b, c *fr.Element) error { - - // the index of the non-zero entry shows if L, R or O has an uninstantiated wire - // the content is the ID of the wire non instantiated - var loc uint8 - - var termToCompute constraint.Term - - processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) error { - for _, t := range l { - vID := t.WireID() - - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - solution.accumulateInto(t, val) - continue - } - - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err - } - // now that the wire is saved, accumulate it into a, b or c - solution.accumulateInto(t, val) - continue - } - - if loc != 0 { - panic("found more than one wire to instantiate") - } - termToCompute = t - loc = locValue - } - return nil - } - - if err := processLExp(r.L, a, 1); err != nil { - return err - } - - if err := processLExp(r.R, b, 2); err != nil { - return err - } - - if err := processLExp(r.O, c, 3); err != nil { - return err - } - - if loc == 0 { - // there is nothing to solve, may happen if we have an assertion - // (ie a constraints that doesn't yield any output) - // or if we solved the unsolved wires with hint functions - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - return nil - } - - // we compute the wire value and instantiate it - wID := termToCompute.WireID() - - // solver result - var wire fr.Element - - switch loc { - case 1: - if !b.IsZero() { - wire.Div(c, b). - Sub(&wire, a) - a.Add(a, &wire) - } else { - // we didn't actually ensure that a * b == c - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 2: - if !a.IsZero() { - wire.Div(c, a). 
- Sub(&wire, b) - b.Add(b, &wire) - } else { - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 3: - wire.Mul(a, b). - Sub(&wire, c) - - c.Add(c, &wire) - } - - // wire is the term (coeff * value) - // but in the solution we want to store the value only - // note that in gnark frontend, coeff here is always 1 or -1 - cs.divByCoeff(&wire, termToCompute) - solution.set(wID, wire) - - return nil -} - -// GetConstraints return the list of R1C and a coefficient resolver -func (cs *R1CS) GetConstraints() ([]constraint.R1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *R1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto -func (cs *R1CS) CurveID() ecc.ID { - return ecc.BLS24_317 -} - -// WriteTo encodes R1CS into provided io.Writer using cbor -func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode R1CS from io.Reader using cbor -func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(&cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bls24-317/r1cs_sparse.go 
b/constraint/bls24-317/r1cs_sparse.go deleted file mode 100644 index 5acd85dd99..0000000000 --- a/constraint/bls24-317/r1cs_sparse.go +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/fxamacker/cbor/v2" - "io" - "math" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" -) - -// SparseR1CS represents a Plonk like circuit -type SparseR1CS struct { - constraint.SparseR1CSCore - CoeffTable - arithEngine -} - -// NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values -func NewSparseR1CS(capacity int) *SparseR1CS { - cs := SparseR1CS{ - SparseR1CSCore: constraint.SparseR1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.SparseR1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - - return &cs -} - -func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = 
append(cs.Constraints, c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - cs.UpdateLevel(cID, &c) - - return cID -} - -// Solve sets all the wires. -// solution.values = [publicInputs | secretInputs | internalVariables ] -// witness: contains the input variables -// it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() - - // set the slices holding the solution.values and monitoring which variables have been solved - nbVariables := cs.NbInternalVariables + len(cs.Secret) + len(cs.Public) - - start := time.Now() - - expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) - if len(witness) != expectedWitnessSize { - return make(fr.Vector, nbVariables), fmt.Errorf( - "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", - len(witness), - expectedWitnessSize, - len(cs.Public), - len(cs.Secret), - ) - } - - // keep track of wire that have a value - solution, err := newSolution(nbVariables, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return solution.values, err - } - - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs - copy(solution.values, witness) - for i := 0; i < len(witness); i++ { - solution.solved[i] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness)) - - // defer log printing once all solution.values are computed - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - // batch invert the coefficients to avoid many divisions in the solver - coefficientsNegInv := 
fr.BatchInvert(cs.Coefficients) - for i := 0; i < len(coefficientsNegInv); i++ { - coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) - } - - if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil - -} - -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - wg.Done() - return - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - chError <- &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } else { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - } - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - return &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// computeHints computes wires associated with a hint function, if any -// if there is no remaining wire to solve, returns -1 -// else returns the wire position (L -> 0, R -> 1, O -> 2) -func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) (int, error) { - r := -1 - lID, rID, oID := c.L.WireID(), c.R.WireID(), c.O.WireID() - - if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { - // check if it's a hint - if hint, ok := cs.MHints[lID]; ok { - if err := solution.solveWithHint(lID, hint); err != nil { - return -1, err - } - } else { - r = 0 - } - - } - - if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { - // check if it's a hint - if hint, ok := cs.MHints[rID]; ok { - if err := solution.solveWithHint(rID, hint); err != nil { - return -1, err - } - } else { - r = 1 - } - } - - if (c.O.CoeffID() != 0) && !solution.solved[oID] { - // check if it's a hint - if hint, ok := cs.MHints[oID]; ok { - if err := solution.solveWithHint(oID, hint); err != nil { - return -1, err - } - } else { - r = 2 - } - } - 
return r, nil -} - -// solveConstraint solve any unsolved wire in given constraint and update the solution -// a SparseR1C may have up to one unsolved wire (excluding hints) -// if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { - - lro, err := cs.computeHints(c, solution) - if err != nil { - return err - } - if lro == -1 { - // no unsolved wire - // can happen if the constraint contained only hint wires. - return nil - } - if lro == 1 { // we solve for R: u1L+u2R+u3LR+u4O+k=0 => R(u2+u3L)+u1L+u4O+k = 0 - if !solution.solved[c.L.WireID()] { - panic("L wire should be instantiated when we solve R") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.L.WireID()]).Add(&den, &u2) - - v1 = solution.computeTerm(c.L) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - } - - if lro == 0 { // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 - if !solution.solved[c.R.WireID()] { - panic("R wire should be instantiated when we solve L") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.R.WireID()]).Add(&den, &u1) - - v1 = solution.computeTerm(c.R) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - - } - // O 
we solve for O - var o fr.Element - cID, vID := c.O.CoeffID(), c.O.WireID() - - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - - // o = - ((m0 * m1) + l + r + c.K) / c.O - o.Mul(&m0, &m1).Add(&o, &l).Add(&o, &r).Add(&o, &cs.Coefficients[c.K]) - o.Mul(&o, &coefficientsNegInv[cID]) - - solution.set(vID, o) - - return nil -} - -// IsSolved returns nil if given witness solves the SparseR1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return err - } - - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, opt) - return err -} - -// GetConstraints return the list of SparseR1C and a coefficient resolver -func (cs *SparseR1CS) GetConstraints() ([]constraint.SparseR1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// checkConstraint verifies that the constraint holds -func (cs *SparseR1CS) checkConstraint(c constraint.SparseR1C, solution *solution) error { - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - o := solution.computeTerm(c.O) - - // l + r + (m0 * m1) + o + c.K == 0 - var t fr.Element - t.Mul(&m0, &m1).Add(&t, &l).Add(&t, &r).Add(&t, &o).Add(&t, &cs.Coefficients[c.K]) - if !t.IsZero() { - return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + (%s × %s) + %s != 0", - l.String(), - r.String(), - o.String(), - m0.String(), - m1.String(), - cs.Coefficients[c.K].String(), - ) - } - return nil - -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto (ecc.BLS24-317) -func (cs 
*SparseR1CS) CurveID() ecc.ID { - return ecc.BLS24_317 -} - -// WriteTo encodes SparseR1CS into provided io.Writer using cbor -func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode SparseR1CS from io.Reader using cbor -func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bls24-317/r1cs_test.go b/constraint/bls24-317/r1cs_test.go index eff1f06499..40bb573a0b 100644 --- a/constraint/bls24-317/r1cs_test.go +++ b/constraint/bls24-317/r1cs_test.go @@ -20,6 +20,7 @@ import ( "bytes" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "reflect" "testing" @@ -27,7 +28,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/consensys/gnark/constraint/bls24-317" + cs "github.com/consensys/gnark/constraint/bls24-317" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" ) @@ -48,7 +49,7 @@ func TestSerialization(t *testing.T) { return } - // copmpile a second time to ensure determinism + // compile a second time to ensure determinism r1cs2, err := frontend.Compile(fr.Modulus(), 
r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) @@ -76,10 +77,10 @@ func TestSerialization(t *testing.T) { if diff := cmp.Diff(r1cs1, &reconstructed, cmpopts.IgnoreFields(cs.R1CS{}, "System.q", - "arithEngine", + "field", "CoeffTable.mCoeffs", "System.lbWireLevel", - "System.lbHints", + "System.genericHint", "System.SymbolTable", "System.lbOutputs", "System.bitLen")); diff != "" { @@ -147,12 +148,6 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { - var c circuit - ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c) - if err != nil { - b.Fatal(err) - } - var w circuit w.X = 1 w.Y = 1 @@ -161,8 +156,32 @@ func BenchmarkSolve(b *testing.B) { b.Fatal(err) } - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ccs.IsSolved(witness) - } + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + } diff --git a/constraint/bls24-317/solution.go b/constraint/bls24-317/solution.go deleted file mode 100644 index d26c56f931..0000000000 --- a/constraint/bls24-317/solution.go +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/debug" - "github.com/rs/zerolog" - "math/big" - "strconv" - "strings" - "sync/atomic" - - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" -) - -// solution represents elements needed to compute -// a solution to a R1CS or SparseR1CS -type solution struct { - values, coefficients []fr.Element - solved []bool - nbSolved uint64 - mHintsFunctions map[hint.ID]hint.Function // maps hintID to hint function - mHints map[int]*constraint.Hint // maps wireID to hint - st *debug.SymbolTable -} - -func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, mHints map[int]*constraint.Hint, coefficients []fr.Element, st *debug.SymbolTable) (solution, error) { - - s := solution{ - st: st, - values: make([]fr.Element, nbWires), - coefficients: coefficients, - solved: make([]bool, nbWires), - mHintsFunctions: hintFunctions, - mHints: mHints, - } - - // hintsDependencies is from compile time; it contains the list of hints the solver **needs** - var missing []string - for hintUUID, hintID := range hintsDependencies { - if _, ok := s.mHintsFunctions[hintUUID]; !ok { - missing = append(missing, hintID) - } - } - - if len(missing) > 0 { - return s, fmt.Errorf("solver missing hint(s): %v", missing) - } - - return s, nil -} - -func (s *solution) set(id int, value fr.Element) { - if s.solved[id] { - panic("solving the 
same wire twice should never happen.") - } - s.values[id] = value - s.solved[id] = true - atomic.AddUint64(&s.nbSolved, 1) - // s.nbSolved++ -} - -func (s *solution) isValid() bool { - return int(s.nbSolved) == len(s.values) -} - -// computeTerm computes coeff*variable -func (s *solution) computeTerm(t constraint.Term) fr.Element { - cID, vID := t.CoeffID(), t.WireID() - if cID != 0 && !s.solved[vID] { - panic("computing a term with an unsolved wire") - } - switch cID { - case constraint.CoeffIdZero: - return fr.Element{} - case constraint.CoeffIdOne: - return s.values[vID] - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - return res - case constraint.CoeffIdMinusOne: - var res fr.Element - res.Neg(&s.values[vID]) - return res - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - return res - } -} - -// r += (t.coeff*t.value) -func (s *solution) accumulateInto(t constraint.Term, r *fr.Element) { - cID := t.CoeffID() - - if t.IsConstant() { - // needed for logs, we may want to not put this in the hot path if we need to - // optimize constraint system solver further. 
- r.Add(r, &s.coefficients[cID]) - return - } - - vID := t.WireID() - switch cID { - case constraint.CoeffIdZero: - return - case constraint.CoeffIdOne: - r.Add(r, &s.values[vID]) - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - r.Add(r, &res) - case constraint.CoeffIdMinusOne: - r.Sub(r, &s.values[vID]) - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - r.Add(r, &res) - } -} - -// solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveWithHint(vID int, h *constraint.Hint) error { - // skip if the wire is already solved by a call to the same hint - // function on the same inputs - if s.solved[vID] { - return nil - } - // ensure hint function was provided - f, ok := s.mHintsFunctions[h.ID] - if !ok { - return errors.New("missing hint function") - } - - // tmp IO big int memory - nbInputs := len(h.Inputs) - nbOutputs := len(h.Wires) - inputs := make([]*big.Int, nbInputs) - outputs := make([]*big.Int, nbOutputs) - for i := 0; i < nbOutputs; i++ { - outputs[i] = big.NewInt(0) - } - - q := fr.Modulus() - - // for each input, we set its big int value, IF all the wires are solved - // the only case where all wires may not be solved, is if one of the input of this hint - // is the output of another hint. - // it is safe to recursively solve this with the parallel solver, since all hints-output wires - // that we can solve this way are marked to be solved with the current constraint we are processing. - recursiveSolve := func(t constraint.Term) error { - if t.IsConstant() { - return nil - } - wID := t.WireID() - if s.solved[wID] { - return nil - } - // unsolved dependency - if h, ok := s.mHints[wID]; ok { - // solve recursively. - return s.solveWithHint(wID, h) - } - - // it's not a hint, we panic. 
- panic("solver can't compute hint; one or more input wires are unsolved") - } - - for i := 0; i < nbInputs; i++ { - inputs[i] = big.NewInt(0) - - var v fr.Element - for _, term := range h.Inputs[i] { - if err := recursiveSolve(term); err != nil { - return err - } - s.accumulateInto(term, &v) - } - v.BigInt(inputs[i]) - } - - err := f(q, inputs, outputs) - - var v fr.Element - for i := range outputs { - v.SetBigInt(outputs[i]) - s.set(h.Wires[i], v) - } - - return err -} - -func (s *solution) printLogs(log zerolog.Logger, logs []constraint.LogEntry) { - if log.GetLevel() == zerolog.Disabled { - return - } - - for i := 0; i < len(logs); i++ { - logLine := s.logValue(logs[i]) - log.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) - } -} - -const unsolvedVariable = "" - -func (s *solution) logValue(log constraint.LogEntry) string { - var toResolve []interface{} - var ( - eval fr.Element - missingValue bool - ) - for j := 0; j < len(log.ToResolve); j++ { - // before eval le - - missingValue = false - eval.SetZero() - - for _, t := range log.ToResolve[j] { - // for each term in the linear expression - - cID, vID := t.CoeffID(), t.WireID() - if t.IsConstant() { - // just add the constant - eval.Add(&eval, &s.coefficients[cID]) - continue - } - - if !s.solved[vID] { - missingValue = true - break // stop the loop we can't evaluate. 
- } - - tv := s.computeTerm(t) - eval.Add(&eval, &tv) - } - - // after - if missingValue { - toResolve = append(toResolve, unsolvedVariable) - } else { - // we have to append our accumulator - toResolve = append(toResolve, eval.String()) - } - - } - if len(log.Stack) > 0 { - var sbb strings.Builder - for _, lID := range log.Stack { - location := s.st.Locations[lID] - function := s.st.Functions[location.FunctionID] - - sbb.WriteString(function.Name) - sbb.WriteByte('\n') - sbb.WriteByte('\t') - sbb.WriteString(function.Filename) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(int(location.Line))) - sbb.WriteByte('\n') - } - toResolve = append(toResolve, sbb.String()) - } - return fmt.Sprintf(log.Format, toResolve...) -} - -// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint -type UnsatisfiedConstraintError struct { - Err error - CID int // constraint ID - DebugInfo *string // optional debug info -} - -func (r *UnsatisfiedConstraintError) Error() string { - if r.DebugInfo != nil { - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) - } - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) -} diff --git a/constraint/bls24-317/solver.go b/constraint/bls24-317/solver.go new file mode 100644 index 0000000000..e1dcad13e9 --- /dev/null +++ b/constraint/bls24-317/solver.go @@ -0,0 +1,647 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by gnark DO NOT EDIT
+
+package cs
+
+import (
+	"errors"
+	"fmt"
+	"github.com/consensys/gnark-crypto/ecc"
+	"github.com/consensys/gnark-crypto/field/pool"
+	"github.com/consensys/gnark/constraint"
+	csolver "github.com/consensys/gnark/constraint/solver"
+	"github.com/rs/zerolog"
+	"math"
+	"math/big"
+	"runtime"
+	"strconv"
+	"strings"
+	"sync"
+	"sync/atomic"
+
+	"github.com/consensys/gnark-crypto/ecc/bls24-317/fr"
+)
+
+// solver represents the state of the solver during a call to System.Solve(...)
+type solver struct {
+	*system
+
+	// values and solved are indexed by the wire (variable) id
+	values   []fr.Element
+	solved   []bool
+	nbSolved uint64
+
+	// maps hintID to hint function
+	mHintsFunctions map[csolver.HintID]csolver.Hint
+
+	// used to output api.Println
+	logger zerolog.Logger
+
+	a, b, c fr.Vector // R1CS solver will compute the a,b,c matrices
+
+	q *big.Int
+}
+
+func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) {
+	// parse options
+	opt, err := csolver.NewConfig(opts...)
+ if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public) - witnessOffset + len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + witnessOffset) + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + +// computeTerm computes 
coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. 
+func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start)+i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. 
+ } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID: cID, VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. 
+func (s *solver) Read(calldata []uint32) (constraint.Element, int) { + if s.Type == constraint.SystemSparseR1CS { + if calldata[0] != 1 { + panic("invalid calldata") + } + return s.GetValue(calldata[1], calldata[2]), 3 + } + var r fr.Element + n := int(calldata[0]) + j := 1 + for k := 0; k < n; k++ { + // we read k Terms + s.accumulateInto(constraint.Term{CID: calldata[j], VID: calldata[j+1]}, &r) + j += 2 + } + + var ret constraint.Element + copy(ret[:], r[:]) + return ret, j +} + +// processInstruction decodes the instruction and execute blueprint-defined logic. +// an instruction can encode a hint, a custom constraint or a generic constraint. +func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error { + // fetch the blueprint + blueprint := solver.Blueprints[pi.BlueprintID] + inst := pi.Unpack(&solver.System) + cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only + + if solver.Type == constraint.SystemR1CS { + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + // TODO @gbotrel we use the solveR1C method for now, having user-defined + // blueprint for R1CS would require constraint.Solver interface to add methods + // to set a,b,c since it's more efficient to compute these while we solve. + bc.DecompressR1C(&scratch.tR1C, inst) + return solver.solveR1C(cID, &scratch.tR1C) + } + } + + // blueprint declared "I know how to solve this." + if bc, ok := blueprint.(constraint.BlueprintSolvable); ok { + if err := bc.Solve(solver, inst); err != nil { + return solver.wrapErrWithDebugInfo(cID, err) + } + return nil + } + + // blueprint encodes a hint, we execute. + // TODO @gbotrel may be worth it to move hint logic in blueprint "solve" + if bc, ok := blueprint.(constraint.BlueprintHint); ok { + bc.DecompressHint(&scratch.tHint, inst) + return solver.solveWithHint(&scratch.tHint) + } + + return nil +} + +// run runs the solver. 
it return an error if a constraint is not satisfied or if not all wires +// were instantiated. +func (solver *solver) run() error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks. + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + var scratch scratch + for t := range chTasks { + for _, i := range t { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + var scratch scratch + + // for each level, we push the tasks + for _, level := range solver.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + return err + } + } + continue + } + + // number of tasks for this level is set to number of CPU + // but if we don't have enough work for all our CPU, it can be lower. 
+		nbTasks := runtime.NumCPU()
+		maxTasks := int(math.Ceil(maxCPU))
+		if nbTasks > maxTasks {
+			nbTasks = maxTasks
+		}
+		nbIterationsPerCpus := len(level) / nbTasks
+
+		// more CPUs than tasks: a CPU will work on exactly one iteration
+		// note: this depends on minWorkPerCPU constant
+		if nbIterationsPerCpus < 1 {
+			nbIterationsPerCpus = 1
+			nbTasks = len(level)
+		}
+
+		extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
+		extraTasksOffset := 0
+
+		for i := 0; i < nbTasks; i++ {
+			wg.Add(1)
+			_start := i*nbIterationsPerCpus + extraTasksOffset
+			_end := _start + nbIterationsPerCpus
+			if extraTasks > 0 {
+				_end++
+				extraTasks--
+				extraTasksOffset++
+			}
+			// since we're never pushing more than num CPU tasks
+			// we will never be blocked here
+			chTasks <- level[_start:_end]
+		}
+
+		// wait for the level to be done
+		wg.Wait()
+
+		if len(chError) > 0 {
+			return <-chError
+		}
+	}
+
+	if int(solver.nbSolved) != len(solver.values) {
+		return errors.New("solver didn't assign a value to all wires")
+	}
+
+	return nil
+}
+
+// solveR1C computes the unsolved wire in the constraint, if any, and sets the solver accordingly
+//
+// it returns an error if the solver called a hint function that errored,
+// or if the constraint is not satisfied (a ⋅ b != c).
+// when exactly one wire is solved here, it is redundant to check that
+// the constraint is satisfied later.
+func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error { + a, b, c := &solver.a[cID], &solver.b[cID], &solver.c[cID] + + // the index of the non-zero entry shows if L, R or O has an uninstantiated wire + // the content is the ID of the wire non instantiated + var loc uint8 + + var termToCompute constraint.Term + + processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) { + for _, t := range l { + vID := t.WireID() + + // wire is already computed, we just accumulate in val + if solver.solved[vID] { + solver.accumulateInto(t, val) + continue + } + + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue + } + } + + processLExp(r.L, a, 1) + processLExp(r.R, b, 2) + processLExp(r.O, c, 3) + + if loc == 0 { + // there is nothing to solve, may happen if we have an assertion + // (ie a constraints that doesn't yield any output) + // or if we solved the unsolved wires with hint functions + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + return nil + } + + // we compute the wire value and instantiate it + wID := termToCompute.WireID() + + // solver result + var wire fr.Element + + switch loc { + case 1: + if !b.IsZero() { + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) + } else { + // we didn't actually ensure that a * b == c + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 2: + if !a.IsZero() { + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) + } else { + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 3: + wire.Mul(a, b). 
+ Sub(&wire, c) + + c.Add(c, &wire) + } + + // wire is the term (coeff * value) + // but in the solver we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + solver.divByCoeff(&wire, termToCompute.CID) + solver.set(wID, wire) + + return nil +} + +// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint +type UnsatisfiedConstraintError struct { + Err error + CID int // constraint ID + DebugInfo *string // optional debug info +} + +func (r *UnsatisfiedConstraintError) Error() string { + if r.DebugInfo != nil { + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) + } + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) +} + +func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { + var debugInfo *string + if dID, ok := solver.MDebug[int(cID)]; ok { + debugInfo = new(string) + *debugInfo = solver.logValue(solver.DebugInfo[dID]) + } + return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo} +} + +// temporary variables to avoid memallocs in hotloop +type scratch struct { + tR1C constraint.R1C + tHint constraint.HintMapping +} diff --git a/constraint/bls24-317/system.go b/constraint/bls24-317/system.go new file mode 100644 index 0000000000..f170242f5b --- /dev/null +++ b/constraint/bls24-317/system.go @@ -0,0 +1,379 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/fxamacker/cbor/v2" + "io" + "time" + + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/internal/backend/ioutils" + "github.com/consensys/gnark/logger" + "reflect" + + "github.com/consensys/gnark-crypto/ecc" + + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" +) + +type R1CS = system +type SparseR1CS = system + +// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element) +type system struct { + constraint.System + CoeffTable + field +} + +func NewR1CS(capacity int) *R1CS { + return newSystem(capacity, constraint.SystemR1CS) +} + +func NewSparseR1CS(capacity int) *SparseR1CS { + return newSystem(capacity, constraint.SystemSparseR1CS) +} + +func newSystem(capacity int, t constraint.SystemType) *system { + return &system{ + System: constraint.NewSystem(fr.Modulus(), capacity, t), + CoeffTable: newCoeffTable(capacity / 10), + } +} + +// Solve solves the constraint system with provided witness. +// If it's a R1CS returns R1CSSolution +// If it's a SparseR1CS returns SparseR1CSSolution +func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) { + log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger() + start := time.Now() + + v := witness.Vector().(fr.Vector) + + // init the solver + solver, err := newSolver(cs, v, opts...) + if err != nil { + log.Err(err).Send() + return nil, err + } + + // defer log printing once all solver.values are computed + // (or sooner, if a constraint is not satisfied) + defer solver.printLogs(cs.Logs) + + // run it. 
+ if err := solver.run(); err != nil { + log.Err(err).Send() + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") + + // format the solution + // TODO @gbotrel revisit post-refactor + if cs.Type == constraint.SystemR1CS { + var res R1CSSolution + res.W = solver.values + res.A = solver.a + res.B = solver.b + res.C = solver.c + return &res, nil + } else { + // sparse R1CS + var res SparseR1CSSolution + // query l, r, o in Lagrange basis, not blinded + res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values) + + return &res, nil + } + +} + +// IsSolved +// Deprecated: use _, err := Solve(...) instead +func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error { + _, err := cs.Solve(witness, opts...) + return err +} + +// GetR1Cs return the list of R1C +func (cs *system) GetR1Cs() []constraint.R1C { + toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + var r1c constraint.R1C + bc.DecompressR1C(&r1c, inst.Unpack(&cs.System)) + toReturn = append(toReturn, r1c) + } else { + panic("not implemented") + } + } + return toReturn +} + +// GetNbCoefficients return the number of unique coefficients needed in the R1CS +func (cs *system) GetNbCoefficients() int { + return len(cs.Coefficients) +} + +// CurveID returns curve ID as defined in gnark-crypto +func (cs *system) CurveID() ecc.ID { + return ecc.BLS24_317 +} + +// WriteTo encodes R1CS into provided io.Writer using cbor +func (cs *system) WriteTo(w io.Writer) (int64, error) { + _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written + ts := getTagSet() + enc, err := cbor.CoreDetEncOptions().EncModeWithTags(ts) + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) + + // encode our object + err = encoder.Encode(cs) + return _w.N, err +} + +// ReadFrom 
attempts to decode R1CS from io.Reader using cbor +func (cs *system) ReadFrom(r io.Reader) (int64, error) { + ts := getTagSet() + dm, err := cbor.DecOptions{ + MaxArrayElements: 2147483647, + MaxMapPairs: 2147483647, + }.DecModeWithTags(ts) + + if err != nil { + return 0, err + } + decoder := dm.NewDecoder(r) + + // initialize coeff table + cs.CoeffTable = newCoeffTable(0) + + if err := decoder.Decode(&cs); err != nil { + return int64(decoder.NumBytesRead()), err + } + + if err := cs.CheckSerializationHeader(); err != nil { + return int64(decoder.NumBytesRead()), err + } + + switch v := cs.CommitmentInfo.(type) { + case *constraint.Groth16Commitments: + cs.CommitmentInfo = *v + case *constraint.PlonkCommitments: + cs.CommitmentInfo = *v + } + + return int64(decoder.NumBytesRead()), nil +} + +func (cs *system) GetCoefficient(i int) (r constraint.Element) { + copy(r[:], cs.Coefficients[i][:]) + return +} + +// GetSparseR1Cs return the list of SparseR1C +func (cs *system) GetSparseR1Cs() []constraint.SparseR1C { + + toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + var sparseR1C constraint.SparseR1C + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + toReturn = append(toReturn, sparseR1C) + } else { + panic("not implemented") + } + } + return toReturn +} + +// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form. 
+// solver = [ public | secret | internal ] +// TODO @gbotrel refactor; this seems to be a small util function for plonk +func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { + + //s := int(pk.Domain[0].Cardinality) + s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints + s = int(ecc.NextPowerOfTwo(uint64(s))) + + var l, r, o []fr.Element + l = make([]fr.Element, s) + r = make([]fr.Element, s) + o = make([]fr.Element, s) + s0 := solution[0] + + for i := 0; i < len(cs.Public); i++ { // placeholders + l[i] = solution[i] + r[i] = s0 + o[i] = s0 + } + offset := len(cs.Public) + nbConstraints := cs.GetNbConstraints() + + var sparseR1C constraint.SparseR1C + j := 0 + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + + l[offset+j] = solution[sparseR1C.XA] + r[offset+j] = solution[sparseR1C.XB] + o[offset+j] = solution[sparseR1C.XC] + j++ + } + } + + offset += nbConstraints + + for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0]) + l[offset+i] = s0 + r[offset+i] = s0 + o[offset+i] = s0 + } + + return l, r, o + +} + +// R1CSSolution represent a valid assignment to all the variables in the constraint system. 
+// The vector W such that Aw o Bw - Cw = 0 +type R1CSSolution struct { + W fr.Vector + A, B, C fr.Vector +} + +func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.W.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.A.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.B.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.C.WriteTo(w) + n += a + return n, err +} + +func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.W.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.A.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.B.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.C.ReadFrom(r) + n += a + return n, err +} + +// SparseR1CSSolution represent a valid assignment to all the variables in the constraint system. +type SparseR1CSSolution struct { + L, R, O fr.Vector +} + +func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.L.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.R.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.O.WriteTo(w) + n += a + return n, err + +} + +func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.L.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.R.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.O.ReadFrom(r) + n += a + return n, err +} + +func getTagSet() cbor.TagSet { + // temporary for refactor + ts := cbor.NewTagSet() + // https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml + // 65536-15309735 Unassigned + tagNum := uint64(5309735) + addType := func(t reflect.Type) { + if err := ts.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + t, + tagNum, + ); err != nil { + panic(err) + } + tagNum++ + } + + addType(reflect.TypeOf(constraint.BlueprintGenericHint{})) + addType(reflect.TypeOf(constraint.BlueprintGenericR1C{})) + 
addType(reflect.TypeOf(constraint.BlueprintGenericSparseR1C{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CAdd{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CMul{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CBool{})) + addType(reflect.TypeOf(constraint.BlueprintLookupHint{})) + addType(reflect.TypeOf(constraint.Groth16Commitments{})) + addType(reflect.TypeOf(constraint.PlonkCommitments{})) + + return ts +} + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} diff --git a/constraint/blueprint.go b/constraint/blueprint.go new file mode 100644 index 0000000000..2949f0665f --- /dev/null +++ b/constraint/blueprint.go @@ -0,0 +1,73 @@ +package constraint + +type BlueprintID uint32 + +// Blueprint enable representing heterogeneous constraints or instructions in a constraint system +// in a memory efficient way. Blueprints essentially help the frontend/ to "compress" +// constraints or instructions, and specify for the solving (or zksnark setup) part how to +// "decompress" and optionally "solve" the associated wires. +type Blueprint interface { + // CalldataSize return the number of calldata input this blueprint expects. + // If this is unknown at compile time, implementation must return -1 and store + // the actual number of inputs in the first index of the calldata. + CalldataSize() int + + // NbConstraints return the number of constraints this blueprint creates. + NbConstraints() int + + // NbOutputs return the number of output wires this blueprint creates. + NbOutputs(inst Instruction) int + + // WireWalker returns a function that walks the wires appearing in the blueprint. + // This is used by the level builder to build a dependency graph between instructions. + WireWalker(inst Instruction) func(cb func(wire uint32)) +} + +// Solver represents the state of a constraint system solver at runtime. 
Blueprint can interact +// with this object to perform run time logic, solve constraints and assign values in the solution. +type Solver interface { + Field + + GetValue(cID, vID uint32) Element + GetCoeff(cID uint32) Element + SetValue(vID uint32, f Element) + IsSolved(vID uint32) bool + + // Read interprets input calldata as a LinearExpression, + // evaluates it and return the result and the number of uint32 word read. + Read(calldata []uint32) (Element, int) +} + +// BlueprintSolvable represents a blueprint that knows how to solve itself. +type BlueprintSolvable interface { + Blueprint + // Solve may return an error if the decoded constraint / calldata is unsolvable. + Solve(s Solver, instruction Instruction) error +} + +// BlueprintR1C indicates that the blueprint and associated calldata encodes a R1C +type BlueprintR1C interface { + Blueprint + CompressR1C(c *R1C, to *[]uint32) + DecompressR1C(into *R1C, instruction Instruction) +} + +// BlueprintSparseR1C indicates that the blueprint and associated calldata encodes a SparseR1C. +type BlueprintSparseR1C interface { + Blueprint + CompressSparseR1C(c *SparseR1C, to *[]uint32) + DecompressSparseR1C(into *SparseR1C, instruction Instruction) +} + +// BlueprintHint indicates that the blueprint and associated calldata encodes a hint. +type BlueprintHint interface { + Blueprint + CompressHint(h HintMapping, to *[]uint32) + DecompressHint(h *HintMapping, instruction Instruction) +} + +// Compressible represent an object that knows how to encode itself as a []uint32. +type Compressible interface { + // Compress interprets the objects as a LinearExpression and encodes it as a []uint32. 
+ Compress(to *[]uint32) +} diff --git a/constraint/blueprint_hint.go b/constraint/blueprint_hint.go new file mode 100644 index 0000000000..ac96413ef0 --- /dev/null +++ b/constraint/blueprint_hint.go @@ -0,0 +1,94 @@ +package constraint + +import ( + "github.com/consensys/gnark/constraint/solver" +) + +type BlueprintGenericHint struct{} + +func (b *BlueprintGenericHint) DecompressHint(h *HintMapping, inst Instruction) { + // ignore first call data == nbInputs + h.HintID = solver.HintID(inst.Calldata[1]) + lenInputs := int(inst.Calldata[2]) + if cap(h.Inputs) >= lenInputs { + h.Inputs = h.Inputs[:lenInputs] + } else { + h.Inputs = make([]LinearExpression, lenInputs) + } + + j := 3 + for i := 0; i < lenInputs; i++ { + n := int(inst.Calldata[j]) // len of linear expr + j++ + if cap(h.Inputs[i]) >= n { + h.Inputs[i] = h.Inputs[i][:0] + } else { + h.Inputs[i] = make(LinearExpression, 0, n) + } + for k := 0; k < n; k++ { + h.Inputs[i] = append(h.Inputs[i], Term{CID: inst.Calldata[j], VID: inst.Calldata[j+1]}) + j += 2 + } + } + h.OutputRange.Start = inst.Calldata[j] + h.OutputRange.End = inst.Calldata[j+1] +} + +func (b *BlueprintGenericHint) CompressHint(h HintMapping, to *[]uint32) { + nbInputs := 1 // storing nb inputs + nbInputs++ // hintID + nbInputs++ // len(h.Inputs) + for i := 0; i < len(h.Inputs); i++ { + nbInputs++ // len of h.Inputs[i] + nbInputs += len(h.Inputs[i]) * 2 + } + + nbInputs += 2 // output range start / end + + (*to) = append((*to), uint32(nbInputs)) + (*to) = append((*to), uint32(h.HintID)) + (*to) = append((*to), uint32(len(h.Inputs))) + + for _, l := range h.Inputs { + (*to) = append((*to), uint32(len(l))) + for _, t := range l { + (*to) = append((*to), uint32(t.CoeffID()), uint32(t.WireID())) + } + } + + (*to) = append((*to), h.OutputRange.Start) + (*to) = append((*to), h.OutputRange.End) +} + +func (b *BlueprintGenericHint) CalldataSize() int { + return -1 +} +func (b *BlueprintGenericHint) NbConstraints() int { + return 0 +} + +func (b 
*BlueprintGenericHint) NbOutputs(inst Instruction) int { + return 0 +} + +func (b *BlueprintGenericHint) WireWalker(inst Instruction) func(cb func(wire uint32)) { + return func(cb func(wire uint32)) { + lenInputs := int(inst.Calldata[2]) + j := 3 + for i := 0; i < lenInputs; i++ { + n := int(inst.Calldata[j]) // len of linear expr + j++ + + for k := 0; k < n; k++ { + t := Term{CID: inst.Calldata[j], VID: inst.Calldata[j+1]} + if !t.IsConstant() { + cb(t.VID) + } + j += 2 + } + } + for k := inst.Calldata[j]; k < inst.Calldata[j+1]; k++ { + cb(k) + } + } +} diff --git a/constraint/blueprint_logderivlookup.go b/constraint/blueprint_logderivlookup.go new file mode 100644 index 0000000000..b46f6d8f01 --- /dev/null +++ b/constraint/blueprint_logderivlookup.go @@ -0,0 +1,111 @@ +package constraint + +import ( + "fmt" +) + +// TODO @gbotrel this shouldn't be there, but we need to figure out a clean way to serialize +// blueprints + +// BlueprintLookupHint is a blueprint that facilitates the lookup of values in a table. +// It is essentially a hint to the solver, but enables storing the table entries only once. +type BlueprintLookupHint struct { + EntriesCalldata []uint32 +} + +// ensures BlueprintLookupHint implements the BlueprintSolvable interface +var _ BlueprintSolvable = (*BlueprintLookupHint)(nil) + +func (b *BlueprintLookupHint) Solve(s Solver, inst Instruction) error { + nbEntries := int(inst.Calldata[1]) + entries := make([]Element, nbEntries) + + // read the static entries from the blueprint + // TODO @gbotrel cache that. 
+ offset, delta := 0, 0 + for i := 0; i < nbEntries; i++ { + entries[i], delta = s.Read(b.EntriesCalldata[offset:]) + offset += delta + } + + nbInputs := int(inst.Calldata[2]) + + // read the inputs from the instruction + inputs := make([]Element, nbInputs) + offset = 3 + for i := 0; i < nbInputs; i++ { + inputs[i], delta = s.Read(inst.Calldata[offset:]) + offset += delta + } + + // set the outputs + nbOutputs := nbInputs + + for i := 0; i < nbOutputs; i++ { + idx, isUint64 := s.Uint64(inputs[i]) + if !isUint64 || idx >= uint64(len(entries)) { + return fmt.Errorf("lookup query too large") + } + // we set the output wire to the value of the entry + s.SetValue(uint32(i+int(inst.WireOffset)), entries[idx]) + } + return nil +} + +func (b *BlueprintLookupHint) CalldataSize() int { + // variable size + return -1 +} +func (b *BlueprintLookupHint) NbConstraints() int { + return 0 +} + +// NbOutputs return the number of output wires this blueprint creates. +func (b *BlueprintLookupHint) NbOutputs(inst Instruction) int { + return int(inst.Calldata[2]) +} + +// Wires returns a function that walks the wires appearing in the blueprint. +// This is used by the level builder to build a dependency graph between instructions. +func (b *BlueprintLookupHint) WireWalker(inst Instruction) func(cb func(wire uint32)) { + return func(cb func(wire uint32)) { + // depend on the table UP to the number of entries at time of instruction creation. 
+ nbEntries := int(inst.Calldata[1]) + + // invoke the callback on each wire appearing in the table + j := 0 + for i := 0; i < nbEntries; i++ { + // first we have the length of the linear expression + n := int(b.EntriesCalldata[j]) + j++ + for k := 0; k < n; k++ { + t := Term{CID: b.EntriesCalldata[j], VID: b.EntriesCalldata[j+1]} + if !t.IsConstant() { + cb(t.VID) + } + j += 2 + } + } + + // invoke the callback on each wire appearing in the inputs + nbInputs := int(inst.Calldata[2]) + j = 3 + for i := 0; i < nbInputs; i++ { + // first we have the length of the linear expression + n := int(inst.Calldata[j]) + j++ + for k := 0; k < n; k++ { + t := Term{CID: inst.Calldata[j], VID: inst.Calldata[j+1]} + if !t.IsConstant() { + cb(t.VID) + } + j += 2 + } + } + + // finally we have the outputs + for i := 0; i < nbInputs; i++ { + cb(uint32(i + int(inst.WireOffset))) + } + } +} diff --git a/constraint/blueprint_r1cs.go b/constraint/blueprint_r1cs.go new file mode 100644 index 0000000000..b231eda067 --- /dev/null +++ b/constraint/blueprint_r1cs.go @@ -0,0 +1,80 @@ +package constraint + +// BlueprintGenericR1C implements Blueprint and BlueprintR1C. +// Encodes +// +// L * R == 0 +type BlueprintGenericR1C struct{} + +func (b *BlueprintGenericR1C) CalldataSize() int { + // size of linear expressions are unknown. 
+ return -1 +} +func (b *BlueprintGenericR1C) NbConstraints() int { + return 1 +} +func (b *BlueprintGenericR1C) NbOutputs(inst Instruction) int { + return 0 +} + +func (b *BlueprintGenericR1C) CompressR1C(c *R1C, to *[]uint32) { + // we store total nb inputs, len L, len R, len O, and then the "flatten" linear expressions + nbInputs := 4 + 2*(len(c.L)+len(c.R)+len(c.O)) + (*to) = append((*to), uint32(nbInputs)) + (*to) = append((*to), uint32(len(c.L)), uint32(len(c.R)), uint32(len(c.O))) + for _, t := range c.L { + (*to) = append((*to), uint32(t.CoeffID()), uint32(t.WireID())) + } + for _, t := range c.R { + (*to) = append((*to), uint32(t.CoeffID()), uint32(t.WireID())) + } + for _, t := range c.O { + (*to) = append((*to), uint32(t.CoeffID()), uint32(t.WireID())) + } +} + +func (b *BlueprintGenericR1C) DecompressR1C(c *R1C, inst Instruction) { + copySlice := func(slice *LinearExpression, expectedLen, idx int) { + if cap(*slice) >= expectedLen { + (*slice) = (*slice)[:expectedLen] + } else { + (*slice) = make(LinearExpression, expectedLen, expectedLen*2) + } + for k := 0; k < expectedLen; k++ { + (*slice)[k].CID = inst.Calldata[idx] + idx++ + (*slice)[k].VID = inst.Calldata[idx] + idx++ + } + } + + lenL := int(inst.Calldata[1]) + lenR := int(inst.Calldata[2]) + lenO := int(inst.Calldata[3]) + + const offset = 4 + copySlice(&c.L, lenL, offset) + copySlice(&c.R, lenR, offset+2*lenL) + copySlice(&c.O, lenO, offset+2*(lenL+lenR)) +} + +func (b *BlueprintGenericR1C) WireWalker(inst Instruction) func(cb func(wire uint32)) { + return func(cb func(wire uint32)) { + lenL := int(inst.Calldata[1]) + lenR := int(inst.Calldata[2]) + lenO := int(inst.Calldata[3]) + + appendWires := func(expectedLen, idx int) { + for k := 0; k < expectedLen; k++ { + idx++ + cb(inst.Calldata[idx]) + idx++ + } + } + + const offset = 4 + appendWires(lenL, offset) + appendWires(lenR, offset+2*lenL) + appendWires(lenO, offset+2*(lenL+lenR)) + } +} diff --git a/constraint/blueprint_scs.go 
b/constraint/blueprint_scs.go new file mode 100644 index 0000000000..3f0973e234 --- /dev/null +++ b/constraint/blueprint_scs.go @@ -0,0 +1,305 @@ +package constraint + +import ( + "errors" + "fmt" +) + +var ( + errDivideByZero = errors.New("division by 0") + errBoolConstrain = errors.New("boolean constraint doesn't hold") +) + +// BlueprintGenericSparseR1C implements Blueprint and BlueprintSparseR1C. +// Encodes +// +// qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC == 0 +type BlueprintGenericSparseR1C struct { +} + +func (b *BlueprintGenericSparseR1C) CalldataSize() int { + return 9 // number of fields in SparseR1C +} +func (b *BlueprintGenericSparseR1C) NbConstraints() int { + return 1 +} + +func (b *BlueprintGenericSparseR1C) NbOutputs(inst Instruction) int { + return 0 +} + +func (b *BlueprintGenericSparseR1C) WireWalker(inst Instruction) func(cb func(wire uint32)) { + return func(cb func(wire uint32)) { + cb(inst.Calldata[0]) // xa + cb(inst.Calldata[1]) // xb + cb(inst.Calldata[2]) // xc + } +} + +func (b *BlueprintGenericSparseR1C) CompressSparseR1C(c *SparseR1C, to *[]uint32) { + *to = append(*to, c.XA, c.XB, c.XC, c.QL, c.QR, c.QO, c.QM, c.QC, uint32(c.Commitment)) +} + +func (b *BlueprintGenericSparseR1C) DecompressSparseR1C(c *SparseR1C, inst Instruction) { + c.Clear() + + c.XA = inst.Calldata[0] + c.XB = inst.Calldata[1] + c.XC = inst.Calldata[2] + c.QL = inst.Calldata[3] + c.QR = inst.Calldata[4] + c.QO = inst.Calldata[5] + c.QM = inst.Calldata[6] + c.QC = inst.Calldata[7] + c.Commitment = CommitmentConstraint(inst.Calldata[8]) +} + +func (b *BlueprintGenericSparseR1C) Solve(s Solver, inst Instruction) error { + var c SparseR1C + b.DecompressSparseR1C(&c, inst) + if c.Commitment != NOT { + // a constraint of the form f_L - PI_2 = 0 or f_L = Comm. + // these are there for enforcing the correctness of the commitment and can be skipped in solving time + return nil + } + + var ok bool + + // constraint has at most one unsolved wire. 
+ if !s.IsSolved(c.XA) { + // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 + u1 := s.GetCoeff(c.QL) + den := s.GetValue(c.QM, c.XB) + den = s.Add(den, u1) + den, ok = s.Inverse(den) + if !ok { + return errDivideByZero + } + v1 := s.GetValue(c.QR, c.XB) + v2 := s.GetValue(c.QO, c.XC) + num := s.Add(v1, v2) + num = s.Add(num, s.GetCoeff(c.QC)) + num = s.Mul(num, den) + num = s.Neg(num) + s.SetValue(c.XA, num) + } else if !s.IsSolved(c.XB) { + u2 := s.GetCoeff(c.QR) + den := s.GetValue(c.QM, c.XA) + den = s.Add(den, u2) + den, ok = s.Inverse(den) + if !ok { + return errDivideByZero + } + + v1 := s.GetValue(c.QL, c.XA) + v2 := s.GetValue(c.QO, c.XC) + + num := s.Add(v1, v2) + num = s.Add(num, s.GetCoeff(c.QC)) + num = s.Mul(num, den) + num = s.Neg(num) + s.SetValue(c.XB, num) + + } else if !s.IsSolved(c.XC) { + // O we solve for O + l := s.GetValue(c.QL, c.XA) + r := s.GetValue(c.QR, c.XB) + m0 := s.GetValue(c.QM, c.XA) + m1 := s.GetValue(CoeffIdOne, c.XB) + + // o = - ((m0 * m1) + l + r + c.QC) / c.O + o := s.Mul(m0, m1) + o = s.Add(o, l) + o = s.Add(o, r) + o = s.Add(o, s.GetCoeff(c.QC)) + + den := s.GetCoeff(c.QO) + den, ok = s.Inverse(den) + if !ok { + return errDivideByZero + } + o = s.Mul(o, den) + o = s.Neg(o) + + s.SetValue(c.XC, o) + } else { + // all wires are solved, we verify that the constraint hold. + // this can happen when all wires are from hints or if the constraint is an assertion. 
+ return b.checkConstraint(&c, s) + } + return nil +} + +func (b *BlueprintGenericSparseR1C) checkConstraint(c *SparseR1C, s Solver) error { + l := s.GetValue(c.QL, c.XA) + r := s.GetValue(c.QR, c.XB) + m0 := s.GetValue(c.QM, c.XA) + m1 := s.GetValue(CoeffIdOne, c.XB) + m0 = s.Mul(m0, m1) + o := s.GetValue(c.QO, c.XC) + qC := s.GetCoeff(c.QC) + + t := s.Add(m0, l) + t = s.Add(t, r) + t = s.Add(t, o) + t = s.Add(t, qC) + + if !t.IsZero() { + return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + %s + %s != 0", + s.String(l), + s.String(r), + s.String(o), + s.String(m0), + s.String(qC), + ) + } + return nil +} + +// BlueprintSparseR1CMul implements Blueprint, BlueprintSolvable and BlueprintSparseR1C. +// Encodes +// +// qM⋅(xaxb) == xc +type BlueprintSparseR1CMul struct{} + +func (b *BlueprintSparseR1CMul) CalldataSize() int { + return 4 +} +func (b *BlueprintSparseR1CMul) NbConstraints() int { + return 1 +} +func (b *BlueprintSparseR1CMul) NbOutputs(inst Instruction) int { + return 0 +} + +func (b *BlueprintSparseR1CMul) WireWalker(inst Instruction) func(cb func(wire uint32)) { + return func(cb func(wire uint32)) { + cb(inst.Calldata[0]) // xa + cb(inst.Calldata[1]) // xb + cb(inst.Calldata[2]) // xc + } +} + +func (b *BlueprintSparseR1CMul) CompressSparseR1C(c *SparseR1C, to *[]uint32) { + *to = append(*to, c.XA, c.XB, c.XC, c.QM) +} + +func (b *BlueprintSparseR1CMul) Solve(s Solver, inst Instruction) error { + // qM⋅(xaxb) == xc + m0 := s.GetValue(inst.Calldata[3], inst.Calldata[0]) + m1 := s.GetValue(CoeffIdOne, inst.Calldata[1]) + + m0 = s.Mul(m0, m1) + + s.SetValue(inst.Calldata[2], m0) + return nil +} + +func (b *BlueprintSparseR1CMul) DecompressSparseR1C(c *SparseR1C, inst Instruction) { + c.Clear() + c.XA = inst.Calldata[0] + c.XB = inst.Calldata[1] + c.XC = inst.Calldata[2] + c.QO = CoeffIdMinusOne + c.QM = inst.Calldata[3] +} + +// BlueprintSparseR1CAdd implements Blueprint, BlueprintSolvable and BlueprintSparseR1C. 
+// Encodes +// +// qL⋅xa + qR⋅xb + qC == xc +type BlueprintSparseR1CAdd struct{} + +func (b *BlueprintSparseR1CAdd) CalldataSize() int { + return 6 +} +func (b *BlueprintSparseR1CAdd) NbConstraints() int { + return 1 +} +func (b *BlueprintSparseR1CAdd) NbOutputs(inst Instruction) int { + return 0 +} + +func (b *BlueprintSparseR1CAdd) WireWalker(inst Instruction) func(cb func(wire uint32)) { + return func(cb func(wire uint32)) { + cb(inst.Calldata[0]) // xa + cb(inst.Calldata[1]) // xb + cb(inst.Calldata[2]) // xc + } +} + +func (b *BlueprintSparseR1CAdd) CompressSparseR1C(c *SparseR1C, to *[]uint32) { + *to = append(*to, c.XA, c.XB, c.XC, c.QL, c.QR, c.QC) +} + +func (blueprint *BlueprintSparseR1CAdd) Solve(s Solver, inst Instruction) error { + // a + b + k == c + a := s.GetValue(inst.Calldata[3], inst.Calldata[0]) + b := s.GetValue(inst.Calldata[4], inst.Calldata[1]) + k := s.GetCoeff(inst.Calldata[5]) + + a = s.Add(a, b) + a = s.Add(a, k) + + s.SetValue(inst.Calldata[2], a) + return nil +} + +func (b *BlueprintSparseR1CAdd) DecompressSparseR1C(c *SparseR1C, inst Instruction) { + c.Clear() + c.XA = inst.Calldata[0] + c.XB = inst.Calldata[1] + c.XC = inst.Calldata[2] + c.QL = inst.Calldata[3] + c.QR = inst.Calldata[4] + c.QO = CoeffIdMinusOne + c.QC = inst.Calldata[5] +} + +// BlueprintSparseR1CBool implements Blueprint, BlueprintSolvable and BlueprintSparseR1C. 
+// Encodes +// +// qL⋅xa + qM⋅(xa*xa) == 0 +// that is v + -v*v == 0 +type BlueprintSparseR1CBool struct{} + +func (b *BlueprintSparseR1CBool) CalldataSize() int { + return 3 +} +func (b *BlueprintSparseR1CBool) NbConstraints() int { + return 1 +} +func (b *BlueprintSparseR1CBool) NbOutputs(inst Instruction) int { + return 0 +} + +func (b *BlueprintSparseR1CBool) WireWalker(inst Instruction) func(cb func(wire uint32)) { + return func(cb func(wire uint32)) { + cb(inst.Calldata[0]) // xa + } +} + +func (b *BlueprintSparseR1CBool) CompressSparseR1C(c *SparseR1C, to *[]uint32) { + *to = append(*to, c.XA, c.QL, c.QM) +} + +func (blueprint *BlueprintSparseR1CBool) Solve(s Solver, inst Instruction) error { + // all wires are already solved, we just check the constraint. + v1 := s.GetValue(inst.Calldata[1], inst.Calldata[0]) + v2 := s.GetValue(inst.Calldata[2], inst.Calldata[0]) + v := s.GetValue(CoeffIdOne, inst.Calldata[0]) + v = s.Mul(v, v2) + v = s.Add(v1, v) + if !v.IsZero() { + return errBoolConstrain + } + return nil +} + +func (b *BlueprintSparseR1CBool) DecompressSparseR1C(c *SparseR1C, inst Instruction) { + c.Clear() + c.XA = inst.Calldata[0] + c.XB = c.XA + c.QL = inst.Calldata[1] + c.QM = inst.Calldata[2] +} diff --git a/constraint/bn254/coeff.go b/constraint/bn254/coeff.go index 2835fd1555..da49b0e68b 100644 --- a/constraint/bn254/coeff.go +++ b/constraint/bn254/coeff.go @@ -46,7 +46,7 @@ func newCoeffTable(capacity int) CoeffTable { } -func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constraint.Term { +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { c := (*fr.Element)(coeff[:]) var cID uint32 if c.IsZero() { @@ -69,7 +69,11 @@ func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constrai ct.mCoeffs[cc] = cID } } + return cID +} +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) return constraint.Term{VID: uint32(variableID), CID: 
cID} } @@ -78,7 +82,10 @@ func (ct *CoeffTable) CoeffToString(cID int) string { return ct.Coefficients[cID].String() } -var _ constraint.CoeffEngine = &arithEngine{} +// implements constraint.Field +type field struct{} + +var _ constraint.Field = &field{} var ( two fr.Element @@ -94,10 +101,7 @@ func init() { minusTwo.Neg(&two) } -// implements constraint.CoeffEngine -type arithEngine struct{} - -func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { +func (engine *field) FromInterface(i interface{}) constraint.Element { var e fr.Element if _, err := e.SetInterface(i); err != nil { // need to clean that --> some code path are dissimilar @@ -106,55 +110,83 @@ func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { b := utils.FromInterface(i) e.SetBigInt(&b) } - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) ToBigInt(c *constraint.Coeff) *big.Int { +func (engine *field) ToBigInt(c constraint.Element) *big.Int { e := (*fr.Element)(c[:]) r := new(big.Int) e.BigInt(r) return r } -func (engine *arithEngine) Mul(a, b *constraint.Coeff) { +func (engine *field) Mul(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Mul(_a, _b) + return a } -func (engine *arithEngine) Add(a, b *constraint.Coeff) { + +func (engine *field) Add(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Add(_a, _b) + return a } -func (engine *arithEngine) Sub(a, b *constraint.Coeff) { +func (engine *field) Sub(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Sub(_a, _b) + return a } -func (engine *arithEngine) Neg(a *constraint.Coeff) { +func (engine *field) Neg(a constraint.Element) constraint.Element { e := (*fr.Element)(a[:]) e.Neg(e) + return a } -func (engine *arithEngine) Inverse(a *constraint.Coeff) { +func (engine *field) Inverse(a 
constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + e.Inverse(e) + return a, true } -func (engine *arithEngine) IsOne(a *constraint.Coeff) bool { +func (engine *field) IsOne(a constraint.Element) bool { e := (*fr.Element)(a[:]) return e.IsOne() } -func (engine *arithEngine) One() constraint.Coeff { +func (engine *field) One() constraint.Element { e := fr.One() - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) String(a *constraint.Coeff) string { +func (engine *field) String(a constraint.Element) string { e := (*fr.Element)(a[:]) return e.String() } + +func (engine *field) Uint64(a constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true +} diff --git a/constraint/bn254/gkr.go b/constraint/bn254/gkr.go new file mode 100644 index 0000000000..8b3ea26755 --- /dev/null +++ b/constraint/bn254/gkr.go @@ -0,0 +1,196 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + hint "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/utils/algo_utils" + "hash" + "math/big" +) + +type GkrSolvingData struct { + assignments gkr.WireAssignment + circuit gkr.Circuit + memoryPool polynomial.Pool + workers *utils.WorkerPool +} + +func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { + resCircuit := make(gkr.Circuit, len(noPtr)) + var found bool + for i := range noPtr { + if resCircuit[i].Gate, found = gkr.Gates[noPtr[i].Gate]; !found && noPtr[i].Gate != "" { + return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) + } + resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) + } + return resCircuit, nil +} + +func (d *GkrSolvingData) init(info constraint.GkrInfo) (assignment gkrAssignment, err error) { + if d.circuit, err = convertCircuit(info.Circuit); err != nil { + return + } + d.memoryPool = polynomial.NewPool(d.circuit.MemoryRequirements(info.NbInstances)...) 
+ d.workers = utils.NewWorkerPool() + + assignment = make(gkrAssignment, len(d.circuit)) + d.assignments = make(gkr.WireAssignment, len(d.circuit)) + for i := range assignment { + assignment[i] = d.memoryPool.Make(info.NbInstances) + d.assignments[&d.circuit[i]] = assignment[i] + } + return +} + +func (d *GkrSolvingData) dumpAssignments() { + for _, p := range d.assignments { + d.memoryPool.Dump(p) + } +} + +// this module assumes that wire and instance indexes respect dependencies + +type gkrAssignment [][]fr.Element //gkrAssignment is indexed wire first, instance second + +func (a gkrAssignment) setOuts(circuit constraint.GkrCircuit, outs []*big.Int) { + outsI := 0 + for i := range circuit { + if circuit[i].IsOutput() { + for j := range a[i] { + a[i][j].BigInt(outs[outsI]) + outsI++ + } + } + } + // Check if outsI == len(outs)? +} + +func GkrSolveHint(info constraint.GkrInfo, solvingData *GkrSolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + // assumes assignmentVector is arranged wire first, instance second in order of solution + circuit := info.Circuit + nbInstances := info.NbInstances + offsets := info.AssignmentOffsets() + assignment, err := solvingData.init(info) + if err != nil { + return err + } + chunks := circuit.Chunks(nbInstances) + + solveTask := func(chunkOffset int) utils.Task { + return func(startInChunk, endInChunk int) { + start := startInChunk + chunkOffset + end := endInChunk + chunkOffset + inputs := solvingData.memoryPool.Make(info.MaxNIns) + dependencyHeads := make([]int, len(circuit)) + for wI, w := range circuit { + dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { + return w.Dependencies[i].InputInstance + }, len(w.Dependencies), start) + } + + for instanceI := start; instanceI < end; instanceI++ { + for wireI, wire := range circuit { + if wire.IsInput() { + if dependencyHeads[wireI] < len(wire.Dependencies) && instanceI == wire.Dependencies[dependencyHeads[wireI]].InputInstance { + dep := 
wire.Dependencies[dependencyHeads[wireI]] + assignment[wireI][instanceI].Set(&assignment[dep.OutputWire][dep.OutputInstance]) + dependencyHeads[wireI]++ + } else { + assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) + } + } else { + // assemble the inputs + inputIndexes := info.Circuit[wireI].Inputs + for i, inputI := range inputIndexes { + inputs[i].Set(&assignment[inputI][instanceI]) + } + gate := solvingData.circuit[wireI].Gate + assignment[wireI][instanceI] = gate.Evaluate(inputs[:len(inputIndexes)]...) + } + } + } + solvingData.memoryPool.Dump(inputs) + } + } + + start := 0 + for _, end := range chunks { + solvingData.workers.Submit(end-start, solveTask(start), 1024).Wait() + start = end + } + + assignment.setOuts(info.Circuit, outs) + + return nil + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} + +func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { + + return func(_ *big.Int, ins, outs []*big.Int) error { + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + b := make([]byte, fr.Bytes) + i.FillBytes(b) + return b[:] + }) + + hsh := HashBuilderRegistry[hashName]() + + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + if err != nil { + return err + } + + // serialize proof: TODO: In gnark-crypto? 
+ offset := 0 + for i := range proof { + for _, poly := range proof[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if proof[i].FinalEvalProof != nil { + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } + + data.dumpAssignments() + + return nil + + } +} + +// TODO: Move to gnark-crypto +var HashBuilderRegistry = make(map[string]func() hash.Hash) diff --git a/constraint/bn254/r1cs.go b/constraint/bn254/r1cs.go deleted file mode 100644 index 5bc4cb56aa..0000000000 --- a/constraint/bn254/r1cs.go +++ /dev/null @@ -1,451 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/fxamacker/cbor/v2" - "io" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc" - "math" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr" -) - -// R1CS describes a set of R1CS constraint -type R1CS struct { - constraint.R1CSCore - CoeffTable - arithEngine -} - -// NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values -// -// capacity pre-allocates memory for capacity nbConstraints -func NewR1CS(capacity int) *R1CS { - r := R1CS{ - R1CSCore: constraint.R1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.R1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - return &r -} - -func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, r1c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - - cs.UpdateLevel(cID, &r1c) - - return cID -} - -// Solve sets all the wires and returns the a, b, c vectors. -// the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. -// a, b, c vectors: ab-c = hz -// witness = [publicWires | secretWires] (without the ONE_WIRE !) 
-// returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() - - nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables - solution, err := newSolution(nbWires, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return make(fr.Vector, nbWires), err - } - start := time.Now() - - if len(witness) != len(cs.Public)-1+len(cs.Secret) { // - 1 for ONE_WIRE - err = fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), int(len(cs.Public)-1+len(cs.Secret)), len(cs.Public)-1, len(cs.Secret)) - log.Err(err).Send() - return solution.values, err - } - - // compute the wires and the a, b, c polynomials - if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { - err = errors.New("invalid input size: len(a, b, c) == len(Constraints)") - log.Err(err).Send() - return solution.values, err - } - - solution.solved[0] = true // ONE_WIRE - solution.values[0].SetOne() - copy(solution.values[1:], witness) - for i := range witness { - solution.solved[i+1] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness) + 1) - - // now that we know all inputs are set, defer log printing once all solution.values are computed - // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - if err := cs.parallelSolve(a, b, c, &solution); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity 
check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil -} - -func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - chError <- &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - return &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// IsSolved returns nil if given witness solves the R1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) 
- if err != nil { - return err - } - - a := make(fr.Vector, len(cs.Constraints)) - b := make(fr.Vector, len(cs.Constraints)) - c := make(fr.Vector, len(cs.Constraints)) - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, a, b, c, opt) - return err -} - -// divByCoeff sets res = res / t.Coeff -func (cs *R1CS) divByCoeff(res *fr.Element, t constraint.Term) { - cID := t.CoeffID() - switch cID { - case constraint.CoeffIdOne: - return - case constraint.CoeffIdMinusOne: - res.Neg(res) - case constraint.CoeffIdZero: - panic("division by 0") - default: - // this is slow, but shouldn't happen as divByCoeff is called to - // remove the coeff of an unsolved wire - // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 - res.Div(res, &cs.Coefficients[cID]) - } -} - -// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly -// -// returns an error if the solver called a hint function that errored -// returns false, nil if there was no wire to solve -// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that -// the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r constraint.R1C, solution *solution, a, b, c *fr.Element) error { - - // the index of the non-zero entry shows if L, R or O has an uninstantiated wire - // the content is the ID of the wire non instantiated - var loc uint8 - - var termToCompute constraint.Term - - processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) error { - for _, t := range l { - vID := t.WireID() - - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - solution.accumulateInto(t, val) - continue - } - - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err - } - // now that the wire is saved, accumulate it into a, b or c - solution.accumulateInto(t, val) - continue - } - - if loc != 0 { - panic("found more than one wire to instantiate") - } - termToCompute = t - loc = locValue - } - return nil - } - - if err := processLExp(r.L, a, 1); err != nil { - return err - } - - if err := processLExp(r.R, b, 2); err != nil { - return err - } - - if err := processLExp(r.O, c, 3); err != nil { - return err - } - - if loc == 0 { - // there is nothing to solve, may happen if we have an assertion - // (ie a constraints that doesn't yield any output) - // or if we solved the unsolved wires with hint functions - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - return nil - } - - // we compute the wire value and instantiate it - wID := termToCompute.WireID() - - // solver result - var wire fr.Element - - switch loc { - case 1: - if !b.IsZero() { - wire.Div(c, b). - Sub(&wire, a) - a.Add(a, &wire) - } else { - // we didn't actually ensure that a * b == c - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 2: - if !a.IsZero() { - wire.Div(c, a). 
- Sub(&wire, b) - b.Add(b, &wire) - } else { - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 3: - wire.Mul(a, b). - Sub(&wire, c) - - c.Add(c, &wire) - } - - // wire is the term (coeff * value) - // but in the solution we want to store the value only - // note that in gnark frontend, coeff here is always 1 or -1 - cs.divByCoeff(&wire, termToCompute) - solution.set(wID, wire) - - return nil -} - -// GetConstraints return the list of R1C and a coefficient resolver -func (cs *R1CS) GetConstraints() ([]constraint.R1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *R1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto -func (cs *R1CS) CurveID() ecc.ID { - return ecc.BN254 -} - -// WriteTo encodes R1CS into provided io.Writer using cbor -func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode R1CS from io.Reader using cbor -func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(&cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bn254/r1cs_sparse.go 
b/constraint/bn254/r1cs_sparse.go deleted file mode 100644 index 4a9df0856e..0000000000 --- a/constraint/bn254/r1cs_sparse.go +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/fxamacker/cbor/v2" - "io" - "math" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr" -) - -// SparseR1CS represents a Plonk like circuit -type SparseR1CS struct { - constraint.SparseR1CSCore - CoeffTable - arithEngine -} - -// NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values -func NewSparseR1CS(capacity int) *SparseR1CS { - cs := SparseR1CS{ - SparseR1CSCore: constraint.SparseR1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.SparseR1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - - return &cs -} - -func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = 
append(cs.Constraints, c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - cs.UpdateLevel(cID, &c) - - return cID -} - -// Solve sets all the wires. -// solution.values = [publicInputs | secretInputs | internalVariables ] -// witness: contains the input variables -// it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() - - // set the slices holding the solution.values and monitoring which variables have been solved - nbVariables := cs.NbInternalVariables + len(cs.Secret) + len(cs.Public) - - start := time.Now() - - expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) - if len(witness) != expectedWitnessSize { - return make(fr.Vector, nbVariables), fmt.Errorf( - "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", - len(witness), - expectedWitnessSize, - len(cs.Public), - len(cs.Secret), - ) - } - - // keep track of wire that have a value - solution, err := newSolution(nbVariables, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return solution.values, err - } - - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs - copy(solution.values, witness) - for i := 0; i < len(witness); i++ { - solution.solved[i] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness)) - - // defer log printing once all solution.values are computed - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - // batch invert the coefficients to avoid many divisions in the solver - coefficientsNegInv := 
fr.BatchInvert(cs.Coefficients) - for i := 0; i < len(coefficientsNegInv); i++ { - coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) - } - - if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil - -} - -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - wg.Done() - return - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - chError <- &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } else { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - } - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - return &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// computeHints computes wires associated with a hint function, if any -// if there is no remaining wire to solve, returns -1 -// else returns the wire position (L -> 0, R -> 1, O -> 2) -func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) (int, error) { - r := -1 - lID, rID, oID := c.L.WireID(), c.R.WireID(), c.O.WireID() - - if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { - // check if it's a hint - if hint, ok := cs.MHints[lID]; ok { - if err := solution.solveWithHint(lID, hint); err != nil { - return -1, err - } - } else { - r = 0 - } - - } - - if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { - // check if it's a hint - if hint, ok := cs.MHints[rID]; ok { - if err := solution.solveWithHint(rID, hint); err != nil { - return -1, err - } - } else { - r = 1 - } - } - - if (c.O.CoeffID() != 0) && !solution.solved[oID] { - // check if it's a hint - if hint, ok := cs.MHints[oID]; ok { - if err := solution.solveWithHint(oID, hint); err != nil { - return -1, err - } - } else { - r = 2 - } - } - 
return r, nil -} - -// solveConstraint solve any unsolved wire in given constraint and update the solution -// a SparseR1C may have up to one unsolved wire (excluding hints) -// if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { - - lro, err := cs.computeHints(c, solution) - if err != nil { - return err - } - if lro == -1 { - // no unsolved wire - // can happen if the constraint contained only hint wires. - return nil - } - if lro == 1 { // we solve for R: u1L+u2R+u3LR+u4O+k=0 => R(u2+u3L)+u1L+u4O+k = 0 - if !solution.solved[c.L.WireID()] { - panic("L wire should be instantiated when we solve R") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.L.WireID()]).Add(&den, &u2) - - v1 = solution.computeTerm(c.L) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - } - - if lro == 0 { // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 - if !solution.solved[c.R.WireID()] { - panic("R wire should be instantiated when we solve L") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.R.WireID()]).Add(&den, &u1) - - v1 = solution.computeTerm(c.R) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - - } - // O 
we solve for O - var o fr.Element - cID, vID := c.O.CoeffID(), c.O.WireID() - - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - - // o = - ((m0 * m1) + l + r + c.K) / c.O - o.Mul(&m0, &m1).Add(&o, &l).Add(&o, &r).Add(&o, &cs.Coefficients[c.K]) - o.Mul(&o, &coefficientsNegInv[cID]) - - solution.set(vID, o) - - return nil -} - -// IsSolved returns nil if given witness solves the SparseR1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return err - } - - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, opt) - return err -} - -// GetConstraints return the list of SparseR1C and a coefficient resolver -func (cs *SparseR1CS) GetConstraints() ([]constraint.SparseR1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// checkConstraint verifies that the constraint holds -func (cs *SparseR1CS) checkConstraint(c constraint.SparseR1C, solution *solution) error { - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - o := solution.computeTerm(c.O) - - // l + r + (m0 * m1) + o + c.K == 0 - var t fr.Element - t.Mul(&m0, &m1).Add(&t, &l).Add(&t, &r).Add(&t, &o).Add(&t, &cs.Coefficients[c.K]) - if !t.IsZero() { - return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + (%s × %s) + %s != 0", - l.String(), - r.String(), - o.String(), - m0.String(), - m1.String(), - cs.Coefficients[c.K].String(), - ) - } - return nil - -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto (ecc.BN254) -func (cs 
*SparseR1CS) CurveID() ecc.ID { - return ecc.BN254 -} - -// WriteTo encodes SparseR1CS into provided io.Writer using cbor -func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode SparseR1CS from io.Reader using cbor -func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bn254/r1cs_test.go b/constraint/bn254/r1cs_test.go index a772621bcc..d295603ebf 100644 --- a/constraint/bn254/r1cs_test.go +++ b/constraint/bn254/r1cs_test.go @@ -20,6 +20,7 @@ import ( "bytes" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "reflect" "testing" @@ -27,7 +28,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/consensys/gnark/constraint/bn254" + cs "github.com/consensys/gnark/constraint/bn254" "github.com/consensys/gnark-crypto/ecc/bn254/fr" ) @@ -48,7 +49,7 @@ func TestSerialization(t *testing.T) { return } - // copmpile a second time to ensure determinism + // compile a second time to ensure determinism r1cs2, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, tc.Circuit) if err != nil { 
t.Fatal(err) @@ -76,10 +77,10 @@ func TestSerialization(t *testing.T) { if diff := cmp.Diff(r1cs1, &reconstructed, cmpopts.IgnoreFields(cs.R1CS{}, "System.q", - "arithEngine", + "field", "CoeffTable.mCoeffs", "System.lbWireLevel", - "System.lbHints", + "System.genericHint", "System.SymbolTable", "System.lbOutputs", "System.bitLen")); diff != "" { @@ -147,12 +148,6 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { - var c circuit - ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c) - if err != nil { - b.Fatal(err) - } - var w circuit w.X = 1 w.Y = 1 @@ -161,8 +156,32 @@ func BenchmarkSolve(b *testing.B) { b.Fatal(err) } - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ccs.IsSolved(witness) - } + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + } diff --git a/constraint/bn254/solution.go b/constraint/bn254/solution.go deleted file mode 100644 index 3346fc9f1c..0000000000 --- a/constraint/bn254/solution.go +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/debug" - "github.com/rs/zerolog" - "math/big" - "strconv" - "strings" - "sync/atomic" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr" -) - -// solution represents elements needed to compute -// a solution to a R1CS or SparseR1CS -type solution struct { - values, coefficients []fr.Element - solved []bool - nbSolved uint64 - mHintsFunctions map[hint.ID]hint.Function // maps hintID to hint function - mHints map[int]*constraint.Hint // maps wireID to hint - st *debug.SymbolTable -} - -func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, mHints map[int]*constraint.Hint, coefficients []fr.Element, st *debug.SymbolTable) (solution, error) { - - s := solution{ - st: st, - values: make([]fr.Element, nbWires), - coefficients: coefficients, - solved: make([]bool, nbWires), - mHintsFunctions: hintFunctions, - mHints: mHints, - } - - // hintsDependencies is from compile time; it contains the list of hints the solver **needs** - var missing []string - for hintUUID, hintID := range hintsDependencies { - if _, ok := s.mHintsFunctions[hintUUID]; !ok { - missing = append(missing, hintID) - } - } - - if len(missing) > 0 { - return s, fmt.Errorf("solver missing hint(s): %v", missing) - } - - return s, nil -} - -func (s *solution) set(id int, value fr.Element) { - if s.solved[id] { - panic("solving the same 
wire twice should never happen.") - } - s.values[id] = value - s.solved[id] = true - atomic.AddUint64(&s.nbSolved, 1) - // s.nbSolved++ -} - -func (s *solution) isValid() bool { - return int(s.nbSolved) == len(s.values) -} - -// computeTerm computes coeff*variable -func (s *solution) computeTerm(t constraint.Term) fr.Element { - cID, vID := t.CoeffID(), t.WireID() - if cID != 0 && !s.solved[vID] { - panic("computing a term with an unsolved wire") - } - switch cID { - case constraint.CoeffIdZero: - return fr.Element{} - case constraint.CoeffIdOne: - return s.values[vID] - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - return res - case constraint.CoeffIdMinusOne: - var res fr.Element - res.Neg(&s.values[vID]) - return res - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - return res - } -} - -// r += (t.coeff*t.value) -func (s *solution) accumulateInto(t constraint.Term, r *fr.Element) { - cID := t.CoeffID() - - if t.IsConstant() { - // needed for logs, we may want to not put this in the hot path if we need to - // optimize constraint system solver further. 
- r.Add(r, &s.coefficients[cID]) - return - } - - vID := t.WireID() - switch cID { - case constraint.CoeffIdZero: - return - case constraint.CoeffIdOne: - r.Add(r, &s.values[vID]) - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - r.Add(r, &res) - case constraint.CoeffIdMinusOne: - r.Sub(r, &s.values[vID]) - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - r.Add(r, &res) - } -} - -// solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveWithHint(vID int, h *constraint.Hint) error { - // skip if the wire is already solved by a call to the same hint - // function on the same inputs - if s.solved[vID] { - return nil - } - // ensure hint function was provided - f, ok := s.mHintsFunctions[h.ID] - if !ok { - return errors.New("missing hint function") - } - - // tmp IO big int memory - nbInputs := len(h.Inputs) - nbOutputs := len(h.Wires) - inputs := make([]*big.Int, nbInputs) - outputs := make([]*big.Int, nbOutputs) - for i := 0; i < nbOutputs; i++ { - outputs[i] = big.NewInt(0) - } - - q := fr.Modulus() - - // for each input, we set its big int value, IF all the wires are solved - // the only case where all wires may not be solved, is if one of the input of this hint - // is the output of another hint. - // it is safe to recursively solve this with the parallel solver, since all hints-output wires - // that we can solve this way are marked to be solved with the current constraint we are processing. - recursiveSolve := func(t constraint.Term) error { - if t.IsConstant() { - return nil - } - wID := t.WireID() - if s.solved[wID] { - return nil - } - // unsolved dependency - if h, ok := s.mHints[wID]; ok { - // solve recursively. - return s.solveWithHint(wID, h) - } - - // it's not a hint, we panic. 
- panic("solver can't compute hint; one or more input wires are unsolved") - } - - for i := 0; i < nbInputs; i++ { - inputs[i] = big.NewInt(0) - - var v fr.Element - for _, term := range h.Inputs[i] { - if err := recursiveSolve(term); err != nil { - return err - } - s.accumulateInto(term, &v) - } - v.BigInt(inputs[i]) - } - - err := f(q, inputs, outputs) - - var v fr.Element - for i := range outputs { - v.SetBigInt(outputs[i]) - s.set(h.Wires[i], v) - } - - return err -} - -func (s *solution) printLogs(log zerolog.Logger, logs []constraint.LogEntry) { - if log.GetLevel() == zerolog.Disabled { - return - } - - for i := 0; i < len(logs); i++ { - logLine := s.logValue(logs[i]) - log.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) - } -} - -const unsolvedVariable = "" - -func (s *solution) logValue(log constraint.LogEntry) string { - var toResolve []interface{} - var ( - eval fr.Element - missingValue bool - ) - for j := 0; j < len(log.ToResolve); j++ { - // before eval le - - missingValue = false - eval.SetZero() - - for _, t := range log.ToResolve[j] { - // for each term in the linear expression - - cID, vID := t.CoeffID(), t.WireID() - if t.IsConstant() { - // just add the constant - eval.Add(&eval, &s.coefficients[cID]) - continue - } - - if !s.solved[vID] { - missingValue = true - break // stop the loop we can't evaluate. 
- } - - tv := s.computeTerm(t) - eval.Add(&eval, &tv) - } - - // after - if missingValue { - toResolve = append(toResolve, unsolvedVariable) - } else { - // we have to append our accumulator - toResolve = append(toResolve, eval.String()) - } - - } - if len(log.Stack) > 0 { - var sbb strings.Builder - for _, lID := range log.Stack { - location := s.st.Locations[lID] - function := s.st.Functions[location.FunctionID] - - sbb.WriteString(function.Name) - sbb.WriteByte('\n') - sbb.WriteByte('\t') - sbb.WriteString(function.Filename) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(int(location.Line))) - sbb.WriteByte('\n') - } - toResolve = append(toResolve, sbb.String()) - } - return fmt.Sprintf(log.Format, toResolve...) -} - -// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint -type UnsatisfiedConstraintError struct { - Err error - CID int // constraint ID - DebugInfo *string // optional debug info -} - -func (r *UnsatisfiedConstraintError) Error() string { - if r.DebugInfo != nil { - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) - } - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) -} diff --git a/constraint/bn254/solver.go b/constraint/bn254/solver.go new file mode 100644 index 0000000000..5038f2ef4b --- /dev/null +++ b/constraint/bn254/solver.go @@ -0,0 +1,647 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/pool" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/rs/zerolog" + "math" + "math/big" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/consensys/gnark-crypto/ecc/bn254/fr" +) + +// solver represent the state of the solver during a call to System.Solve(...) +type solver struct { + *system + + // values and solved are index by the wire (variable) id + values []fr.Element + solved []bool + nbSolved uint64 + + // maps hintID to hint function + mHintsFunctions map[csolver.HintID]csolver.Hint + + // used to out api.Println + logger zerolog.Logger + + a, b, c fr.Vector // R1CS solver will compute the a,b,c matrices + + q *big.Int +} + +func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { + // parse options + opt, err := csolver.NewConfig(opts...) 
+ if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public) - witnessOffset + len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + witnessOffset) + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + +// computeTerm computes 
coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. 
+func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start)+i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. 
+ } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID: cID, VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. 
+func (s *solver) Read(calldata []uint32) (constraint.Element, int) { + if s.Type == constraint.SystemSparseR1CS { + if calldata[0] != 1 { + panic("invalid calldata") + } + return s.GetValue(calldata[1], calldata[2]), 3 + } + var r fr.Element + n := int(calldata[0]) + j := 1 + for k := 0; k < n; k++ { + // we read k Terms + s.accumulateInto(constraint.Term{CID: calldata[j], VID: calldata[j+1]}, &r) + j += 2 + } + + var ret constraint.Element + copy(ret[:], r[:]) + return ret, j +} + +// processInstruction decodes the instruction and execute blueprint-defined logic. +// an instruction can encode a hint, a custom constraint or a generic constraint. +func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error { + // fetch the blueprint + blueprint := solver.Blueprints[pi.BlueprintID] + inst := pi.Unpack(&solver.System) + cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only + + if solver.Type == constraint.SystemR1CS { + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + // TODO @gbotrel we use the solveR1C method for now, having user-defined + // blueprint for R1CS would require constraint.Solver interface to add methods + // to set a,b,c since it's more efficient to compute these while we solve. + bc.DecompressR1C(&scratch.tR1C, inst) + return solver.solveR1C(cID, &scratch.tR1C) + } + } + + // blueprint declared "I know how to solve this." + if bc, ok := blueprint.(constraint.BlueprintSolvable); ok { + if err := bc.Solve(solver, inst); err != nil { + return solver.wrapErrWithDebugInfo(cID, err) + } + return nil + } + + // blueprint encodes a hint, we execute. + // TODO @gbotrel may be worth it to move hint logic in blueprint "solve" + if bc, ok := blueprint.(constraint.BlueprintHint); ok { + bc.DecompressHint(&scratch.tHint, inst) + return solver.solveWithHint(&scratch.tHint) + } + + return nil +} + +// run runs the solver. 
it return an error if a constraint is not satisfied or if not all wires +// were instantiated. +func (solver *solver) run() error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks. + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + var scratch scratch + for t := range chTasks { + for _, i := range t { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + var scratch scratch + + // for each level, we push the tasks + for _, level := range solver.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + return err + } + } + continue + } + + // number of tasks for this level is set to number of CPU + // but if we don't have enough work for all our CPU, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + if int(solver.nbSolved) != len(solver.values) { + return errors.New("solver didn't assign a value to all wires") + } + + return nil +} + +// solveR1C compute unsolved wires in the constraint, if any and set the solver accordingly +// +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. 
+func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error { + a, b, c := &solver.a[cID], &solver.b[cID], &solver.c[cID] + + // the index of the non-zero entry shows if L, R or O has an uninstantiated wire + // the content is the ID of the wire non instantiated + var loc uint8 + + var termToCompute constraint.Term + + processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) { + for _, t := range l { + vID := t.WireID() + + // wire is already computed, we just accumulate in val + if solver.solved[vID] { + solver.accumulateInto(t, val) + continue + } + + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue + } + } + + processLExp(r.L, a, 1) + processLExp(r.R, b, 2) + processLExp(r.O, c, 3) + + if loc == 0 { + // there is nothing to solve, may happen if we have an assertion + // (ie a constraints that doesn't yield any output) + // or if we solved the unsolved wires with hint functions + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + return nil + } + + // we compute the wire value and instantiate it + wID := termToCompute.WireID() + + // solver result + var wire fr.Element + + switch loc { + case 1: + if !b.IsZero() { + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) + } else { + // we didn't actually ensure that a * b == c + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 2: + if !a.IsZero() { + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) + } else { + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 3: + wire.Mul(a, b). 
+ Sub(&wire, c) + + c.Add(c, &wire) + } + + // wire is the term (coeff * value) + // but in the solver we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + solver.divByCoeff(&wire, termToCompute.CID) + solver.set(wID, wire) + + return nil +} + +// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint +type UnsatisfiedConstraintError struct { + Err error + CID int // constraint ID + DebugInfo *string // optional debug info +} + +func (r *UnsatisfiedConstraintError) Error() string { + if r.DebugInfo != nil { + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) + } + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) +} + +func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { + var debugInfo *string + if dID, ok := solver.MDebug[int(cID)]; ok { + debugInfo = new(string) + *debugInfo = solver.logValue(solver.DebugInfo[dID]) + } + return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo} +} + +// temporary variables to avoid memallocs in hotloop +type scratch struct { + tR1C constraint.R1C + tHint constraint.HintMapping +} diff --git a/constraint/bn254/system.go b/constraint/bn254/system.go new file mode 100644 index 0000000000..fc9e816703 --- /dev/null +++ b/constraint/bn254/system.go @@ -0,0 +1,379 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/fxamacker/cbor/v2" + "io" + "time" + + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/internal/backend/ioutils" + "github.com/consensys/gnark/logger" + "reflect" + + "github.com/consensys/gnark-crypto/ecc" + + "github.com/consensys/gnark-crypto/ecc/bn254/fr" +) + +type R1CS = system +type SparseR1CS = system + +// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element) +type system struct { + constraint.System + CoeffTable + field +} + +func NewR1CS(capacity int) *R1CS { + return newSystem(capacity, constraint.SystemR1CS) +} + +func NewSparseR1CS(capacity int) *SparseR1CS { + return newSystem(capacity, constraint.SystemSparseR1CS) +} + +func newSystem(capacity int, t constraint.SystemType) *system { + return &system{ + System: constraint.NewSystem(fr.Modulus(), capacity, t), + CoeffTable: newCoeffTable(capacity / 10), + } +} + +// Solve solves the constraint system with provided witness. +// If it's a R1CS returns R1CSSolution +// If it's a SparseR1CS returns SparseR1CSSolution +func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) { + log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger() + start := time.Now() + + v := witness.Vector().(fr.Vector) + + // init the solver + solver, err := newSolver(cs, v, opts...) + if err != nil { + log.Err(err).Send() + return nil, err + } + + // defer log printing once all solver.values are computed + // (or sooner, if a constraint is not satisfied) + defer solver.printLogs(cs.Logs) + + // run it. 
+ if err := solver.run(); err != nil { + log.Err(err).Send() + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") + + // format the solution + // TODO @gbotrel revisit post-refactor + if cs.Type == constraint.SystemR1CS { + var res R1CSSolution + res.W = solver.values + res.A = solver.a + res.B = solver.b + res.C = solver.c + return &res, nil + } else { + // sparse R1CS + var res SparseR1CSSolution + // query l, r, o in Lagrange basis, not blinded + res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values) + + return &res, nil + } + +} + +// IsSolved +// Deprecated: use _, err := Solve(...) instead +func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error { + _, err := cs.Solve(witness, opts...) + return err +} + +// GetR1Cs return the list of R1C +func (cs *system) GetR1Cs() []constraint.R1C { + toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + var r1c constraint.R1C + bc.DecompressR1C(&r1c, inst.Unpack(&cs.System)) + toReturn = append(toReturn, r1c) + } else { + panic("not implemented") + } + } + return toReturn +} + +// GetNbCoefficients return the number of unique coefficients needed in the R1CS +func (cs *system) GetNbCoefficients() int { + return len(cs.Coefficients) +} + +// CurveID returns curve ID as defined in gnark-crypto +func (cs *system) CurveID() ecc.ID { + return ecc.BN254 +} + +// WriteTo encodes R1CS into provided io.Writer using cbor +func (cs *system) WriteTo(w io.Writer) (int64, error) { + _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written + ts := getTagSet() + enc, err := cbor.CoreDetEncOptions().EncModeWithTags(ts) + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) + + // encode our object + err = encoder.Encode(cs) + return _w.N, err +} + +// ReadFrom 
attempts to decode R1CS from io.Reader using cbor +func (cs *system) ReadFrom(r io.Reader) (int64, error) { + ts := getTagSet() + dm, err := cbor.DecOptions{ + MaxArrayElements: 2147483647, + MaxMapPairs: 2147483647, + }.DecModeWithTags(ts) + + if err != nil { + return 0, err + } + decoder := dm.NewDecoder(r) + + // initialize coeff table + cs.CoeffTable = newCoeffTable(0) + + if err := decoder.Decode(&cs); err != nil { + return int64(decoder.NumBytesRead()), err + } + + if err := cs.CheckSerializationHeader(); err != nil { + return int64(decoder.NumBytesRead()), err + } + + switch v := cs.CommitmentInfo.(type) { + case *constraint.Groth16Commitments: + cs.CommitmentInfo = *v + case *constraint.PlonkCommitments: + cs.CommitmentInfo = *v + } + + return int64(decoder.NumBytesRead()), nil +} + +func (cs *system) GetCoefficient(i int) (r constraint.Element) { + copy(r[:], cs.Coefficients[i][:]) + return +} + +// GetSparseR1Cs return the list of SparseR1C +func (cs *system) GetSparseR1Cs() []constraint.SparseR1C { + + toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + var sparseR1C constraint.SparseR1C + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + toReturn = append(toReturn, sparseR1C) + } else { + panic("not implemented") + } + } + return toReturn +} + +// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form. 
+// solver = [ public | secret | internal ] +// TODO @gbotrel refactor; this seems to be a small util function for plonk +func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { + + //s := int(pk.Domain[0].Cardinality) + s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints + s = int(ecc.NextPowerOfTwo(uint64(s))) + + var l, r, o []fr.Element + l = make([]fr.Element, s) + r = make([]fr.Element, s) + o = make([]fr.Element, s) + s0 := solution[0] + + for i := 0; i < len(cs.Public); i++ { // placeholders + l[i] = solution[i] + r[i] = s0 + o[i] = s0 + } + offset := len(cs.Public) + nbConstraints := cs.GetNbConstraints() + + var sparseR1C constraint.SparseR1C + j := 0 + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + + l[offset+j] = solution[sparseR1C.XA] + r[offset+j] = solution[sparseR1C.XB] + o[offset+j] = solution[sparseR1C.XC] + j++ + } + } + + offset += nbConstraints + + for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0]) + l[offset+i] = s0 + r[offset+i] = s0 + o[offset+i] = s0 + } + + return l, r, o + +} + +// R1CSSolution represent a valid assignment to all the variables in the constraint system. 
+// The vector W such that Aw o Bw - Cw = 0 +type R1CSSolution struct { + W fr.Vector + A, B, C fr.Vector +} + +func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.W.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.A.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.B.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.C.WriteTo(w) + n += a + return n, err +} + +func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.W.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.A.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.B.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.C.ReadFrom(r) + n += a + return n, err +} + +// SparseR1CSSolution represent a valid assignment to all the variables in the constraint system. +type SparseR1CSSolution struct { + L, R, O fr.Vector +} + +func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.L.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.R.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.O.WriteTo(w) + n += a + return n, err + +} + +func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.L.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.R.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.O.ReadFrom(r) + n += a + return n, err +} + +func getTagSet() cbor.TagSet { + // temporary for refactor + ts := cbor.NewTagSet() + // https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml + // 65536-15309735 Unassigned + tagNum := uint64(5309735) + addType := func(t reflect.Type) { + if err := ts.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + t, + tagNum, + ); err != nil { + panic(err) + } + tagNum++ + } + + addType(reflect.TypeOf(constraint.BlueprintGenericHint{})) + addType(reflect.TypeOf(constraint.BlueprintGenericR1C{})) + 
addType(reflect.TypeOf(constraint.BlueprintGenericSparseR1C{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CAdd{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CMul{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CBool{})) + addType(reflect.TypeOf(constraint.BlueprintLookupHint{})) + addType(reflect.TypeOf(constraint.Groth16Commitments{})) + addType(reflect.TypeOf(constraint.PlonkCommitments{})) + + return ts +} + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} diff --git a/constraint/bw6-633/coeff.go b/constraint/bw6-633/coeff.go index 48b63963d7..dd1d29ad18 100644 --- a/constraint/bw6-633/coeff.go +++ b/constraint/bw6-633/coeff.go @@ -46,7 +46,7 @@ func newCoeffTable(capacity int) CoeffTable { } -func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constraint.Term { +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { c := (*fr.Element)(coeff[:]) var cID uint32 if c.IsZero() { @@ -69,7 +69,11 @@ func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constrai ct.mCoeffs[cc] = cID } } + return cID +} +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) return constraint.Term{VID: uint32(variableID), CID: cID} } @@ -78,7 +82,10 @@ func (ct *CoeffTable) CoeffToString(cID int) string { return ct.Coefficients[cID].String() } -var _ constraint.CoeffEngine = &arithEngine{} +// implements constraint.Field +type field struct{} + +var _ constraint.Field = &field{} var ( two fr.Element @@ -94,10 +101,7 @@ func init() { minusTwo.Neg(&two) } -// implements constraint.CoeffEngine -type arithEngine struct{} - -func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { +func (engine *field) FromInterface(i interface{}) constraint.Element { var e fr.Element if _, err := e.SetInterface(i); err != nil { // need to clean that --> some code path are dissimilar @@ -106,55 +110,83 @@ 
func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { b := utils.FromInterface(i) e.SetBigInt(&b) } - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) ToBigInt(c *constraint.Coeff) *big.Int { +func (engine *field) ToBigInt(c constraint.Element) *big.Int { e := (*fr.Element)(c[:]) r := new(big.Int) e.BigInt(r) return r } -func (engine *arithEngine) Mul(a, b *constraint.Coeff) { +func (engine *field) Mul(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Mul(_a, _b) + return a } -func (engine *arithEngine) Add(a, b *constraint.Coeff) { + +func (engine *field) Add(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Add(_a, _b) + return a } -func (engine *arithEngine) Sub(a, b *constraint.Coeff) { +func (engine *field) Sub(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Sub(_a, _b) + return a } -func (engine *arithEngine) Neg(a *constraint.Coeff) { +func (engine *field) Neg(a constraint.Element) constraint.Element { e := (*fr.Element)(a[:]) e.Neg(e) + return a } -func (engine *arithEngine) Inverse(a *constraint.Coeff) { +func (engine *field) Inverse(a constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + e.Inverse(e) + return a, true } -func (engine *arithEngine) IsOne(a *constraint.Coeff) bool { +func (engine *field) IsOne(a constraint.Element) bool { e := (*fr.Element)(a[:]) return e.IsOne() } -func (engine *arithEngine) One() constraint.Coeff { +func (engine *field) One() constraint.Element { e := fr.One() - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) String(a 
*constraint.Coeff) string { +func (engine *field) String(a constraint.Element) string { e := (*fr.Element)(a[:]) return e.String() } + +func (engine *field) Uint64(a constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true +} diff --git a/constraint/bw6-633/gkr.go b/constraint/bw6-633/gkr.go new file mode 100644 index 0000000000..8018e76d3d --- /dev/null +++ b/constraint/bw6-633/gkr.go @@ -0,0 +1,196 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/gkr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + hint "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/utils/algo_utils" + "hash" + "math/big" +) + +type GkrSolvingData struct { + assignments gkr.WireAssignment + circuit gkr.Circuit + memoryPool polynomial.Pool + workers *utils.WorkerPool +} + +func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { + resCircuit := make(gkr.Circuit, len(noPtr)) + var found bool + for i := range noPtr { + if resCircuit[i].Gate, found = gkr.Gates[noPtr[i].Gate]; !found && noPtr[i].Gate != "" { + return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) + } + resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) + } + return resCircuit, nil +} + +func (d *GkrSolvingData) init(info constraint.GkrInfo) (assignment gkrAssignment, err error) { + if d.circuit, err = convertCircuit(info.Circuit); err != nil { + return + } + d.memoryPool = polynomial.NewPool(d.circuit.MemoryRequirements(info.NbInstances)...) 
+ d.workers = utils.NewWorkerPool() + + assignment = make(gkrAssignment, len(d.circuit)) + d.assignments = make(gkr.WireAssignment, len(d.circuit)) + for i := range assignment { + assignment[i] = d.memoryPool.Make(info.NbInstances) + d.assignments[&d.circuit[i]] = assignment[i] + } + return +} + +func (d *GkrSolvingData) dumpAssignments() { + for _, p := range d.assignments { + d.memoryPool.Dump(p) + } +} + +// this module assumes that wire and instance indexes respect dependencies + +type gkrAssignment [][]fr.Element //gkrAssignment is indexed wire first, instance second + +func (a gkrAssignment) setOuts(circuit constraint.GkrCircuit, outs []*big.Int) { + outsI := 0 + for i := range circuit { + if circuit[i].IsOutput() { + for j := range a[i] { + a[i][j].BigInt(outs[outsI]) + outsI++ + } + } + } + // Check if outsI == len(outs)? +} + +func GkrSolveHint(info constraint.GkrInfo, solvingData *GkrSolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + // assumes assignmentVector is arranged wire first, instance second in order of solution + circuit := info.Circuit + nbInstances := info.NbInstances + offsets := info.AssignmentOffsets() + assignment, err := solvingData.init(info) + if err != nil { + return err + } + chunks := circuit.Chunks(nbInstances) + + solveTask := func(chunkOffset int) utils.Task { + return func(startInChunk, endInChunk int) { + start := startInChunk + chunkOffset + end := endInChunk + chunkOffset + inputs := solvingData.memoryPool.Make(info.MaxNIns) + dependencyHeads := make([]int, len(circuit)) + for wI, w := range circuit { + dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { + return w.Dependencies[i].InputInstance + }, len(w.Dependencies), start) + } + + for instanceI := start; instanceI < end; instanceI++ { + for wireI, wire := range circuit { + if wire.IsInput() { + if dependencyHeads[wireI] < len(wire.Dependencies) && instanceI == wire.Dependencies[dependencyHeads[wireI]].InputInstance { + dep := 
wire.Dependencies[dependencyHeads[wireI]] + assignment[wireI][instanceI].Set(&assignment[dep.OutputWire][dep.OutputInstance]) + dependencyHeads[wireI]++ + } else { + assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) + } + } else { + // assemble the inputs + inputIndexes := info.Circuit[wireI].Inputs + for i, inputI := range inputIndexes { + inputs[i].Set(&assignment[inputI][instanceI]) + } + gate := solvingData.circuit[wireI].Gate + assignment[wireI][instanceI] = gate.Evaluate(inputs[:len(inputIndexes)]...) + } + } + } + solvingData.memoryPool.Dump(inputs) + } + } + + start := 0 + for _, end := range chunks { + solvingData.workers.Submit(end-start, solveTask(start), 1024).Wait() + start = end + } + + assignment.setOuts(info.Circuit, outs) + + return nil + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} + +func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { + + return func(_ *big.Int, ins, outs []*big.Int) error { + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + b := make([]byte, fr.Bytes) + i.FillBytes(b) + return b[:] + }) + + hsh := HashBuilderRegistry[hashName]() + + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + if err != nil { + return err + } + + // serialize proof: TODO: In gnark-crypto? 
+ offset := 0 + for i := range proof { + for _, poly := range proof[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if proof[i].FinalEvalProof != nil { + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } + + data.dumpAssignments() + + return nil + + } +} + +// TODO: Move to gnark-crypto +var HashBuilderRegistry = make(map[string]func() hash.Hash) diff --git a/constraint/bw6-633/r1cs.go b/constraint/bw6-633/r1cs.go deleted file mode 100644 index b73a94b155..0000000000 --- a/constraint/bw6-633/r1cs.go +++ /dev/null @@ -1,451 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/fxamacker/cbor/v2" - "io" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc" - "math" - - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" -) - -// R1CS describes a set of R1CS constraint -type R1CS struct { - constraint.R1CSCore - CoeffTable - arithEngine -} - -// NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values -// -// capacity pre-allocates memory for capacity nbConstraints -func NewR1CS(capacity int) *R1CS { - r := R1CS{ - R1CSCore: constraint.R1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.R1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - return &r -} - -func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, r1c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - - cs.UpdateLevel(cID, &r1c) - - return cID -} - -// Solve sets all the wires and returns the a, b, c vectors. -// the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. -// a, b, c vectors: ab-c = hz -// witness = [publicWires | secretWires] (without the ONE_WIRE !) 
-// returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() - - nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables - solution, err := newSolution(nbWires, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return make(fr.Vector, nbWires), err - } - start := time.Now() - - if len(witness) != len(cs.Public)-1+len(cs.Secret) { // - 1 for ONE_WIRE - err = fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), int(len(cs.Public)-1+len(cs.Secret)), len(cs.Public)-1, len(cs.Secret)) - log.Err(err).Send() - return solution.values, err - } - - // compute the wires and the a, b, c polynomials - if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { - err = errors.New("invalid input size: len(a, b, c) == len(Constraints)") - log.Err(err).Send() - return solution.values, err - } - - solution.solved[0] = true // ONE_WIRE - solution.values[0].SetOne() - copy(solution.values[1:], witness) - for i := range witness { - solution.solved[i+1] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness) + 1) - - // now that we know all inputs are set, defer log printing once all solution.values are computed - // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - if err := cs.parallelSolve(a, b, c, &solution); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity 
check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil -} - -func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - chError <- &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - return &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// IsSolved returns nil if given witness solves the R1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) 
- if err != nil { - return err - } - - a := make(fr.Vector, len(cs.Constraints)) - b := make(fr.Vector, len(cs.Constraints)) - c := make(fr.Vector, len(cs.Constraints)) - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, a, b, c, opt) - return err -} - -// divByCoeff sets res = res / t.Coeff -func (cs *R1CS) divByCoeff(res *fr.Element, t constraint.Term) { - cID := t.CoeffID() - switch cID { - case constraint.CoeffIdOne: - return - case constraint.CoeffIdMinusOne: - res.Neg(res) - case constraint.CoeffIdZero: - panic("division by 0") - default: - // this is slow, but shouldn't happen as divByCoeff is called to - // remove the coeff of an unsolved wire - // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 - res.Div(res, &cs.Coefficients[cID]) - } -} - -// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly -// -// returns an error if the solver called a hint function that errored -// returns false, nil if there was no wire to solve -// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that -// the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r constraint.R1C, solution *solution, a, b, c *fr.Element) error { - - // the index of the non-zero entry shows if L, R or O has an uninstantiated wire - // the content is the ID of the wire non instantiated - var loc uint8 - - var termToCompute constraint.Term - - processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) error { - for _, t := range l { - vID := t.WireID() - - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - solution.accumulateInto(t, val) - continue - } - - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err - } - // now that the wire is saved, accumulate it into a, b or c - solution.accumulateInto(t, val) - continue - } - - if loc != 0 { - panic("found more than one wire to instantiate") - } - termToCompute = t - loc = locValue - } - return nil - } - - if err := processLExp(r.L, a, 1); err != nil { - return err - } - - if err := processLExp(r.R, b, 2); err != nil { - return err - } - - if err := processLExp(r.O, c, 3); err != nil { - return err - } - - if loc == 0 { - // there is nothing to solve, may happen if we have an assertion - // (ie a constraints that doesn't yield any output) - // or if we solved the unsolved wires with hint functions - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - return nil - } - - // we compute the wire value and instantiate it - wID := termToCompute.WireID() - - // solver result - var wire fr.Element - - switch loc { - case 1: - if !b.IsZero() { - wire.Div(c, b). - Sub(&wire, a) - a.Add(a, &wire) - } else { - // we didn't actually ensure that a * b == c - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 2: - if !a.IsZero() { - wire.Div(c, a). 
- Sub(&wire, b) - b.Add(b, &wire) - } else { - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 3: - wire.Mul(a, b). - Sub(&wire, c) - - c.Add(c, &wire) - } - - // wire is the term (coeff * value) - // but in the solution we want to store the value only - // note that in gnark frontend, coeff here is always 1 or -1 - cs.divByCoeff(&wire, termToCompute) - solution.set(wID, wire) - - return nil -} - -// GetConstraints return the list of R1C and a coefficient resolver -func (cs *R1CS) GetConstraints() ([]constraint.R1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *R1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto -func (cs *R1CS) CurveID() ecc.ID { - return ecc.BW6_633 -} - -// WriteTo encodes R1CS into provided io.Writer using cbor -func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode R1CS from io.Reader using cbor -func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(&cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bw6-633/r1cs_sparse.go 
b/constraint/bw6-633/r1cs_sparse.go deleted file mode 100644 index 0cba25ce7b..0000000000 --- a/constraint/bw6-633/r1cs_sparse.go +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/fxamacker/cbor/v2" - "io" - "math" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" -) - -// SparseR1CS represents a Plonk like circuit -type SparseR1CS struct { - constraint.SparseR1CSCore - CoeffTable - arithEngine -} - -// NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values -func NewSparseR1CS(capacity int) *SparseR1CS { - cs := SparseR1CS{ - SparseR1CSCore: constraint.SparseR1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.SparseR1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - - return &cs -} - -func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = 
append(cs.Constraints, c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - cs.UpdateLevel(cID, &c) - - return cID -} - -// Solve sets all the wires. -// solution.values = [publicInputs | secretInputs | internalVariables ] -// witness: contains the input variables -// it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() - - // set the slices holding the solution.values and monitoring which variables have been solved - nbVariables := cs.NbInternalVariables + len(cs.Secret) + len(cs.Public) - - start := time.Now() - - expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) - if len(witness) != expectedWitnessSize { - return make(fr.Vector, nbVariables), fmt.Errorf( - "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", - len(witness), - expectedWitnessSize, - len(cs.Public), - len(cs.Secret), - ) - } - - // keep track of wire that have a value - solution, err := newSolution(nbVariables, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return solution.values, err - } - - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs - copy(solution.values, witness) - for i := 0; i < len(witness); i++ { - solution.solved[i] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness)) - - // defer log printing once all solution.values are computed - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - // batch invert the coefficients to avoid many divisions in the solver - coefficientsNegInv := 
fr.BatchInvert(cs.Coefficients) - for i := 0; i < len(coefficientsNegInv); i++ { - coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) - } - - if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil - -} - -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - wg.Done() - return - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - chError <- &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } else { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - } - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - return &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// computeHints computes wires associated with a hint function, if any -// if there is no remaining wire to solve, returns -1 -// else returns the wire position (L -> 0, R -> 1, O -> 2) -func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) (int, error) { - r := -1 - lID, rID, oID := c.L.WireID(), c.R.WireID(), c.O.WireID() - - if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { - // check if it's a hint - if hint, ok := cs.MHints[lID]; ok { - if err := solution.solveWithHint(lID, hint); err != nil { - return -1, err - } - } else { - r = 0 - } - - } - - if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { - // check if it's a hint - if hint, ok := cs.MHints[rID]; ok { - if err := solution.solveWithHint(rID, hint); err != nil { - return -1, err - } - } else { - r = 1 - } - } - - if (c.O.CoeffID() != 0) && !solution.solved[oID] { - // check if it's a hint - if hint, ok := cs.MHints[oID]; ok { - if err := solution.solveWithHint(oID, hint); err != nil { - return -1, err - } - } else { - r = 2 - } - } - 
return r, nil -} - -// solveConstraint solve any unsolved wire in given constraint and update the solution -// a SparseR1C may have up to one unsolved wire (excluding hints) -// if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { - - lro, err := cs.computeHints(c, solution) - if err != nil { - return err - } - if lro == -1 { - // no unsolved wire - // can happen if the constraint contained only hint wires. - return nil - } - if lro == 1 { // we solve for R: u1L+u2R+u3LR+u4O+k=0 => R(u2+u3L)+u1L+u4O+k = 0 - if !solution.solved[c.L.WireID()] { - panic("L wire should be instantiated when we solve R") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.L.WireID()]).Add(&den, &u2) - - v1 = solution.computeTerm(c.L) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - } - - if lro == 0 { // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 - if !solution.solved[c.R.WireID()] { - panic("R wire should be instantiated when we solve L") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.R.WireID()]).Add(&den, &u1) - - v1 = solution.computeTerm(c.R) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - - } - // O 
we solve for O - var o fr.Element - cID, vID := c.O.CoeffID(), c.O.WireID() - - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - - // o = - ((m0 * m1) + l + r + c.K) / c.O - o.Mul(&m0, &m1).Add(&o, &l).Add(&o, &r).Add(&o, &cs.Coefficients[c.K]) - o.Mul(&o, &coefficientsNegInv[cID]) - - solution.set(vID, o) - - return nil -} - -// IsSolved returns nil if given witness solves the SparseR1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return err - } - - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, opt) - return err -} - -// GetConstraints return the list of SparseR1C and a coefficient resolver -func (cs *SparseR1CS) GetConstraints() ([]constraint.SparseR1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// checkConstraint verifies that the constraint holds -func (cs *SparseR1CS) checkConstraint(c constraint.SparseR1C, solution *solution) error { - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - o := solution.computeTerm(c.O) - - // l + r + (m0 * m1) + o + c.K == 0 - var t fr.Element - t.Mul(&m0, &m1).Add(&t, &l).Add(&t, &r).Add(&t, &o).Add(&t, &cs.Coefficients[c.K]) - if !t.IsZero() { - return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + (%s × %s) + %s != 0", - l.String(), - r.String(), - o.String(), - m0.String(), - m1.String(), - cs.Coefficients[c.K].String(), - ) - } - return nil - -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto (ecc.BW6-633) -func (cs 
*SparseR1CS) CurveID() ecc.ID { - return ecc.BW6_633 -} - -// WriteTo encodes SparseR1CS into provided io.Writer using cbor -func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode SparseR1CS from io.Reader using cbor -func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bw6-633/r1cs_test.go b/constraint/bw6-633/r1cs_test.go index 0e07d67a8a..7111c25c77 100644 --- a/constraint/bw6-633/r1cs_test.go +++ b/constraint/bw6-633/r1cs_test.go @@ -20,6 +20,7 @@ import ( "bytes" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "reflect" "testing" @@ -27,7 +28,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/consensys/gnark/constraint/bw6-633" + cs "github.com/consensys/gnark/constraint/bw6-633" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" ) @@ -48,7 +49,7 @@ func TestSerialization(t *testing.T) { return } - // copmpile a second time to ensure determinism + // compile a second time to ensure determinism r1cs2, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, tc.Circuit) 
if err != nil { t.Fatal(err) @@ -76,10 +77,10 @@ func TestSerialization(t *testing.T) { if diff := cmp.Diff(r1cs1, &reconstructed, cmpopts.IgnoreFields(cs.R1CS{}, "System.q", - "arithEngine", + "field", "CoeffTable.mCoeffs", "System.lbWireLevel", - "System.lbHints", + "System.genericHint", "System.SymbolTable", "System.lbOutputs", "System.bitLen")); diff != "" { @@ -147,12 +148,6 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { - var c circuit - ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c) - if err != nil { - b.Fatal(err) - } - var w circuit w.X = 1 w.Y = 1 @@ -161,8 +156,32 @@ func BenchmarkSolve(b *testing.B) { b.Fatal(err) } - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ccs.IsSolved(witness) - } + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + } diff --git a/constraint/bw6-633/solution.go b/constraint/bw6-633/solution.go deleted file mode 100644 index aea0b28a49..0000000000 --- a/constraint/bw6-633/solution.go +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/debug" - "github.com/rs/zerolog" - "math/big" - "strconv" - "strings" - "sync/atomic" - - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" -) - -// solution represents elements needed to compute -// a solution to a R1CS or SparseR1CS -type solution struct { - values, coefficients []fr.Element - solved []bool - nbSolved uint64 - mHintsFunctions map[hint.ID]hint.Function // maps hintID to hint function - mHints map[int]*constraint.Hint // maps wireID to hint - st *debug.SymbolTable -} - -func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, mHints map[int]*constraint.Hint, coefficients []fr.Element, st *debug.SymbolTable) (solution, error) { - - s := solution{ - st: st, - values: make([]fr.Element, nbWires), - coefficients: coefficients, - solved: make([]bool, nbWires), - mHintsFunctions: hintFunctions, - mHints: mHints, - } - - // hintsDependencies is from compile time; it contains the list of hints the solver **needs** - var missing []string - for hintUUID, hintID := range hintsDependencies { - if _, ok := s.mHintsFunctions[hintUUID]; !ok { - missing = append(missing, hintID) - } - } - - if len(missing) > 0 { - return s, fmt.Errorf("solver missing hint(s): %v", missing) - } - - return s, nil -} - -func (s *solution) set(id int, value fr.Element) { - if s.solved[id] { - panic("solving the same 
wire twice should never happen.") - } - s.values[id] = value - s.solved[id] = true - atomic.AddUint64(&s.nbSolved, 1) - // s.nbSolved++ -} - -func (s *solution) isValid() bool { - return int(s.nbSolved) == len(s.values) -} - -// computeTerm computes coeff*variable -func (s *solution) computeTerm(t constraint.Term) fr.Element { - cID, vID := t.CoeffID(), t.WireID() - if cID != 0 && !s.solved[vID] { - panic("computing a term with an unsolved wire") - } - switch cID { - case constraint.CoeffIdZero: - return fr.Element{} - case constraint.CoeffIdOne: - return s.values[vID] - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - return res - case constraint.CoeffIdMinusOne: - var res fr.Element - res.Neg(&s.values[vID]) - return res - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - return res - } -} - -// r += (t.coeff*t.value) -func (s *solution) accumulateInto(t constraint.Term, r *fr.Element) { - cID := t.CoeffID() - - if t.IsConstant() { - // needed for logs, we may want to not put this in the hot path if we need to - // optimize constraint system solver further. 
- r.Add(r, &s.coefficients[cID]) - return - } - - vID := t.WireID() - switch cID { - case constraint.CoeffIdZero: - return - case constraint.CoeffIdOne: - r.Add(r, &s.values[vID]) - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - r.Add(r, &res) - case constraint.CoeffIdMinusOne: - r.Sub(r, &s.values[vID]) - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - r.Add(r, &res) - } -} - -// solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveWithHint(vID int, h *constraint.Hint) error { - // skip if the wire is already solved by a call to the same hint - // function on the same inputs - if s.solved[vID] { - return nil - } - // ensure hint function was provided - f, ok := s.mHintsFunctions[h.ID] - if !ok { - return errors.New("missing hint function") - } - - // tmp IO big int memory - nbInputs := len(h.Inputs) - nbOutputs := len(h.Wires) - inputs := make([]*big.Int, nbInputs) - outputs := make([]*big.Int, nbOutputs) - for i := 0; i < nbOutputs; i++ { - outputs[i] = big.NewInt(0) - } - - q := fr.Modulus() - - // for each input, we set its big int value, IF all the wires are solved - // the only case where all wires may not be solved, is if one of the input of this hint - // is the output of another hint. - // it is safe to recursively solve this with the parallel solver, since all hints-output wires - // that we can solve this way are marked to be solved with the current constraint we are processing. - recursiveSolve := func(t constraint.Term) error { - if t.IsConstant() { - return nil - } - wID := t.WireID() - if s.solved[wID] { - return nil - } - // unsolved dependency - if h, ok := s.mHints[wID]; ok { - // solve recursively. - return s.solveWithHint(wID, h) - } - - // it's not a hint, we panic. 
- panic("solver can't compute hint; one or more input wires are unsolved") - } - - for i := 0; i < nbInputs; i++ { - inputs[i] = big.NewInt(0) - - var v fr.Element - for _, term := range h.Inputs[i] { - if err := recursiveSolve(term); err != nil { - return err - } - s.accumulateInto(term, &v) - } - v.BigInt(inputs[i]) - } - - err := f(q, inputs, outputs) - - var v fr.Element - for i := range outputs { - v.SetBigInt(outputs[i]) - s.set(h.Wires[i], v) - } - - return err -} - -func (s *solution) printLogs(log zerolog.Logger, logs []constraint.LogEntry) { - if log.GetLevel() == zerolog.Disabled { - return - } - - for i := 0; i < len(logs); i++ { - logLine := s.logValue(logs[i]) - log.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) - } -} - -const unsolvedVariable = "" - -func (s *solution) logValue(log constraint.LogEntry) string { - var toResolve []interface{} - var ( - eval fr.Element - missingValue bool - ) - for j := 0; j < len(log.ToResolve); j++ { - // before eval le - - missingValue = false - eval.SetZero() - - for _, t := range log.ToResolve[j] { - // for each term in the linear expression - - cID, vID := t.CoeffID(), t.WireID() - if t.IsConstant() { - // just add the constant - eval.Add(&eval, &s.coefficients[cID]) - continue - } - - if !s.solved[vID] { - missingValue = true - break // stop the loop we can't evaluate. 
- } - - tv := s.computeTerm(t) - eval.Add(&eval, &tv) - } - - // after - if missingValue { - toResolve = append(toResolve, unsolvedVariable) - } else { - // we have to append our accumulator - toResolve = append(toResolve, eval.String()) - } - - } - if len(log.Stack) > 0 { - var sbb strings.Builder - for _, lID := range log.Stack { - location := s.st.Locations[lID] - function := s.st.Functions[location.FunctionID] - - sbb.WriteString(function.Name) - sbb.WriteByte('\n') - sbb.WriteByte('\t') - sbb.WriteString(function.Filename) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(int(location.Line))) - sbb.WriteByte('\n') - } - toResolve = append(toResolve, sbb.String()) - } - return fmt.Sprintf(log.Format, toResolve...) -} - -// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint -type UnsatisfiedConstraintError struct { - Err error - CID int // constraint ID - DebugInfo *string // optional debug info -} - -func (r *UnsatisfiedConstraintError) Error() string { - if r.DebugInfo != nil { - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) - } - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) -} diff --git a/constraint/bw6-633/solver.go b/constraint/bw6-633/solver.go new file mode 100644 index 0000000000..948bb4a2e4 --- /dev/null +++ b/constraint/bw6-633/solver.go @@ -0,0 +1,647 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/pool" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/rs/zerolog" + "math" + "math/big" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" +) + +// solver represent the state of the solver during a call to System.Solve(...) +type solver struct { + *system + + // values and solved are index by the wire (variable) id + values []fr.Element + solved []bool + nbSolved uint64 + + // maps hintID to hint function + mHintsFunctions map[csolver.HintID]csolver.Hint + + // used to out api.Println + logger zerolog.Logger + + a, b, c fr.Vector // R1CS solver will compute the a,b,c matrices + + q *big.Int +} + +func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { + // parse options + opt, err := csolver.NewConfig(opts...) 
+ if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public) - witnessOffset + len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + witnessOffset) + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + +// computeTerm computes 
coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. 
+func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start)+i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. 
+ } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID: cID, VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. 
+func (s *solver) Read(calldata []uint32) (constraint.Element, int) { + if s.Type == constraint.SystemSparseR1CS { + if calldata[0] != 1 { + panic("invalid calldata") + } + return s.GetValue(calldata[1], calldata[2]), 3 + } + var r fr.Element + n := int(calldata[0]) + j := 1 + for k := 0; k < n; k++ { + // we read k Terms + s.accumulateInto(constraint.Term{CID: calldata[j], VID: calldata[j+1]}, &r) + j += 2 + } + + var ret constraint.Element + copy(ret[:], r[:]) + return ret, j +} + +// processInstruction decodes the instruction and execute blueprint-defined logic. +// an instruction can encode a hint, a custom constraint or a generic constraint. +func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error { + // fetch the blueprint + blueprint := solver.Blueprints[pi.BlueprintID] + inst := pi.Unpack(&solver.System) + cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only + + if solver.Type == constraint.SystemR1CS { + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + // TODO @gbotrel we use the solveR1C method for now, having user-defined + // blueprint for R1CS would require constraint.Solver interface to add methods + // to set a,b,c since it's more efficient to compute these while we solve. + bc.DecompressR1C(&scratch.tR1C, inst) + return solver.solveR1C(cID, &scratch.tR1C) + } + } + + // blueprint declared "I know how to solve this." + if bc, ok := blueprint.(constraint.BlueprintSolvable); ok { + if err := bc.Solve(solver, inst); err != nil { + return solver.wrapErrWithDebugInfo(cID, err) + } + return nil + } + + // blueprint encodes a hint, we execute. + // TODO @gbotrel may be worth it to move hint logic in blueprint "solve" + if bc, ok := blueprint.(constraint.BlueprintHint); ok { + bc.DecompressHint(&scratch.tHint, inst) + return solver.solveWithHint(&scratch.tHint) + } + + return nil +} + +// run runs the solver. 
it return an error if a constraint is not satisfied or if not all wires +// were instantiated. +func (solver *solver) run() error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks. + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + var scratch scratch + for t := range chTasks { + for _, i := range t { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + var scratch scratch + + // for each level, we push the tasks + for _, level := range solver.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + return err + } + } + continue + } + + // number of tasks for this level is set to number of CPU + // but if we don't have enough work for all our CPU, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + if int(solver.nbSolved) != len(solver.values) { + return errors.New("solver didn't assign a value to all wires") + } + + return nil +} + +// solveR1C compute unsolved wires in the constraint, if any and set the solver accordingly +// +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. 
+func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error { + a, b, c := &solver.a[cID], &solver.b[cID], &solver.c[cID] + + // the index of the non-zero entry shows if L, R or O has an uninstantiated wire + // the content is the ID of the wire non instantiated + var loc uint8 + + var termToCompute constraint.Term + + processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) { + for _, t := range l { + vID := t.WireID() + + // wire is already computed, we just accumulate in val + if solver.solved[vID] { + solver.accumulateInto(t, val) + continue + } + + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue + } + } + + processLExp(r.L, a, 1) + processLExp(r.R, b, 2) + processLExp(r.O, c, 3) + + if loc == 0 { + // there is nothing to solve, may happen if we have an assertion + // (ie a constraints that doesn't yield any output) + // or if we solved the unsolved wires with hint functions + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + return nil + } + + // we compute the wire value and instantiate it + wID := termToCompute.WireID() + + // solver result + var wire fr.Element + + switch loc { + case 1: + if !b.IsZero() { + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) + } else { + // we didn't actually ensure that a * b == c + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 2: + if !a.IsZero() { + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) + } else { + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 3: + wire.Mul(a, b). 
+ Sub(&wire, c) + + c.Add(c, &wire) + } + + // wire is the term (coeff * value) + // but in the solver we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + solver.divByCoeff(&wire, termToCompute.CID) + solver.set(wID, wire) + + return nil +} + +// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint +type UnsatisfiedConstraintError struct { + Err error + CID int // constraint ID + DebugInfo *string // optional debug info +} + +func (r *UnsatisfiedConstraintError) Error() string { + if r.DebugInfo != nil { + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) + } + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) +} + +func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { + var debugInfo *string + if dID, ok := solver.MDebug[int(cID)]; ok { + debugInfo = new(string) + *debugInfo = solver.logValue(solver.DebugInfo[dID]) + } + return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo} +} + +// temporary variables to avoid memallocs in hotloop +type scratch struct { + tR1C constraint.R1C + tHint constraint.HintMapping +} diff --git a/constraint/bw6-633/system.go b/constraint/bw6-633/system.go new file mode 100644 index 0000000000..555ed8fb9e --- /dev/null +++ b/constraint/bw6-633/system.go @@ -0,0 +1,379 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/fxamacker/cbor/v2" + "io" + "time" + + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/internal/backend/ioutils" + "github.com/consensys/gnark/logger" + "reflect" + + "github.com/consensys/gnark-crypto/ecc" + + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" +) + +type R1CS = system +type SparseR1CS = system + +// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element) +type system struct { + constraint.System + CoeffTable + field +} + +func NewR1CS(capacity int) *R1CS { + return newSystem(capacity, constraint.SystemR1CS) +} + +func NewSparseR1CS(capacity int) *SparseR1CS { + return newSystem(capacity, constraint.SystemSparseR1CS) +} + +func newSystem(capacity int, t constraint.SystemType) *system { + return &system{ + System: constraint.NewSystem(fr.Modulus(), capacity, t), + CoeffTable: newCoeffTable(capacity / 10), + } +} + +// Solve solves the constraint system with provided witness. +// If it's a R1CS returns R1CSSolution +// If it's a SparseR1CS returns SparseR1CSSolution +func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) { + log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger() + start := time.Now() + + v := witness.Vector().(fr.Vector) + + // init the solver + solver, err := newSolver(cs, v, opts...) + if err != nil { + log.Err(err).Send() + return nil, err + } + + // defer log printing once all solver.values are computed + // (or sooner, if a constraint is not satisfied) + defer solver.printLogs(cs.Logs) + + // run it. 
	if err := solver.run(); err != nil {
		log.Err(err).Send()
		return nil, err
	}

	log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done")

	// format the solution
	// TODO @gbotrel revisit post-refactor
	if cs.Type == constraint.SystemR1CS {
		var res R1CSSolution
		res.W = solver.values
		res.A = solver.a
		res.B = solver.b
		res.C = solver.c
		return &res, nil
	} else {
		// sparse R1CS
		var res SparseR1CSSolution
		// query l, r, o in Lagrange basis, not blinded
		res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values)

		return &res, nil
	}

}

// IsSolved returns nil if the witness solves the system, the solver error otherwise.
//
// Deprecated: use _, err := Solve(...) instead
func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error {
	_, err := cs.Solve(witness, opts...)
	return err
}

// GetR1Cs return the list of R1C, decompressing each instruction through its
// blueprint. Panics if an instruction's blueprint is not a BlueprintR1C.
func (cs *system) GetR1Cs() []constraint.R1C {
	toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints())

	for _, inst := range cs.Instructions {
		blueprint := cs.Blueprints[inst.BlueprintID]
		if bc, ok := blueprint.(constraint.BlueprintR1C); ok {
			var r1c constraint.R1C
			bc.DecompressR1C(&r1c, inst.Unpack(&cs.System))
			toReturn = append(toReturn, r1c)
		} else {
			panic("not implemented")
		}
	}
	return toReturn
}

// GetNbCoefficients return the number of unique coefficients needed in the R1CS
func (cs *system) GetNbCoefficients() int {
	return len(cs.Coefficients)
}

// CurveID returns curve ID as defined in gnark-crypto
func (cs *system) CurveID() ecc.ID {
	return ecc.BW6_633
}

// WriteTo encodes R1CS into provided io.Writer using cbor
func (cs *system) WriteTo(w io.Writer) (int64, error) {
	_w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written
	ts := getTagSet()
	enc, err := cbor.CoreDetEncOptions().EncModeWithTags(ts)
	if err != nil {
		return 0, err
	}
	encoder := enc.NewEncoder(&_w)

	// encode our object
	err = encoder.Encode(cs)
	return _w.N, err
}

// ReadFrom attempts to decode R1CS from io.Reader using cbor
func (cs *system) ReadFrom(r io.Reader) (int64, error) {
	ts := getTagSet()
	dm, err := cbor.DecOptions{
		// generous limits: large circuits serialize to large arrays/maps
		MaxArrayElements: 2147483647,
		MaxMapPairs:      2147483647,
	}.DecModeWithTags(ts)

	if err != nil {
		return 0, err
	}
	decoder := dm.NewDecoder(r)

	// initialize coeff table
	cs.CoeffTable = newCoeffTable(0)

	if err := decoder.Decode(&cs); err != nil {
		return int64(decoder.NumBytesRead()), err
	}

	if err := cs.CheckSerializationHeader(); err != nil {
		return int64(decoder.NumBytesRead()), err
	}

	// cbor decodes the tagged commitment info as a pointer; normalize it back
	// to the value form the rest of the code expects.
	switch v := cs.CommitmentInfo.(type) {
	case *constraint.Groth16Commitments:
		cs.CommitmentInfo = *v
	case *constraint.PlonkCommitments:
		cs.CommitmentInfo = *v
	}

	return int64(decoder.NumBytesRead()), nil
}

// GetCoefficient returns a copy of the i-th coefficient.
func (cs *system) GetCoefficient(i int) (r constraint.Element) {
	copy(r[:], cs.Coefficients[i][:])
	return
}

// GetSparseR1Cs return the list of SparseR1C, decompressing each instruction
// through its blueprint. Panics if a blueprint is not a BlueprintSparseR1C.
func (cs *system) GetSparseR1Cs() []constraint.SparseR1C {

	toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints())

	for _, inst := range cs.Instructions {
		blueprint := cs.Blueprints[inst.BlueprintID]
		if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok {
			var sparseR1C constraint.SparseR1C
			bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System))
			toReturn = append(toReturn, sparseR1C)
		} else {
			panic("not implemented")
		}
	}
	return toReturn
}

// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form.
+// solver = [ public | secret | internal ] +// TODO @gbotrel refactor; this seems to be a small util function for plonk +func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { + + //s := int(pk.Domain[0].Cardinality) + s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints + s = int(ecc.NextPowerOfTwo(uint64(s))) + + var l, r, o []fr.Element + l = make([]fr.Element, s) + r = make([]fr.Element, s) + o = make([]fr.Element, s) + s0 := solution[0] + + for i := 0; i < len(cs.Public); i++ { // placeholders + l[i] = solution[i] + r[i] = s0 + o[i] = s0 + } + offset := len(cs.Public) + nbConstraints := cs.GetNbConstraints() + + var sparseR1C constraint.SparseR1C + j := 0 + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + + l[offset+j] = solution[sparseR1C.XA] + r[offset+j] = solution[sparseR1C.XB] + o[offset+j] = solution[sparseR1C.XC] + j++ + } + } + + offset += nbConstraints + + for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0]) + l[offset+i] = s0 + r[offset+i] = s0 + o[offset+i] = s0 + } + + return l, r, o + +} + +// R1CSSolution represent a valid assignment to all the variables in the constraint system. 
+// The vector W such that Aw o Bw - Cw = 0 +type R1CSSolution struct { + W fr.Vector + A, B, C fr.Vector +} + +func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.W.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.A.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.B.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.C.WriteTo(w) + n += a + return n, err +} + +func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.W.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.A.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.B.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.C.ReadFrom(r) + n += a + return n, err +} + +// SparseR1CSSolution represent a valid assignment to all the variables in the constraint system. +type SparseR1CSSolution struct { + L, R, O fr.Vector +} + +func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.L.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.R.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.O.WriteTo(w) + n += a + return n, err + +} + +func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.L.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.R.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.O.ReadFrom(r) + n += a + return n, err +} + +func getTagSet() cbor.TagSet { + // temporary for refactor + ts := cbor.NewTagSet() + // https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml + // 65536-15309735 Unassigned + tagNum := uint64(5309735) + addType := func(t reflect.Type) { + if err := ts.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + t, + tagNum, + ); err != nil { + panic(err) + } + tagNum++ + } + + addType(reflect.TypeOf(constraint.BlueprintGenericHint{})) + addType(reflect.TypeOf(constraint.BlueprintGenericR1C{})) + 
addType(reflect.TypeOf(constraint.BlueprintGenericSparseR1C{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CAdd{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CMul{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CBool{})) + addType(reflect.TypeOf(constraint.BlueprintLookupHint{})) + addType(reflect.TypeOf(constraint.Groth16Commitments{})) + addType(reflect.TypeOf(constraint.PlonkCommitments{})) + + return ts +} + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} diff --git a/constraint/bw6-761/coeff.go b/constraint/bw6-761/coeff.go index e63c24f7ca..506ee5e586 100644 --- a/constraint/bw6-761/coeff.go +++ b/constraint/bw6-761/coeff.go @@ -46,7 +46,7 @@ func newCoeffTable(capacity int) CoeffTable { } -func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constraint.Term { +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { c := (*fr.Element)(coeff[:]) var cID uint32 if c.IsZero() { @@ -69,7 +69,11 @@ func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constrai ct.mCoeffs[cc] = cID } } + return cID +} +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) return constraint.Term{VID: uint32(variableID), CID: cID} } @@ -78,7 +82,10 @@ func (ct *CoeffTable) CoeffToString(cID int) string { return ct.Coefficients[cID].String() } -var _ constraint.CoeffEngine = &arithEngine{} +// implements constraint.Field +type field struct{} + +var _ constraint.Field = &field{} var ( two fr.Element @@ -94,10 +101,7 @@ func init() { minusTwo.Neg(&two) } -// implements constraint.CoeffEngine -type arithEngine struct{} - -func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { +func (engine *field) FromInterface(i interface{}) constraint.Element { var e fr.Element if _, err := e.SetInterface(i); err != nil { // need to clean that --> some code path are dissimilar @@ -106,55 +110,83 @@ 
func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { b := utils.FromInterface(i) e.SetBigInt(&b) } - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) ToBigInt(c *constraint.Coeff) *big.Int { +func (engine *field) ToBigInt(c constraint.Element) *big.Int { e := (*fr.Element)(c[:]) r := new(big.Int) e.BigInt(r) return r } -func (engine *arithEngine) Mul(a, b *constraint.Coeff) { +func (engine *field) Mul(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Mul(_a, _b) + return a } -func (engine *arithEngine) Add(a, b *constraint.Coeff) { + +func (engine *field) Add(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Add(_a, _b) + return a } -func (engine *arithEngine) Sub(a, b *constraint.Coeff) { +func (engine *field) Sub(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Sub(_a, _b) + return a } -func (engine *arithEngine) Neg(a *constraint.Coeff) { +func (engine *field) Neg(a constraint.Element) constraint.Element { e := (*fr.Element)(a[:]) e.Neg(e) + return a } -func (engine *arithEngine) Inverse(a *constraint.Coeff) { +func (engine *field) Inverse(a constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + e.Inverse(e) + return a, true } -func (engine *arithEngine) IsOne(a *constraint.Coeff) bool { +func (engine *field) IsOne(a constraint.Element) bool { e := (*fr.Element)(a[:]) return e.IsOne() } -func (engine *arithEngine) One() constraint.Coeff { +func (engine *field) One() constraint.Element { e := fr.One() - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) String(a 
*constraint.Coeff) string { +func (engine *field) String(a constraint.Element) string { e := (*fr.Element)(a[:]) return e.String() } + +func (engine *field) Uint64(a constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true +} diff --git a/constraint/bw6-761/gkr.go b/constraint/bw6-761/gkr.go new file mode 100644 index 0000000000..d69ee82d31 --- /dev/null +++ b/constraint/bw6-761/gkr.go @@ -0,0 +1,196 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 

// Code generated by gnark DO NOT EDIT

package cs

import (
	"fmt"
	"github.com/consensys/gnark-crypto/ecc/bw6-761/fr"
	"github.com/consensys/gnark-crypto/ecc/bw6-761/fr/gkr"
	"github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial"
	fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir"
	"github.com/consensys/gnark-crypto/utils"
	"github.com/consensys/gnark/constraint"
	hint "github.com/consensys/gnark/constraint/solver"
	"github.com/consensys/gnark/std/utils/algo_utils"
	"hash"
	"math/big"
)

// GkrSolvingData holds the state shared between the GKR solve hint and the
// GKR prove hint: the converted circuit, its wire assignments, a memory pool
// for scratch polynomials and a worker pool for parallel solving.
type GkrSolvingData struct {
	assignments gkr.WireAssignment
	circuit     gkr.Circuit
	memoryPool  polynomial.Pool
	workers     *utils.WorkerPool
}

// convertCircuit translates the index-based constraint.GkrCircuit into the
// pointer-based gkr.Circuit, resolving gate names through the gkr.Gates
// registry. An empty gate name is allowed (input wire); any other unknown
// gate name is an error.
func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) {
	resCircuit := make(gkr.Circuit, len(noPtr))
	var found bool
	for i := range noPtr {
		if resCircuit[i].Gate, found = gkr.Gates[noPtr[i].Gate]; !found && noPtr[i].Gate != "" {
			return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate)
		}
		resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit))
	}
	return resCircuit, nil
}

// init converts the circuit, sizes the memory pool, and allocates one
// assignment slice (NbInstances long) per wire, registering it both in the
// returned gkrAssignment (indexed by wire number) and in d.assignments
// (keyed by wire pointer).
func (d *GkrSolvingData) init(info constraint.GkrInfo) (assignment gkrAssignment, err error) {
	if d.circuit, err = convertCircuit(info.Circuit); err != nil {
		return
	}
	d.memoryPool = polynomial.NewPool(d.circuit.MemoryRequirements(info.NbInstances)...)
	d.workers = utils.NewWorkerPool()

	assignment = make(gkrAssignment, len(d.circuit))
	d.assignments = make(gkr.WireAssignment, len(d.circuit))
	for i := range assignment {
		assignment[i] = d.memoryPool.Make(info.NbInstances)
		// same backing slice, addressable both by wire index and wire pointer
		d.assignments[&d.circuit[i]] = assignment[i]
	}
	return
}

// dumpAssignments returns every wire assignment slice to the memory pool.
func (d *GkrSolvingData) dumpAssignments() {
	for _, p := range d.assignments {
		d.memoryPool.Dump(p)
	}
}

// this module assumes that wire and instance indexes respect dependencies

type gkrAssignment [][]fr.Element //gkrAssignment is indexed wire first, instance second

// setOuts copies the values of all output wires (wire-major, instance-minor)
// into outs. It assumes outs is exactly the right length.
func (a gkrAssignment) setOuts(circuit constraint.GkrCircuit, outs []*big.Int) {
	outsI := 0
	for i := range circuit {
		if circuit[i].IsOutput() {
			for j := range a[i] {
				a[i][j].BigInt(outs[outsI])
				outsI++
			}
		}
	}
	// Check if outsI == len(outs)?
}

// GkrSolveHint returns a hint that evaluates the GKR circuit over all
// instances, reading input values from ins (minus dependency-supplied slots)
// and writing output-wire values to outs. The computed assignments are kept
// in solvingData for the subsequent prove hint.
func GkrSolveHint(info constraint.GkrInfo, solvingData *GkrSolvingData) hint.Hint {
	return func(_ *big.Int, ins, outs []*big.Int) error {
		// assumes assignmentVector is arranged wire first, instance second in order of solution
		circuit := info.Circuit
		nbInstances := info.NbInstances
		offsets := info.AssignmentOffsets()
		assignment, err := solvingData.init(info)
		if err != nil {
			return err
		}
		// chunk boundaries chosen so dependencies stay within earlier chunks
		chunks := circuit.Chunks(nbInstances)

		solveTask := func(chunkOffset int) utils.Task {
			return func(startInChunk, endInChunk int) {
				start := startInChunk + chunkOffset
				end := endInChunk + chunkOffset
				inputs := solvingData.memoryPool.Make(info.MaxNIns)
				// for each wire, index of the first dependency at or after
				// this task's starting instance
				dependencyHeads := make([]int, len(circuit))
				for wI, w := range circuit {
					dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int {
						return w.Dependencies[i].InputInstance
					}, len(w.Dependencies), start)
				}

				for instanceI := start; instanceI < end; instanceI++ {
					for wireI, wire := range circuit {
						if wire.IsInput() {
							if dependencyHeads[wireI] < len(wire.Dependencies) && instanceI == wire.Dependencies[dependencyHeads[wireI]].InputInstance {
								// value comes from another wire's output
								dep := wire.Dependencies[dependencyHeads[wireI]]
								assignment[wireI][instanceI].Set(&assignment[dep.OutputWire][dep.OutputInstance])
								dependencyHeads[wireI]++
							} else {
								// value comes from ins; dependency-filled slots
								// are skipped in the ins layout, hence the
								// dependencyHeads correction
								assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]])
							}
						} else {
							// assemble the inputs
							inputIndexes := info.Circuit[wireI].Inputs
							for i, inputI := range inputIndexes {
								inputs[i].Set(&assignment[inputI][instanceI])
							}
							gate := solvingData.circuit[wireI].Gate
							assignment[wireI][instanceI] = gate.Evaluate(inputs[:len(inputIndexes)]...)
						}
					}
				}
				solvingData.memoryPool.Dump(inputs)
			}
		}

		// chunks are solved sequentially (Wait after each Submit) because a
		// chunk may depend on values computed by earlier chunks; instances
		// within a chunk are solved in parallel
		start := 0
		for _, end := range chunks {
			solvingData.workers.Submit(end-start, solveTask(start), 1024).Wait()
			start = end
		}

		assignment.setOuts(info.Circuit, outs)

		return nil
	}
}

// frToBigInts writes each src element into the corresponding dst big.Int.
// It assumes len(dst) >= len(src).
func frToBigInts(dst []*big.Int, src []fr.Element) {
	for i := range src {
		src[i].BigInt(dst[i])
	}
}

// GkrProveHint returns a hint that produces a GKR proof over the assignments
// computed by the solve hint, using the named hash (looked up in
// HashBuilderRegistry) for Fiat-Shamir.
func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint {

	return func(_ *big.Int, ins, outs []*big.Int) error {
		insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called
			b := make([]byte, fr.Bytes)
			i.FillBytes(b)
			return b[:]
		})

		// NOTE(review): panics if hashName is not registered — presumably
		// guaranteed by the caller; confirm
		hsh := HashBuilderRegistry[hashName]()

		proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers))
		if err != nil {
			return err
		}

		// serialize proof: TODO: In gnark-crypto?
+ offset := 0 + for i := range proof { + for _, poly := range proof[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if proof[i].FinalEvalProof != nil { + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } + + data.dumpAssignments() + + return nil + + } +} + +// TODO: Move to gnark-crypto +var HashBuilderRegistry = make(map[string]func() hash.Hash) diff --git a/constraint/bw6-761/r1cs.go b/constraint/bw6-761/r1cs.go deleted file mode 100644 index f7347529f8..0000000000 --- a/constraint/bw6-761/r1cs.go +++ /dev/null @@ -1,451 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/fxamacker/cbor/v2" - "io" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc" - "math" - - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" -) - -// R1CS describes a set of R1CS constraint -type R1CS struct { - constraint.R1CSCore - CoeffTable - arithEngine -} - -// NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values -// -// capacity pre-allocates memory for capacity nbConstraints -func NewR1CS(capacity int) *R1CS { - r := R1CS{ - R1CSCore: constraint.R1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.R1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - return &r -} - -func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, r1c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - - cs.UpdateLevel(cID, &r1c) - - return cID -} - -// Solve sets all the wires and returns the a, b, c vectors. -// the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. -// a, b, c vectors: ab-c = hz -// witness = [publicWires | secretWires] (without the ONE_WIRE !) 
-// returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() - - nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables - solution, err := newSolution(nbWires, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return make(fr.Vector, nbWires), err - } - start := time.Now() - - if len(witness) != len(cs.Public)-1+len(cs.Secret) { // - 1 for ONE_WIRE - err = fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), int(len(cs.Public)-1+len(cs.Secret)), len(cs.Public)-1, len(cs.Secret)) - log.Err(err).Send() - return solution.values, err - } - - // compute the wires and the a, b, c polynomials - if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { - err = errors.New("invalid input size: len(a, b, c) == len(Constraints)") - log.Err(err).Send() - return solution.values, err - } - - solution.solved[0] = true // ONE_WIRE - solution.values[0].SetOne() - copy(solution.values[1:], witness) - for i := range witness { - solution.solved[i+1] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness) + 1) - - // now that we know all inputs are set, defer log printing once all solution.values are computed - // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - if err := cs.parallelSolve(a, b, c, &solution); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity 
check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil -} - -func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - chError <- &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - return &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// IsSolved returns nil if given witness solves the R1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) 
- if err != nil { - return err - } - - a := make(fr.Vector, len(cs.Constraints)) - b := make(fr.Vector, len(cs.Constraints)) - c := make(fr.Vector, len(cs.Constraints)) - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, a, b, c, opt) - return err -} - -// divByCoeff sets res = res / t.Coeff -func (cs *R1CS) divByCoeff(res *fr.Element, t constraint.Term) { - cID := t.CoeffID() - switch cID { - case constraint.CoeffIdOne: - return - case constraint.CoeffIdMinusOne: - res.Neg(res) - case constraint.CoeffIdZero: - panic("division by 0") - default: - // this is slow, but shouldn't happen as divByCoeff is called to - // remove the coeff of an unsolved wire - // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 - res.Div(res, &cs.Coefficients[cID]) - } -} - -// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly -// -// returns an error if the solver called a hint function that errored -// returns false, nil if there was no wire to solve -// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that -// the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r constraint.R1C, solution *solution, a, b, c *fr.Element) error { - - // the index of the non-zero entry shows if L, R or O has an uninstantiated wire - // the content is the ID of the wire non instantiated - var loc uint8 - - var termToCompute constraint.Term - - processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) error { - for _, t := range l { - vID := t.WireID() - - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - solution.accumulateInto(t, val) - continue - } - - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err - } - // now that the wire is saved, accumulate it into a, b or c - solution.accumulateInto(t, val) - continue - } - - if loc != 0 { - panic("found more than one wire to instantiate") - } - termToCompute = t - loc = locValue - } - return nil - } - - if err := processLExp(r.L, a, 1); err != nil { - return err - } - - if err := processLExp(r.R, b, 2); err != nil { - return err - } - - if err := processLExp(r.O, c, 3); err != nil { - return err - } - - if loc == 0 { - // there is nothing to solve, may happen if we have an assertion - // (ie a constraints that doesn't yield any output) - // or if we solved the unsolved wires with hint functions - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - return nil - } - - // we compute the wire value and instantiate it - wID := termToCompute.WireID() - - // solver result - var wire fr.Element - - switch loc { - case 1: - if !b.IsZero() { - wire.Div(c, b). - Sub(&wire, a) - a.Add(a, &wire) - } else { - // we didn't actually ensure that a * b == c - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 2: - if !a.IsZero() { - wire.Div(c, a). 
- Sub(&wire, b) - b.Add(b, &wire) - } else { - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 3: - wire.Mul(a, b). - Sub(&wire, c) - - c.Add(c, &wire) - } - - // wire is the term (coeff * value) - // but in the solution we want to store the value only - // note that in gnark frontend, coeff here is always 1 or -1 - cs.divByCoeff(&wire, termToCompute) - solution.set(wID, wire) - - return nil -} - -// GetConstraints return the list of R1C and a coefficient resolver -func (cs *R1CS) GetConstraints() ([]constraint.R1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *R1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto -func (cs *R1CS) CurveID() ecc.ID { - return ecc.BW6_761 -} - -// WriteTo encodes R1CS into provided io.Writer using cbor -func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode R1CS from io.Reader using cbor -func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(&cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bw6-761/r1cs_sparse.go 
b/constraint/bw6-761/r1cs_sparse.go deleted file mode 100644 index 6f9f1ff048..0000000000 --- a/constraint/bw6-761/r1cs_sparse.go +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/fxamacker/cbor/v2" - "io" - "math" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" -) - -// SparseR1CS represents a Plonk like circuit -type SparseR1CS struct { - constraint.SparseR1CSCore - CoeffTable - arithEngine -} - -// NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values -func NewSparseR1CS(capacity int) *SparseR1CS { - cs := SparseR1CS{ - SparseR1CSCore: constraint.SparseR1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.SparseR1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - - return &cs -} - -func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = 
append(cs.Constraints, c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - cs.UpdateLevel(cID, &c) - - return cID -} - -// Solve sets all the wires. -// solution.values = [publicInputs | secretInputs | internalVariables ] -// witness: contains the input variables -// it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() - - // set the slices holding the solution.values and monitoring which variables have been solved - nbVariables := cs.NbInternalVariables + len(cs.Secret) + len(cs.Public) - - start := time.Now() - - expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) - if len(witness) != expectedWitnessSize { - return make(fr.Vector, nbVariables), fmt.Errorf( - "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", - len(witness), - expectedWitnessSize, - len(cs.Public), - len(cs.Secret), - ) - } - - // keep track of wire that have a value - solution, err := newSolution(nbVariables, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return solution.values, err - } - - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs - copy(solution.values, witness) - for i := 0; i < len(witness); i++ { - solution.solved[i] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness)) - - // defer log printing once all solution.values are computed - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - // batch invert the coefficients to avoid many divisions in the solver - coefficientsNegInv := 
fr.BatchInvert(cs.Coefficients) - for i := 0; i < len(coefficientsNegInv); i++ { - coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) - } - - if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil - -} - -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - wg.Done() - return - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - chError <- &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } else { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - } - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - return &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// computeHints computes wires associated with a hint function, if any -// if there is no remaining wire to solve, returns -1 -// else returns the wire position (L -> 0, R -> 1, O -> 2) -func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) (int, error) { - r := -1 - lID, rID, oID := c.L.WireID(), c.R.WireID(), c.O.WireID() - - if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { - // check if it's a hint - if hint, ok := cs.MHints[lID]; ok { - if err := solution.solveWithHint(lID, hint); err != nil { - return -1, err - } - } else { - r = 0 - } - - } - - if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { - // check if it's a hint - if hint, ok := cs.MHints[rID]; ok { - if err := solution.solveWithHint(rID, hint); err != nil { - return -1, err - } - } else { - r = 1 - } - } - - if (c.O.CoeffID() != 0) && !solution.solved[oID] { - // check if it's a hint - if hint, ok := cs.MHints[oID]; ok { - if err := solution.solveWithHint(oID, hint); err != nil { - return -1, err - } - } else { - r = 2 - } - } - 
return r, nil -} - -// solveConstraint solve any unsolved wire in given constraint and update the solution -// a SparseR1C may have up to one unsolved wire (excluding hints) -// if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { - - lro, err := cs.computeHints(c, solution) - if err != nil { - return err - } - if lro == -1 { - // no unsolved wire - // can happen if the constraint contained only hint wires. - return nil - } - if lro == 1 { // we solve for R: u1L+u2R+u3LR+u4O+k=0 => R(u2+u3L)+u1L+u4O+k = 0 - if !solution.solved[c.L.WireID()] { - panic("L wire should be instantiated when we solve R") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.L.WireID()]).Add(&den, &u2) - - v1 = solution.computeTerm(c.L) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - } - - if lro == 0 { // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 - if !solution.solved[c.R.WireID()] { - panic("R wire should be instantiated when we solve L") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.R.WireID()]).Add(&den, &u1) - - v1 = solution.computeTerm(c.R) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - - } - // O 
we solve for O - var o fr.Element - cID, vID := c.O.CoeffID(), c.O.WireID() - - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - - // o = - ((m0 * m1) + l + r + c.K) / c.O - o.Mul(&m0, &m1).Add(&o, &l).Add(&o, &r).Add(&o, &cs.Coefficients[c.K]) - o.Mul(&o, &coefficientsNegInv[cID]) - - solution.set(vID, o) - - return nil -} - -// IsSolved returns nil if given witness solves the SparseR1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return err - } - - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, opt) - return err -} - -// GetConstraints return the list of SparseR1C and a coefficient resolver -func (cs *SparseR1CS) GetConstraints() ([]constraint.SparseR1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// checkConstraint verifies that the constraint holds -func (cs *SparseR1CS) checkConstraint(c constraint.SparseR1C, solution *solution) error { - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - o := solution.computeTerm(c.O) - - // l + r + (m0 * m1) + o + c.K == 0 - var t fr.Element - t.Mul(&m0, &m1).Add(&t, &l).Add(&t, &r).Add(&t, &o).Add(&t, &cs.Coefficients[c.K]) - if !t.IsZero() { - return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + (%s × %s) + %s != 0", - l.String(), - r.String(), - o.String(), - m0.String(), - m1.String(), - cs.Coefficients[c.K].String(), - ) - } - return nil - -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto (ecc.BW6-761) -func (cs 
*SparseR1CS) CurveID() ecc.ID { - return ecc.BW6_761 -} - -// WriteTo encodes SparseR1CS into provided io.Writer using cbor -func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode SparseR1CS from io.Reader using cbor -func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/bw6-761/r1cs_test.go b/constraint/bw6-761/r1cs_test.go index ae50dafdc3..d8f3b48039 100644 --- a/constraint/bw6-761/r1cs_test.go +++ b/constraint/bw6-761/r1cs_test.go @@ -20,6 +20,7 @@ import ( "bytes" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "reflect" "testing" @@ -27,7 +28,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/consensys/gnark/constraint/bw6-761" + cs "github.com/consensys/gnark/constraint/bw6-761" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" ) @@ -51,7 +52,7 @@ func TestSerialization(t *testing.T) { return } - // copmpile a second time to ensure determinism + // compile a second time to ensure determinism r1cs2, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, tc.Circuit) 
if err != nil { t.Fatal(err) @@ -79,10 +80,10 @@ func TestSerialization(t *testing.T) { if diff := cmp.Diff(r1cs1, &reconstructed, cmpopts.IgnoreFields(cs.R1CS{}, "System.q", - "arithEngine", + "field", "CoeffTable.mCoeffs", "System.lbWireLevel", - "System.lbHints", + "System.genericHint", "System.SymbolTable", "System.lbOutputs", "System.bitLen")); diff != "" { @@ -150,12 +151,6 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { - var c circuit - ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c) - if err != nil { - b.Fatal(err) - } - var w circuit w.X = 1 w.Y = 1 @@ -164,8 +159,32 @@ func BenchmarkSolve(b *testing.B) { b.Fatal(err) } - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ccs.IsSolved(witness) - } + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + } diff --git a/constraint/bw6-761/solution.go b/constraint/bw6-761/solution.go deleted file mode 100644 index 54e9eee0bb..0000000000 --- a/constraint/bw6-761/solution.go +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/debug" - "github.com/rs/zerolog" - "math/big" - "strconv" - "strings" - "sync/atomic" - - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" -) - -// solution represents elements needed to compute -// a solution to a R1CS or SparseR1CS -type solution struct { - values, coefficients []fr.Element - solved []bool - nbSolved uint64 - mHintsFunctions map[hint.ID]hint.Function // maps hintID to hint function - mHints map[int]*constraint.Hint // maps wireID to hint - st *debug.SymbolTable -} - -func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, mHints map[int]*constraint.Hint, coefficients []fr.Element, st *debug.SymbolTable) (solution, error) { - - s := solution{ - st: st, - values: make([]fr.Element, nbWires), - coefficients: coefficients, - solved: make([]bool, nbWires), - mHintsFunctions: hintFunctions, - mHints: mHints, - } - - // hintsDependencies is from compile time; it contains the list of hints the solver **needs** - var missing []string - for hintUUID, hintID := range hintsDependencies { - if _, ok := s.mHintsFunctions[hintUUID]; !ok { - missing = append(missing, hintID) - } - } - - if len(missing) > 0 { - return s, fmt.Errorf("solver missing hint(s): %v", missing) - } - - return s, nil -} - -func (s *solution) set(id int, value fr.Element) { - if s.solved[id] { - panic("solving the same 
wire twice should never happen.") - } - s.values[id] = value - s.solved[id] = true - atomic.AddUint64(&s.nbSolved, 1) - // s.nbSolved++ -} - -func (s *solution) isValid() bool { - return int(s.nbSolved) == len(s.values) -} - -// computeTerm computes coeff*variable -func (s *solution) computeTerm(t constraint.Term) fr.Element { - cID, vID := t.CoeffID(), t.WireID() - if cID != 0 && !s.solved[vID] { - panic("computing a term with an unsolved wire") - } - switch cID { - case constraint.CoeffIdZero: - return fr.Element{} - case constraint.CoeffIdOne: - return s.values[vID] - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - return res - case constraint.CoeffIdMinusOne: - var res fr.Element - res.Neg(&s.values[vID]) - return res - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - return res - } -} - -// r += (t.coeff*t.value) -func (s *solution) accumulateInto(t constraint.Term, r *fr.Element) { - cID := t.CoeffID() - - if t.IsConstant() { - // needed for logs, we may want to not put this in the hot path if we need to - // optimize constraint system solver further. 
- r.Add(r, &s.coefficients[cID]) - return - } - - vID := t.WireID() - switch cID { - case constraint.CoeffIdZero: - return - case constraint.CoeffIdOne: - r.Add(r, &s.values[vID]) - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - r.Add(r, &res) - case constraint.CoeffIdMinusOne: - r.Sub(r, &s.values[vID]) - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - r.Add(r, &res) - } -} - -// solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveWithHint(vID int, h *constraint.Hint) error { - // skip if the wire is already solved by a call to the same hint - // function on the same inputs - if s.solved[vID] { - return nil - } - // ensure hint function was provided - f, ok := s.mHintsFunctions[h.ID] - if !ok { - return errors.New("missing hint function") - } - - // tmp IO big int memory - nbInputs := len(h.Inputs) - nbOutputs := len(h.Wires) - inputs := make([]*big.Int, nbInputs) - outputs := make([]*big.Int, nbOutputs) - for i := 0; i < nbOutputs; i++ { - outputs[i] = big.NewInt(0) - } - - q := fr.Modulus() - - // for each input, we set its big int value, IF all the wires are solved - // the only case where all wires may not be solved, is if one of the input of this hint - // is the output of another hint. - // it is safe to recursively solve this with the parallel solver, since all hints-output wires - // that we can solve this way are marked to be solved with the current constraint we are processing. - recursiveSolve := func(t constraint.Term) error { - if t.IsConstant() { - return nil - } - wID := t.WireID() - if s.solved[wID] { - return nil - } - // unsolved dependency - if h, ok := s.mHints[wID]; ok { - // solve recursively. - return s.solveWithHint(wID, h) - } - - // it's not a hint, we panic. 
- panic("solver can't compute hint; one or more input wires are unsolved") - } - - for i := 0; i < nbInputs; i++ { - inputs[i] = big.NewInt(0) - - var v fr.Element - for _, term := range h.Inputs[i] { - if err := recursiveSolve(term); err != nil { - return err - } - s.accumulateInto(term, &v) - } - v.BigInt(inputs[i]) - } - - err := f(q, inputs, outputs) - - var v fr.Element - for i := range outputs { - v.SetBigInt(outputs[i]) - s.set(h.Wires[i], v) - } - - return err -} - -func (s *solution) printLogs(log zerolog.Logger, logs []constraint.LogEntry) { - if log.GetLevel() == zerolog.Disabled { - return - } - - for i := 0; i < len(logs); i++ { - logLine := s.logValue(logs[i]) - log.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) - } -} - -const unsolvedVariable = "" - -func (s *solution) logValue(log constraint.LogEntry) string { - var toResolve []interface{} - var ( - eval fr.Element - missingValue bool - ) - for j := 0; j < len(log.ToResolve); j++ { - // before eval le - - missingValue = false - eval.SetZero() - - for _, t := range log.ToResolve[j] { - // for each term in the linear expression - - cID, vID := t.CoeffID(), t.WireID() - if t.IsConstant() { - // just add the constant - eval.Add(&eval, &s.coefficients[cID]) - continue - } - - if !s.solved[vID] { - missingValue = true - break // stop the loop we can't evaluate. 
- } - - tv := s.computeTerm(t) - eval.Add(&eval, &tv) - } - - // after - if missingValue { - toResolve = append(toResolve, unsolvedVariable) - } else { - // we have to append our accumulator - toResolve = append(toResolve, eval.String()) - } - - } - if len(log.Stack) > 0 { - var sbb strings.Builder - for _, lID := range log.Stack { - location := s.st.Locations[lID] - function := s.st.Functions[location.FunctionID] - - sbb.WriteString(function.Name) - sbb.WriteByte('\n') - sbb.WriteByte('\t') - sbb.WriteString(function.Filename) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(int(location.Line))) - sbb.WriteByte('\n') - } - toResolve = append(toResolve, sbb.String()) - } - return fmt.Sprintf(log.Format, toResolve...) -} - -// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint -type UnsatisfiedConstraintError struct { - Err error - CID int // constraint ID - DebugInfo *string // optional debug info -} - -func (r *UnsatisfiedConstraintError) Error() string { - if r.DebugInfo != nil { - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) - } - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) -} diff --git a/constraint/bw6-761/solver.go b/constraint/bw6-761/solver.go new file mode 100644 index 0000000000..803035e1a9 --- /dev/null +++ b/constraint/bw6-761/solver.go @@ -0,0 +1,647 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/pool" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/rs/zerolog" + "math" + "math/big" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" +) + +// solver represent the state of the solver during a call to System.Solve(...) +type solver struct { + *system + + // values and solved are index by the wire (variable) id + values []fr.Element + solved []bool + nbSolved uint64 + + // maps hintID to hint function + mHintsFunctions map[csolver.HintID]csolver.Hint + + // used to out api.Println + logger zerolog.Logger + + a, b, c fr.Vector // R1CS solver will compute the a,b,c matrices + + q *big.Int +} + +func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { + // parse options + opt, err := csolver.NewConfig(opts...) 
+ if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public) - witnessOffset + len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + witnessOffset) + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + +// computeTerm computes 
coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. 
+func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start)+i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. 
+ } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID: cID, VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. 
+func (s *solver) Read(calldata []uint32) (constraint.Element, int) { + if s.Type == constraint.SystemSparseR1CS { + if calldata[0] != 1 { + panic("invalid calldata") + } + return s.GetValue(calldata[1], calldata[2]), 3 + } + var r fr.Element + n := int(calldata[0]) + j := 1 + for k := 0; k < n; k++ { + // we read k Terms + s.accumulateInto(constraint.Term{CID: calldata[j], VID: calldata[j+1]}, &r) + j += 2 + } + + var ret constraint.Element + copy(ret[:], r[:]) + return ret, j +} + +// processInstruction decodes the instruction and execute blueprint-defined logic. +// an instruction can encode a hint, a custom constraint or a generic constraint. +func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error { + // fetch the blueprint + blueprint := solver.Blueprints[pi.BlueprintID] + inst := pi.Unpack(&solver.System) + cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only + + if solver.Type == constraint.SystemR1CS { + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + // TODO @gbotrel we use the solveR1C method for now, having user-defined + // blueprint for R1CS would require constraint.Solver interface to add methods + // to set a,b,c since it's more efficient to compute these while we solve. + bc.DecompressR1C(&scratch.tR1C, inst) + return solver.solveR1C(cID, &scratch.tR1C) + } + } + + // blueprint declared "I know how to solve this." + if bc, ok := blueprint.(constraint.BlueprintSolvable); ok { + if err := bc.Solve(solver, inst); err != nil { + return solver.wrapErrWithDebugInfo(cID, err) + } + return nil + } + + // blueprint encodes a hint, we execute. + // TODO @gbotrel may be worth it to move hint logic in blueprint "solve" + if bc, ok := blueprint.(constraint.BlueprintHint); ok { + bc.DecompressHint(&scratch.tHint, inst) + return solver.solveWithHint(&scratch.tHint) + } + + return nil +} + +// run runs the solver. 
it returns an error if a constraint is not satisfied or if not all wires
+// were instantiated.
+func (solver *solver) run() error {
+	// minWorkPerCPU is the minimum target number of constraint a task should hold
+	// in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed
+	// sequentially without sync.
+	const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks.
+
+	// cs.Levels has a list of levels, where all constraints in a level l(n) are independent
+	// and may only have dependencies on previous levels
+	// for each constraint
+	// we are guaranteed that each R1C contains at most one unsolved wire
+	// first we solve the unsolved wire (if any)
+	// then we check that the constraint is valid
+	// if a[i] * b[i] != c[i]; it means the constraint is not satisfied
+	var wg sync.WaitGroup
+	chTasks := make(chan []int, runtime.NumCPU())
+	chError := make(chan error, runtime.NumCPU())
+
+	// start a worker pool
+	// each worker waits on chTasks
+	// a task is a slice of constraint indexes to be solved
+	for i := 0; i < runtime.NumCPU(); i++ {
+		go func() {
+			var scratch scratch
+			for t := range chTasks {
+				for _, i := range t {
+					if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil {
+						chError <- err
+						wg.Done()
+						return
+					}
+				}
+				wg.Done()
+			}
+		}()
+	}
+
+	// clean up pool go routines
+	defer func() {
+		close(chTasks)
+		close(chError)
+	}()
+
+	var scratch scratch
+
+	// for each level, we push the tasks
+	for _, level := range solver.Levels {
+
+		// max CPU to use
+		maxCPU := float64(len(level)) / minWorkPerCPU
+
+		if maxCPU <= 1.0 {
+			// we do it sequentially
+			for _, i := range level {
+				if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil {
+					return err
+				}
+			}
+			continue
+		}
+
+		// number of tasks for this level is set to number of CPU
+		// but if we don't have enough work for all our CPU, it can be lower.
+		nbTasks := runtime.NumCPU()
+		maxTasks := int(math.Ceil(maxCPU))
+		if nbTasks > maxTasks {
+			nbTasks = maxTasks
+		}
+		nbIterationsPerCpus := len(level) / nbTasks
+
+		// more CPUs than tasks: a CPU will work on exactly one iteration
+		// note: this depends on minWorkPerCPU constant
+		if nbIterationsPerCpus < 1 {
+			nbIterationsPerCpus = 1
+			nbTasks = len(level)
+		}
+
+		extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
+		extraTasksOffset := 0
+
+		for i := 0; i < nbTasks; i++ {
+			wg.Add(1)
+			_start := i*nbIterationsPerCpus + extraTasksOffset
+			_end := _start + nbIterationsPerCpus
+			if extraTasks > 0 {
+				_end++
+				extraTasks--
+				extraTasksOffset++
+			}
+			// since we're never pushing more than num CPU tasks
+			// we will never be blocked here
+			chTasks <- level[_start:_end]
+		}
+
+		// wait for the level to be done
+		wg.Wait()
+
+		if len(chError) > 0 {
+			return <-chError
+		}
+	}
+
+	if int(solver.nbSolved) != len(solver.values) {
+		return errors.New("solver didn't assign a value to all wires")
+	}
+
+	return nil
+}
+
+// solveR1C computes unsolved wires in the constraint, if any, and sets the solver accordingly
+//
+// returns an error if the solver called a hint function that errored
+// returns false, nil if there was no wire to solve
+// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that
+// the constraint is satisfied later.
+func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error {
+	a, b, c := &solver.a[cID], &solver.b[cID], &solver.c[cID]
+
+	// the index of the non-zero entry shows if L, R or O has an uninstantiated wire
+	// the content is the ID of the non-instantiated wire
+	var loc uint8
+
+	var termToCompute constraint.Term
+
+	processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) {
+		for _, t := range l {
+			vID := t.WireID()
+
+			// wire is already computed, we just accumulate in val
+			if solver.solved[vID] {
+				solver.accumulateInto(t, val)
+				continue
+			}
+
+			if loc != 0 {
+				panic("found more than one wire to instantiate")
+			}
+			termToCompute = t
+			loc = locValue
+		}
+	}
+
+	processLExp(r.L, a, 1)
+	processLExp(r.R, b, 2)
+	processLExp(r.O, c, 3)
+
+	if loc == 0 {
+		// there is nothing to solve, may happen if we have an assertion
+		// (ie a constraint that doesn't yield any output)
+		// or if we solved the unsolved wires with hint functions
+		var check fr.Element
+		if !check.Mul(a, b).Equal(c) {
+			return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()))
+		}
+		return nil
+	}
+
+	// we compute the wire value and instantiate it
+	wID := termToCompute.WireID()
+
+	// solver result
+	var wire fr.Element
+
+	switch loc {
+	case 1:
+		if !b.IsZero() {
+			wire.Div(c, b).
+				Sub(&wire, a)
+			a.Add(a, &wire)
+		} else {
+			// we didn't actually ensure that a * b == c
+			var check fr.Element
+			if !check.Mul(a, b).Equal(c) {
+				return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()))
+			}
+		}
+	case 2:
+		if !a.IsZero() {
+			wire.Div(c, a).
+				Sub(&wire, b)
+			b.Add(b, &wire)
+		} else {
+			var check fr.Element
+			if !check.Mul(a, b).Equal(c) {
+				return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()))
+			}
+		}
+	case 3:
+		wire.Mul(a, b).
+			Sub(&wire, c)
+
+		c.Add(c, &wire)
+	}
+
+	// wire is the term (coeff * value)
+	// but in the solver we want to store the value only
+	// note that in gnark frontend, coeff here is always 1 or -1
+	solver.divByCoeff(&wire, termToCompute.CID)
+	solver.set(wID, wire)
+
+	return nil
+}
+
+// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint
+type UnsatisfiedConstraintError struct {
+	Err       error
+	CID       int     // constraint ID
+	DebugInfo *string // optional debug info
+}
+
+func (r *UnsatisfiedConstraintError) Error() string {
+	if r.DebugInfo != nil {
+		return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo)
+	}
+	return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error())
+}
+
+func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError {
+	var debugInfo *string
+	if dID, ok := solver.MDebug[int(cID)]; ok {
+		debugInfo = new(string)
+		*debugInfo = solver.logValue(solver.DebugInfo[dID])
+	}
+	return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo}
+}
+
+// temporary variables to avoid memallocs in hotloop
+type scratch struct {
+	tR1C  constraint.R1C
+	tHint constraint.HintMapping
+}
diff --git a/constraint/bw6-761/system.go b/constraint/bw6-761/system.go
new file mode 100644
index 0000000000..f08c43b4ea
--- /dev/null
+++ b/constraint/bw6-761/system.go
@@ -0,0 +1,379 @@
+// Copyright 2020 ConsenSys Software Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/fxamacker/cbor/v2" + "io" + "time" + + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/internal/backend/ioutils" + "github.com/consensys/gnark/logger" + "reflect" + + "github.com/consensys/gnark-crypto/ecc" + + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" +) + +type R1CS = system +type SparseR1CS = system + +// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element) +type system struct { + constraint.System + CoeffTable + field +} + +func NewR1CS(capacity int) *R1CS { + return newSystem(capacity, constraint.SystemR1CS) +} + +func NewSparseR1CS(capacity int) *SparseR1CS { + return newSystem(capacity, constraint.SystemSparseR1CS) +} + +func newSystem(capacity int, t constraint.SystemType) *system { + return &system{ + System: constraint.NewSystem(fr.Modulus(), capacity, t), + CoeffTable: newCoeffTable(capacity / 10), + } +} + +// Solve solves the constraint system with provided witness. +// If it's a R1CS returns R1CSSolution +// If it's a SparseR1CS returns SparseR1CSSolution +func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) { + log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger() + start := time.Now() + + v := witness.Vector().(fr.Vector) + + // init the solver + solver, err := newSolver(cs, v, opts...) + if err != nil { + log.Err(err).Send() + return nil, err + } + + // defer log printing once all solver.values are computed + // (or sooner, if a constraint is not satisfied) + defer solver.printLogs(cs.Logs) + + // run it. 
+ if err := solver.run(); err != nil { + log.Err(err).Send() + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") + + // format the solution + // TODO @gbotrel revisit post-refactor + if cs.Type == constraint.SystemR1CS { + var res R1CSSolution + res.W = solver.values + res.A = solver.a + res.B = solver.b + res.C = solver.c + return &res, nil + } else { + // sparse R1CS + var res SparseR1CSSolution + // query l, r, o in Lagrange basis, not blinded + res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values) + + return &res, nil + } + +} + +// IsSolved +// Deprecated: use _, err := Solve(...) instead +func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error { + _, err := cs.Solve(witness, opts...) + return err +} + +// GetR1Cs return the list of R1C +func (cs *system) GetR1Cs() []constraint.R1C { + toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + var r1c constraint.R1C + bc.DecompressR1C(&r1c, inst.Unpack(&cs.System)) + toReturn = append(toReturn, r1c) + } else { + panic("not implemented") + } + } + return toReturn +} + +// GetNbCoefficients return the number of unique coefficients needed in the R1CS +func (cs *system) GetNbCoefficients() int { + return len(cs.Coefficients) +} + +// CurveID returns curve ID as defined in gnark-crypto +func (cs *system) CurveID() ecc.ID { + return ecc.BW6_761 +} + +// WriteTo encodes R1CS into provided io.Writer using cbor +func (cs *system) WriteTo(w io.Writer) (int64, error) { + _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written + ts := getTagSet() + enc, err := cbor.CoreDetEncOptions().EncModeWithTags(ts) + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) + + // encode our object + err = encoder.Encode(cs) + return _w.N, err +} + +// ReadFrom 
attempts to decode R1CS from io.Reader using cbor +func (cs *system) ReadFrom(r io.Reader) (int64, error) { + ts := getTagSet() + dm, err := cbor.DecOptions{ + MaxArrayElements: 2147483647, + MaxMapPairs: 2147483647, + }.DecModeWithTags(ts) + + if err != nil { + return 0, err + } + decoder := dm.NewDecoder(r) + + // initialize coeff table + cs.CoeffTable = newCoeffTable(0) + + if err := decoder.Decode(&cs); err != nil { + return int64(decoder.NumBytesRead()), err + } + + if err := cs.CheckSerializationHeader(); err != nil { + return int64(decoder.NumBytesRead()), err + } + + switch v := cs.CommitmentInfo.(type) { + case *constraint.Groth16Commitments: + cs.CommitmentInfo = *v + case *constraint.PlonkCommitments: + cs.CommitmentInfo = *v + } + + return int64(decoder.NumBytesRead()), nil +} + +func (cs *system) GetCoefficient(i int) (r constraint.Element) { + copy(r[:], cs.Coefficients[i][:]) + return +} + +// GetSparseR1Cs return the list of SparseR1C +func (cs *system) GetSparseR1Cs() []constraint.SparseR1C { + + toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + var sparseR1C constraint.SparseR1C + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + toReturn = append(toReturn, sparseR1C) + } else { + panic("not implemented") + } + } + return toReturn +} + +// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form. 
+// solver = [ public | secret | internal ] +// TODO @gbotrel refactor; this seems to be a small util function for plonk +func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { + + //s := int(pk.Domain[0].Cardinality) + s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints + s = int(ecc.NextPowerOfTwo(uint64(s))) + + var l, r, o []fr.Element + l = make([]fr.Element, s) + r = make([]fr.Element, s) + o = make([]fr.Element, s) + s0 := solution[0] + + for i := 0; i < len(cs.Public); i++ { // placeholders + l[i] = solution[i] + r[i] = s0 + o[i] = s0 + } + offset := len(cs.Public) + nbConstraints := cs.GetNbConstraints() + + var sparseR1C constraint.SparseR1C + j := 0 + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + + l[offset+j] = solution[sparseR1C.XA] + r[offset+j] = solution[sparseR1C.XB] + o[offset+j] = solution[sparseR1C.XC] + j++ + } + } + + offset += nbConstraints + + for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0]) + l[offset+i] = s0 + r[offset+i] = s0 + o[offset+i] = s0 + } + + return l, r, o + +} + +// R1CSSolution represent a valid assignment to all the variables in the constraint system. 
+// The vector W such that Aw o Bw - Cw = 0 +type R1CSSolution struct { + W fr.Vector + A, B, C fr.Vector +} + +func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.W.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.A.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.B.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.C.WriteTo(w) + n += a + return n, err +} + +func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.W.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.A.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.B.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.C.ReadFrom(r) + n += a + return n, err +} + +// SparseR1CSSolution represent a valid assignment to all the variables in the constraint system. +type SparseR1CSSolution struct { + L, R, O fr.Vector +} + +func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.L.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.R.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.O.WriteTo(w) + n += a + return n, err + +} + +func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.L.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.R.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.O.ReadFrom(r) + n += a + return n, err +} + +func getTagSet() cbor.TagSet { + // temporary for refactor + ts := cbor.NewTagSet() + // https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml + // 65536-15309735 Unassigned + tagNum := uint64(5309735) + addType := func(t reflect.Type) { + if err := ts.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + t, + tagNum, + ); err != nil { + panic(err) + } + tagNum++ + } + + addType(reflect.TypeOf(constraint.BlueprintGenericHint{})) + addType(reflect.TypeOf(constraint.BlueprintGenericR1C{})) + 
addType(reflect.TypeOf(constraint.BlueprintGenericSparseR1C{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CAdd{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CMul{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CBool{})) + addType(reflect.TypeOf(constraint.BlueprintLookupHint{})) + addType(reflect.TypeOf(constraint.Groth16Commitments{})) + addType(reflect.TypeOf(constraint.PlonkCommitments{})) + + return ts +} + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} diff --git a/constraint/coeff.go b/constraint/coeff.go deleted file mode 100644 index 434919269f..0000000000 --- a/constraint/coeff.go +++ /dev/null @@ -1,39 +0,0 @@ -package constraint - -import ( - "math/big" -) - -// Coeff represents a term coefficient data. It is instantiated by the concrete -// constraint system implementation. -// Most of the scalar field used in gnark are on 4 uint64, so we have a clear memory overhead here. -type Coeff [6]uint64 - -// IsZero returns true if coefficient == 0 -func (z *Coeff) IsZero() bool { - return (z[5] | z[4] | z[3] | z[2] | z[1] | z[0]) == 0 -} - -// CoeffEngine capability to perform arithmetic on Coeff -type CoeffEngine interface { - FromInterface(interface{}) Coeff - ToBigInt(*Coeff) *big.Int - Mul(a, b *Coeff) - Add(a, b *Coeff) - Sub(a, b *Coeff) - Neg(a *Coeff) - Inverse(a *Coeff) - One() Coeff - IsOne(*Coeff) bool - String(*Coeff) string -} - -// ids of the coefficients with simple values in any cs.coeffs slice. 
-// TODO @gbotrel let's keep that here for the refactoring -- and move it to concrete cs package after -const ( - CoeffIdZero = 0 - CoeffIdOne = 1 - CoeffIdTwo = 2 - CoeffIdMinusOne = 3 - CoeffIdMinusTwo = 4 -) diff --git a/constraint/commitment.go b/constraint/commitment.go index bc6a0e8d0a..2e812ce906 100644 --- a/constraint/commitment.go +++ b/constraint/commitment.go @@ -1,62 +1,109 @@ package constraint import ( + "github.com/consensys/gnark/constraint/solver" "math/big" - - "github.com/consensys/gnark/backend/hint" ) const CommitmentDst = "bsb22-commitment" -type Commitment struct { - Committed []int // sorted list of id's of committed variables - NbPrivateCommitted int - HintID hint.ID // TODO @gbotrel we probably don't need that here - CommitmentIndex int - CommittedAndCommitment []int // sorted list of id's of committed variables AND the commitment itself +type Groth16Commitment struct { + PublicAndCommitmentCommitted []int // PublicAndCommitmentCommitted sorted list of id's of public and commitment committed wires + PrivateCommitted []int // PrivateCommitted sorted list of id's of private/internal committed wires + CommitmentIndex int // CommitmentIndex the wire index of the commitment + HintID solver.HintID + NbPublicCommitted int } -func (i *Commitment) NbPublicCommitted() int { - return i.NbCommitted() - i.NbPrivateCommitted +type PlonkCommitment struct { + Committed []int // sorted list of id's of committed variables in groth16. 
in plonk, list of indexes of constraints defining committed values + CommitmentIndex int // CommitmentIndex index of the constraint defining the commitment + HintID solver.HintID } -func (i *Commitment) NbCommitted() int { - return len(i.Committed) +type Commitment interface{} +type Commitments interface{ CommitmentIndexes() []int } + +type Groth16Commitments []Groth16Commitment +type PlonkCommitments []PlonkCommitment + +func (c Groth16Commitments) CommitmentIndexes() []int { + commitmentWires := make([]int, len(c)) + for i := range c { + commitmentWires[i] = c[i].CommitmentIndex + } + return commitmentWires } -func (i *Commitment) Is() bool { - return len(i.Committed) != 0 +func (c PlonkCommitments) CommitmentIndexes() []int { + commitmentWires := make([]int, len(c)) + for i := range c { + commitmentWires[i] = c[i].CommitmentIndex + } + return commitmentWires } -// NewCommitment initialize a Commitment object -// - committed are the sorted wireID to commit to (without duplicate) -// - nbPublicCommited is the number of public inputs among the commited wireIDs -func NewCommitment(committed []int, nbPublicCommitted int) Commitment { - return Commitment{ - Committed: committed, - NbPrivateCommitted: len(committed) - nbPublicCommitted, +func (c Groth16Commitments) GetPrivateCommitted() [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = c[i].PrivateCommitted } + return res } -func (i *Commitment) SerializeCommitment(privateCommitment []byte, publicCommitted []*big.Int, fieldByteLen int) []byte { +// GetPublicAndCommitmentCommitted returns the list of public and commitment committed wires +// if committedTranslationList is not nil, commitment indexes are translated into their relative positions on the list plus the offset +func (c Groth16Commitments) GetPublicAndCommitmentCommitted(committedTranslationList []int, offset int) [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, len(c[i].PublicAndCommitmentCommitted)) 
+ copy(res[i], c[i].GetPublicCommitted()) + translatedCommitmentCommitted := res[i][c[i].NbPublicCommitted:] + commitmentCommitted := c[i].GetCommitmentCommitted() + // convert commitment indexes to verifier understandable ones + if committedTranslationList == nil { + copy(translatedCommitmentCommitted, commitmentCommitted) + } else { + k := 0 + for j := range translatedCommitmentCommitted { + for committedTranslationList[k] != commitmentCommitted[j] { + k++ + } // find it in the translation list + translatedCommitmentCommitted[j] = k + offset + } + } + } + return res +} + +func SerializeCommitment(privateCommitment []byte, publicCommitted []*big.Int, fieldByteLen int) []byte { res := make([]byte, len(privateCommitment)+len(publicCommitted)*fieldByteLen) copy(res, privateCommitment) offset := len(privateCommitment) - for j, inJ := range publicCommitted { - offset += j * fieldByteLen + for _, inJ := range publicCommitted { inJ.FillBytes(res[offset : offset+fieldByteLen]) + offset += fieldByteLen } return res } -// PrivateToPublic returns indexes of variables which are private to the constraint system, but public to Groth16. 
That is, private committed variables and the commitment itself -func (i *Commitment) PrivateToPublic() []int { - return i.CommittedAndCommitment[i.NbPublicCommitted():] +func NewCommitments(t SystemType) Commitments { + switch t { + case SystemR1CS: + return Groth16Commitments{} + case SystemSparseR1CS: + return PlonkCommitments{} + } + panic("unknown cs type") +} + +func (c Groth16Commitment) GetPublicCommitted() []int { + return c.PublicAndCommitmentCommitted[:c.NbPublicCommitted] } -func (i *Commitment) PrivateCommitted() []int { - return i.Committed[i.NbPublicCommitted():] +func (c Groth16Commitment) GetCommitmentCommitted() []int { + return c.PublicAndCommitmentCommitted[c.NbPublicCommitted:] } diff --git a/constraint/core.go b/constraint/core.go new file mode 100644 index 0000000000..12aac144cb --- /dev/null +++ b/constraint/core.go @@ -0,0 +1,469 @@ +package constraint + +import ( + "fmt" + "math/big" + "strconv" + "sync" + + "github.com/blang/semver/v4" + "github.com/consensys/gnark" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/debug" + "github.com/consensys/gnark/internal/tinyfield" + "github.com/consensys/gnark/internal/utils" + "github.com/consensys/gnark/logger" + "github.com/consensys/gnark/profile" +) + +type SystemType uint16 + +const ( + SystemUnknown SystemType = iota + SystemR1CS + SystemSparseR1CS +) + +// PackedInstruction is the lowest element of a constraint system. It stores just enough data to +// reconstruct a constraint of any shape or a hint at solving time. +type PackedInstruction struct { + // BlueprintID maps this instruction to a blueprint + BlueprintID BlueprintID + + // ConstraintOffset stores the starting constraint ID of this instruction. + // Might not be strictly necessary; but speeds up solver for instructions that represents + // multiple constraints. + ConstraintOffset uint32 + + // WireOffset stores the starting internal wire ID of this instruction. 
Blueprints may use this + // and refer to output wires by their offset. + // For example, if a blueprint declared 5 outputs, the first output wire will be WireOffset, + // the last one WireOffset+4. + WireOffset uint32 + + // The constraint system stores a single []uint32 calldata slice. StartCallData + // points to the starting index in the mentioned slice. This avoid storing a slice per + // instruction (3 * uint64 in memory). + StartCallData uint64 +} + +// Unpack returns the instruction corresponding to the packed instruction. +func (pi PackedInstruction) Unpack(cs *System) Instruction { + + blueprint := cs.Blueprints[pi.BlueprintID] + cSize := blueprint.CalldataSize() + if cSize < 0 { + // by convention, we store nbInputs < 0 for non-static input length. + cSize = int(cs.CallData[pi.StartCallData]) + } + + return Instruction{ + ConstraintOffset: pi.ConstraintOffset, + WireOffset: pi.WireOffset, + Calldata: cs.CallData[pi.StartCallData : pi.StartCallData+uint64(cSize)], + } +} + +// Instruction is the lowest element of a constraint system. It stores all the data needed to +// reconstruct a constraint of any shape or a hint at solving time. +type Instruction struct { + ConstraintOffset uint32 + WireOffset uint32 + Calldata []uint32 +} + +// System contains core elements for a constraint System +type System struct { + // serialization header + GnarkVersion string + ScalarField string + + Type SystemType + + Instructions []PackedInstruction + Blueprints []Blueprint + CallData []uint32 // huge slice. 
+ + // can be != than len(instructions) + NbConstraints int + + // number of internal wires + NbInternalVariables int + + // input wires names + Public, Secret []string + + // logs (added with system.Println, resolved when solver sets a value to a wire) + Logs []LogEntry + + // debug info contains stack trace (including line number) of a call to a system.API that + // results in an unsolved constraint + DebugInfo []LogEntry + SymbolTable debug.SymbolTable + // maps constraint id to debugInfo id + // several constraints may point to the same debug info + MDebug map[int]int + + // maps hintID to hint string identifier + MHintsDependencies map[solver.HintID]string + + // each level contains independent constraints and can be parallelized + // it is guaranteed that all dependencies for constraints in a level l are solved + // in previous levels + // TODO @gbotrel these are currently updated after we add a constraint. + // but in case the object is built from a serialized representation + // we need to init the level builder lbWireLevel from the existing constraints. + Levels [][]int + + // scalar field + q *big.Int `cbor:"-"` + bitLen int `cbor:"-"` + + // level builder + lbWireLevel []int `cbor:"-"` // at which level we solve a wire. init at -1. + lbOutputs []uint32 `cbor:"-"` // wire outputs for current constraint. 
+ + CommitmentInfo Commitments + GkrInfo GkrInfo + + genericHint BlueprintID +} + +// NewSystem initialize the common structure among constraint system +func NewSystem(scalarField *big.Int, capacity int, t SystemType) System { + system := System{ + Type: t, + SymbolTable: debug.NewSymbolTable(), + MDebug: map[int]int{}, + GnarkVersion: gnark.Version.String(), + ScalarField: scalarField.Text(16), + MHintsDependencies: make(map[solver.HintID]string), + q: new(big.Int).Set(scalarField), + bitLen: scalarField.BitLen(), + Instructions: make([]PackedInstruction, 0, capacity), + CallData: make([]uint32, 0, capacity*8), + lbOutputs: make([]uint32, 0, 256), + lbWireLevel: make([]int, 0, capacity), + Levels: make([][]int, 0, capacity/2), + CommitmentInfo: NewCommitments(t), + } + + system.genericHint = system.AddBlueprint(&BlueprintGenericHint{}) + return system +} + +// GetNbInstructions returns the number of instructions in the system +func (system *System) GetNbInstructions() int { + return len(system.Instructions) +} + +// GetInstruction returns the instruction at index id +func (system *System) GetInstruction(id int) Instruction { + return system.Instructions[id].Unpack(system) +} + +// AddBlueprint adds a blueprint to the system and returns its ID +func (system *System) AddBlueprint(b Blueprint) BlueprintID { + system.Blueprints = append(system.Blueprints, b) + return BlueprintID(len(system.Blueprints) - 1) +} + +func (system *System) GetNbSecretVariables() int { + return len(system.Secret) +} +func (system *System) GetNbPublicVariables() int { + return len(system.Public) +} +func (system *System) GetNbInternalVariables() int { + return system.NbInternalVariables +} + +// CheckSerializationHeader parses the scalar field and gnark version headers +// +// This is meant to be use at the deserialization step, and will error for illegal values +func (system *System) CheckSerializationHeader() error { + // check gnark version + binaryVersion := gnark.Version + objectVersion, 
err := semver.Parse(system.GnarkVersion) + if err != nil { + return fmt.Errorf("when parsing gnark version: %w", err) + } + + if binaryVersion.Compare(objectVersion) != 0 { + log := logger.Logger() + log.Warn().Str("binary", binaryVersion.String()).Str("object", objectVersion.String()).Msg("gnark version (binary) mismatch with constraint system. there are no guarantees on compatibilty") + } + + // TODO @gbotrel maintain version changes and compare versions properly + // (ie if major didn't change,we shouldn't have a compatibility issue) + + scalarField := new(big.Int) + _, ok := scalarField.SetString(system.ScalarField, 16) + if !ok { + return fmt.Errorf("when parsing serialized modulus: %s", system.ScalarField) + } + curveID := utils.FieldToCurve(scalarField) + if curveID == ecc.UNKNOWN && scalarField.Cmp(tinyfield.Modulus()) != 0 { + return fmt.Errorf("unsupported scalar field %s", scalarField.Text(16)) + } + system.q = new(big.Int).Set(scalarField) + system.bitLen = system.q.BitLen() + return nil +} + +// GetNbVariables return number of internal, secret and public variables +func (system *System) GetNbVariables() (internal, secret, public int) { + return system.NbInternalVariables, system.GetNbSecretVariables(), system.GetNbPublicVariables() +} + +func (system *System) Field() *big.Int { + return new(big.Int).Set(system.q) +} + +// bitLen returns the number of bits needed to represent a fr.Element +func (system *System) FieldBitLen() int { + return system.bitLen +} + +func (system *System) AddInternalVariable() (idx int) { + idx = system.NbInternalVariables + system.GetNbPublicVariables() + system.GetNbSecretVariables() + system.NbInternalVariables++ + return idx +} + +func (system *System) AddPublicVariable(name string) (idx int) { + idx = system.GetNbPublicVariables() + system.Public = append(system.Public, name) + return idx +} + +func (system *System) AddSecretVariable(name string) (idx int) { + idx = system.GetNbSecretVariables() + 
system.GetNbPublicVariables() + system.Secret = append(system.Secret, name) + return idx +} + +func (system *System) AddSolverHint(f solver.Hint, id solver.HintID, input []LinearExpression, nbOutput int) (internalVariables []int, err error) { + if nbOutput <= 0 { + return nil, fmt.Errorf("hint function must return at least one output") + } + + var name string + if id == solver.GetHintID(f) { + name = solver.GetHintName(f) + } else { + name = strconv.Itoa(int(id)) + } + + // register the hint as dependency + if registeredName, ok := system.MHintsDependencies[id]; ok { + // hint already registered, let's ensure string registeredName matches + if registeredName != name { + return nil, fmt.Errorf("hint dependency registration failed; %s previously register with same UUID as %s", name, registeredName) + } + } else { + system.MHintsDependencies[id] = name + } + + // prepare wires + internalVariables = make([]int, nbOutput) + for i := 0; i < len(internalVariables); i++ { + internalVariables[i] = system.AddInternalVariable() + } + + // associate these wires with the solver hint + hm := HintMapping{ + HintID: id, + Inputs: input, + OutputRange: struct { + Start uint32 + End uint32 + }{ + uint32(internalVariables[0]), + uint32(internalVariables[len(internalVariables)-1]) + 1, + }, + } + + blueprint := system.Blueprints[system.genericHint] + + // get []uint32 from the pool + calldata := getBuffer() + + blueprint.(BlueprintHint).CompressHint(hm, calldata) + + system.AddInstruction(system.genericHint, *calldata) + + // return []uint32 to the pool + putBuffer(calldata) + + return +} + +func (system *System) AddCommitment(c Commitment) error { + switch v := c.(type) { + case Groth16Commitment: + system.CommitmentInfo = append(system.CommitmentInfo.(Groth16Commitments), v) + case PlonkCommitment: + system.CommitmentInfo = append(system.CommitmentInfo.(PlonkCommitments), v) + default: + return fmt.Errorf("unknown commitment type %T", v) + } + return nil +} + +func (system *System) 
AddLog(l LogEntry) { + system.Logs = append(system.Logs, l) +} + +func (system *System) AttachDebugInfo(debugInfo DebugInfo, constraintID []int) { + system.DebugInfo = append(system.DebugInfo, LogEntry(debugInfo)) + id := len(system.DebugInfo) - 1 + for _, cID := range constraintID { + system.MDebug[cID] = id + } +} + +// VariableToString implements Resolver +func (system *System) VariableToString(vID int) string { + nbPublic := system.GetNbPublicVariables() + nbSecret := system.GetNbSecretVariables() + + if vID < nbPublic { + return system.Public[vID] + } + vID -= nbPublic + if vID < nbSecret { + return system.Secret[vID] + } + vID -= nbSecret + return fmt.Sprintf("v%d", vID) // TODO @gbotrel vs strconv.Itoa. +} + +func (cs *System) AddR1C(c R1C, bID BlueprintID) int { + profile.RecordConstraint() + + blueprint := cs.Blueprints[bID] + + // get a []uint32 from a pool + calldata := getBuffer() + + // compress the R1C into a []uint32 and add the instruction + blueprint.(BlueprintR1C).CompressR1C(&c, calldata) + cs.AddInstruction(bID, *calldata) + + // release the []uint32 to the pool + putBuffer(calldata) + + return cs.NbConstraints - 1 +} + +func (cs *System) AddSparseR1C(c SparseR1C, bID BlueprintID) int { + profile.RecordConstraint() + + blueprint := cs.Blueprints[bID] + + // get a []uint32 from a pool + calldata := getBuffer() + + // compress the SparceR1C into a []uint32 and add the instruction + blueprint.(BlueprintSparseR1C).CompressSparseR1C(&c, calldata) + + cs.AddInstruction(bID, *calldata) + + // release the []uint32 to the pool + putBuffer(calldata) + + return cs.NbConstraints - 1 +} + +func (cs *System) AddInstruction(bID BlueprintID, calldata []uint32) []uint32 { + // set the offsets + pi := PackedInstruction{ + StartCallData: uint64(len(cs.CallData)), + ConstraintOffset: uint32(cs.NbConstraints), + WireOffset: uint32(cs.NbInternalVariables + cs.GetNbPublicVariables() + cs.GetNbSecretVariables()), + BlueprintID: bID, + } + + // append the call data + 
cs.CallData = append(cs.CallData, calldata...) + + // update the total number of constraints + blueprint := cs.Blueprints[pi.BlueprintID] + cs.NbConstraints += blueprint.NbConstraints() + + // add the output wires + inst := pi.Unpack(cs) + nbOutputs := blueprint.NbOutputs(inst) + var wires []uint32 + for i := 0; i < nbOutputs; i++ { + wires = append(wires, uint32(cs.AddInternalVariable())) + } + + // add the instruction + cs.Instructions = append(cs.Instructions, pi) + + // update the instruction dependency tree + cs.updateLevel(len(cs.Instructions)-1, blueprint.WireWalker(inst)) + + return wires +} + +// GetNbConstraints returns the number of constraints +func (cs *System) GetNbConstraints() int { + return cs.NbConstraints +} + +func (cs *System) CheckUnconstrainedWires() error { + // TODO @gbotrel + return nil +} + +func (cs *System) GetR1CIterator() R1CIterator { + return R1CIterator{cs: cs} +} + +func (cs *System) GetSparseR1CIterator() SparseR1CIterator { + return SparseR1CIterator{cs: cs} +} + +func (cs *System) GetCommitments() Commitments { + return cs.CommitmentInfo +} + +// bufPool is a pool of buffers used by getBuffer and putBuffer. +// It is used to avoid allocating buffers for each constraint. +var bufPool = sync.Pool{ + New: func() interface{} { + r := make([]uint32, 0, 20) + return &r + }, +} + +// getBuffer returns a buffer of at least the given size. +// The buffer is taken from the pool if it is large enough, +// otherwise a new buffer is allocated. +// Caller must call putBuffer when done with the buffer. +func getBuffer() *[]uint32 { + to := bufPool.Get().(*[]uint32) + *to = (*to)[:0] + return to +} + +// putBuffer returns a buffer to the pool. 
+func putBuffer(buf *[]uint32) { + if buf == nil { + panic("invalid entry in putBuffer") + } + bufPool.Put(buf) +} + +func (system *System) AddGkr(gkr GkrInfo) error { + if system.GkrInfo.Is() { + return fmt.Errorf("currently only one GKR sub-circuit per SNARK is supported") + } + + system.GkrInfo = gkr + return nil +} diff --git a/constraint/field.go b/constraint/field.go new file mode 100644 index 0000000000..f63bc910e2 --- /dev/null +++ b/constraint/field.go @@ -0,0 +1,53 @@ +package constraint + +import ( + "encoding/binary" + "math/big" +) + +// Element represents a term coefficient data. It is instantiated by the concrete +// constraint system implementation. +// Most of the scalar field used in gnark are on 4 uint64, so we have a clear memory overhead here. +type Element [6]uint64 + +// IsZero returns true if coefficient == 0 +func (z *Element) IsZero() bool { + return (z[5] | z[4] | z[3] | z[2] | z[1] | z[0]) == 0 +} + +// Bytes return the Element as a big-endian byte slice +func (z *Element) Bytes() [48]byte { + var b [48]byte + binary.BigEndian.PutUint64(b[40:48], z[0]) + binary.BigEndian.PutUint64(b[32:40], z[1]) + binary.BigEndian.PutUint64(b[24:32], z[2]) + binary.BigEndian.PutUint64(b[16:24], z[3]) + binary.BigEndian.PutUint64(b[8:16], z[4]) + binary.BigEndian.PutUint64(b[0:8], z[5]) + return b +} + +// SetBytes sets the Element from a big-endian byte slice +func (z *Element) SetBytes(b [48]byte) { + z[0] = binary.BigEndian.Uint64(b[40:48]) + z[1] = binary.BigEndian.Uint64(b[32:40]) + z[2] = binary.BigEndian.Uint64(b[24:32]) + z[3] = binary.BigEndian.Uint64(b[16:24]) + z[4] = binary.BigEndian.Uint64(b[8:16]) + z[5] = binary.BigEndian.Uint64(b[0:8]) +} + +// Field capability to perform arithmetic on Coeff +type Field interface { + FromInterface(interface{}) Element + ToBigInt(Element) *big.Int + Mul(a, b Element) Element + Add(a, b Element) Element + Sub(a, b Element) Element + Neg(a Element) Element + Inverse(a Element) (Element, bool) + One() Element 
+ IsOne(Element) bool + String(Element) string + Uint64(Element) (uint64, bool) +} diff --git a/constraint/gkr.go b/constraint/gkr.go new file mode 100644 index 0000000000..4d84a5466d --- /dev/null +++ b/constraint/gkr.go @@ -0,0 +1,158 @@ +package constraint + +import ( + "fmt" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/utils/algo_utils" + "sort" +) + +type GkrVariable int // Just an alias to hide implementation details. May be more trouble than worth + +type InputDependency struct { + OutputWire int + OutputInstance int + InputInstance int +} + +type GkrWire struct { + Gate string // TODO: Change to description + Inputs []int + Dependencies []InputDependency // nil for input wires + NbUniqueOutputs int +} + +type GkrCircuit []GkrWire + +type GkrInfo struct { + Circuit GkrCircuit + MaxNIns int + NbInstances int + HashName string + SolveHintID solver.HintID + ProveHintID solver.HintID +} + +type GkrPermutations struct { + SortedInstances []int + SortedWires []int + InstancesPermutation []int + WiresPermutation []int +} + +func (w GkrWire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w GkrWire) IsOutput() bool { + return w.NbUniqueOutputs == 0 +} + +// AssignmentOffsets returns the index of the first value assigned to a wire TODO: Explain clearly +func (d *GkrInfo) AssignmentOffsets() []int { + c := d.Circuit + res := make([]int, len(c)+1) + for i := range c { + nbExplicitAssignments := 0 + if c[i].IsInput() { + nbExplicitAssignments = d.NbInstances - len(c[i].Dependencies) + } + res[i+1] = res[i] + nbExplicitAssignments + } + return res +} + +func (d *GkrInfo) NewInputVariable() GkrVariable { + i := len(d.Circuit) + d.Circuit = append(d.Circuit, GkrWire{}) + return GkrVariable(i) +} + +// Compile sorts the circuit wires, their dependencies and the instances +func (d *GkrInfo) Compile(nbInstances int) (GkrPermutations, error) { + + var p GkrPermutations + d.NbInstances = 
nbInstances + // sort the instances to decide the order in which they are to be solved + instanceDeps := make([][]int, nbInstances) + for i := range d.Circuit { + for _, dep := range d.Circuit[i].Dependencies { + instanceDeps[dep.InputInstance] = append(instanceDeps[dep.InputInstance], dep.OutputInstance) + } + } + + p.SortedInstances, _ = algo_utils.TopologicalSort(instanceDeps) + p.InstancesPermutation = algo_utils.InvertPermutation(p.SortedInstances) + + // this whole circuit sorting is a bit of a charade. if things are built using an api, there's no way it could NOT already be topologically sorted + // worth keeping for future-proofing? + + inputs := algo_utils.Map(d.Circuit, func(w GkrWire) []int { + return w.Inputs + }) + + var uniqueOuts [][]int + p.SortedWires, uniqueOuts = algo_utils.TopologicalSort(inputs) + p.WiresPermutation = algo_utils.InvertPermutation(p.SortedWires) + wirePermutationAt := algo_utils.SliceAt(p.WiresPermutation) + sorted := make([]GkrWire, len(d.Circuit)) // TODO: Directly manipulate d.Circuit instead + for newI, oldI := range p.SortedWires { + oldW := d.Circuit[oldI] + + if !oldW.IsInput() { + d.MaxNIns = utils.Max(d.MaxNIns, len(oldW.Inputs)) + } + + for j := range oldW.Dependencies { + dep := &oldW.Dependencies[j] + dep.OutputWire = p.WiresPermutation[dep.OutputWire] + dep.InputInstance = p.InstancesPermutation[dep.InputInstance] + dep.OutputInstance = p.InstancesPermutation[dep.OutputInstance] + } + sort.Slice(oldW.Dependencies, func(i, j int) bool { + return oldW.Dependencies[i].InputInstance < oldW.Dependencies[j].InputInstance + }) + for i := 1; i < len(oldW.Dependencies); i++ { + if oldW.Dependencies[i].InputInstance == oldW.Dependencies[i-1].InputInstance { + return p, fmt.Errorf("an input wire can only have one dependency per instance") + } + } // TODO: Check that dependencies and explicit assignments cover all instances + + sorted[newI] = GkrWire{ + Gate: oldW.Gate, + Inputs: algo_utils.Map(oldW.Inputs, wirePermutationAt), 
+ Dependencies: oldW.Dependencies, + NbUniqueOutputs: len(uniqueOuts[oldI]), + } + } + d.Circuit = sorted + + return p, nil +} + +func (d *GkrInfo) Is() bool { + return d.Circuit != nil +} + +// Chunks returns intervals of instances that are independent of each other and can be solved in parallel +func (c GkrCircuit) Chunks(nbInstances int) []int { + res := make([]int, 0, 1) + lastSeenDependencyI := make([]int, len(c)) + + for start, end := 0, 0; start != nbInstances; start = end { + end = nbInstances + endWireI := -1 + for wI, w := range c { + if wDepI := lastSeenDependencyI[wI]; wDepI < len(w.Dependencies) && w.Dependencies[wDepI].InputInstance < end { + end = w.Dependencies[wDepI].InputInstance + endWireI = wI + } + } + if endWireI != -1 { + lastSeenDependencyI[endWireI]++ + } + res = append(res, end) + } + return res +} diff --git a/constraint/hint.go b/constraint/hint.go index 0e92d6afd6..ac1dbf96ba 100644 --- a/constraint/hint.go +++ b/constraint/hint.go @@ -1,14 +1,14 @@ package constraint import ( - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" ) -// Hint represents a solver hint -// it enables the solver to compute a Wire with a function provided at solving time -// using pre-defined inputs -type Hint struct { - ID hint.ID // hint function id - Inputs []LinearExpression // terms to inject in the hint function - Wires []int // IDs of wires the hint outputs map to +// HintMapping mark a list of output variables to be computed using provided hint and inputs. 
+type HintMapping struct { + HintID solver.HintID // Hint function id + Inputs []LinearExpression // Terms to inject in the hint function + OutputRange struct { // IDs of wires the hint outputs map to + Start, End uint32 + } } diff --git a/constraint/level_builder.go b/constraint/level_builder.go index 80c24614b8..e6cc47cbe6 100644 --- a/constraint/level_builder.go +++ b/constraint/level_builder.go @@ -8,34 +8,31 @@ package constraint // // We build a graph of dependency; we say that a wire is solved at a level l // --> l = max(level_of_dependencies(wire)) + 1 -func (system *System) updateLevel(cID int, c Iterable) { - system.lbOutputs = system.lbOutputs[:0] - system.lbHints = map[*Hint]struct{}{} +func (system *System) updateLevel(iID int, walkWires func(cb func(wire uint32))) { level := -1 - wireIterator := c.WireIterator() - for wID := wireIterator(); wID != -1; wID = wireIterator() { - // iterate over all wires of the R1C - system.processWire(uint32(wID), &level) - } + + // process all wires of the instruction + walkWires(func(wire uint32) { + system.processWire(wire, &level) + }) // level = max(dependencies) + 1 level++ // mark output wire with level for _, wireID := range system.lbOutputs { - for int(wireID) >= len(system.lbWireLevel) { - // we didn't encounter this wire yet, we need to grow b.wireLevels - system.lbWireLevel = append(system.lbWireLevel, -1) - } system.lbWireLevel[wireID] = level } // we can't skip levels, so appending is fine. if level >= len(system.Levels) { - system.Levels = append(system.Levels, []int{cID}) + system.Levels = append(system.Levels, []int{iID}) } else { - system.Levels[level] = append(system.Levels[level], cID) + system.Levels[level] = append(system.Levels[level], iID) } + // clean the table. NB! Do not remove or move, this is required to make the + // compilation deterministic. 
+ system.lbOutputs = system.lbOutputs[:0] } func (system *System) processWire(wireID uint32, maxLevel *int) { @@ -53,32 +50,6 @@ func (system *System) processWire(wireID uint32, maxLevel *int) { } return } - // we don't know how to solve this wire; it's either THE wire we have to solve or a hint. - if h, ok := system.MHints[int(wireID)]; ok { - // check that we didn't process that hint already; performance wise, if many wires in a - // constraint are the output of the same hint, and input to parent hint are themselves - // computed with a hint, we can suffer. - // (nominal case: not too many different hints involved for a single constraint) - if _, ok := system.lbHints[h]; ok { - // skip - return - } - system.lbHints[h] = struct{}{} - - for _, hwid := range h.Wires { - system.lbOutputs = append(system.lbOutputs, uint32(hwid)) - } - for _, in := range h.Inputs { - for _, t := range in { - if !t.IsConstant() { - system.processWire(t.VID, maxLevel) - } - } - } - - return - } - - // it's the missing wire + // this wire is an output to the instruction system.lbOutputs = append(system.lbOutputs, wireID) } diff --git a/constraint/level_builder_test.go b/constraint/level_builder_test.go new file mode 100644 index 0000000000..c7489409e1 --- /dev/null +++ b/constraint/level_builder_test.go @@ -0,0 +1,42 @@ +package constraint_test + +import ( + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +func idHint(_ *big.Int, in []*big.Int, out []*big.Int) error { + if len(in) != len(out) { + return fmt.Errorf("in/out length mismatch %d≠%d", len(in), len(out)) + } + for i := range in { + out[i].Set(in[i]) + } + return nil +} + +type idHintCircuit struct { + X frontend.Variable +} + +func (c *idHintCircuit) Define(api frontend.API) error { + x, err := api.Compiler().NewHint(idHint, 1, c.X) + 
if err != nil { + return err + } + api.AssertIsEqual(x[0], c.X) + return nil +} + +func TestIdHint(t *testing.T) { + solver.RegisterHint(idHint) + assignment := idHintCircuit{0} + test.NewAssert(t).SolvingSucceeded(&idHintCircuit{}, &assignment, test.WithBackends(backend.GROTH16), test.WithCurves(ecc.BN254)) +} diff --git a/constraint/linear_expression.go b/constraint/linear_expression.go index 32a66b7843..20fb9e9986 100644 --- a/constraint/linear_expression.go +++ b/constraint/linear_expression.go @@ -29,3 +29,10 @@ func (l LinearExpression) String(r Resolver) string { sbb.WriteLinearExpression(l) return sbb.String() } + +func (l LinearExpression) Compress(to *[]uint32) { + (*to) = append((*to), uint32(len(l))) + for i := 0; i < len(l); i++ { + (*to) = append((*to), l[i].CID, l[i].VID) + } +} diff --git a/constraint/r1cs.go b/constraint/r1cs.go index 02c327e30e..b58ff95d40 100644 --- a/constraint/r1cs.go +++ b/constraint/r1cs.go @@ -14,161 +14,143 @@ package constraint -import ( - "errors" - "strconv" - "strings" - - "github.com/consensys/gnark/logger" -) - type R1CS interface { ConstraintSystem - // AddConstraint adds a constraint to the system and returns its id + // AddR1C adds a constraint to the system and returns its id // This does not check for validity of the constraint. - // If a debugInfo parameter is provided, it will be appended to the debug info structure - // and will grow the memory usage of the constraint system. - AddConstraint(r1c R1C, debugInfo ...DebugInfo) int + AddR1C(r1c R1C, bID BlueprintID) int - // GetConstraints return the list of R1C and a helper for pretty printing. + // GetR1Cs return the list of R1C // See StringBuilder for more info. // ! this is an experimental API. 
- GetConstraints() ([]R1C, Resolver) -} - -// R1CS describes a set of R1C constraint -type R1CSCore struct { - System - Constraints []R1C -} + GetR1Cs() []R1C -// GetNbConstraints returns the number of constraints -func (r1cs *R1CSCore) GetNbConstraints() int { - return len(r1cs.Constraints) + // GetR1CIterator returns an R1CIterator to iterate on the R1C constraints of the system. + GetR1CIterator() R1CIterator } -func (r1cs *R1CSCore) UpdateLevel(cID int, c Iterable) { - r1cs.updateLevel(cID, c) +// R1CIterator facilitates iterating through R1C constraints. +type R1CIterator struct { + R1C + cs *System + n int } -// IsValid perform post compilation checks on the Variables -// -// 1. checks that all user inputs are referenced in at least one constraint -// 2. checks that all hints are constrained -func (r1cs *R1CSCore) CheckUnconstrainedWires() error { - - // TODO @gbotrel add unit test for that. - - inputConstrained := make([]bool, r1cs.GetNbSecretVariables()+r1cs.GetNbPublicVariables()) - // one wire does not need to be constrained - inputConstrained[0] = true - cptInputs := len(inputConstrained) - 1 // marking 1 wire as already constrained // TODO @gbotrel check that - if cptInputs == 0 { - return errors.New("invalid constraint system: no input defined") - } - - cptHints := len(r1cs.MHints) - mHintsConstrained := make(map[int]bool) - - // for each constraint, we check the linear expressions and mark our inputs / hints as constrained - processLinearExpression := func(l LinearExpression) { - for _, t := range l { - if t.CoeffID() == CoeffIdZero { - // ignore zero coefficient, as it does not constraint the Variable - // though, we may want to flag that IF the Variable doesn't appear else where - continue - } - vID := t.WireID() - if vID < len(inputConstrained) { - if !inputConstrained[vID] { - inputConstrained[vID] = true - cptInputs-- - } - } else { - // internal variable, let's check if it's a hint - if _, ok := r1cs.MHints[vID]; ok { - if 
!mHintsConstrained[vID] { - mHintsConstrained[vID] = true - cptHints-- - } - } - } - - } - } - for _, r1c := range r1cs.Constraints { - processLinearExpression(r1c.L) - processLinearExpression(r1c.R) - processLinearExpression(r1c.O) - - if cptHints|cptInputs == 0 { - return nil // we can stop. - } - - } - - // something is a miss, we build the error string - var sbb strings.Builder - if cptInputs != 0 { - sbb.WriteString(strconv.Itoa(cptInputs)) - sbb.WriteString(" unconstrained input(s):") - sbb.WriteByte('\n') - for i := 0; i < len(inputConstrained) && cptInputs != 0; i++ { - if !inputConstrained[i] { - if i < len(r1cs.Public) { - sbb.WriteString(r1cs.Public[i]) - } else { - sbb.WriteString(r1cs.Secret[i-len(r1cs.Public)]) - } - - sbb.WriteByte('\n') - cptInputs-- - } - } - sbb.WriteByte('\n') - return errors.New(sbb.String()) - } - - if cptHints != 0 { - // TODO @gbotrel @ivokub investigate --> emulated hints seems to go in this path a lot. - sbb.WriteString(strconv.Itoa(cptHints)) - sbb.WriteString(" unconstrained hints; i.e. wire created through NewHint() but doesn't not appear in the constraint system") - sbb.WriteByte('\n') - log := logger.Logger() - log.Warn().Err(errors.New(sbb.String())).Send() +// Next returns the next R1C or nil if end. Caller must not store the result since the +// same memory space is re-used for subsequent calls to Next. +func (it *R1CIterator) Next() *R1C { + if it.n >= it.cs.GetNbInstructions() { return nil - // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some - // debugInfo to find where a hint was declared (and not constrained) } - return errors.New(sbb.String()) + inst := it.cs.Instructions[it.n] + it.n++ + blueprint := it.cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(BlueprintR1C); ok { + bc.DecompressR1C(&it.R1C, inst.Unpack(it.cs)) + return &it.R1C + } + return it.Next() } +// // IsValid perform post compilation checks on the Variables +// // +// // 1. 
checks that all user inputs are referenced in at least one constraint +// // 2. checks that all hints are constrained +// func (r1cs *R1CSCore) CheckUnconstrainedWires() error { +// return nil + +// // TODO @gbotrel add unit test for that. + +// inputConstrained := make([]bool, r1cs.GetNbSecretVariables()+r1cs.GetNbPublicVariables()) +// // one wire does not need to be constrained +// inputConstrained[0] = true +// cptInputs := len(inputConstrained) - 1 // marking 1 wire as already constrained // TODO @gbotrel check that +// if cptInputs == 0 { +// return errors.New("invalid constraint system: no input defined") +// } + +// cptHints := len(r1cs.MHints) +// mHintsConstrained := make(map[int]bool) + +// // for each constraint, we check the linear expressions and mark our inputs / hints as constrained +// processLinearExpression := func(l LinearExpression) { +// for _, t := range l { +// if t.CoeffID() == CoeffIdZero { +// // ignore zero coefficient, as it does not constraint the Variable +// // though, we may want to flag that IF the Variable doesn't appear else where +// continue +// } +// vID := t.WireID() +// if vID < len(inputConstrained) { +// if !inputConstrained[vID] { +// inputConstrained[vID] = true +// cptInputs-- +// } +// } else { +// // internal variable, let's check if it's a hint +// if _, ok := r1cs.MHints[vID]; ok { +// if !mHintsConstrained[vID] { +// mHintsConstrained[vID] = true +// cptHints-- +// } +// } +// } + +// } +// } +// for _, r1c := range r1cs.Constraints { +// processLinearExpression(r1c.L) +// processLinearExpression(r1c.R) +// processLinearExpression(r1c.O) + +// if cptHints|cptInputs == 0 { +// return nil // we can stop. 
+// } + +// } + +// // something is a miss, we build the error string +// var sbb strings.Builder +// if cptInputs != 0 { +// sbb.WriteString(strconv.Itoa(cptInputs)) +// sbb.WriteString(" unconstrained input(s):") +// sbb.WriteByte('\n') +// for i := 0; i < len(inputConstrained) && cptInputs != 0; i++ { +// if !inputConstrained[i] { +// if i < len(r1cs.Public) { +// sbb.WriteString(r1cs.Public[i]) +// } else { +// sbb.WriteString(r1cs.Secret[i-len(r1cs.Public)]) +// } + +// sbb.WriteByte('\n') +// cptInputs-- +// } +// } +// sbb.WriteByte('\n') +// return errors.New(sbb.String()) +// } + +// if cptHints != 0 { +// // TODO @gbotrel @ivokub investigate --> emulated hints seems to go in this path a lot. +// sbb.WriteString(strconv.Itoa(cptHints)) +// sbb.WriteString(" unconstrained hints; i.e. wire created through NewHint() but doesn't not appear in the constraint system") +// sbb.WriteByte('\n') +// log := logger.Logger() +// log.Warn().Err(errors.New(sbb.String())).Send() +// return nil +// // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some +// // debugInfo to find where a hint was declared (and not constrained) +// } +// return errors.New(sbb.String()) +// } + // R1C used to compute the wires type R1C struct { L, R, O LinearExpression } -// WireIterator implements constraint.Iterable -func (r1c *R1C) WireIterator() func() int { - curr := 0 - return func() int { - if curr < len(r1c.L) { - curr++ - return r1c.L[curr-1].WireID() - } - if curr < len(r1c.L)+len(r1c.R) { - curr++ - return r1c.R[curr-1-len(r1c.L)].WireID() - } - if curr < len(r1c.L)+len(r1c.R)+len(r1c.O) { - curr++ - return r1c.O[curr-1-len(r1c.L)-len(r1c.R)].WireID() - } - return -1 - } -} - // String formats a R1C as L⋅R == O func (r1c *R1C) String(r Resolver) string { sbb := NewStringBuilder(r) diff --git a/constraint/r1cs_sparse.go b/constraint/r1cs_sparse.go index 4d99c82857..742521e9d5 100644 --- a/constraint/r1cs_sparse.go +++ 
b/constraint/r1cs_sparse.go @@ -14,160 +14,153 @@ package constraint -import ( - "errors" - "strconv" - "strings" -) - type SparseR1CS interface { ConstraintSystem - // AddConstraint adds a constraint to the sytem and returns its id - // This does not check for validity of the constraint. - // If a debugInfo parameter is provided, it will be appended to the debug info structure - // and will grow the memory usage of the constraint system. - AddConstraint(c SparseR1C, debugInfo ...DebugInfo) int + // AddSparseR1C adds a constraint to the constraint system. + AddSparseR1C(c SparseR1C, bID BlueprintID) int - // GetConstraints return the list of SparseR1C and a helper for pretty printing. + // GetSparseR1Cs return the list of SparseR1C // See StringBuilder for more info. // ! this is an experimental API. - GetConstraints() ([]SparseR1C, Resolver) -} - -// R1CS describes a set of SparseR1C constraint -// TODO @gbotrel maybe SparseR1CSCore and R1CSCore should go in code generation directly to avoid confusing this package. -type SparseR1CSCore struct { - System - Constraints []SparseR1C -} + GetSparseR1Cs() []SparseR1C -// GetNbConstraints returns the number of constraints -func (cs *SparseR1CSCore) GetNbConstraints() int { - return len(cs.Constraints) + // GetSparseR1CIterator returns an SparseR1CIterator to iterate on the SparseR1C constraints of the system. + GetSparseR1CIterator() SparseR1CIterator } -func (cs *SparseR1CSCore) UpdateLevel(cID int, c Iterable) { - cs.updateLevel(cID, c) +// SparseR1CIterator facilitates iterating through SparseR1C constraints. +type SparseR1CIterator struct { + SparseR1C + cs *System + n int } -func (system *SparseR1CSCore) CheckUnconstrainedWires() error { - // TODO @gbotrel add unit test for that. 
- - inputConstrained := make([]bool, system.GetNbSecretVariables()+system.GetNbPublicVariables()) - cptInputs := len(inputConstrained) - if cptInputs == 0 { - return errors.New("invalid constraint system: no input defined") +// Next returns the next SparseR1C or nil if end. Caller must not store the result since the +// same memory space is re-used for subsequent calls to Next. +func (it *SparseR1CIterator) Next() *SparseR1C { + if it.n >= it.cs.GetNbInstructions() { + return nil } - - cptHints := len(system.MHints) - mHintsConstrained := make(map[int]bool) - - // for each constraint, we check the terms and mark our inputs / hints as constrained - processTerm := func(t Term) { - - // L and M[0] handles the same wire but with a different coeff - vID := t.WireID() - if t.CoeffID() != CoeffIdZero { - if vID < len(inputConstrained) { - if !inputConstrained[vID] { - inputConstrained[vID] = true - cptInputs-- - } - } else { - // internal variable, let's check if it's a hint - if _, ok := system.MHints[vID]; ok { - vID -= (system.GetNbPublicVariables() + system.GetNbSecretVariables()) - if !mHintsConstrained[vID] { - mHintsConstrained[vID] = true - cptHints-- - } - } - } - } - + inst := it.cs.Instructions[it.n] + it.n++ + blueprint := it.cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&it.SparseR1C, inst.Unpack(it.cs)) + return &it.SparseR1C } - for _, c := range system.Constraints { - processTerm(c.L) - processTerm(c.R) - processTerm(c.M[0]) - processTerm(c.M[1]) - processTerm(c.O) - if cptHints|cptInputs == 0 { - return nil // we can stop. 
- } - - } - - // something is a miss, we build the error string - var sbb strings.Builder - if cptInputs != 0 { - sbb.WriteString(strconv.Itoa(cptInputs)) - sbb.WriteString(" unconstrained input(s):") - sbb.WriteByte('\n') - for i := 0; i < len(inputConstrained) && cptInputs != 0; i++ { - if !inputConstrained[i] { - if i < len(system.Public) { - sbb.WriteString(system.Public[i]) - } else { - sbb.WriteString(system.Secret[i-len(system.Public)]) - } - sbb.WriteByte('\n') - cptInputs-- - } - } - sbb.WriteByte('\n') - } - - if cptHints != 0 { - sbb.WriteString(strconv.Itoa(cptHints)) - sbb.WriteString(" unconstrained hints") - sbb.WriteByte('\n') - // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some - // debugInfo to find where a hint was declared (and not constrained) - } - return errors.New(sbb.String()) + return it.Next() } -// SparseR1C used to compute the wires -// L+R+M[0]M[1]+O+k=0 -// if a Term is zero, it means the field doesn't exist (ex M=[0,0] means there is no multiplicative term) +// func (system *SparseR1CSCore) CheckUnconstrainedWires() error { +// // TODO @gbotrel add unit test for that. 
+// return nil +// inputConstrained := make([]bool, system.GetNbSecretVariables()+system.GetNbPublicVariables()) +// cptInputs := len(inputConstrained) +// if cptInputs == 0 { +// return errors.New("invalid constraint system: no input defined") +// } + +// cptHints := len(system.MHints) +// mHintsConstrained := make(map[int]bool) + +// // for each constraint, we check the terms and mark our inputs / hints as constrained +// processTerm := func(t Term) { + +// // L and M[0] handles the same wire but with a different coeff +// vID := t.WireID() +// if t.CoeffID() != CoeffIdZero { +// if vID < len(inputConstrained) { +// if !inputConstrained[vID] { +// inputConstrained[vID] = true +// cptInputs-- +// } +// } else { +// // internal variable, let's check if it's a hint +// if _, ok := system.MHints[vID]; ok { +// vID -= (system.GetNbPublicVariables() + system.GetNbSecretVariables()) +// if !mHintsConstrained[vID] { +// mHintsConstrained[vID] = true +// cptHints-- +// } +// } +// } +// } + +// } +// for _, c := range system.Constraints { +// processTerm(c.L) +// processTerm(c.R) +// processTerm(c.M[0]) +// processTerm(c.M[1]) +// processTerm(c.O) +// if cptHints|cptInputs == 0 { +// return nil // we can stop. 
+// } + +// } + +// // something is a miss, we build the error string +// var sbb strings.Builder +// if cptInputs != 0 { +// sbb.WriteString(strconv.Itoa(cptInputs)) +// sbb.WriteString(" unconstrained input(s):") +// sbb.WriteByte('\n') +// for i := 0; i < len(inputConstrained) && cptInputs != 0; i++ { +// if !inputConstrained[i] { +// if i < len(system.Public) { +// sbb.WriteString(system.Public[i]) +// } else { +// sbb.WriteString(system.Secret[i-len(system.Public)]) +// } +// sbb.WriteByte('\n') +// cptInputs-- +// } +// } +// sbb.WriteByte('\n') +// } + +// if cptHints != 0 { +// sbb.WriteString(strconv.Itoa(cptHints)) +// sbb.WriteString(" unconstrained hints") +// sbb.WriteByte('\n') +// // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some +// // debugInfo to find where a hint was declared (and not constrained) +// } +// return errors.New(sbb.String()) +// } + +type CommitmentConstraint uint32 + +const ( + NOT CommitmentConstraint = 0 + COMMITTED CommitmentConstraint = 1 + COMMITMENT CommitmentConstraint = 2 +) + +// SparseR1C represent a PlonK-ish constraint +// qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC -committed?*Bsb22Commitments-commitment?*commitmentValue == 0 type SparseR1C struct { - L, R, O Term - M [2]Term - K int // stores only the ID of the constant term that is used + XA, XB, XC uint32 + QL, QR, QO, QM, QC uint32 + Commitment CommitmentConstraint } -// WireIterator implements constraint.Iterable -func (c *SparseR1C) WireIterator() func() int { - curr := 0 - return func() int { - switch curr { - case 0: - curr++ - return c.L.WireID() - case 1: - curr++ - return c.R.WireID() - case 2: - curr++ - return c.O.WireID() - } - return -1 - } +func (c *SparseR1C) Clear() { + *c = SparseR1C{} } // String formats the constraint as qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC == 0 func (c *SparseR1C) String(r Resolver) string { sbb := NewStringBuilder(r) - sbb.WriteTerm(c.L) + sbb.WriteTerm(Term{CID: c.QL, 
VID: c.XA}) sbb.WriteString(" + ") - sbb.WriteTerm(c.R) + sbb.WriteTerm(Term{CID: c.QR, VID: c.XB}) sbb.WriteString(" + ") - sbb.WriteTerm(c.O) - if qM := sbb.CoeffToString(c.M[0].CoeffID()); qM != "0" { - xa := sbb.VariableToString(c.M[0].WireID()) - xb := sbb.VariableToString(c.M[1].WireID()) + sbb.WriteTerm(Term{CID: c.QO, VID: c.XC}) + if qM := sbb.CoeffToString(int(c.QM)); qM != "0" { + xa := sbb.VariableToString(int(c.XA)) + xb := sbb.VariableToString(int(c.XB)) sbb.WriteString(" + ") sbb.WriteString(qM) sbb.WriteString("⋅(") @@ -177,7 +170,7 @@ func (c *SparseR1C) String(r Resolver) string { sbb.WriteByte(')') } sbb.WriteString(" + ") - sbb.WriteString(r.CoeffToString(c.K)) + sbb.WriteString(r.CoeffToString(int(c.QC))) sbb.WriteString(" == 0") return sbb.String() } diff --git a/constraint/r1cs_sparse_test.go b/constraint/r1cs_sparse_test.go index 024560ff93..5a0bfd647c 100644 --- a/constraint/r1cs_sparse_test.go +++ b/constraint/r1cs_sparse_test.go @@ -7,12 +7,13 @@ import ( cs "github.com/consensys/gnark/constraint/bn254" ) -func ExampleSparseR1CS_GetConstraints() { +func ExampleSparseR1CS_GetSparseR1Cs() { // build a constraint system; this is (usually) done by the frontend package // for this Example we want to manipulate the constraints and output a string representation // and build the linear expressions "manually". // note: R1CS apis are more mature; SparseR1CS apis are going to change in the next release(s). 
scs := cs.NewSparseR1CS(0) + blueprint := scs.AddBlueprint(&constraint.BlueprintGenericSparseR1C{}) Y := scs.AddPublicVariable("Y") X := scs.AddSecretVariable("X") @@ -20,40 +21,34 @@ func ExampleSparseR1CS_GetConstraints() { v0 := scs.AddInternalVariable() // X² // coefficients - cZero := scs.FromInterface(0) cOne := scs.FromInterface(1) - cMinusOne := scs.FromInterface(-1) cFive := scs.FromInterface(5) // X² == X * X - scs.AddConstraint(constraint.SparseR1C{ - L: scs.MakeTerm(&cZero, X), - R: scs.MakeTerm(&cZero, X), - O: scs.MakeTerm(&cMinusOne, v0), - M: [2]constraint.Term{ - scs.MakeTerm(&cOne, X), - scs.MakeTerm(&cOne, X), - }, - K: int(scs.MakeTerm(&cZero, 0).CID), - }) + scs.AddSparseR1C(constraint.SparseR1C{ + XA: uint32(X), + XB: uint32(X), + XC: uint32(v0), + QO: constraint.CoeffIdMinusOne, + QM: constraint.CoeffIdOne, + }, blueprint) // X² + 5X + 5 == Y - scs.AddConstraint(constraint.SparseR1C{ - R: scs.MakeTerm(&cOne, v0), - L: scs.MakeTerm(&cFive, X), - O: scs.MakeTerm(&cMinusOne, Y), - M: [2]constraint.Term{ - scs.MakeTerm(&cZero, v0), - scs.MakeTerm(&cZero, X), - }, - K: int(scs.MakeTerm(&cFive, 0).CID), - }) + scs.AddSparseR1C(constraint.SparseR1C{ + XA: uint32(X), + XB: uint32(v0), + XC: uint32(Y), + QO: constraint.CoeffIdMinusOne, + QL: scs.AddCoeff(cFive), + QR: scs.AddCoeff(cOne), + QC: scs.AddCoeff(cFive), + }, blueprint) // get the constraints - constraints, r := scs.GetConstraints() + constraints := scs.GetSparseR1Cs() for _, c := range constraints { - fmt.Println(c.String(r)) + fmt.Println(c.String(scs)) // for more granularity use constraint.NewStringBuilder(r) that embeds a string.Builder // and has WriteLinearExpression and WriteTerm methods. 
} diff --git a/constraint/r1cs_test.go b/constraint/r1cs_test.go index b2736d3073..9444b5ef53 100644 --- a/constraint/r1cs_test.go +++ b/constraint/r1cs_test.go @@ -3,16 +3,21 @@ package constraint_test import ( "fmt" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/constraint" cs "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" ) -func ExampleR1CS_GetConstraints() { +func ExampleR1CS_GetR1Cs() { // build a constraint system; this is (usually) done by the frontend package // for this Example we want to manipulate the constraints and output a string representation // and build the linear expressions "manually". r1cs := cs.NewR1CS(0) + blueprint := r1cs.AddBlueprint(&constraint.BlueprintGenericR1C{}) + ONE := r1cs.AddPublicVariable("1") // the "ONE" wire Y := r1cs.AddPublicVariable("Y") X := r1cs.AddSecretVariable("X") @@ -25,35 +30,35 @@ func ExampleR1CS_GetConstraints() { cFive := r1cs.FromInterface(5) // X² == X * X - r1cs.AddConstraint(constraint.R1C{ - L: constraint.LinearExpression{r1cs.MakeTerm(&cOne, X)}, - R: constraint.LinearExpression{r1cs.MakeTerm(&cOne, X)}, - O: constraint.LinearExpression{r1cs.MakeTerm(&cOne, v0)}, - }) + r1cs.AddR1C(constraint.R1C{ + L: constraint.LinearExpression{r1cs.MakeTerm(cOne, X)}, + R: constraint.LinearExpression{r1cs.MakeTerm(cOne, X)}, + O: constraint.LinearExpression{r1cs.MakeTerm(cOne, v0)}, + }, blueprint) // X³ == X² * X - r1cs.AddConstraint(constraint.R1C{ - L: constraint.LinearExpression{r1cs.MakeTerm(&cOne, v0)}, - R: constraint.LinearExpression{r1cs.MakeTerm(&cOne, X)}, - O: constraint.LinearExpression{r1cs.MakeTerm(&cOne, v1)}, - }) + r1cs.AddR1C(constraint.R1C{ + L: constraint.LinearExpression{r1cs.MakeTerm(cOne, v0)}, + R: constraint.LinearExpression{r1cs.MakeTerm(cOne, X)}, + O: constraint.LinearExpression{r1cs.MakeTerm(cOne, v1)}, + }, blueprint) // Y == X³ + X + 5 - r1cs.AddConstraint(constraint.R1C{ - R: 
constraint.LinearExpression{r1cs.MakeTerm(&cOne, ONE)}, - L: constraint.LinearExpression{r1cs.MakeTerm(&cOne, Y)}, + r1cs.AddR1C(constraint.R1C{ + R: constraint.LinearExpression{r1cs.MakeTerm(cOne, ONE)}, + L: constraint.LinearExpression{r1cs.MakeTerm(cOne, Y)}, O: constraint.LinearExpression{ - r1cs.MakeTerm(&cFive, ONE), - r1cs.MakeTerm(&cOne, X), - r1cs.MakeTerm(&cOne, v1), + r1cs.MakeTerm(cFive, ONE), + r1cs.MakeTerm(cOne, X), + r1cs.MakeTerm(cOne, v1), }, - }) + }, blueprint) // get the constraints - constraints, r := r1cs.GetConstraints() + constraints := r1cs.GetR1Cs() for _, r1c := range constraints { - fmt.Println(r1c.String(r)) + fmt.Println(r1c.String(r1cs)) // for more granularity use constraint.NewStringBuilder(r) that embeds a string.Builder // and has WriteLinearExpression and WriteTerm methods. } @@ -63,3 +68,38 @@ func ExampleR1CS_GetConstraints() { // v0 ⋅ X == v1 // Y ⋅ 1 == 5 + X + v1 } + +func ExampleR1CS_Solve() { + // build a constraint system and a witness; + ccs, _ := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &cubic{}) + w, _ := frontend.NewWitness(&cubic{X: 3, Y: 35}, ecc.BN254.ScalarField()) + + _solution, _ := ccs.Solve(w) + + // concrete solution + solution := _solution.(*cs.R1CSSolution) + + // solution vector should have [1, 3, 35, 9, 27] + for _, v := range solution.W { + fmt.Println(v.String()) + } + + // Output: + // 1 + // 3 + // 35 + // 9 + // 27 +} + +type cubic struct { + X, Y frontend.Variable +} + +// Define declares the circuit constraints +// x**3 + x + 5 == y +func (circuit *cubic) Define(api frontend.API) error { + x3 := api.Mul(circuit.X, circuit.X, circuit.X) + api.AssertIsEqual(circuit.Y, api.Add(x3, circuit.X, 5)) + return nil +} diff --git a/constraint/solver/hint.go b/constraint/solver/hint.go new file mode 100644 index 0000000000..5ffc3f1f31 --- /dev/null +++ b/constraint/solver/hint.go @@ -0,0 +1,101 @@ +package solver + +import ( + "hash/fnv" + "math/big" + "reflect" + "runtime" +) + +// HintID 
is a unique identifier for a hint function used for lookup.
+type HintID uint32
+
+// Hint allows defining computations outside of a circuit.
+//
+// It defines an annotated hint function; the number of inputs and outputs injected at solving
+// time is defined in the circuit (compile time).
+//
+// For example:
+//
+//	b := api.NewHint(hint, 2, a)
+//	--> at solving time, hint is going to be invoked with 1 input (a) and is expected to return 2 outputs
+//	b[0] and b[1].
+//
+// Usually, it is expected that computations in circuits are performed on
+// variables. However, in some cases defining the computations in circuits may be
+// complicated or computationally expensive. By using hints, the computations are
+// performed outside of the circuit on integers (compared to the frontend.Variable
+// values inside the circuits) and the result of a hint function is assigned to a
+// newly created variable in a circuit.
+//
+// As the computations are performed outside of the circuit, then the correctness of
+// the result is not guaranteed. This also means that the result of a hint function
+// is unconstrained by default, leading to failure while composing circuit proof.
+// Thus, it is the circuit developer's responsibility to verify the correctness of the
+// hint result by adding necessary constraints in the circuit.
+//
+// As an example, let's say the hint function computes a factorization of a
+// semiprime n:
+//
+//	p, q <- hint(n) st. p * q = n
+//
+// into primes p and q. Then, the circuit developer needs to assert in the circuit
+// that p*q indeed equals n:
+//
+//	n == p * q.
+//
+// However, if the hint function is incorrectly defined (e.g. in the previous
+// example, it returns 1 and n instead of p and q), then the assertion may still
+// hold, but the constructed proof is semantically invalid. Thus, the user
+// constructing the proof must be extremely cautious when using hints.
+// +// # Using hint functions in circuits +// +// To use a hint function in a circuit, the developer first needs to define a hint +// function hintFn according to the Function interface. Then, in a circuit, the +// developer applies the hint function with frontend.API.NewHint(hintFn, vars...), +// where vars are the variables the hint function will be applied to (and +// correspond to the argument inputs in the Function type) which returns a new +// unconstrained variable. The returned variables must be constrained using +// frontend.API.Assert[.*] methods. +// +// As explained, the hints are essentially black boxes from the circuit point of +// view and thus the defined hints in circuits are not used when constructing a +// proof. To allow the particular hint functions to be used during proof +// construction, the user needs to supply a solver.Option indicating the +// enabled hints. Such options can be obtained by a call to +// solver.WithHints(hintFns...), where hintFns are the corresponding hint +// functions. +// +// # Using hint functions in gadgets +// +// Similar considerations apply for hint functions used in gadgets as in +// user-defined circuits. However, listing all hint functions used in a particular +// gadget for constructing solver.Option puts high overhead for the user to +// enable all necessary hints. +// +// For that, this package also provides a registry of trusted hint functions. When +// a gadget registers a hint function, then it is automatically enabled during +// proof computation and the prover does not need to provide a corresponding +// proving option. +// +// In the init() method of the gadget, call the method RegisterHint(hintFn) function on +// the hint function hintFn to register a hint function in the package registry. 
+type Hint func(field *big.Int, inputs []*big.Int, outputs []*big.Int) error
+
+// GetHintID is a reference function for computing the hint ID based on a function name
+func GetHintID(fn Hint) HintID {
+	hf := fnv.New32a()
+	name := GetHintName(fn)
+
+	// TODO relying on name to derive UUID is risky; if fn is an anonymous func, will be package.glob..funcN
+	// and if new anonymous functions are added in the package, N may change, so will UUID.
+	hf.Write([]byte(name)) // #nosec G104 -- does not err
+
+	return HintID(hf.Sum32())
+}
+
+func GetHintName(fn Hint) string {
+	fnptr := reflect.ValueOf(fn).Pointer()
+	return runtime.FuncForPC(fnptr).Name()
+}
diff --git a/constraint/solver/hint_registry.go b/constraint/solver/hint_registry.go
new file mode 100644
index 0000000000..7021de5024
--- /dev/null
+++ b/constraint/solver/hint_registry.go
@@ -0,0 +1,88 @@
+package solver
+
+import (
+	"fmt"
+	"math/big"
+	"sync"
+
+	"github.com/consensys/gnark/logger"
+)
+
+func init() {
+	RegisterHint(InvZeroHint)
+}
+
+var (
+	registry  = make(map[HintID]Hint)
+	registryM sync.RWMutex
+)
+
+// RegisterHint registers a hint function in the global registry.
+func RegisterHint(hintFns ...Hint) {
+	registryM.Lock()
+	defer registryM.Unlock()
+	for _, hintFn := range hintFns {
+		key := GetHintID(hintFn)
+		name := GetHintName(hintFn)
+		if _, ok := registry[key]; ok {
+			log := logger.Logger()
+			log.Warn().Str("name", name).Msg("function registered multiple times")
+			return
+		}
+		registry[key] = hintFn
+	}
+}
+
+func GetRegisteredHint(key HintID) Hint {
+	return registry[key]
+}
+
+func RegisterNamedHint(hintFn Hint, key HintID) {
+	registryM.Lock()
+	defer registryM.Unlock()
+	if _, ok := registry[key]; ok {
+		panic(fmt.Errorf("hint id %d already taken", key))
+	}
+	registry[key] = hintFn
+}
+
+// GetRegisteredHints returns all registered hint functions.
+func GetRegisteredHints() []Hint { + registryM.RLock() + defer registryM.RUnlock() + ret := make([]Hint, 0, len(registry)) + for _, v := range registry { + ret = append(ret, v) + } + return ret +} + +func cloneMap[K comparable, V any](src map[K]V) map[K]V { + res := make(map[K]V, len(registry)) + for k, v := range src { + res[k] = v + } + return res +} + +func cloneHintRegistry() map[HintID]Hint { + registryM.Lock() + defer registryM.Unlock() + return cloneMap(registry) +} + +// InvZeroHint computes the value 1/a for the single input a. If a == 0, returns 0. +func InvZeroHint(q *big.Int, inputs []*big.Int, results []*big.Int) error { + result := results[0] + + // save input + result.Set(inputs[0]) + + // a == 0, return + if result.IsUint64() && result.Uint64() == 0 { + return nil + } + + result.ModInverse(result, q) + return nil +} diff --git a/constraint/solver/options.go b/constraint/solver/options.go new file mode 100644 index 0000000000..d7aa7cac4e --- /dev/null +++ b/constraint/solver/options.go @@ -0,0 +1,67 @@ +package solver + +import ( + "github.com/consensys/gnark/logger" + "github.com/rs/zerolog" +) + +// Option defines option for altering the behavior of a constraint system +// solver (Solve() method). See the descriptions of functions returning instances +// of this type for implemented options. +type Option func(*Config) error + +// Config is the configuration for the solver with the options applied. +type Config struct { + HintFunctions map[HintID]Hint // defaults to all built-in hint functions + Logger zerolog.Logger // defaults to gnark.Logger +} + +// WithHints is a solver option that specifies additional hint functions to be used +// by the constraint solver. +func WithHints(hintFunctions ...Hint) Option { + log := logger.Logger() + return func(opt *Config) error { + // it is an error to register hint function several times, but as the + // prover already checks it then omit here. 
+ for _, h := range hintFunctions { + uuid := GetHintID(h) + if _, ok := opt.HintFunctions[uuid]; ok { + log.Warn().Int("hintID", int(uuid)).Str("name", GetHintName(h)).Msg("duplicate hint function") + } else { + opt.HintFunctions[uuid] = h + } + } + return nil + } +} + +// OverrideHint forces the solver to use provided hint function for given id. +func OverrideHint(id HintID, f Hint) Option { + return func(opt *Config) error { + opt.HintFunctions[id] = f + return nil + } +} + +// WithLogger is a prover option that specifies zerolog.Logger as a destination for the +// logs printed by api.Println(). By default, uses gnark/logger. +// zerolog.Nop() will disable logging +func WithLogger(l zerolog.Logger) Option { + return func(opt *Config) error { + opt.Logger = l + return nil + } +} + +// NewConfig returns a default SolverConfig with given prover options opts applied. +func NewConfig(opts ...Option) (Config, error) { + log := logger.Logger() + opt := Config{Logger: log} + opt.HintFunctions = cloneHintRegistry() + for _, option := range opts { + if err := option(&opt); err != nil { + return Config{}, err + } + } + return opt, nil +} diff --git a/constraint/system.go b/constraint/system.go index 2462f4aebb..e03586af4e 100644 --- a/constraint/system.go +++ b/constraint/system.go @@ -1,30 +1,29 @@ package constraint import ( - "fmt" "io" "math/big" - "github.com/blang/semver/v4" - "github.com/consensys/gnark" - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/debug" - "github.com/consensys/gnark/internal/tinyfield" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/logger" + "github.com/consensys/gnark/constraint/solver" ) // ConstraintSystem interface that all constraint systems implement. 
type ConstraintSystem interface { io.WriterTo io.ReaderFrom - CoeffEngine + Field + Resolver + CustomizableSystem // IsSolved returns nil if given witness solves the constraint system and error otherwise - IsSolved(witness witness.Witness, opts ...backend.ProverOption) error + // Deprecated: use _, err := Solve(...) instead + IsSolved(witness witness.Witness, opts ...solver.Option) error + + // Solve attempts to solve the constraint system using provided witness. + // Returns an error if the witness does not allow all the constraints to be satisfied. + // Returns a typed solution (R1CSSolution or SparseR1CSSolution) and nil otherwise. + Solve(witness witness.Witness, opts ...solver.Option) (any, error) // GetNbVariables return number of internal, secret and public Variables // Deprecated: use GetNbSecretVariables() instead @@ -34,6 +33,7 @@ type ConstraintSystem interface { GetNbSecretVariables() int GetNbPublicVariables() int + GetNbInstructions() int GetNbConstraints() int GetNbCoefficients() int @@ -46,243 +46,45 @@ type ConstraintSystem interface { // AddSolverHint adds a hint to the solver such that the output variables will be computed // using a call to output := f(input...) at solve time. - AddSolverHint(f hint.Function, input []LinearExpression, nbOutput int) (internalVariables []int, err error) + // Providing the function f is optional. If it is provided, id will be ignored and one will be derived from f's name. + // Otherwise, the provided id will be used to register the hint with, + AddSolverHint(f solver.Hint, id solver.HintID, input []LinearExpression, nbOutput int) (internalVariables []int, err error) AddCommitment(c Commitment) error + GetCommitments() Commitments + AddGkr(gkr GkrInfo) error AddLog(l LogEntry) // MakeTerm returns a new Term. The constraint system may store coefficients in a map, so // calls to this function will grow the memory usage of the constraint system. 
- MakeTerm(coeff *Coeff, variableID int) Term + MakeTerm(coeff Element, variableID int) Term + + // AddCoeff adds a coefficient to the underlying constraint system. The system will not store duplicate, + // but is not purging for unused coeff either, so this grows memory usage. + AddCoeff(coeff Element) uint32 NewDebugInfo(errName string, i ...interface{}) DebugInfo // AttachDebugInfo enables attaching debug information to multiple constraints. - // This is more efficient than using the AddConstraint(.., debugInfo) since it will store the + // This is more efficient than using the AddR1C(.., debugInfo) since it will store the // debug information only once. AttachDebugInfo(debugInfo DebugInfo, constraintID []int) // CheckUnconstrainedWires returns and error if the constraint system has wires that are not uniquely constrained. // This is experimental. CheckUnconstrainedWires() error -} - -type Iterable interface { - // WireIterator returns a new iterator to iterate over the wires of the implementer (usually, a constraint) - // Call to next() returns the next wireID of the Iterable object and -1 when iteration is over. 
- // - // For example a R1C constraint with L, R, O linear expressions, each of size 2, calling several times - // next := r1c.WireIterator(); - // for wID := next(); wID != -1; wID = next() {} - // // will return in order L[0],L[1],R[0],R[1],O[0],O[1],-1 - WireIterator() (next func() int) -} - -var _ Iterable = &SparseR1C{} -var _ Iterable = &R1C{} - -// System contains core elements for a constraint System -type System struct { - // serialization header - GnarkVersion string - ScalarField string - - // number of internal wires - NbInternalVariables int - - // input wires names - Public, Secret []string - - // logs (added with system.Println, resolved when solver sets a value to a wire) - Logs []LogEntry - - // debug info contains stack trace (including line number) of a call to a system.API that - // results in an unsolved constraint - DebugInfo []LogEntry - SymbolTable debug.SymbolTable - // maps constraint id to debugInfo id - // several constraints may point to the same debug info - MDebug map[int]int - - MHints map[int]*Hint // maps wireID to hint - MHintsDependencies map[hint.ID]string // maps hintID to hint string identifier - - // each level contains independent constraints and can be parallelized - // it is guaranteed that all dependencies for constraints in a level l are solved - // in previous levels - // TODO @gbotrel these are currently updated after we add a constraint. - // but in case the object is built from a serialized representation - // we need to init the level builder lbWireLevel from the existing constraints. - Levels [][]int - - // scalar field - q *big.Int `cbor:"-"` - bitLen int `cbor:"-"` - - // level builder - lbWireLevel []int `cbor:"-"` // at which level we solve a wire. init at -1. - lbOutputs []uint32 `cbor:"-"` // wire outputs for current constraint. 
- lbHints map[*Hint]struct{} `cbor:"-"` // hints we processed in current round - - CommitmentInfo Commitment -} - -// NewSystem initialize the common structure among constraint system -func NewSystem(scalarField *big.Int) System { - return System{ - SymbolTable: debug.NewSymbolTable(), - MDebug: map[int]int{}, - GnarkVersion: gnark.Version.String(), - ScalarField: scalarField.Text(16), - MHints: make(map[int]*Hint), - MHintsDependencies: make(map[hint.ID]string), - q: new(big.Int).Set(scalarField), - bitLen: scalarField.BitLen(), - lbHints: map[*Hint]struct{}{}, - } -} - -func (system *System) GetNbSecretVariables() int { - return len(system.Secret) -} -func (system *System) GetNbPublicVariables() int { - return len(system.Public) -} -func (system *System) GetNbInternalVariables() int { - return system.NbInternalVariables -} - -// CheckSerializationHeader parses the scalar field and gnark version headers -// -// This is meant to be use at the deserialization step, and will error for illegal values -func (system *System) CheckSerializationHeader() error { - // check gnark version - binaryVersion := gnark.Version - objectVersion, err := semver.Parse(system.GnarkVersion) - if err != nil { - return fmt.Errorf("when parsing gnark version: %w", err) - } - - if binaryVersion.Compare(objectVersion) != 0 { - log := logger.Logger() - log.Warn().Str("binary", binaryVersion.String()).Str("object", objectVersion.String()).Msg("gnark version (binary) mismatch with constraint system. 
there are no guarantees on compatibilty") - } - - // TODO @gbotrel maintain version changes and compare versions properly - // (ie if major didn't change,we shouldn't have a compat issue) - scalarField := new(big.Int) - _, ok := scalarField.SetString(system.ScalarField, 16) - if !ok { - return fmt.Errorf("when parsing serialized modulus: %s", system.ScalarField) - } - curveID := utils.FieldToCurve(scalarField) - if curveID == ecc.UNKNOWN && scalarField.Cmp(tinyfield.Modulus()) != 0 { - return fmt.Errorf("unsupported scalard field %s", scalarField.Text(16)) - } - system.q = new(big.Int).Set(scalarField) - system.bitLen = system.q.BitLen() - return nil -} - -// GetNbVariables return number of internal, secret and public variables -func (system *System) GetNbVariables() (internal, secret, public int) { - return system.NbInternalVariables, system.GetNbSecretVariables(), system.GetNbPublicVariables() -} - -func (system *System) Field() *big.Int { - return new(big.Int).Set(system.q) -} - -// bitLen returns the number of bits needed to represent a fr.Element -func (system *System) FieldBitLen() int { - return system.bitLen -} - -func (system *System) AddInternalVariable() (idx int) { - idx = system.NbInternalVariables + system.GetNbPublicVariables() + system.GetNbSecretVariables() - system.NbInternalVariables++ - return idx -} - -func (system *System) AddPublicVariable(name string) (idx int) { - idx = system.GetNbPublicVariables() - system.Public = append(system.Public, name) - return idx -} - -func (system *System) AddSecretVariable(name string) (idx int) { - idx = system.GetNbSecretVariables() + system.GetNbPublicVariables() - system.Secret = append(system.Secret, name) - return idx -} - -func (system *System) AddSolverHint(f hint.Function, input []LinearExpression, nbOutput int) (internalVariables []int, err error) { - if nbOutput <= 0 { - return nil, fmt.Errorf("hint function must return at least one output") - } - - // register the hint as dependency - hintUUID, 
hintID := hint.UUID(f), hint.Name(f) - if id, ok := system.MHintsDependencies[hintUUID]; ok { - // hint already registered, let's ensure string id matches - if id != hintID { - return nil, fmt.Errorf("hint dependency registration failed; %s previously register with same UUID as %s", hintID, id) - } - } else { - system.MHintsDependencies[hintUUID] = hintID - } - - // prepare wires - internalVariables = make([]int, nbOutput) - for i := 0; i < len(internalVariables); i++ { - internalVariables[i] = system.AddInternalVariable() - } - - // associate these wires with the solver hint - ch := &Hint{ID: hintUUID, Inputs: input, Wires: internalVariables} - for _, vID := range internalVariables { - system.MHints[vID] = ch - } - - return -} - -func (system *System) AddCommitment(c Commitment) error { - if system.CommitmentInfo.Is() { - return fmt.Errorf("currently only one commitment per circuit is supported") - } - - system.CommitmentInfo = c - - return nil -} - -func (system *System) AddLog(l LogEntry) { - system.Logs = append(system.Logs, l) -} + GetInstruction(int) Instruction -func (system *System) AttachDebugInfo(debugInfo DebugInfo, constraintID []int) { - system.DebugInfo = append(system.DebugInfo, LogEntry(debugInfo)) - id := len(system.DebugInfo) - 1 - for _, cID := range constraintID { - system.MDebug[cID] = id - } + GetCoefficient(i int) Element } -// VariableToString implements Resolver -func (system *System) VariableToString(vID int) string { - nbPublic := system.GetNbPublicVariables() - nbSecret := system.GetNbSecretVariables() +type CustomizableSystem interface { + // AddBlueprint registers the given blueprint and returns its id. This should be called only once per blueprint. + AddBlueprint(b Blueprint) BlueprintID - if vID < nbPublic { - return system.Public[vID] - } - vID -= nbPublic - if vID < nbSecret { - return system.Secret[vID] - } - vID -= nbSecret - return fmt.Sprintf("v%d", vID) // TODO @gbotrel vs strconv.Itoa. 
+ // AddInstruction adds an instruction to the system and returns a list of created wires + // if the blueprint declared any outputs. + AddInstruction(bID BlueprintID, calldata []uint32) []uint32 } diff --git a/constraint/term.go b/constraint/term.go index 182329e576..799fb4f7a3 100644 --- a/constraint/term.go +++ b/constraint/term.go @@ -18,6 +18,15 @@ import ( "math" ) +// ids of the coefficients with simple values in any cs.coeffs slice. +const ( + CoeffIdZero = iota + CoeffIdOne + CoeffIdTwo + CoeffIdMinusOne + CoeffIdMinusTwo +) + // Term represents a coeff * variable in a constraint system type Term struct { CID, VID uint32 @@ -44,3 +53,12 @@ func (t Term) String(r Resolver) string { sbb.WriteTerm(t) return sbb.String() } + +// implements constraint.Compressible + +// Compress compresses the term into a slice of uint32 words. +// For compatibility with test engine and LinearExpression, the term is encoded as: +// 1, CID, VID (i.e a LinearExpression with a single term) +func (t Term) Compress(to *[]uint32) { + (*to) = append((*to), 1, t.CID, t.VID) +} diff --git a/constraint/tinyfield/coeff.go b/constraint/tinyfield/coeff.go index 3bd799d820..3aec8da553 100644 --- a/constraint/tinyfield/coeff.go +++ b/constraint/tinyfield/coeff.go @@ -46,7 +46,7 @@ func newCoeffTable(capacity int) CoeffTable { } -func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constraint.Term { +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { c := (*fr.Element)(coeff[:]) var cID uint32 if c.IsZero() { @@ -69,7 +69,11 @@ func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constrai ct.mCoeffs[cc] = cID } } + return cID +} +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) return constraint.Term{VID: uint32(variableID), CID: cID} } @@ -78,7 +82,10 @@ func (ct *CoeffTable) CoeffToString(cID int) string { return ct.Coefficients[cID].String() } -var _ 
constraint.CoeffEngine = &arithEngine{} +// implements constraint.Field +type field struct{} + +var _ constraint.Field = &field{} var ( two fr.Element @@ -94,10 +101,7 @@ func init() { minusTwo.Neg(&two) } -// implements constraint.CoeffEngine -type arithEngine struct{} - -func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { +func (engine *field) FromInterface(i interface{}) constraint.Element { var e fr.Element if _, err := e.SetInterface(i); err != nil { // need to clean that --> some code path are dissimilar @@ -106,55 +110,83 @@ func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { b := utils.FromInterface(i) e.SetBigInt(&b) } - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) ToBigInt(c *constraint.Coeff) *big.Int { +func (engine *field) ToBigInt(c constraint.Element) *big.Int { e := (*fr.Element)(c[:]) r := new(big.Int) e.BigInt(r) return r } -func (engine *arithEngine) Mul(a, b *constraint.Coeff) { +func (engine *field) Mul(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Mul(_a, _b) + return a } -func (engine *arithEngine) Add(a, b *constraint.Coeff) { + +func (engine *field) Add(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Add(_a, _b) + return a } -func (engine *arithEngine) Sub(a, b *constraint.Coeff) { +func (engine *field) Sub(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Sub(_a, _b) + return a } -func (engine *arithEngine) Neg(a *constraint.Coeff) { +func (engine *field) Neg(a constraint.Element) constraint.Element { e := (*fr.Element)(a[:]) e.Neg(e) + return a } -func (engine *arithEngine) Inverse(a *constraint.Coeff) { +func (engine *field) Inverse(a constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, 
false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + e.Inverse(e) + return a, true } -func (engine *arithEngine) IsOne(a *constraint.Coeff) bool { +func (engine *field) IsOne(a constraint.Element) bool { e := (*fr.Element)(a[:]) return e.IsOne() } -func (engine *arithEngine) One() constraint.Coeff { +func (engine *field) One() constraint.Element { e := fr.One() - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) String(a *constraint.Coeff) string { +func (engine *field) String(a constraint.Element) string { e := (*fr.Element)(a[:]) return e.String() } + +func (engine *field) Uint64(a constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true +} diff --git a/constraint/tinyfield/r1cs.go b/constraint/tinyfield/r1cs.go deleted file mode 100644 index accfdaef50..0000000000 --- a/constraint/tinyfield/r1cs.go +++ /dev/null @@ -1,451 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/fxamacker/cbor/v2" - "io" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - "github.com/consensys/gnark-crypto/ecc" - "math" - - fr "github.com/consensys/gnark/internal/tinyfield" -) - -// R1CS describes a set of R1CS constraint -type R1CS struct { - constraint.R1CSCore - CoeffTable - arithEngine -} - -// NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values -// -// capacity pre-allocates memory for capacity nbConstraints -func NewR1CS(capacity int) *R1CS { - r := R1CS{ - R1CSCore: constraint.R1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.R1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - return &r -} - -func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, r1c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - - cs.UpdateLevel(cID, &r1c) - - return cID -} - -// Solve sets all the wires and returns the a, b, c vectors. -// the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. -// a, b, c vectors: ab-c = hz -// witness = [publicWires | secretWires] (without the ONE_WIRE !) 
-// returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() - - nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables - solution, err := newSolution(nbWires, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return make(fr.Vector, nbWires), err - } - start := time.Now() - - if len(witness) != len(cs.Public)-1+len(cs.Secret) { // - 1 for ONE_WIRE - err = fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), int(len(cs.Public)-1+len(cs.Secret)), len(cs.Public)-1, len(cs.Secret)) - log.Err(err).Send() - return solution.values, err - } - - // compute the wires and the a, b, c polynomials - if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { - err = errors.New("invalid input size: len(a, b, c) == len(Constraints)") - log.Err(err).Send() - return solution.values, err - } - - solution.solved[0] = true // ONE_WIRE - solution.values[0].SetOne() - copy(solution.values[1:], witness) - for i := range witness { - solution.solved[i+1] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness) + 1) - - // now that we know all inputs are set, defer log printing once all solution.values are computed - // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - if err := cs.parallelSolve(a, b, c, &solution); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity 
check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil -} - -func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - chError <- &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - return &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// IsSolved returns nil if given witness solves the R1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) 
- if err != nil { - return err - } - - a := make(fr.Vector, len(cs.Constraints)) - b := make(fr.Vector, len(cs.Constraints)) - c := make(fr.Vector, len(cs.Constraints)) - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, a, b, c, opt) - return err -} - -// divByCoeff sets res = res / t.Coeff -func (cs *R1CS) divByCoeff(res *fr.Element, t constraint.Term) { - cID := t.CoeffID() - switch cID { - case constraint.CoeffIdOne: - return - case constraint.CoeffIdMinusOne: - res.Neg(res) - case constraint.CoeffIdZero: - panic("division by 0") - default: - // this is slow, but shouldn't happen as divByCoeff is called to - // remove the coeff of an unsolved wire - // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 - res.Div(res, &cs.Coefficients[cID]) - } -} - -// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly -// -// returns an error if the solver called a hint function that errored -// returns false, nil if there was no wire to solve -// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that -// the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r constraint.R1C, solution *solution, a, b, c *fr.Element) error { - - // the index of the non-zero entry shows if L, R or O has an uninstantiated wire - // the content is the ID of the wire non instantiated - var loc uint8 - - var termToCompute constraint.Term - - processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) error { - for _, t := range l { - vID := t.WireID() - - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - solution.accumulateInto(t, val) - continue - } - - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err - } - // now that the wire is saved, accumulate it into a, b or c - solution.accumulateInto(t, val) - continue - } - - if loc != 0 { - panic("found more than one wire to instantiate") - } - termToCompute = t - loc = locValue - } - return nil - } - - if err := processLExp(r.L, a, 1); err != nil { - return err - } - - if err := processLExp(r.R, b, 2); err != nil { - return err - } - - if err := processLExp(r.O, c, 3); err != nil { - return err - } - - if loc == 0 { - // there is nothing to solve, may happen if we have an assertion - // (ie a constraints that doesn't yield any output) - // or if we solved the unsolved wires with hint functions - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - return nil - } - - // we compute the wire value and instantiate it - wID := termToCompute.WireID() - - // solver result - var wire fr.Element - - switch loc { - case 1: - if !b.IsZero() { - wire.Div(c, b). - Sub(&wire, a) - a.Add(a, &wire) - } else { - // we didn't actually ensure that a * b == c - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 2: - if !a.IsZero() { - wire.Div(c, a). 
- Sub(&wire, b) - b.Add(b, &wire) - } else { - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 3: - wire.Mul(a, b). - Sub(&wire, c) - - c.Add(c, &wire) - } - - // wire is the term (coeff * value) - // but in the solution we want to store the value only - // note that in gnark frontend, coeff here is always 1 or -1 - cs.divByCoeff(&wire, termToCompute) - solution.set(wID, wire) - - return nil -} - -// GetConstraints return the list of R1C and a coefficient resolver -func (cs *R1CS) GetConstraints() ([]constraint.R1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *R1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto -func (cs *R1CS) CurveID() ecc.ID { - return ecc.UNKNOWN -} - -// WriteTo encodes R1CS into provided io.Writer using cbor -func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode R1CS from io.Reader using cbor -func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(&cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/tinyfield/r1cs_sparse.go 
b/constraint/tinyfield/r1cs_sparse.go deleted file mode 100644 index 4ad9f0023c..0000000000 --- a/constraint/tinyfield/r1cs_sparse.go +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/fxamacker/cbor/v2" - "io" - "math" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/witness" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - - fr "github.com/consensys/gnark/internal/tinyfield" -) - -// SparseR1CS represents a Plonk like circuit -type SparseR1CS struct { - constraint.SparseR1CSCore - CoeffTable - arithEngine -} - -// NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values -func NewSparseR1CS(capacity int) *SparseR1CS { - cs := SparseR1CS{ - SparseR1CSCore: constraint.SparseR1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.SparseR1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - - return &cs -} - -func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = 
append(cs.Constraints, c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - cs.UpdateLevel(cID, &c) - - return cID -} - -// Solve sets all the wires. -// solution.values = [publicInputs | secretInputs | internalVariables ] -// witness: contains the input variables -// it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() - - // set the slices holding the solution.values and monitoring which variables have been solved - nbVariables := cs.NbInternalVariables + len(cs.Secret) + len(cs.Public) - - start := time.Now() - - expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) - if len(witness) != expectedWitnessSize { - return make(fr.Vector, nbVariables), fmt.Errorf( - "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", - len(witness), - expectedWitnessSize, - len(cs.Public), - len(cs.Secret), - ) - } - - // keep track of wire that have a value - solution, err := newSolution(nbVariables, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return solution.values, err - } - - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs - copy(solution.values, witness) - for i := 0; i < len(witness); i++ { - solution.solved[i] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness)) - - // defer log printing once all solution.values are computed - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - // batch invert the coefficients to avoid many divisions in the solver - coefficientsNegInv := 
fr.BatchInvert(cs.Coefficients) - for i := 0; i < len(coefficientsNegInv); i++ { - coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) - } - - if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil - -} - -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - wg.Done() - return - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - chError <- &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } else { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - } - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - return &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// computeHints computes wires associated with a hint function, if any -// if there is no remaining wire to solve, returns -1 -// else returns the wire position (L -> 0, R -> 1, O -> 2) -func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) (int, error) { - r := -1 - lID, rID, oID := c.L.WireID(), c.R.WireID(), c.O.WireID() - - if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { - // check if it's a hint - if hint, ok := cs.MHints[lID]; ok { - if err := solution.solveWithHint(lID, hint); err != nil { - return -1, err - } - } else { - r = 0 - } - - } - - if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { - // check if it's a hint - if hint, ok := cs.MHints[rID]; ok { - if err := solution.solveWithHint(rID, hint); err != nil { - return -1, err - } - } else { - r = 1 - } - } - - if (c.O.CoeffID() != 0) && !solution.solved[oID] { - // check if it's a hint - if hint, ok := cs.MHints[oID]; ok { - if err := solution.solveWithHint(oID, hint); err != nil { - return -1, err - } - } else { - r = 2 - } - } - 
return r, nil -} - -// solveConstraint solve any unsolved wire in given constraint and update the solution -// a SparseR1C may have up to one unsolved wire (excluding hints) -// if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { - - lro, err := cs.computeHints(c, solution) - if err != nil { - return err - } - if lro == -1 { - // no unsolved wire - // can happen if the constraint contained only hint wires. - return nil - } - if lro == 1 { // we solve for R: u1L+u2R+u3LR+u4O+k=0 => R(u2+u3L)+u1L+u4O+k = 0 - if !solution.solved[c.L.WireID()] { - panic("L wire should be instantiated when we solve R") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.L.WireID()]).Add(&den, &u2) - - v1 = solution.computeTerm(c.L) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - } - - if lro == 0 { // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 - if !solution.solved[c.R.WireID()] { - panic("R wire should be instantiated when we solve L") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.R.WireID()]).Add(&den, &u1) - - v1 = solution.computeTerm(c.R) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - - } - // O 
we solve for O - var o fr.Element - cID, vID := c.O.CoeffID(), c.O.WireID() - - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - - // o = - ((m0 * m1) + l + r + c.K) / c.O - o.Mul(&m0, &m1).Add(&o, &l).Add(&o, &r).Add(&o, &cs.Coefficients[c.K]) - o.Mul(&o, &coefficientsNegInv[cID]) - - solution.set(vID, o) - - return nil -} - -// IsSolved returns nil if given witness solves the SparseR1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return err - } - - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, opt) - return err -} - -// GetConstraints return the list of SparseR1C and a coefficient resolver -func (cs *SparseR1CS) GetConstraints() ([]constraint.SparseR1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// checkConstraint verifies that the constraint holds -func (cs *SparseR1CS) checkConstraint(c constraint.SparseR1C, solution *solution) error { - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - o := solution.computeTerm(c.O) - - // l + r + (m0 * m1) + o + c.K == 0 - var t fr.Element - t.Mul(&m0, &m1).Add(&t, &l).Add(&t, &r).Add(&t, &o).Add(&t, &cs.Coefficients[c.K]) - if !t.IsZero() { - return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + (%s × %s) + %s != 0", - l.String(), - r.String(), - o.String(), - m0.String(), - m1.String(), - cs.Coefficients[c.K].String(), - ) - } - return nil - -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto (ecc.tinyfield) -func (cs 
*SparseR1CS) CurveID() ecc.ID { - return ecc.UNKNOWN -} - -// WriteTo encodes SparseR1CS into provided io.Writer using cbor -func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode SparseR1CS from io.Reader using cbor -func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/constraint/tinyfield/r1cs_test.go b/constraint/tinyfield/r1cs_test.go index 09bb5f0968..5015c06972 100644 --- a/constraint/tinyfield/r1cs_test.go +++ b/constraint/tinyfield/r1cs_test.go @@ -20,6 +20,7 @@ import ( "bytes" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "reflect" "testing" @@ -51,7 +52,7 @@ func TestSerialization(t *testing.T) { return } - // copmpile a second time to ensure determinism + // compile a second time to ensure determinism r1cs2, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) @@ -79,10 +80,10 @@ func TestSerialization(t *testing.T) { if diff := cmp.Diff(r1cs1, &reconstructed, cmpopts.IgnoreFields(cs.R1CS{}, "System.q", - "arithEngine", + "field", "CoeffTable.mCoeffs", 
"System.lbWireLevel", - "System.lbHints", + "System.genericHint", "System.SymbolTable", "System.lbOutputs", "System.bitLen")); diff != "" { @@ -150,12 +151,6 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { - var c circuit - ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c) - if err != nil { - b.Fatal(err) - } - var w circuit w.X = 1 w.Y = 1 @@ -164,8 +159,32 @@ func BenchmarkSolve(b *testing.B) { b.Fatal(err) } - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ccs.IsSolved(witness) - } + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + } diff --git a/constraint/tinyfield/solution.go b/constraint/tinyfield/solution.go deleted file mode 100644 index 2f5eb4caec..0000000000 --- a/constraint/tinyfield/solution.go +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package cs - -import ( - "errors" - "fmt" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/debug" - "github.com/rs/zerolog" - "math/big" - "strconv" - "strings" - "sync/atomic" - - fr "github.com/consensys/gnark/internal/tinyfield" -) - -// solution represents elements needed to compute -// a solution to a R1CS or SparseR1CS -type solution struct { - values, coefficients []fr.Element - solved []bool - nbSolved uint64 - mHintsFunctions map[hint.ID]hint.Function // maps hintID to hint function - mHints map[int]*constraint.Hint // maps wireID to hint - st *debug.SymbolTable -} - -func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, mHints map[int]*constraint.Hint, coefficients []fr.Element, st *debug.SymbolTable) (solution, error) { - - s := solution{ - st: st, - values: make([]fr.Element, nbWires), - coefficients: coefficients, - solved: make([]bool, nbWires), - mHintsFunctions: hintFunctions, - mHints: mHints, - } - - // hintsDependencies is from compile time; it contains the list of hints the solver **needs** - var missing []string - for hintUUID, hintID := range hintsDependencies { - if _, ok := s.mHintsFunctions[hintUUID]; !ok { - missing = append(missing, hintID) - } - } - - if len(missing) > 0 { - return s, fmt.Errorf("solver missing hint(s): %v", missing) - } - - return s, nil -} - -func (s *solution) set(id int, value fr.Element) { - if s.solved[id] { - panic("solving the same wire twice should never happen.") - } - s.values[id] = value - s.solved[id] = true - atomic.AddUint64(&s.nbSolved, 1) - // s.nbSolved++ -} - -func (s *solution) isValid() bool { - return int(s.nbSolved) == len(s.values) -} - -// computeTerm computes coeff*variable -func (s *solution) computeTerm(t constraint.Term) fr.Element { - cID, vID := t.CoeffID(), t.WireID() - if cID != 0 && !s.solved[vID] { - panic("computing a 
term with an unsolved wire") - } - switch cID { - case constraint.CoeffIdZero: - return fr.Element{} - case constraint.CoeffIdOne: - return s.values[vID] - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - return res - case constraint.CoeffIdMinusOne: - var res fr.Element - res.Neg(&s.values[vID]) - return res - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - return res - } -} - -// r += (t.coeff*t.value) -func (s *solution) accumulateInto(t constraint.Term, r *fr.Element) { - cID := t.CoeffID() - - if t.IsConstant() { - // needed for logs, we may want to not put this in the hot path if we need to - // optimize constraint system solver further. - r.Add(r, &s.coefficients[cID]) - return - } - - vID := t.WireID() - switch cID { - case constraint.CoeffIdZero: - return - case constraint.CoeffIdOne: - r.Add(r, &s.values[vID]) - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - r.Add(r, &res) - case constraint.CoeffIdMinusOne: - r.Sub(r, &s.values[vID]) - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - r.Add(r, &res) - } -} - -// solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveWithHint(vID int, h *constraint.Hint) error { - // skip if the wire is already solved by a call to the same hint - // function on the same inputs - if s.solved[vID] { - return nil - } - // ensure hint function was provided - f, ok := s.mHintsFunctions[h.ID] - if !ok { - return errors.New("missing hint function") - } - - // tmp IO big int memory - nbInputs := len(h.Inputs) - nbOutputs := len(h.Wires) - inputs := make([]*big.Int, nbInputs) - outputs := make([]*big.Int, nbOutputs) - for i := 0; i < nbOutputs; i++ { - outputs[i] = big.NewInt(0) - } - - q := fr.Modulus() - - // for each input, we set its big int value, IF all the wires are solved - // the only case where all wires may not be solved, is if one of the input of this hint - // 
is the output of another hint. - // it is safe to recursively solve this with the parallel solver, since all hints-output wires - // that we can solve this way are marked to be solved with the current constraint we are processing. - recursiveSolve := func(t constraint.Term) error { - if t.IsConstant() { - return nil - } - wID := t.WireID() - if s.solved[wID] { - return nil - } - // unsolved dependency - if h, ok := s.mHints[wID]; ok { - // solve recursively. - return s.solveWithHint(wID, h) - } - - // it's not a hint, we panic. - panic("solver can't compute hint; one or more input wires are unsolved") - } - - for i := 0; i < nbInputs; i++ { - inputs[i] = big.NewInt(0) - - var v fr.Element - for _, term := range h.Inputs[i] { - if err := recursiveSolve(term); err != nil { - return err - } - s.accumulateInto(term, &v) - } - v.BigInt(inputs[i]) - } - - err := f(q, inputs, outputs) - - var v fr.Element - for i := range outputs { - v.SetBigInt(outputs[i]) - s.set(h.Wires[i], v) - } - - return err -} - -func (s *solution) printLogs(log zerolog.Logger, logs []constraint.LogEntry) { - if log.GetLevel() == zerolog.Disabled { - return - } - - for i := 0; i < len(logs); i++ { - logLine := s.logValue(logs[i]) - log.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) - } -} - -const unsolvedVariable = "" - -func (s *solution) logValue(log constraint.LogEntry) string { - var toResolve []interface{} - var ( - eval fr.Element - missingValue bool - ) - for j := 0; j < len(log.ToResolve); j++ { - // before eval le - - missingValue = false - eval.SetZero() - - for _, t := range log.ToResolve[j] { - // for each term in the linear expression - - cID, vID := t.CoeffID(), t.WireID() - if t.IsConstant() { - // just add the constant - eval.Add(&eval, &s.coefficients[cID]) - continue - } - - if !s.solved[vID] { - missingValue = true - break // stop the loop we can't evaluate. 
- } - - tv := s.computeTerm(t) - eval.Add(&eval, &tv) - } - - // after - if missingValue { - toResolve = append(toResolve, unsolvedVariable) - } else { - // we have to append our accumulator - toResolve = append(toResolve, eval.String()) - } - - } - if len(log.Stack) > 0 { - var sbb strings.Builder - for _, lID := range log.Stack { - location := s.st.Locations[lID] - function := s.st.Functions[location.FunctionID] - - sbb.WriteString(function.Name) - sbb.WriteByte('\n') - sbb.WriteByte('\t') - sbb.WriteString(function.Filename) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(int(location.Line))) - sbb.WriteByte('\n') - } - toResolve = append(toResolve, sbb.String()) - } - return fmt.Sprintf(log.Format, toResolve...) -} - -// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint -type UnsatisfiedConstraintError struct { - Err error - CID int // constraint ID - DebugInfo *string // optional debug info -} - -func (r *UnsatisfiedConstraintError) Error() string { - if r.DebugInfo != nil { - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) - } - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) -} diff --git a/constraint/tinyfield/solver.go b/constraint/tinyfield/solver.go new file mode 100644 index 0000000000..7cb146906c --- /dev/null +++ b/constraint/tinyfield/solver.go @@ -0,0 +1,647 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/pool" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/rs/zerolog" + "math" + "math/big" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + + fr "github.com/consensys/gnark/internal/tinyfield" +) + +// solver represent the state of the solver during a call to System.Solve(...) +type solver struct { + *system + + // values and solved are index by the wire (variable) id + values []fr.Element + solved []bool + nbSolved uint64 + + // maps hintID to hint function + mHintsFunctions map[csolver.HintID]csolver.Hint + + // used to out api.Println + logger zerolog.Logger + + a, b, c fr.Vector // R1CS solver will compute the a,b,c matrices + + q *big.Int +} + +func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { + // parse options + opt, err := csolver.NewConfig(opts...) 
+ if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public) - witnessOffset + len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + witnessOffset) + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + +// computeTerm computes 
coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. 
+func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start)+i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. 
+ } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID: cID, VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. 
+func (s *solver) Read(calldata []uint32) (constraint.Element, int) { + if s.Type == constraint.SystemSparseR1CS { + if calldata[0] != 1 { + panic("invalid calldata") + } + return s.GetValue(calldata[1], calldata[2]), 3 + } + var r fr.Element + n := int(calldata[0]) + j := 1 + for k := 0; k < n; k++ { + // we read k Terms + s.accumulateInto(constraint.Term{CID: calldata[j], VID: calldata[j+1]}, &r) + j += 2 + } + + var ret constraint.Element + copy(ret[:], r[:]) + return ret, j +} + +// processInstruction decodes the instruction and execute blueprint-defined logic. +// an instruction can encode a hint, a custom constraint or a generic constraint. +func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error { + // fetch the blueprint + blueprint := solver.Blueprints[pi.BlueprintID] + inst := pi.Unpack(&solver.System) + cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only + + if solver.Type == constraint.SystemR1CS { + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + // TODO @gbotrel we use the solveR1C method for now, having user-defined + // blueprint for R1CS would require constraint.Solver interface to add methods + // to set a,b,c since it's more efficient to compute these while we solve. + bc.DecompressR1C(&scratch.tR1C, inst) + return solver.solveR1C(cID, &scratch.tR1C) + } + } + + // blueprint declared "I know how to solve this." + if bc, ok := blueprint.(constraint.BlueprintSolvable); ok { + if err := bc.Solve(solver, inst); err != nil { + return solver.wrapErrWithDebugInfo(cID, err) + } + return nil + } + + // blueprint encodes a hint, we execute. + // TODO @gbotrel may be worth it to move hint logic in blueprint "solve" + if bc, ok := blueprint.(constraint.BlueprintHint); ok { + bc.DecompressHint(&scratch.tHint, inst) + return solver.solveWithHint(&scratch.tHint) + } + + return nil +} + +// run runs the solver. 
it return an error if a constraint is not satisfied or if not all wires +// were instantiated. +func (solver *solver) run() error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks. + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + var scratch scratch + for t := range chTasks { + for _, i := range t { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + var scratch scratch + + // for each level, we push the tasks + for _, level := range solver.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + return err + } + } + continue + } + + // number of tasks for this level is set to number of CPU + // but if we don't have enough work for all our CPU, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + if int(solver.nbSolved) != len(solver.values) { + return errors.New("solver didn't assign a value to all wires") + } + + return nil +} + +// solveR1C compute unsolved wires in the constraint, if any and set the solver accordingly +// +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. 
+func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error { + a, b, c := &solver.a[cID], &solver.b[cID], &solver.c[cID] + + // the index of the non-zero entry shows if L, R or O has an uninstantiated wire + // the content is the ID of the wire non instantiated + var loc uint8 + + var termToCompute constraint.Term + + processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) { + for _, t := range l { + vID := t.WireID() + + // wire is already computed, we just accumulate in val + if solver.solved[vID] { + solver.accumulateInto(t, val) + continue + } + + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue + } + } + + processLExp(r.L, a, 1) + processLExp(r.R, b, 2) + processLExp(r.O, c, 3) + + if loc == 0 { + // there is nothing to solve, may happen if we have an assertion + // (ie a constraints that doesn't yield any output) + // or if we solved the unsolved wires with hint functions + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + return nil + } + + // we compute the wire value and instantiate it + wID := termToCompute.WireID() + + // solver result + var wire fr.Element + + switch loc { + case 1: + if !b.IsZero() { + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) + } else { + // we didn't actually ensure that a * b == c + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 2: + if !a.IsZero() { + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) + } else { + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 3: + wire.Mul(a, b). 
+ Sub(&wire, c) + + c.Add(c, &wire) + } + + // wire is the term (coeff * value) + // but in the solver we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + solver.divByCoeff(&wire, termToCompute.CID) + solver.set(wID, wire) + + return nil +} + +// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint +type UnsatisfiedConstraintError struct { + Err error + CID int // constraint ID + DebugInfo *string // optional debug info +} + +func (r *UnsatisfiedConstraintError) Error() string { + if r.DebugInfo != nil { + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) + } + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) +} + +func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { + var debugInfo *string + if dID, ok := solver.MDebug[int(cID)]; ok { + debugInfo = new(string) + *debugInfo = solver.logValue(solver.DebugInfo[dID]) + } + return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo} +} + +// temporary variables to avoid memallocs in hotloop +type scratch struct { + tR1C constraint.R1C + tHint constraint.HintMapping +} diff --git a/constraint/tinyfield/system.go b/constraint/tinyfield/system.go new file mode 100644 index 0000000000..f7cb51bbbf --- /dev/null +++ b/constraint/tinyfield/system.go @@ -0,0 +1,379 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/fxamacker/cbor/v2" + "io" + "time" + + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/internal/backend/ioutils" + "github.com/consensys/gnark/logger" + "reflect" + + "github.com/consensys/gnark-crypto/ecc" + + fr "github.com/consensys/gnark/internal/tinyfield" +) + +type R1CS = system +type SparseR1CS = system + +// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element) +type system struct { + constraint.System + CoeffTable + field +} + +func NewR1CS(capacity int) *R1CS { + return newSystem(capacity, constraint.SystemR1CS) +} + +func NewSparseR1CS(capacity int) *SparseR1CS { + return newSystem(capacity, constraint.SystemSparseR1CS) +} + +func newSystem(capacity int, t constraint.SystemType) *system { + return &system{ + System: constraint.NewSystem(fr.Modulus(), capacity, t), + CoeffTable: newCoeffTable(capacity / 10), + } +} + +// Solve solves the constraint system with provided witness. +// If it's a R1CS returns R1CSSolution +// If it's a SparseR1CS returns SparseR1CSSolution +func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) { + log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger() + start := time.Now() + + v := witness.Vector().(fr.Vector) + + // init the solver + solver, err := newSolver(cs, v, opts...) + if err != nil { + log.Err(err).Send() + return nil, err + } + + // defer log printing once all solver.values are computed + // (or sooner, if a constraint is not satisfied) + defer solver.printLogs(cs.Logs) + + // run it. 
+ if err := solver.run(); err != nil { + log.Err(err).Send() + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") + + // format the solution + // TODO @gbotrel revisit post-refactor + if cs.Type == constraint.SystemR1CS { + var res R1CSSolution + res.W = solver.values + res.A = solver.a + res.B = solver.b + res.C = solver.c + return &res, nil + } else { + // sparse R1CS + var res SparseR1CSSolution + // query l, r, o in Lagrange basis, not blinded + res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values) + + return &res, nil + } + +} + +// IsSolved +// Deprecated: use _, err := Solve(...) instead +func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error { + _, err := cs.Solve(witness, opts...) + return err +} + +// GetR1Cs return the list of R1C +func (cs *system) GetR1Cs() []constraint.R1C { + toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + var r1c constraint.R1C + bc.DecompressR1C(&r1c, inst.Unpack(&cs.System)) + toReturn = append(toReturn, r1c) + } else { + panic("not implemented") + } + } + return toReturn +} + +// GetNbCoefficients return the number of unique coefficients needed in the R1CS +func (cs *system) GetNbCoefficients() int { + return len(cs.Coefficients) +} + +// CurveID returns curve ID as defined in gnark-crypto +func (cs *system) CurveID() ecc.ID { + return ecc.UNKNOWN +} + +// WriteTo encodes R1CS into provided io.Writer using cbor +func (cs *system) WriteTo(w io.Writer) (int64, error) { + _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written + ts := getTagSet() + enc, err := cbor.CoreDetEncOptions().EncModeWithTags(ts) + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) + + // encode our object + err = encoder.Encode(cs) + return _w.N, err +} + +// ReadFrom 
attempts to decode R1CS from io.Reader using cbor +func (cs *system) ReadFrom(r io.Reader) (int64, error) { + ts := getTagSet() + dm, err := cbor.DecOptions{ + MaxArrayElements: 2147483647, + MaxMapPairs: 2147483647, + }.DecModeWithTags(ts) + + if err != nil { + return 0, err + } + decoder := dm.NewDecoder(r) + + // initialize coeff table + cs.CoeffTable = newCoeffTable(0) + + if err := decoder.Decode(&cs); err != nil { + return int64(decoder.NumBytesRead()), err + } + + if err := cs.CheckSerializationHeader(); err != nil { + return int64(decoder.NumBytesRead()), err + } + + switch v := cs.CommitmentInfo.(type) { + case *constraint.Groth16Commitments: + cs.CommitmentInfo = *v + case *constraint.PlonkCommitments: + cs.CommitmentInfo = *v + } + + return int64(decoder.NumBytesRead()), nil +} + +func (cs *system) GetCoefficient(i int) (r constraint.Element) { + copy(r[:], cs.Coefficients[i][:]) + return +} + +// GetSparseR1Cs return the list of SparseR1C +func (cs *system) GetSparseR1Cs() []constraint.SparseR1C { + + toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + var sparseR1C constraint.SparseR1C + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + toReturn = append(toReturn, sparseR1C) + } else { + panic("not implemented") + } + } + return toReturn +} + +// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form. 
+// solver = [ public | secret | internal ] +// TODO @gbotrel refactor; this seems to be a small util function for plonk +func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { + + //s := int(pk.Domain[0].Cardinality) + s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints + s = int(ecc.NextPowerOfTwo(uint64(s))) + + var l, r, o []fr.Element + l = make([]fr.Element, s) + r = make([]fr.Element, s) + o = make([]fr.Element, s) + s0 := solution[0] + + for i := 0; i < len(cs.Public); i++ { // placeholders + l[i] = solution[i] + r[i] = s0 + o[i] = s0 + } + offset := len(cs.Public) + nbConstraints := cs.GetNbConstraints() + + var sparseR1C constraint.SparseR1C + j := 0 + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + + l[offset+j] = solution[sparseR1C.XA] + r[offset+j] = solution[sparseR1C.XB] + o[offset+j] = solution[sparseR1C.XC] + j++ + } + } + + offset += nbConstraints + + for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0]) + l[offset+i] = s0 + r[offset+i] = s0 + o[offset+i] = s0 + } + + return l, r, o + +} + +// R1CSSolution represent a valid assignment to all the variables in the constraint system. 
+// The vector W such that Aw o Bw - Cw = 0 +type R1CSSolution struct { + W fr.Vector + A, B, C fr.Vector +} + +func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.W.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.A.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.B.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.C.WriteTo(w) + n += a + return n, err +} + +func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.W.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.A.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.B.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.C.ReadFrom(r) + n += a + return n, err +} + +// SparseR1CSSolution represent a valid assignment to all the variables in the constraint system. +type SparseR1CSSolution struct { + L, R, O fr.Vector +} + +func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.L.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.R.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.O.WriteTo(w) + n += a + return n, err + +} + +func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.L.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.R.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.O.ReadFrom(r) + n += a + return n, err +} + +func getTagSet() cbor.TagSet { + // temporary for refactor + ts := cbor.NewTagSet() + // https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml + // 65536-15309735 Unassigned + tagNum := uint64(5309735) + addType := func(t reflect.Type) { + if err := ts.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + t, + tagNum, + ); err != nil { + panic(err) + } + tagNum++ + } + + addType(reflect.TypeOf(constraint.BlueprintGenericHint{})) + addType(reflect.TypeOf(constraint.BlueprintGenericR1C{})) + 
addType(reflect.TypeOf(constraint.BlueprintGenericSparseR1C{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CAdd{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CMul{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CBool{})) + addType(reflect.TypeOf(constraint.BlueprintLookupHint{})) + addType(reflect.TypeOf(constraint.Groth16Commitments{})) + addType(reflect.TypeOf(constraint.PlonkCommitments{})) + + return ts +} + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} diff --git a/debug/debug.go b/debug/debug.go index f41bac9e83..32b891c38a 100644 --- a/debug/debug.go +++ b/debug/debug.go @@ -56,7 +56,7 @@ func writeStack(sbb *strings.Builder, forceClean ...bool) { if strings.Contains(frame.File, "test/engine.go") { continue } - if strings.Contains(frame.File, "gnark/frontend") { + if strings.Contains(frame.File, "gnark/frontend/cs") { continue } file = filepath.Base(file) @@ -75,5 +75,8 @@ func writeStack(sbb *strings.Builder, forceClean ...bool) { if strings.HasSuffix(function, "Define") { break } + if strings.HasSuffix(function, "callDeferred") { + break + } } } diff --git a/debug/symbol_table.go b/debug/symbol_table.go index a840aeb7ae..dc25c3e8d8 100644 --- a/debug/symbol_table.go +++ b/debug/symbol_table.go @@ -61,7 +61,7 @@ func (st *SymbolTable) CollectStack() []int { if strings.Contains(frame.File, "test/engine.go") { continue } - if strings.Contains(frame.File, "gnark/frontend") { + if strings.Contains(frame.File, "gnark/frontend/cs") { continue } frame.File = filepath.Base(frame.File) @@ -76,6 +76,9 @@ func (st *SymbolTable) CollectStack() []int { if strings.HasSuffix(function, "Define") { break } + if strings.HasSuffix(function, "callDeferred") { + break + } } return r } diff --git a/debug_test.go b/debug_test.go index d3efb39e1e..620e9860d0 100644 --- a/debug_test.go +++ b/debug_test.go @@ -9,6 +9,8 @@ import ( "github.com/consensys/gnark/backend" 
"github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/plonk" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/frontend/cs/scs" @@ -46,11 +48,11 @@ func TestPrintln(t *testing.T) { witness.B = 11 var expected bytes.Buffer - expected.WriteString("debug_test.go:28 > 13 is the addition\n") - expected.WriteString("debug_test.go:30 > 26 42\n") - expected.WriteString("debug_test.go:32 > bits 1\n") - expected.WriteString("debug_test.go:33 > circuit {A: 2, B: 11}\n") - expected.WriteString("debug_test.go:37 > m .*\n") + expected.WriteString("debug_test.go:30 > 13 is the addition\n") + expected.WriteString("debug_test.go:32 > 26 42\n") + expected.WriteString("debug_test.go:34 > bits 1\n") + expected.WriteString("debug_test.go:35 > circuit {A: 2, B: 11}\n") + expected.WriteString("debug_test.go:39 > m .*\n") { trace, _ := getGroth16Trace(&circuit, &witness) @@ -94,9 +96,14 @@ func TestTraceDivBy0(t *testing.T) { { _, err := getPlonkTrace(&circuit, &witness) assert.Error(err) - assert.Contains(err.Error(), "constraint #1 is not satisfied: [inverse] 1/0 < ∞") - assert.Contains(err.Error(), "(*divBy0Trace).Define") - assert.Contains(err.Error(), "debug_test.go:") + if debug.Debug { + assert.Contains(err.Error(), "constraint #1 is not satisfied: [inverse] 1/0 < ∞") + assert.Contains(err.Error(), "(*divBy0Trace).Define") + assert.Contains(err.Error(), "debug_test.go:") + } else { + assert.Contains(err.Error(), "constraint #1 is not satisfied: division by 0") + } + } } @@ -123,17 +130,26 @@ func TestTraceNotEqual(t *testing.T) { { _, err := getGroth16Trace(&circuit, &witness) assert.Error(err) - assert.Contains(err.Error(), "constraint #0 is not satisfied: [assertIsEqual] 1 == 66") - assert.Contains(err.Error(), "(*notEqualTrace).Define") - assert.Contains(err.Error(), "debug_test.go:") + if debug.Debug { + 
assert.Contains(err.Error(), "constraint #0 is not satisfied: [assertIsEqual] 1 == 66") + assert.Contains(err.Error(), "(*notEqualTrace).Define") + assert.Contains(err.Error(), "debug_test.go:") + } else { + assert.Contains(err.Error(), "constraint #0 is not satisfied: 1 ⋅ 1 != 66") + } } { _, err := getPlonkTrace(&circuit, &witness) assert.Error(err) - assert.Contains(err.Error(), "constraint #1 is not satisfied: [assertIsEqual] 1 + -66 == 0") - assert.Contains(err.Error(), "(*notEqualTrace).Define") - assert.Contains(err.Error(), "debug_test.go:") + if debug.Debug { + assert.Contains(err.Error(), "constraint #1 is not satisfied: [assertIsEqual] 1 == 66") + assert.Contains(err.Error(), "(*notEqualTrace).Define") + assert.Contains(err.Error(), "debug_test.go:") + } else { + assert.Contains(err.Error(), "constraint #1 is not satisfied: qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → 1 + -66 + 0 + 0 + 0 != 0") + } + } } @@ -158,7 +174,7 @@ func getPlonkTrace(circuit, w frontend.Circuit) (string, error) { return "", err } log := zerolog.New(&zerolog.ConsoleWriter{Out: &buf, NoColor: true, PartsExclude: []string{zerolog.LevelFieldName, zerolog.TimestampFieldName}}) - _, err = plonk.Prove(ccs, pk, sw, backend.WithCircuitLogger(log)) + _, err = plonk.Prove(ccs, pk, sw, backend.WithSolverOptions(solver.WithLogger(log))) return buf.String(), err } @@ -179,6 +195,6 @@ func getGroth16Trace(circuit, w frontend.Circuit) (string, error) { return "", err } log := zerolog.New(&zerolog.ConsoleWriter{Out: &buf, NoColor: true, PartsExclude: []string{zerolog.LevelFieldName, zerolog.TimestampFieldName}}) - _, err = groth16.Prove(ccs, pk, sw, backend.WithCircuitLogger(log)) + _, err = groth16.Prove(ccs, pk, sw, backend.WithSolverOptions(solver.WithLogger(log))) return buf.String(), err } diff --git a/doc.go b/doc.go index 7c1b685665..9a1a7958a5 100644 --- a/doc.go +++ b/doc.go @@ -22,7 +22,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" ) -var Version = semver.MustParse("0.8.1") 
+var Version = semver.MustParse("0.9.0-alpha") // Curves return the curves supported by gnark func Curves() []ecc.ID { diff --git a/examples/plonk/main.go b/examples/plonk/main.go index 36e355ccde..80aeea5eaf 100644 --- a/examples/plonk/main.go +++ b/examples/plonk/main.go @@ -80,7 +80,7 @@ func main() { // create the necessary data for KZG. // This is a toy example, normally the trusted setup to build ZKG - // has been ran before. + // has been run before. // The size of the data in KZG should be the closest power of 2 bounding // // above max(nbConstraints, nbVariables). _r1cs := ccs.(*cs.SparseR1CS) diff --git a/examples/rollup/circuit.go b/examples/rollup/circuit.go index 0b3b5000ba..27b47f4fda 100644 --- a/examples/rollup/circuit.go +++ b/examples/rollup/circuit.go @@ -20,7 +20,7 @@ import ( tedwards "github.com/consensys/gnark-crypto/ecc/twistededwards" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/accumulator/merkle" - "github.com/consensys/gnark/std/algebra/twistededwards" + "github.com/consensys/gnark/std/algebra/native/twistededwards" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/std/signature/eddsa" ) diff --git a/frontend/api.go b/frontend/api.go index 3645f04197..a54e2ec3c0 100644 --- a/frontend/api.go +++ b/frontend/api.go @@ -19,7 +19,7 @@ package frontend import ( "math/big" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" ) // API represents the available functions to circuit developers @@ -30,9 +30,16 @@ type API interface { // Add returns res = i1+i2+...in Add(i1, i2 Variable, in ...Variable) Variable - // MulAcc sets and return a = a + (b*c) - // ! may mutate a without allocating a new result - // ! always use MulAcc(...) result for correctness + // MulAcc sets and return a = a + (b*c). + // + // ! The method may mutate a without allocating a new result. 
If the input + // is used elsewhere, then first initialize new variable, for example by + // doing: + // + // acopy := api.Mul(a, 1) + // acopy = MulAcc(acopy, b, c) + // + // ! But it may not modify a, always use MulAcc(...) result for correctness. MulAcc(a, b, c Variable) Variable // Neg returns -i @@ -122,7 +129,7 @@ type API interface { // NewHint is a shortcut to api.Compiler().NewHint() // Deprecated: use api.Compiler().NewHint() instead - NewHint(f hint.Function, nbOutputs int, inputs ...Variable) ([]Variable, error) + NewHint(f solver.Hint, nbOutputs int, inputs ...Variable) ([]Variable, error) // ConstantValue is a shortcut to api.Compiler().ConstantValue() // Deprecated: use api.Compiler().ConstantValue() instead diff --git a/frontend/builder.go b/frontend/builder.go index c15dc1d334..717dbfca7a 100644 --- a/frontend/builder.go +++ b/frontend/builder.go @@ -3,8 +3,8 @@ package frontend import ( "math/big" - "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend/schema" ) @@ -12,6 +12,8 @@ type NewBuilder func(*big.Int, CompileConfig) (Builder, error) // Compiler represents a constraint system compiler type Compiler interface { + constraint.CustomizableSystem + // MarkBoolean sets (but do not constraint!) v to be boolean // This is useful in scenarios where a variable is known to be boolean through a constraint // that is not api.AssertIsBoolean. If v is a constant, this is a no-op. @@ -36,7 +38,8 @@ type Compiler interface { // manually in the circuit. Failing to do so leads to solver failure. 
// // If nbOutputs is specified, it must be >= 1 and <= f.NbOutputs - NewHint(f hint.Function, nbOutputs int, inputs ...Variable) ([]Variable, error) + NewHint(f solver.Hint, nbOutputs int, inputs ...Variable) ([]Variable, error) + NewHintForId(id solver.HintID, nbOutputs int, inputs ...Variable) ([]Variable, error) // ConstantValue returns the big.Int value of v and true if op is a success. // nil and false if failure. This API returns a boolean to allow for future refactoring @@ -49,12 +52,20 @@ type Compiler interface { // FieldBitLen returns the number of bits needed to represent an element in the scalar field FieldBitLen() int - // Commit returns a commitment to the given variables, to be used as initial randomness in - // Fiat-Shamir when the statement to prove is particularly large. - // TODO cite paper - // ! Experimental - // TENTATIVE: Functions regarding fiat-shamir-ed proofs over enormous statements TODO finalize - Commit(...Variable) (Variable, error) + // Defer is called after circuit.Define() and before Compile(). This method + // allows for the circuits to register callbacks which finalize batching + // operations etc. Unlike Go defer, it is not locally scoped. + Defer(cb func(api API) error) + + // InternalVariable returns the internal variable associated with the given wireID + // ! Experimental: use in conjunction with constraint.CustomizableSystem + InternalVariable(wireID uint32) Variable + + // ToCanonicalVariable converts a frontend.Variable to a constraint system specific Variable + // ! Experimental: use in conjunction with constraint.CustomizableSystem + ToCanonicalVariable(Variable) CanonicalVariable + + SetGkrInfo(constraint.GkrInfo) error } // Builder represents a constraint system builder @@ -73,3 +84,27 @@ type Builder interface { // called inside circuit.Define() SecretVariable(schema.LeafInfo) Variable } + +// Committer allows to commit to the variables and returns the commitment. 
The +// commitment can be used as a challenge using Fiat-Shamir heuristic. +type Committer interface { + // Commit commits to the variables and returns the commitment. + Commit(toCommit ...Variable) (commitment Variable, err error) +} + +// Rangechecker allows to externally range-check the variables to be of +// specified width. Not all compilers implement this interface. Users should +// instead use [github.com/consensys/gnark/std/rangecheck] package which +// automatically chooses most optimal method for range checking the variables. +type Rangechecker interface { + // Check checks that the given variable v has bit-length bits. + Check(v Variable, bits int) +} + +// CanonicalVariable represents a variable that's encoded in a constraint system specific way. +// For example a R1CS builder may represent this as a constraint.LinearExpression, +// a PLONK builder --> constraint.Term +// and the test/Engine --> ~*big.Int. +type CanonicalVariable interface { + constraint.Compressible +} diff --git a/frontend/compile.go b/frontend/compile.go index d39ff5dd7b..072aca1c34 100644 --- a/frontend/compile.go +++ b/frontend/compile.go @@ -9,6 +9,7 @@ import ( "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend/schema" + "github.com/consensys/gnark/internal/circuitdefer" "github.com/consensys/gnark/logger" ) @@ -36,7 +37,7 @@ func Compile(field *big.Int, newBuilder NewBuilder, circuit Circuit, opts ...Com log := logger.Logger() log.Info().Msg("compiling circuit") // parse options - opt := CompileConfig{} + opt := defaultCompileConfig() for _, o := range opts { if err := o(&opt); err != nil { log.Err(err).Msg("applying compile option") @@ -122,15 +123,33 @@ func parseCircuit(builder Builder, circuit Circuit) (err error) { if err = circuit.Define(builder); err != nil { return fmt.Errorf("define circuit: %w", err) } + if err = callDeferred(builder); err != nil { + return fmt.Errorf("deferred: %w", err) + } return } +func 
callDeferred(builder Builder) error { + for i := 0; i < len(circuitdefer.GetAll[func(API) error](builder)); i++ { + if err := circuitdefer.GetAll[func(API) error](builder)[i](builder); err != nil { + return fmt.Errorf("defer fn %d: %w", i, err) + } + } + return nil +} + // CompileOption defines option for altering the behaviour of the Compile // method. See the descriptions of the functions returning instances of this // type for available options. type CompileOption func(opt *CompileConfig) error +func defaultCompileConfig() CompileConfig { + return CompileConfig{ + CompressThreshold: 300, + } +} + type CompileConfig struct { Capacity int IgnoreUnconstrainedInputs bool @@ -174,6 +193,9 @@ func IgnoreUnconstrainedInputs() CompileOption { // fast. The compression adds some overhead in the number of constraints. The // overhead and compile performance depends on threshold value, and it should be // chosen carefully. +// +// If this option is not given then by default we use the compress threshold of +// 300. func WithCompressThreshold(threshold int) CompileOption { return func(opt *CompileConfig) error { opt.CompressThreshold = threshold diff --git a/frontend/cs/commitment.go b/frontend/cs/commitment.go new file mode 100644 index 0000000000..e8b92478f9 --- /dev/null +++ b/frontend/cs/commitment.go @@ -0,0 +1,52 @@ +package cs + +import ( + "crypto/rand" + "fmt" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/debug" + "github.com/consensys/gnark/logger" + "hash/fnv" + "math/big" + "os" + "strconv" + "strings" + "sync" +) + +func Bsb22CommitmentComputePlaceholder(mod *big.Int, _ []*big.Int, output []*big.Int) (err error) { + if (len(os.Args) > 0 && (strings.HasSuffix(os.Args[0], ".test") || strings.HasSuffix(os.Args[0], ".test.exe"))) || debug.Debug { + // usually we only run solver without prover during testing + log := logger.Logger() + log.Error().Msg("Augmented commitment hint not replaced. 
Proof will not be sound and verification will fail!") + output[0], err = rand.Int(rand.Reader, mod) + return + } + return fmt.Errorf("placeholder function: to be replaced by commitment computation") +} + +var maxNbCommitments int +var m sync.Mutex + +func RegisterBsb22CommitmentComputePlaceholder(index int) (hintId solver.HintID, err error) { + + hintName := "bsb22 commitment #" + strconv.Itoa(index) + // mimic solver.GetHintID + hf := fnv.New32a() + if _, err = hf.Write([]byte(hintName)); err != nil { + return + } + hintId = solver.HintID(hf.Sum32()) + + m.Lock() + if maxNbCommitments == index { + maxNbCommitments++ + solver.RegisterNamedHint(Bsb22CommitmentComputePlaceholder, hintId) + } + m.Unlock() + + return +} +func init() { + solver.RegisterHint(Bsb22CommitmentComputePlaceholder) +} diff --git a/frontend/cs/r1cs/api.go b/frontend/cs/r1cs/api.go index 92f6281de9..797bba0854 100644 --- a/frontend/cs/r1cs/api.go +++ b/frontend/cs/r1cs/api.go @@ -19,14 +19,17 @@ package r1cs import ( "errors" "fmt" - "math/big" + "github.com/consensys/gnark/internal/utils" "path/filepath" "reflect" "runtime" "strings" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/debug" + "github.com/consensys/gnark/frontend/cs" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/internal/expr" "github.com/consensys/gnark/frontend/schema" @@ -55,14 +58,14 @@ func (builder *builder) MulAcc(a, b, c frontend.Variable) frontend.Variable { // v1 and v2 are both unknown, this is the only case we add a constraint if !v1Constant && !v2Constant { res := builder.newInternalVariable() - builder.cs.AddConstraint(builder.newR1C(b, c, res)) + builder.cs.AddR1C(builder.newR1C(b, c, res), builder.genericGate) builder.mbuf1 = append(builder.mbuf1, res...) 
return } // v1 and v2 are constants, we multiply big.Int values and return resulting constant if v1Constant && v2Constant { - builder.cs.Mul(&n1, &n2) + n1 = builder.cs.Mul(n1, n2) builder.mbuf1 = append(builder.mbuf1, expr.NewTerm(0, n1)) return } @@ -143,9 +146,9 @@ func (builder *builder) add(vars []expr.LinearExpression, sub bool, capacity int if curr != -1 && t.VID == (*res)[curr].VID { // accumulate, it's the same variable ID if sub && lID != 0 { - builder.cs.Sub(&(*res)[curr].Coeff, &t.Coeff) + (*res)[curr].Coeff = builder.cs.Sub((*res)[curr].Coeff, t.Coeff) } else { - builder.cs.Add(&(*res)[curr].Coeff, &t.Coeff) + (*res)[curr].Coeff = builder.cs.Add((*res)[curr].Coeff, t.Coeff) } if (*res)[curr].Coeff.IsZero() { // remove self. @@ -157,14 +160,14 @@ func (builder *builder) add(vars []expr.LinearExpression, sub bool, capacity int (*res) = append((*res), *t) curr++ if sub && lID != 0 { - builder.cs.Neg(&(*res)[curr].Coeff) + (*res)[curr].Coeff = builder.cs.Neg((*res)[curr].Coeff) } } } if len((*res)) == 0 { // keep the linear expression valid (assertIsSet) - (*res) = append((*res), expr.NewTerm(0, constraint.Coeff{})) + (*res) = append((*res), expr.NewTerm(0, constraint.Element{})) } // if the linear expression LE is too long then record an equality // constraint LE * 1 = t and return short linear expression instead. 
@@ -183,7 +186,7 @@ func (builder *builder) Neg(i frontend.Variable) frontend.Variable { v := builder.toVariable(i) if n, ok := builder.constantValue(v); ok { - builder.cs.Neg(&n) + n = builder.cs.Neg(n) return expr.NewLinearExpression(0, n) } @@ -202,13 +205,13 @@ func (builder *builder) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) f // v1 and v2 are both unknown, this is the only case we add a constraint if !v1Constant && !v2Constant { res := builder.newInternalVariable() - builder.cs.AddConstraint(builder.newR1C(v1, v2, res)) + builder.cs.AddR1C(builder.newR1C(v1, v2, res), builder.genericGate) return res } // v1 and v2 are constants, we multiply big.Int values and return resulting constant if v1Constant && v2Constant { - builder.cs.Mul(&n1, &n2) + n1 = builder.cs.Mul(n1, n2) return expr.NewLinearExpression(0, n1) } @@ -227,7 +230,7 @@ func (builder *builder) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) f return res } -func (builder *builder) mulConstant(v1 expr.LinearExpression, lambda constraint.Coeff, inPlace bool) expr.LinearExpression { +func (builder *builder) mulConstant(v1 expr.LinearExpression, lambda constraint.Element, inPlace bool) expr.LinearExpression { // multiplying a frontend.Variable by a constant -> we updated the coefficients in the linear expression // leading to that frontend.Variable var res expr.LinearExpression @@ -238,7 +241,7 @@ func (builder *builder) mulConstant(v1 expr.LinearExpression, lambda constraint. 
} for i := 0; i < len(res); i++ { - builder.cs.Mul(&res[i].Coeff, &lambda) + res[i].Coeff = builder.cs.Mul(res[i].Coeff, lambda) } return res } @@ -254,9 +257,12 @@ func (builder *builder) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable if !v2Constant { res := builder.newInternalVariable() - debug := builder.newDebugInfo("div", v1, "/", v2, " == ", res) // note that here we don't ensure that divisor is != 0 - builder.cs.AddConstraint(builder.newR1C(v2, res, v1), debug) + cID := builder.cs.AddR1C(builder.newR1C(v2, res, v1), builder.genericGate) + if debug.Debug { + debug := builder.newDebugInfo("div", v1, "/", v2, " == ", res) + builder.cs.AttachDebugInfo(debug, []int{cID}) + } return res } @@ -264,12 +270,10 @@ func (builder *builder) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable if n2.IsZero() { panic("div by constant(0)") } - // q := builder.q - builder.cs.Inverse(&n2) - // n2.ModInverse(n2, q) + n2, _ = builder.cs.Inverse(n2) if v1Constant { - builder.cs.Mul(&n2, &n1) + n2 = builder.cs.Mul(n2, n1) return expr.NewLinearExpression(0, n2) } @@ -292,8 +296,8 @@ func (builder *builder) Div(i1, i2 frontend.Variable) frontend.Variable { debug := builder.newDebugInfo("div", v1, "/", v2, " == ", res) v2Inv := builder.newInternalVariable() // note that here we ensure that v2 can't be 0, but it costs us one extra constraint - c1 := builder.cs.AddConstraint(builder.newR1C(v2, v2Inv, builder.cstOne())) - c2 := builder.cs.AddConstraint(builder.newR1C(v1, v2Inv, res)) + c1 := builder.cs.AddR1C(builder.newR1C(v2, v2Inv, builder.cstOne()), builder.genericGate) + c2 := builder.cs.AddR1C(builder.newR1C(v1, v2Inv, res), builder.genericGate) builder.cs.AttachDebugInfo(debug, []int{c1, c2}) return res } @@ -302,10 +306,10 @@ func (builder *builder) Div(i1, i2 frontend.Variable) frontend.Variable { if n2.IsZero() { panic("div by constant(0)") } - builder.cs.Inverse(&n2) + n2, _ = builder.cs.Inverse(n2) if v1Constant { - builder.cs.Mul(&n2, &n1) + n2 = 
builder.cs.Mul(n2, n1) return expr.NewLinearExpression(0, n2) } @@ -322,15 +326,18 @@ func (builder *builder) Inverse(i1 frontend.Variable) frontend.Variable { panic("inverse by constant(0)") } - builder.cs.Inverse(&c) + c, _ = builder.cs.Inverse(c) return expr.NewLinearExpression(0, c) } // allocate resulting frontend.Variable res := builder.newInternalVariable() - debug := builder.newDebugInfo("inverse", vars[0], "*", res, " == 1") - builder.cs.AddConstraint(builder.newR1C(res, vars[0], builder.cstOne()), debug) + cID := builder.cs.AddR1C(builder.newR1C(res, vars[0], builder.cstOne()), builder.genericGate) + if debug.Debug { + debug := builder.newDebugInfo("inverse", vars[0], "*", res, " == 1") + builder.cs.AttachDebugInfo(debug, []int{cID}) + } return res } @@ -406,7 +413,7 @@ func (builder *builder) Or(_a, _b frontend.Variable) frontend.Variable { c = append(c, a...) c = append(c, b...) - builder.cs.AddConstraint(builder.newR1C(a, b, c)) + builder.cs.AddR1C(builder.newR1C(a, b, c), builder.genericGate) return res } @@ -441,7 +448,7 @@ func (builder *builder) Select(i0, i1, i2 frontend.Variable) frontend.Variable { if c, ok := builder.constantValue(cond); ok { // condition is a constant return i1 if true, i2 if false - if builder.isCstOne(&c) { + if builder.isCstOne(c) { return vars[1] } return vars[2] @@ -451,7 +458,7 @@ func (builder *builder) Select(i0, i1, i2 frontend.Variable) frontend.Variable { n2, ok2 := builder.constantValue(vars[2]) if ok1 && ok2 { - builder.cs.Sub(&n1, &n2) + n1 = builder.cs.Sub(n1, n2) res := builder.Mul(cond, n1) // no constraint is recorded res = builder.Add(res, vars[2]) // no constraint is recorded return res @@ -488,8 +495,8 @@ func (builder *builder) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 fronten c1, b1IsConstant := builder.constantValue(s1) if b0IsConstant && b1IsConstant { - b0 := builder.isCstOne(&c0) - b1 := builder.isCstOne(&c1) + b0 := builder.isCstOne(c0) + b1 := builder.isCstOne(c1) if !b0 && !b1 { return in0 
@@ -543,17 +550,17 @@ func (builder *builder) IsZero(i1 frontend.Variable) frontend.Variable { m := builder.newInternalVariable() // x = 1/a // in a hint (x == 0 if a == 0) - x, err := builder.NewHint(hint.InvZero, 1, a) + x, err := builder.NewHint(solver.InvZeroHint, 1, a) if err != nil { // the function errs only if the number of inputs is invalid. panic(err) } // m = -a*x + 1 // constrain m to be 1 if a == 0 - c1 := builder.cs.AddConstraint(builder.newR1C(builder.Neg(a), x[0], builder.Sub(m, 1))) + c1 := builder.cs.AddR1C(builder.newR1C(builder.Neg(a), x[0], builder.Sub(m, 1)), builder.genericGate) // a * m = 0 // constrain m to be 0 if a != 0 - c2 := builder.cs.AddConstraint(builder.newR1C(a, m, builder.cstZero())) + c2 := builder.cs.AddR1C(builder.newR1C(a, m, builder.cstZero()), builder.genericGate) builder.cs.AttachDebugInfo(debug, []int{c1, c2}) @@ -660,7 +667,7 @@ func (builder *builder) negateLinExp(l expr.LinearExpression) expr.LinearExpress res := make(expr.LinearExpression, len(l)) copy(res, l) for i := 0; i < len(res); i++ { - builder.cs.Neg(&res[i].Coeff) + res[i].Coeff = builder.cs.Neg(res[i].Coeff) } return res } @@ -670,23 +677,32 @@ func (builder *builder) Compiler() frontend.Compiler { } func (builder *builder) Commit(v ...frontend.Variable) (frontend.Variable, error) { - // we want to build a sorted slice of commited variables, without duplicates + + commitments := builder.cs.GetCommitments().(constraint.Groth16Commitments) + existingCommitmentIndexes := commitments.CommitmentIndexes() + privateCommittedSeeker := utils.MultiListSeeker(commitments.GetPrivateCommitted()) + + // we want to build a sorted slice of committed variables, without duplicates // this is the same algorithm as builder.add(...); but we expect len(v) to be quite large. vars, s := builder.toVariables(v...) 
+ nbPublicCommitted := 0 // initialize the min-heap // this is the same algorithm as api.add --> we want to merge k sorted linear expression for lID, v := range vars { - builder.heap = append(builder.heap, linMeta{val: v[0].VID, lID: lID}) + if v[0].VID < builder.cs.GetNbPublicVariables() { + nbPublicCommitted++ + } + builder.heap = append(builder.heap, linMeta{val: v[0].VID, lID: lID}) // TODO: Use int heap } builder.heap.heapify() // sort all the wires - committed := make([]int, 0, s) - - curr := -1 - nbPublicCommitted := 0 + publicAndCommitmentCommitted := make([]int, 0, nbPublicCommitted+len(existingCommitmentIndexes)) // right now nbPublicCommitted is an upper bound + privateCommitted := make([]int, 0, s) + lastInsertedWireId := -1 + nbPublicCommitted = 0 // process all the terms from all the inputs, in sorted order for len(builder.heap) > 0 { @@ -704,43 +720,79 @@ func (builder *builder) Commit(v ...frontend.Variable) (frontend.Variable, error if t.VID == 0 { continue // don't commit to ONE_WIRE } - if curr != -1 && t.VID == committed[curr] { + if lastInsertedWireId == t.VID { // it's the same variable ID, do nothing continue - } else { - // append, it's a new variable ID - committed = append(committed, t.VID) - if t.VID < builder.cs.GetNbPublicVariables() { - nbPublicCommitted++ - } - curr++ } + + if t.VID < builder.cs.GetNbPublicVariables() { // public + publicAndCommitmentCommitted = append(publicAndCommitmentCommitted, t.VID) + lastInsertedWireId = t.VID + nbPublicCommitted++ + continue + } + + // private or commitment + for len(existingCommitmentIndexes) > 0 && existingCommitmentIndexes[0] < t.VID { + existingCommitmentIndexes = existingCommitmentIndexes[1:] + } + if len(existingCommitmentIndexes) > 0 && existingCommitmentIndexes[0] == t.VID { // commitment + publicAndCommitmentCommitted = append(publicAndCommitmentCommitted, t.VID) + existingCommitmentIndexes = existingCommitmentIndexes[1:] // technically unnecessary + lastInsertedWireId = t.VID + 
continue + } + + // private + // Cannot commit to a secret variable that has already been committed to + // instead we commit to its commitment + if committer := privateCommittedSeeker.Seek(t.VID); committer != -1 { + committerWireIndex := existingCommitmentIndexes[committer] // commit to this commitment instead + vars = append(vars, expr.LinearExpression{{Coeff: constraint.Element{1}, VID: committerWireIndex}}) // TODO Replace with mont 1 + builder.heap.push(linMeta{lID: len(vars) - 1, tID: 0, val: committerWireIndex}) // pushing to heap mid-op is okay because toCommit > t.VID > anything popped so far + continue + } + + // so it's a new, so-far-uncommitted private variable + privateCommitted = append(privateCommitted, t.VID) + lastInsertedWireId = t.VID } - if len(committed) == 0 { + if len(privateCommitted)+len(publicAndCommitmentCommitted) == 0 { // TODO @tabaie Necessary? return nil, errors.New("must commit to at least one variable") } // build commitment - commitment := constraint.NewCommitment(committed, nbPublicCommitted) + commitment := constraint.Groth16Commitment{ + PublicAndCommitmentCommitted: publicAndCommitmentCommitted, + NbPublicCommitted: nbPublicCommitted, + PrivateCommitted: privateCommitted, + } // hint is used at solving time to compute the actual value of the commitment // it is going to be dynamically replaced at solving time. - hintOut, err := builder.NewHint(bsb22CommitmentComputePlaceholder, 1, builder.getCommittedVariables(&commitment)...) 
+ + var ( + hintOut []frontend.Variable + err error + ) + + commitment.HintID, err = cs.RegisterBsb22CommitmentComputePlaceholder(len(commitments)) if err != nil { return nil, err } + + if hintOut, err = builder.NewHintForId(commitment.HintID, 1, builder.wireIDsToVars( + commitment.PublicAndCommitmentCommitted, + commitment.PrivateCommitted, + )...); err != nil { + return nil, err + } + cVar := hintOut[0] - commitment.HintID = hint.UUID(bsb22CommitmentComputePlaceholder) // TODO @gbotrel probably not needed commitment.CommitmentIndex = (cVar.(expr.LinearExpression))[0].WireID() - // TODO @Tabaie: Get rid of this field - commitment.CommittedAndCommitment = append(commitment.Committed, commitment.CommitmentIndex) - if commitment.CommitmentIndex <= commitment.Committed[len(commitment.Committed)-1] { - return nil, fmt.Errorf("commitment variable index smaller than some committed variable indices") - } - if err := builder.cs.AddCommitment(commitment); err != nil { return nil, err } @@ -748,14 +800,22 @@ func (builder *builder) Commit(v ...frontend.Variable) (frontend.Variable, error return cVar, nil } -func (builder *builder) getCommittedVariables(i *constraint.Commitment) []frontend.Variable { - res := make([]frontend.Variable, len(i.Committed)) - for j, wireIndex := range i.Committed { - res[j] = expr.NewLinearExpression(wireIndex, builder.tOne) +func (builder *builder) wireIDsToVars(wireIDs ...[]int) []frontend.Variable { + n := 0 + for i := range wireIDs { + n += len(wireIDs[i]) + } + res := make([]frontend.Variable, n) + n = 0 + for _, list := range wireIDs { + for i := range list { + res[n+i] = expr.NewLinearExpression(list[i], builder.tOne) + } + n += len(list) } return res } -func bsb22CommitmentComputePlaceholder(*big.Int, []*big.Int, []*big.Int) error { - return fmt.Errorf("placeholder function: to be replaced by commitment computation") +func (builder *builder) SetGkrInfo(info constraint.GkrInfo) error { + return builder.cs.AddGkr(info) } diff --git 
a/frontend/cs/r1cs/api_assertions.go b/frontend/cs/r1cs/api_assertions.go index 2d2db8f010..3c249f4845 100644 --- a/frontend/cs/r1cs/api_assertions.go +++ b/frontend/cs/r1cs/api_assertions.go @@ -17,12 +17,12 @@ limitations under the License. package r1cs import ( + "fmt" "math/big" "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/internal/expr" - "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/math/bits" ) @@ -32,14 +32,22 @@ func (builder *builder) AssertIsEqual(i1, i2 frontend.Variable) { r := builder.getLinearExpression(builder.toVariable(i1)) o := builder.getLinearExpression(builder.toVariable(i2)) - debug := builder.newDebugInfo("assertIsEqual", r, " == ", o) + cID := builder.cs.AddR1C(builder.newR1C(builder.cstOne(), r, o), builder.genericGate) - builder.cs.AddConstraint(builder.newR1C(builder.cstOne(), r, o), debug) + if debug.Debug { + debug := builder.newDebugInfo("assertIsEqual", r, " == ", o) + builder.cs.AttachDebugInfo(debug, []int{cID}) + } } // AssertIsDifferent constrain i1 and i2 to be different func (builder *builder) AssertIsDifferent(i1, i2 frontend.Variable) { - builder.Inverse(builder.Sub(i1, i2)) + s := builder.Sub(i1, i2).(expr.LinearExpression) + if len(s) == 1 && s[0].Coeff.IsZero() { + panic("AssertIsDifferent(x,x) will never be satisfied") + } + + builder.Inverse(s) } // AssertIsBoolean adds an assertion in the constraint builder (v == 0 ∥ v == 1) @@ -48,7 +56,7 @@ func (builder *builder) AssertIsBoolean(i1 frontend.Variable) { v := builder.toVariable(i1) if b, ok := builder.constantValue(v); ok { - if !(builder.isCstZero(&b) || builder.isCstOne(&b)) { + if !(b.IsZero() || builder.isCstOne(b)) { panic("assertIsBoolean failed: constant is not 0 or 1") // TODO @gbotrel print } return @@ -66,11 +74,10 @@ func (builder *builder) AssertIsBoolean(i1 frontend.Variable) { V := builder.getLinearExpression(v) + cID := builder.cs.AddR1C(builder.newR1C(V, _v, 
o), builder.genericGate) if debug.Debug { debug := builder.newDebugInfo("assertIsBoolean", V, " == (0|1)") - builder.cs.AddConstraint(builder.newR1C(V, _v, o), debug) - } else { - builder.cs.AddConstraint(builder.newR1C(V, _v, o)) + builder.cs.AttachDebugInfo(debug, []int{cID}) } } @@ -80,18 +87,34 @@ func (builder *builder) AssertIsBoolean(i1 frontend.Variable) { // // derived from: // https://github.com/zcash/zips/blob/main/protocol/protocol.pdf -func (builder *builder) AssertIsLessOrEqual(_v frontend.Variable, bound frontend.Variable) { - v := builder.toVariable(_v) - - if b, ok := bound.(expr.LinearExpression); ok { - assertIsSet(b) - builder.mustBeLessOrEqVar(v, b) - } else { - builder.mustBeLessOrEqCst(v, utils.FromInterface(bound)) +func (builder *builder) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) { + cv, vConst := builder.constantValue(v) + cb, bConst := builder.constantValue(bound) + + // both inputs are constants + if vConst && bConst { + bv, bb := builder.cs.ToBigInt(cv), builder.cs.ToBigInt(cb) + if bv.Cmp(bb) == 1 { + panic(fmt.Sprintf("AssertIsLessOrEqual: %s > %s", bv.String(), bb.String())) + } } + + // bound is constant + if bConst { + vv := builder.toVariable(v) + builder.mustBeLessOrEqCst(vv, builder.cs.ToBigInt(cb)) + return + } + + builder.mustBeLessOrEqVar(v, bound) } -func (builder *builder) mustBeLessOrEqVar(a, bound expr.LinearExpression) { +func (builder *builder) mustBeLessOrEqVar(a, bound frontend.Variable) { + // here bound is NOT a constant, + // but a can be either constant or a wire. 
+ + _, aConst := builder.constantValue(a) + debug := builder.newDebugInfo("mustBeLessOrEq", a, " <= ", bound) nbBits := builder.cs.FieldBitLen() @@ -128,16 +151,22 @@ func (builder *builder) mustBeLessOrEqVar(a, bound expr.LinearExpression) { // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 // → this is a boolean constraint // if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too - builder.MarkBoolean(aBits[i].(expr.LinearExpression)) // this does not create a constraint - added = append(added, builder.cs.AddConstraint(builder.newR1C(l, aBits[i], zero))) + if aConst { + // aBits[i] is a constant; + l = builder.Mul(l, aBits[i]) + // TODO @gbotrel this constraint seems useless. + added = append(added, builder.cs.AddR1C(builder.newR1C(l, zero, zero), builder.genericGate)) + } else { + added = append(added, builder.cs.AddR1C(builder.newR1C(l, aBits[i], zero), builder.genericGate)) + } } builder.cs.AttachDebugInfo(debug, added) } -func (builder *builder) mustBeLessOrEqCst(a expr.LinearExpression, bound big.Int) { +func (builder *builder) mustBeLessOrEqCst(a expr.LinearExpression, bound *big.Int) { nbBits := builder.cs.FieldBitLen() @@ -186,8 +215,7 @@ func (builder *builder) mustBeLessOrEqCst(a expr.LinearExpression, bound big.Int l := builder.Sub(1, p[i+1]) l = builder.Sub(l, aBits[i]) - added = append(added, builder.cs.AddConstraint(builder.newR1C(l, aBits[i], builder.cstZero()))) - builder.MarkBoolean(aBits[i].(expr.LinearExpression)) + added = append(added, builder.cs.AddR1C(builder.newR1C(l, aBits[i], builder.cstZero()), builder.genericGate)) } else { builder.AssertIsBoolean(aBits[i]) } diff --git a/frontend/cs/r1cs/builder.go b/frontend/cs/r1cs/builder.go index d06bd9c3d9..1a27261bda 100644 --- a/frontend/cs/r1cs/builder.go +++ b/frontend/cs/r1cs/builder.go @@ -23,12 +23,14 @@ import ( "sort" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/constraint" 
"github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/internal/expr" "github.com/consensys/gnark/frontend/schema" + "github.com/consensys/gnark/internal/circuitdefer" + "github.com/consensys/gnark/internal/frontendtype" + "github.com/consensys/gnark/internal/kvstore" "github.com/consensys/gnark/internal/tinyfield" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" @@ -40,29 +42,36 @@ import ( bn254r1cs "github.com/consensys/gnark/constraint/bn254" bw6633r1cs "github.com/consensys/gnark/constraint/bw6-633" bw6761r1cs "github.com/consensys/gnark/constraint/bw6-761" + "github.com/consensys/gnark/constraint/solver" tinyfieldr1cs "github.com/consensys/gnark/constraint/tinyfield" ) // NewBuilder returns a new R1CS builder which implements frontend.API. +// Additionally, this builder also implements [frontend.Committer]. func NewBuilder(field *big.Int, config frontend.CompileConfig) (frontend.Builder, error) { return newBuilder(field, config), nil } type builder struct { - cs constraint.R1CS - + cs constraint.R1CS config frontend.CompileConfig + kvstore.Store // map for recording boolean constrained variables (to not constrain them twice) mtBooleans map[uint64][]expr.LinearExpression - q *big.Int - tOne constraint.Coeff - heap minHeap // helps merge k sorted linear expressions + tOne constraint.Element + eZero, eOne expr.LinearExpression + cZero, cOne constraint.LinearExpression + + // helps merge k sorted linear expressions + heap minHeap // buffers used to do in place api.MAC mbuf1 expr.LinearExpression mbuf2 expr.LinearExpression + + genericGate constraint.BlueprintID } // initialCapacity has quite some impact on frontend performance, especially on large circuits size @@ -78,6 +87,7 @@ func newBuilder(field *big.Int, config frontend.CompileConfig) *builder { heap: make(minHeap, 0, 100), mbuf1: make(expr.LinearExpression, 0, macCapacity), mbuf2: make(expr.LinearExpression, 0, 
macCapacity), + Store: kvstore.New(), } // by default the circuit is given a public wire equal to 1 @@ -110,10 +120,13 @@ func newBuilder(field *big.Int, config frontend.CompileConfig) *builder { builder.tOne = builder.cs.One() builder.cs.AddPublicVariable("1") - builder.q = builder.cs.Field() - if builder.q.Cmp(field) != 0 { - panic("invalid modulus on cs impl") // sanity check - } + builder.genericGate = builder.cs.AddBlueprint(&constraint.BlueprintGenericR1C{}) + + builder.eZero = expr.NewLinearExpression(0, constraint.Element{}) + builder.eOne = expr.NewLinearExpression(0, builder.tOne) + + builder.cOne = constraint.LinearExpression{constraint.Term{VID: 0, CID: constraint.CoeffIdOne}} + builder.cZero = constraint.LinearExpression{constraint.Term{VID: 0, CID: constraint.CoeffIdZero}} return &builder } @@ -139,19 +152,15 @@ func (builder *builder) SecretVariable(f schema.LeafInfo) frontend.Variable { // cstOne return the one constant func (builder *builder) cstOne() expr.LinearExpression { - return expr.NewLinearExpression(0, builder.tOne) + return builder.eOne } // cstZero return the zero constant func (builder *builder) cstZero() expr.LinearExpression { - return expr.NewLinearExpression(0, constraint.Coeff{}) -} - -func (builder *builder) isCstZero(c *constraint.Coeff) bool { - return c.IsZero() + return builder.eZero } -func (builder *builder) isCstOne(c *constraint.Coeff) bool { +func (builder *builder) isCstOne(c constraint.Element) bool { return builder.cs.IsOne(c) } @@ -188,9 +197,16 @@ func (builder *builder) getLinearExpression(_l interface{}) constraint.LinearExp var L constraint.LinearExpression switch tl := _l.(type) { case expr.LinearExpression: + if len(tl) == 1 && tl[0].VID == 0 { + if tl[0].Coeff.IsZero() { + return builder.cZero + } else if tl[0].Coeff == builder.tOne { + return builder.cOne + } + } L = make(constraint.LinearExpression, 0, len(tl)) for _, t := range tl { - L = append(L, builder.cs.MakeTerm(&t.Coeff, t.VID)) + L = append(L, 
builder.cs.MakeTerm(t.Coeff, t.VID)) } case constraint.LinearExpression: L = tl @@ -208,7 +224,7 @@ func (builder *builder) getLinearExpression(_l interface{}) constraint.LinearExp // that is not api.AssertIsBoolean. If v is a constant, this is a no-op. func (builder *builder) MarkBoolean(v frontend.Variable) { if b, ok := builder.constantValue(v); ok { - if !(builder.isCstZero(&b) || builder.isCstOne(&b)) { + if !(b.IsZero() || builder.isCstOne(b)) { panic("MarkBoolean called a non-boolean constant") } return @@ -228,7 +244,7 @@ func (builder *builder) MarkBoolean(v frontend.Variable) { // This returns true if the v is a constant and v == 0 || v == 1. func (builder *builder) IsBoolean(v frontend.Variable) bool { if b, ok := builder.constantValue(v); ok { - return (builder.isCstZero(&b) || builder.isCstOne(&b)) + return (b.IsZero() || builder.isCstOne(b)) } // v is a linear expression l := v.(expr.LinearExpression) @@ -280,20 +296,20 @@ func (builder *builder) ConstantValue(v frontend.Variable) (*big.Int, bool) { if !ok { return nil, false } - return builder.cs.ToBigInt(&coeff), true + return builder.cs.ToBigInt(coeff), true } -func (builder *builder) constantValue(v frontend.Variable) (constraint.Coeff, bool) { +func (builder *builder) constantValue(v frontend.Variable) (constraint.Element, bool) { if _v, ok := v.(expr.LinearExpression); ok { assertIsSet(_v) if len(_v) != 1 { // TODO @gbotrel this assumes linear expressions of coeff are not possible // and are always reduced to one element. may not always be true? 
- return constraint.Coeff{}, false + return constraint.Element{}, false } if !(_v[0].WireID() == 0) { // public ONE WIRE - return constraint.Coeff{}, false + return constraint.Element{}, false } return _v[0].Coeff, true } @@ -314,9 +330,9 @@ func (builder *builder) toVariable(input interface{}) expr.LinearExpression { case *expr.LinearExpression: assertIsSet(*t) return *t - case constraint.Coeff: + case constraint.Element: return expr.NewLinearExpression(0, t) - case *constraint.Coeff: + case *constraint.Element: return expr.NewLinearExpression(0, *t) default: // try to make it into a constant @@ -354,7 +370,15 @@ func (builder *builder) toVariables(in ...frontend.Variable) ([]expr.LinearExpre // // No new constraints are added to the newly created wire and must be added // manually in the circuit. Failing to do so leads to solver failure. -func (builder *builder) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { +func (builder *builder) NewHint(f solver.Hint, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { + return builder.newHint(f, solver.GetHintID(f), nbOutputs, inputs) +} + +func (builder *builder) NewHintForId(id solver.HintID, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { + return builder.newHint(nil, id, nbOutputs, inputs) +} + +func (builder *builder) newHint(f solver.Hint, id solver.HintID, nbOutputs int, inputs []frontend.Variable) ([]frontend.Variable, error) { hintInputs := make([]constraint.LinearExpression, len(inputs)) // TODO @gbotrel hint input pass @@ -365,13 +389,13 @@ func (builder *builder) NewHint(f hint.Function, nbOutputs int, inputs ...fronte hintInputs[i] = builder.getLinearExpression(t) } else { c := builder.cs.FromInterface(in) - term := builder.cs.MakeTerm(&c, 0) + term := builder.cs.MakeTerm(c, 0) term.MarkConstant() hintInputs[i] = constraint.LinearExpression{term} } } - internalVariables, err := builder.cs.AddSolverHint(f, 
hintInputs, nbOutputs) + internalVariables, err := builder.cs.AddSolverHint(f, id, hintInputs, nbOutputs) if err != nil { return nil, err } @@ -382,7 +406,6 @@ func (builder *builder) NewHint(f hint.Function, nbOutputs int, inputs ...fronte res[i] = expr.NewLinearExpression(idx, builder.tOne) } return res, nil - } // assertIsSet panics if the variable is unset @@ -423,10 +446,10 @@ func (builder *builder) newDebugInfo(errName string, in ...interface{}) constrai in[i] = builder.getLinearExpression(expr.LinearExpression{t}) case *expr.Term: in[i] = builder.getLinearExpression(expr.LinearExpression{*t}) - case constraint.Coeff: - in[i] = builder.cs.String(&t) - case *constraint.Coeff: + case constraint.Element: in[i] = builder.cs.String(t) + case *constraint.Element: + in[i] = builder.cs.String(*t) } } @@ -445,6 +468,42 @@ func (builder *builder) compress(le expr.LinearExpression) expr.LinearExpression one := builder.cstOne() t := builder.newInternalVariable() - builder.cs.AddConstraint(builder.newR1C(le, one, t)) + builder.cs.AddR1C(builder.newR1C(le, one, t), builder.genericGate) return t } + +func (builder *builder) Defer(cb func(frontend.API) error) { + circuitdefer.Put(builder, cb) +} + +func (*builder) FrontendType() frontendtype.Type { + return frontendtype.R1CS +} + +// AddInstruction is used to add custom instructions to the constraint system. +func (builder *builder) AddInstruction(bID constraint.BlueprintID, calldata []uint32) []uint32 { + return builder.cs.AddInstruction(bID, calldata) +} + +// AddBlueprint adds a custom blueprint to the constraint system. +func (builder *builder) AddBlueprint(b constraint.Blueprint) constraint.BlueprintID { + return builder.cs.AddBlueprint(b) +} + +func (builder *builder) InternalVariable(wireID uint32) frontend.Variable { + return expr.NewLinearExpression(int(wireID), builder.tOne) +} + +// ToCanonicalVariable converts a frontend.Variable to a constraint system specific Variable +// ! 
Experimental: use in conjunction with constraint.CustomizableSystem +func (builder *builder) ToCanonicalVariable(in frontend.Variable) frontend.CanonicalVariable { + if t, ok := in.(expr.LinearExpression); ok { + assertIsSet(t) + return builder.getLinearExpression(t) + } else { + c := builder.cs.FromInterface(in) + term := builder.cs.MakeTerm(c, 0) + term.MarkConstant() + return constraint.LinearExpression{term} + } +} diff --git a/frontend/cs/r1cs/heap.go b/frontend/cs/r1cs/heap.go index a0d6035411..07c6bc436e 100644 --- a/frontend/cs/r1cs/heap.go +++ b/frontend/cs/r1cs/heap.go @@ -21,7 +21,7 @@ func (h *minHeap) heapify() { } } -// push pushes the element x onto the heap. +// push the element x onto the heap. // The complexity is O(log n) where n = len(*h). func (h *minHeap) push(x linMeta) { *h = append(*h, x) diff --git a/frontend/cs/r1cs/r1cs_test.go b/frontend/cs/r1cs/r1cs_test.go index 51370f9802..ceb965322c 100644 --- a/frontend/cs/r1cs/r1cs_test.go +++ b/frontend/cs/r1cs/r1cs_test.go @@ -99,7 +99,7 @@ func BenchmarkReduce(b *testing.B) { // Add few large linear expressions // Add many large linear expressions // Doubling of large linear expressions - rand.Seed(time.Now().Unix()) + rand := rand.New(rand.NewSource(time.Now().Unix())) //#nosec G404 weak rng is fine here const nbTerms = 100000 terms := make([]frontend.Variable, nbTerms) for i := 0; i < len(terms); i++ { @@ -135,3 +135,28 @@ func BenchmarkReduce(b *testing.B) { } }) } + +type EmptyCircuit struct { + X frontend.Variable + cb func(frontend.API) error +} + +func (c *EmptyCircuit) Define(api frontend.API) error { + api.AssertIsEqual(c.X, 0) + api.Compiler().Defer(c.cb) + return nil +} + +func TestPreCompileHook(t *testing.T) { + var called bool + c := &EmptyCircuit{ + cb: func(a frontend.API) error { called = true; return nil }, + } + _, err := frontend.Compile(ecc.BN254.ScalarField(), NewBuilder, c) + if err != nil { + t.Fatal(err) + } + if !called { + t.Error("callback not called") + } +} diff 
--git a/frontend/cs/scs/api.go b/frontend/cs/scs/api.go index 0971715911..75764878d3 100644 --- a/frontend/cs/scs/api.go +++ b/frontend/cs/scs/api.go @@ -18,49 +18,48 @@ package scs import ( "fmt" - "math/big" "path/filepath" "reflect" "runtime" "strings" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/debug" + "github.com/consensys/gnark/frontend/cs" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/internal/expr" "github.com/consensys/gnark/frontend/schema" + "github.com/consensys/gnark/internal/frontendtype" "github.com/consensys/gnark/std/math/bits" ) // Add returns res = i1+i2+...in -func (builder *scs) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - zero := big.NewInt(0) +func (builder *builder) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + // separate the constant part from the variables vars, k := builder.filterConstantSum(append([]frontend.Variable{i1, i2}, in...)) if len(vars) == 0 { - return k + // no variables, we return the constant. 
+ return builder.cs.ToBigInt(k) } + vars = builder.reduce(vars) - if k.Cmp(zero) == 0 { - return builder.splitSum(vars[0], vars[1:]) + if k.IsZero() { + return builder.splitSum(vars[0], vars[1:], nil) } - cl, _ := vars[0].Unpack() - kID := builder.st.CoeffID(&k) - o := builder.newInternalVariable() - builder.addPlonkConstraint(vars[0], builder.zero(), o, cl, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdMinusOne, kID) - return builder.splitSum(o, vars[1:]) - + // no constant we decompose the linear expressions in additions of 2 terms + return builder.splitSum(vars[0], vars[1:], &k) } -func (builder *scs) MulAcc(a, b, c frontend.Variable) frontend.Variable { +func (builder *builder) MulAcc(a, b, c frontend.Variable) frontend.Variable { // TODO can we do better here to limit allocations? - // technically we could do that in one PlonK constraint (against 2 for separate Add & Mul) return builder.Add(a, builder.Mul(b, c)) } // neg returns -in -func (builder *scs) neg(in []frontend.Variable) []frontend.Variable { - +func (builder *builder) neg(in []frontend.Variable) []frontend.Variable { res := make([]frontend.Variable, len(in)) for i := 0; i < len(in); i++ { @@ -70,87 +69,79 @@ func (builder *scs) neg(in []frontend.Variable) []frontend.Variable { } // Sub returns res = i1 - i2 - ...in -func (builder *scs) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (builder *builder) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { r := builder.neg(append([]frontend.Variable{i2}, in...)) return builder.Add(i1, r[0], r[1:]...) 
} // Neg returns -i -func (builder *scs) Neg(i1 frontend.Variable) frontend.Variable { - if n, ok := builder.ConstantValue(i1); ok { - n.Neg(n) - return *n - } else { - v := i1.(expr.TermToRefactor) - c, _ := v.Unpack() - var coef big.Int - coef.Set(&builder.st.Coeffs[c]) - coef.Neg(&coef) - c = builder.st.CoeffID(&coef) - v.SetCoeffID(c) - return v +func (builder *builder) Neg(i1 frontend.Variable) frontend.Variable { + if n, ok := builder.constantValue(i1); ok { + n = builder.cs.Neg(n) + return builder.cs.ToBigInt(n) } + v := i1.(expr.Term) + v.Coeff = builder.cs.Neg(v.Coeff) + return v } // Mul returns res = i1 * i2 * ... in -func (builder *scs) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - +func (builder *builder) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { vars, k := builder.filterConstantProd(append([]frontend.Variable{i1, i2}, in...)) if len(vars) == 0 { - return k + return builder.cs.ToBigInt(k) } - l := builder.mulConstant(vars[0], &k) - return builder.splitProd(l, vars[1:]) + l := builder.mulConstant(vars[0], k) + return builder.splitProd(l, vars[1:]) } // returns t*m -func (builder *scs) mulConstant(t expr.TermToRefactor, m *big.Int) expr.TermToRefactor { - var coef big.Int - cid, _ := t.Unpack() - coef.Set(&builder.st.Coeffs[cid]) - coef.Mul(m, &coef).Mod(&coef, builder.q) - cid = builder.st.CoeffID(&coef) - t.SetCoeffID(cid) +func (builder *builder) mulConstant(t expr.Term, m constraint.Element) expr.Term { + t.Coeff = builder.cs.Mul(t.Coeff, m) return t } // DivUnchecked returns i1 / i2 . 
if i1 == i2 == 0, returns 0 -func (builder *scs) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { - c1, i1Constant := builder.ConstantValue(i1) - c2, i2Constant := builder.ConstantValue(i2) +func (builder *builder) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { + c1, i1Constant := builder.constantValue(i1) + c2, i2Constant := builder.constantValue(i2) if i1Constant && i2Constant { - l := c1 - r := c2 - q := builder.q - return r.ModInverse(r, q). - Mul(l, r). - Mod(r, q) + if c2.IsZero() { + panic("inverse by constant(0)") + } + c2, _ = builder.cs.Inverse(c2) + c2 = builder.cs.Mul(c2, c1) + return builder.cs.ToBigInt(c2) } if i2Constant { - c := c2 - q := builder.q - c.ModInverse(c, q) - return builder.mulConstant(i1.(expr.TermToRefactor), c) + if c2.IsZero() { + panic("inverse by constant(0)") + } + c2, _ = builder.cs.Inverse(c2) + return builder.mulConstant(i1.(expr.Term), c2) } if i1Constant { res := builder.Inverse(i2) - return builder.mulConstant(res.(expr.TermToRefactor), c1) + return builder.mulConstant(res.(expr.Term), c1) } + // res * i2 == i1 res := builder.newInternalVariable() - r := i2.(expr.TermToRefactor) - o := builder.Neg(i1).(expr.TermToRefactor) - cr, _ := r.Unpack() - co, _ := o.Unpack() - builder.addPlonkConstraint(res, r, o, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdOne, cr, co, constraint.CoeffIdZero) + builder.addPlonkConstraint(sparseR1C{ + xa: res.VID, + xb: i2.(expr.Term).VID, + xc: i1.(expr.Term).VID, + qM: i2.(expr.Term).Coeff, + qO: builder.cs.Neg(i1.(expr.Term).Coeff), + }) + return res } // Div returns i1 / i2 -func (builder *scs) Div(i1, i2 frontend.Variable) frontend.Variable { - +func (builder *builder) Div(i1, i2 frontend.Variable) frontend.Variable { // note that here we ensure that v2 can't be 0, but it costs us one extra constraint builder.Inverse(i2) @@ -158,16 +149,32 @@ func (builder *scs) Div(i1, i2 frontend.Variable) frontend.Variable { } // Inverse returns res = 1 / i1 -func 
(builder *scs) Inverse(i1 frontend.Variable) frontend.Variable { - if c, ok := builder.ConstantValue(i1); ok { - c.ModInverse(c, builder.q) - return c +func (builder *builder) Inverse(i1 frontend.Variable) frontend.Variable { + if c, ok := builder.constantValue(i1); ok { + if c.IsZero() { + panic("inverse by constant(0)") + } + c, _ = builder.cs.Inverse(c) + return builder.cs.ToBigInt(c) } - t := i1.(expr.TermToRefactor) - cr, _ := t.Unpack() - debug := builder.newDebugInfo("inverse", "1/", i1, " < ∞") + t := i1.(expr.Term) res := builder.newInternalVariable() - builder.addPlonkConstraint(res, t, builder.zero(), constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdOne, cr, constraint.CoeffIdZero, constraint.CoeffIdMinusOne, debug) + + // res * i1 - 1 == 0 + constraint := sparseR1C{ + xa: res.VID, + xb: t.VID, + qM: t.Coeff, + qC: builder.tMinusOne, + } + + if debug.Debug { + debug := builder.newDebugInfo("inverse", "1/", i1, " < ∞") + builder.addPlonkConstraint(constraint, debug) + } else { + builder.addPlonkConstraint(constraint) + } + return res } @@ -179,7 +186,7 @@ func (builder *scs) Inverse(i1 frontend.Variable) frontend.Variable { // n default value is fr.Bits the number of bits needed to represent a field element // // The result in in little endian (first bit= lsb) -func (builder *scs) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { +func (builder *builder) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { // nbBits nbBits := builder.cs.FieldBitLen() if len(n) == 1 { @@ -193,85 +200,151 @@ func (builder *scs) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable } // FromBinary packs b, seen as a fr.Element in little endian -func (builder *scs) FromBinary(b ...frontend.Variable) frontend.Variable { +func (builder *builder) FromBinary(b ...frontend.Variable) frontend.Variable { return bits.FromBinary(builder, b) } // Xor returns a ^ b // a and b must be 0 or 1 -func (builder *scs) Xor(a, b frontend.Variable) 
frontend.Variable { - +func (builder *builder) Xor(a, b frontend.Variable) frontend.Variable { + // pre condition: a, b must be booleans builder.AssertIsBoolean(a) builder.AssertIsBoolean(b) - _a, aConstant := builder.ConstantValue(a) - _b, bConstant := builder.ConstantValue(b) + _a, aConstant := builder.constantValue(a) + _b, bConstant := builder.constantValue(b) + + // if both inputs are constants if aConstant && bConstant { - _a.Xor(_a, _b) - return _a + b0 := 0 + b1 := 0 + if builder.cs.IsOne(_a) { + b0 = 1 + } + if builder.cs.IsOne(_b) { + b1 = 1 + } + return b0 ^ b1 } res := builder.newInternalVariable() builder.MarkBoolean(res) + + // if one input is constant, ensure we put it in b. if aConstant { a, b = b, a bConstant = aConstant _b = _a } if bConstant { - l := a.(expr.TermToRefactor) - r := l - oneMinusTwoB := big.NewInt(1) - oneMinusTwoB.Sub(oneMinusTwoB, _b).Sub(oneMinusTwoB, _b) - builder.addPlonkConstraint(l, r, res, builder.st.CoeffID(oneMinusTwoB), constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdMinusOne, builder.st.CoeffID(_b)) + xa := a.(expr.Term) + // 1 - 2b + qL := builder.tOne + qL = builder.cs.Sub(qL, _b) + qL = builder.cs.Sub(qL, _b) + qL = builder.cs.Mul(qL, xa.Coeff) + + // (1-2b)a + b == res + builder.addPlonkConstraint(sparseR1C{ + xa: xa.VID, + xc: res.VID, + qL: qL, + qO: builder.tMinusOne, + qC: _b, + }) + // builder.addPlonkConstraint(xa, xb, res, builder.st.CoeffID(oneMinusTwoB), constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdMinusOne, builder.st.CoeffID(_b)) return res } - l := a.(expr.TermToRefactor) - r := b.(expr.TermToRefactor) - builder.addPlonkConstraint(l, r, res, constraint.CoeffIdMinusOne, constraint.CoeffIdMinusOne, constraint.CoeffIdTwo, constraint.CoeffIdOne, constraint.CoeffIdOne, constraint.CoeffIdZero) + xa := a.(expr.Term) + xb := b.(expr.Term) + + // -a - b + 2ab + res == 0 + qM := builder.tOne + qM = builder.cs.Add(qM, qM) + qM = 
builder.cs.Mul(qM, xa.Coeff) + qM = builder.cs.Mul(qM, xb.Coeff) + + qL := builder.cs.Neg(xa.Coeff) + qR := builder.cs.Neg(xb.Coeff) + + builder.addPlonkConstraint(sparseR1C{ + xa: xa.VID, + xb: xb.VID, + xc: res.VID, + qL: qL, + qR: qR, + qO: builder.tOne, + qM: qM, + }) + // builder.addPlonkConstraint(xa, xb, res, constraint.CoeffIdMinusOne, constraint.CoeffIdMinusOne, constraint.CoeffIdTwo, constraint.CoeffIdOne, constraint.CoeffIdOne, constraint.CoeffIdZero) return res } // Or returns a | b // a and b must be 0 or 1 -func (builder *scs) Or(a, b frontend.Variable) frontend.Variable { - +func (builder *builder) Or(a, b frontend.Variable) frontend.Variable { builder.AssertIsBoolean(a) builder.AssertIsBoolean(b) - _a, aConstant := builder.ConstantValue(a) - _b, bConstant := builder.ConstantValue(b) + _a, aConstant := builder.constantValue(a) + _b, bConstant := builder.constantValue(b) if aConstant && bConstant { - _a.Or(_a, _b) - return _a + if builder.cs.IsOne(_a) || builder.cs.IsOne(_b) { + return 1 + } + return 0 } + res := builder.newInternalVariable() builder.MarkBoolean(res) + + // if one input is constant, ensure we put it in b if aConstant { a, b = b, a _b = _a bConstant = aConstant } - if bConstant { - l := a.(expr.TermToRefactor) - r := l - one := big.NewInt(1) - _b.Sub(_b, one) - idl := builder.st.CoeffID(_b) - builder.addPlonkConstraint(l, r, res, idl, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdOne, constraint.CoeffIdZero) + if bConstant { + xa := a.(expr.Term) + // b = b - 1 + qL := _b + qL = builder.cs.Sub(qL, builder.tOne) + qL = builder.cs.Mul(qL, xa.Coeff) + // a * (b-1) + res == 0 + builder.addPlonkConstraint(sparseR1C{ + xa: xa.VID, + xc: res.VID, + qL: qL, + qO: builder.tOne, + }) return res } - l := a.(expr.TermToRefactor) - r := b.(expr.TermToRefactor) - builder.addPlonkConstraint(l, r, res, constraint.CoeffIdMinusOne, constraint.CoeffIdMinusOne, constraint.CoeffIdOne, constraint.CoeffIdOne, 
constraint.CoeffIdOne, constraint.CoeffIdZero) + xa := a.(expr.Term) + xb := b.(expr.Term) + // -a - b + ab + res == 0 + + qM := builder.cs.Mul(xa.Coeff, xb.Coeff) + + qL := builder.cs.Neg(xa.Coeff) + qR := builder.cs.Neg(xb.Coeff) + + builder.addPlonkConstraint(sparseR1C{ + xa: xa.VID, + xb: xb.VID, + xc: res.VID, + qL: qL, + qR: qR, + qM: qM, + qO: builder.tOne, + }) return res } // Or returns a & b // a and b must be 0 or 1 -func (builder *scs) And(a, b frontend.Variable) frontend.Variable { +func (builder *builder) And(a, b frontend.Variable) frontend.Variable { builder.AssertIsBoolean(a) builder.AssertIsBoolean(b) res := builder.Mul(a, b) @@ -283,14 +356,14 @@ func (builder *scs) And(a, b frontend.Variable) frontend.Variable { // Conditionals // Select if b is true, yields i1 else yields i2 -func (builder *scs) Select(b frontend.Variable, i1, i2 frontend.Variable) frontend.Variable { - _b, bConstant := builder.ConstantValue(b) +func (builder *builder) Select(b frontend.Variable, i1, i2 frontend.Variable) frontend.Variable { + _b, bConstant := builder.constantValue(b) if bConstant { - if !(_b.IsUint64() && (_b.Uint64() <= 1)) { - panic(fmt.Sprintf("%s should be 0 or 1", _b.String())) + if !builder.IsBoolean(b) { + panic(fmt.Sprintf("%s should be 0 or 1", builder.cs.String(_b))) } - if _b.Uint64() == 0 { + if _b.IsZero() { return i2 } return i1 @@ -305,23 +378,18 @@ func (builder *scs) Select(b frontend.Variable, i1, i2 frontend.Variable) fronte // Lookup2 performs a 2-bit lookup between i1, i2, i3, i4 based on bits b0 // and b1. Returns i0 if b0=b1=0, i1 if b0=1 and b1=0, i2 if b0=0 and b1=1 // and i3 if b0=b1=1. 
-func (builder *scs) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { - - // vars, _ := builder.toVariables(b0, b1, i0, i1, i2, i3) - // s0, s1 := vars[0], vars[1] - // in0, in1, in2, in3 := vars[2], vars[3], vars[4], vars[5] - +func (builder *builder) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { // ensure that bits are actually bits. Adds no constraints if the variables // are already constrained. builder.AssertIsBoolean(b0) builder.AssertIsBoolean(b1) - c0, b0IsConstant := builder.ConstantValue(b0) - c1, b1IsConstant := builder.ConstantValue(b1) + c0, b0IsConstant := builder.constantValue(b0) + c1, b1IsConstant := builder.constantValue(b1) if b0IsConstant && b1IsConstant { - b0 := c0.Uint64() == 1 - b1 := c1.Uint64() == 1 + b0 := builder.cs.IsOne(c0) + b1 := builder.cs.IsOne(c1) if !b0 && !b1 { return i0 @@ -360,22 +428,22 @@ func (builder *scs) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Va } // IsZero returns 1 if a is zero, 0 otherwise -func (builder *scs) IsZero(i1 frontend.Variable) frontend.Variable { - if a, ok := builder.ConstantValue(i1); ok { - if !(a.IsUint64() && a.Uint64() == 0) { - return 0 +func (builder *builder) IsZero(i1 frontend.Variable) frontend.Variable { + if a, ok := builder.constantValue(i1); ok { + if a.IsZero() { + return 1 } - return 1 + return 0 } // x = 1/a // in a hint (x == 0 if a == 0) // m = -a*x + 1 // constrain m to be 1 if a == 0 // a * m = 0 // constrain m to be 0 if a != 0 - a := i1.(expr.TermToRefactor) + a := i1.(expr.Term) m := builder.newInternalVariable() // x = 1/a // in a hint (x == 0 if a == 0) - x, err := builder.NewHint(hint.InvZero, 1, a) + x, err := builder.NewHint(solver.InvZeroHint, 1, a) if err != nil { // the function errs only if the number of inputs is invalid. 
panic(err) @@ -383,24 +451,28 @@ func (builder *scs) IsZero(i1 frontend.Variable) frontend.Variable { // m = -a*x + 1 // constrain m to be 1 if a == 0 // a*x + m - 1 == 0 - builder.addPlonkConstraint(a, - x[0].(expr.TermToRefactor), - m, - constraint.CoeffIdZero, - constraint.CoeffIdZero, - constraint.CoeffIdOne, - constraint.CoeffIdOne, - constraint.CoeffIdOne, - constraint.CoeffIdMinusOne) + X := x[0].(expr.Term) + builder.addPlonkConstraint(sparseR1C{ + xa: a.VID, + xb: X.VID, + xc: m.VID, + qM: a.Coeff, + qO: builder.tOne, + qC: builder.tMinusOne, + }) // a * m = 0 // constrain m to be 0 if a != 0 - builder.addPlonkConstraint(a, m, builder.zero(), constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdOne, constraint.CoeffIdOne, constraint.CoeffIdZero, constraint.CoeffIdZero) + builder.addPlonkConstraint(sparseR1C{ + xa: a.VID, + xb: m.VID, + qM: a.Coeff, + }) return m } // Cmp returns 1 if i1>i2, 0 if i1=i2, -1 if i1= 0; i-- { - iszeroi1 := builder.IsZero(bi1[i]) iszeroi2 := builder.IsZero(bi2[i]) @@ -420,7 +491,6 @@ func (builder *scs) Cmp(i1, i2 frontend.Variable) frontend.Variable { m := builder.Select(i1i2, 1, n) res = builder.Select(builder.IsZero(res), m, res) - } return res } @@ -431,8 +501,8 @@ func (builder *scs) Cmp(i1, i2 frontend.Variable) frontend.Variable { // // the print will be done once the R1CS.Solve() method is executed // -// if one of the input is a variable, its value will be resolved avec R1CS.Solve() method is called -func (builder *scs) Println(a ...frontend.Variable) { +// if one of the input is a variable, its value will be resolved when R1CS.Solve() method is called +func (builder *builder) Println(a ...frontend.Variable) { var log constraint.LogEntry // prefix log line with file.go:line @@ -446,12 +516,12 @@ func (builder *scs) Println(a ...frontend.Variable) { if i > 0 { sbb.WriteByte(' ') } - if v, ok := arg.(expr.TermToRefactor); ok { + if v, ok := arg.(expr.Term); ok { sbb.WriteString("%s") // we set limits to the 
linear expression, so that the log printer // can evaluate it before printing it - log.ToResolve = append(log.ToResolve, constraint.LinearExpression{builder.TOREFACTORMakeTerm(&builder.st.Coeffs[v.CID], v.VID)}) + log.ToResolve = append(log.ToResolve, constraint.LinearExpression{builder.cs.MakeTerm(v.Coeff, v.VID)}) } else { builder.printArg(&log, &sbb, arg) } @@ -463,7 +533,7 @@ func (builder *scs) Println(a ...frontend.Variable) { builder.cs.AddLog(log) } -func (builder *scs) printArg(log *constraint.LogEntry, sbb *strings.Builder, a frontend.Variable) { +func (builder *builder) printArg(log *constraint.LogEntry, sbb *strings.Builder, a frontend.Variable) { leafCount, err := schema.Walk(a, tVariable, nil) count := leafCount.Public + leafCount.Secret @@ -484,10 +554,10 @@ func (builder *scs) printArg(log *constraint.LogEntry, sbb *strings.Builder, a f sbb.WriteString(", ") } - v := tValue.Interface().(expr.TermToRefactor) + v := tValue.Interface().(expr.Term) // we set limits to the linear expression, so that the log printer // can evaluate it before printing it - log.ToResolve = append(log.ToResolve, constraint.LinearExpression{builder.TOREFACTORMakeTerm(&builder.st.Coeffs[v.CID], v.VID)}) + log.ToResolve = append(log.ToResolve, constraint.LinearExpression{builder.cs.MakeTerm(v.Coeff, v.VID)}) return nil } // ignoring error, printer() doesn't return errors @@ -495,6 +565,61 @@ func (builder *scs) printArg(log *constraint.LogEntry, sbb *strings.Builder, a f sbb.WriteByte('}') } -func (builder *scs) Compiler() frontend.Compiler { +func (builder *builder) Compiler() frontend.Compiler { return builder } + +func (builder *builder) Commit(v ...frontend.Variable) (frontend.Variable, error) { + + commitments := builder.cs.GetCommitments().(constraint.PlonkCommitments) + v = filterConstants(v) // TODO: @Tabaie Settle on a way to represent even constants; conventional hash? 
+ + committed := make([]int, len(v)) + + for i, vI := range v { // TODO @Tabaie Perf; If public, just hash it + vINeg := builder.Neg(vI).(expr.Term) + committed[i] = builder.cs.GetNbConstraints() + // a constraint to enforce consistency between the commitment and committed value + // - v + comm(n) = 0 + builder.addPlonkConstraint(sparseR1C{xa: vINeg.VID, qL: vINeg.Coeff, commitment: constraint.COMMITTED}) + } + + hintId, err := cs.RegisterBsb22CommitmentComputePlaceholder(len(commitments)) + if err != nil { + return nil, err + } + + var outs []frontend.Variable + if outs, err = builder.NewHintForId(hintId, 1, v...); err != nil { + return nil, err + } + + commitmentVar := builder.Neg(outs[0]).(expr.Term) + commitmentConstraintIndex := builder.cs.GetNbConstraints() + // RHS will be provided by both prover and verifier independently, as for a public wire + builder.addPlonkConstraint(sparseR1C{xa: commitmentVar.VID, qL: commitmentVar.Coeff, commitment: constraint.COMMITMENT}) // value will be injected later + + return outs[0], builder.cs.AddCommitment(constraint.PlonkCommitment{ + HintID: hintId, + CommitmentIndex: commitmentConstraintIndex, + Committed: committed, + }) +} + +func filterConstants(v []frontend.Variable) []frontend.Variable { + res := make([]frontend.Variable, 0, len(v)) + for _, vI := range v { + if _, ok := vI.(expr.Term); ok { + res = append(res, vI) + } + } + return res +} + +func (*builder) FrontendType() frontendtype.Type { + return frontendtype.SCS +} + +func (builder *builder) SetGkrInfo(info constraint.GkrInfo) error { + return builder.cs.AddGkr(info) +} diff --git a/frontend/cs/scs/api_assertions.go b/frontend/cs/scs/api_assertions.go index f30b177666..a4e4447b17 100644 --- a/frontend/cs/scs/api_assertions.go +++ b/frontend/cs/scs/api_assertions.go @@ -20,7 +20,7 @@ import ( "fmt" "math/big" - "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend" 
"github.com/consensys/gnark/frontend/internal/expr" "github.com/consensys/gnark/internal/utils" @@ -28,13 +28,13 @@ import ( ) // AssertIsEqual fails if i1 != i2 -func (builder *scs) AssertIsEqual(i1, i2 frontend.Variable) { +func (builder *builder) AssertIsEqual(i1, i2 frontend.Variable) { - c1, i1Constant := builder.ConstantValue(i1) - c2, i2Constant := builder.ConstantValue(i2) + c1, i1Constant := builder.constantValue(i1) + c2, i2Constant := builder.constantValue(i2) if i1Constant && i2Constant { - if c1.Cmp(c2) != 0 { + if c1 != c2 { panic("i1, i2 should be equal") } return @@ -45,62 +45,100 @@ func (builder *scs) AssertIsEqual(i1, i2 frontend.Variable) { c2 = c1 } if i2Constant { - l := i1.(expr.TermToRefactor) - lc, _ := l.Unpack() - k := c2 - debug := builder.newDebugInfo("assertIsEqual", l, "+", i2, " == 0") - k.Neg(k) - _k := builder.st.CoeffID(k) - builder.addPlonkConstraint(l, builder.zero(), builder.zero(), lc, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdZero, _k, debug) + xa := i1.(expr.Term) + c2 := builder.cs.Neg(c2) + + // xa - i2 == 0 + toAdd := sparseR1C{ + xa: xa.VID, + qL: xa.Coeff, + qC: c2, + } + + if debug.Debug { + debug := builder.newDebugInfo("assertIsEqual", xa, "==", i2) + builder.addPlonkConstraint(toAdd, debug) + } else { + builder.addPlonkConstraint(toAdd) + } return } - l := i1.(expr.TermToRefactor) - r := builder.Neg(i2).(expr.TermToRefactor) - lc, _ := l.Unpack() - rc, _ := r.Unpack() + xa := i1.(expr.Term) + xb := i2.(expr.Term) + + xb.Coeff = builder.cs.Neg(xb.Coeff) + // xa - xb == 0 + toAdd := sparseR1C{ + xa: xa.VID, + xb: xb.VID, + qL: xa.Coeff, + qR: xb.Coeff, + } - debug := builder.newDebugInfo("assertIsEqual", l, " + ", r, " == 0") - builder.addPlonkConstraint(l, r, builder.zero(), lc, rc, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdZero, debug) + if debug.Debug { + debug := builder.newDebugInfo("assertIsEqual", xa, " == ", xb) + 
builder.addPlonkConstraint(toAdd, debug) + } else { + builder.addPlonkConstraint(toAdd) + } } // AssertIsDifferent fails if i1 == i2 -func (builder *scs) AssertIsDifferent(i1, i2 frontend.Variable) { - builder.Inverse(builder.Sub(i1, i2)) +func (builder *builder) AssertIsDifferent(i1, i2 frontend.Variable) { + s := builder.Sub(i1, i2) + if c, ok := builder.constantValue(s); ok && c.IsZero() { + panic("AssertIsDifferent(x,x) will never be satisfied") + } else if t := s.(expr.Term); t.Coeff.IsZero() { + panic("AssertIsDifferent(x,x) will never be satisfied") + } + builder.Inverse(s) } // AssertIsBoolean fails if v != 0 ∥ v != 1 -func (builder *scs) AssertIsBoolean(i1 frontend.Variable) { - if c, ok := builder.ConstantValue(i1); ok { - if !(c.IsUint64() && (c.Uint64() == 0 || c.Uint64() == 1)) { - panic(fmt.Sprintf("assertIsBoolean failed: constant(%s)", c.String())) +func (builder *builder) AssertIsBoolean(i1 frontend.Variable) { + if c, ok := builder.constantValue(i1); ok { + if !(c.IsZero() || builder.cs.IsOne(c)) { + panic(fmt.Sprintf("assertIsBoolean failed: constant(%s)", builder.cs.String(c))) } return } - t := i1.(expr.TermToRefactor) - if builder.IsBoolean(t) { + + v := i1.(expr.Term) + if builder.IsBoolean(v) { return } - builder.MarkBoolean(t) - builder.mtBooleans[int(t.CID)|t.VID<<32] = struct{}{} // TODO @gbotrel smelly fix me - debug := builder.newDebugInfo("assertIsBoolean", t, " == (0|1)") - cID, _ := t.Unpack() - var mCoef big.Int - mCoef.Neg(&builder.st.Coeffs[cID]) - mcID := builder.st.CoeffID(&mCoef) - builder.addPlonkConstraint(t, t, builder.zero(), cID, constraint.CoeffIdZero, mcID, cID, constraint.CoeffIdZero, constraint.CoeffIdZero, debug) + builder.MarkBoolean(v) + + // ensure v * (1 - v) == 0 + // that is v + -v*v == 0 + // qM = -v.Coeff*v.Coeff + qM := builder.cs.Neg(v.Coeff) + qM = builder.cs.Mul(qM, v.Coeff) + toAdd := sparseR1C{ + xa: v.VID, + qL: v.Coeff, + qM: qM, + } + if debug.Debug { + debug := builder.newDebugInfo("assertIsBoolean", 
v, " == (0|1)") + builder.addBoolGate(toAdd, debug) + } else { + builder.addBoolGate(toAdd) + } + } // AssertIsLessOrEqual fails if v > bound -func (builder *scs) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) { +func (builder *builder) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) { switch b := bound.(type) { - case expr.TermToRefactor: - builder.mustBeLessOrEqVar(v.(expr.TermToRefactor), b) + case expr.Term: + builder.mustBeLessOrEqVar(v, b) default: - builder.mustBeLessOrEqCst(v.(expr.TermToRefactor), utils.FromInterface(b)) + builder.mustBeLessOrEqCst(v, utils.FromInterface(b)) } } -func (builder *scs) mustBeLessOrEqVar(a expr.TermToRefactor, bound expr.TermToRefactor) { +func (builder *builder) mustBeLessOrEqVar(a frontend.Variable, bound expr.Term) { debug := builder.newDebugInfo("mustBeLessOrEq", a, " <= ", bound) @@ -126,28 +164,33 @@ func (builder *scs) mustBeLessOrEqVar(a expr.TermToRefactor, bound expr.TermToRe t := builder.Select(boundBits[i], 0, p[i+1]) // (1 - t - ai) * ai == 0 - l := builder.Sub(1, t, aBits[i]) + l := builder.Sub(1, t, aBits[i]).(expr.Term) // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 // → this is a boolean constraint // if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too - builder.MarkBoolean(aBits[i].(expr.TermToRefactor)) // this does not create a constraint - builder.addPlonkConstraint( - l.(expr.TermToRefactor), - aBits[i].(expr.TermToRefactor), - builder.zero(), - constraint.CoeffIdZero, - constraint.CoeffIdZero, - constraint.CoeffIdOne, - constraint.CoeffIdOne, - constraint.CoeffIdZero, - constraint.CoeffIdZero, debug) + if ai, ok := builder.constantValue(aBits[i]); ok { + // a is constant; ensure l == 0 + l.Coeff = builder.cs.Mul(l.Coeff, ai) + builder.addPlonkConstraint(sparseR1C{ + xa: l.VID, + qL: l.Coeff, + }, debug) + } else { + // l * a[i] == 0 + builder.addPlonkConstraint(sparseR1C{ + xa: l.VID, + xb: aBits[i].(expr.Term).VID, + qM: l.Coeff, + }, 
debug) + } + } } -func (builder *scs) mustBeLessOrEqCst(a expr.TermToRefactor, bound big.Int) { +func (builder *builder) mustBeLessOrEqCst(a frontend.Variable, bound big.Int) { nbBits := builder.cs.FieldBitLen() @@ -159,6 +202,14 @@ func (builder *scs) mustBeLessOrEqCst(a expr.TermToRefactor, bound big.Int) { panic("AssertIsLessOrEqual: bound is too large, constraint will never be satisfied") } + if ca, ok := builder.constantValue(a); ok { + // a is constant, compare the big int values + ba := builder.cs.ToBigInt(ca) + if ba.Cmp(&bound) == 1 { + panic(fmt.Sprintf("AssertIsLessOrEqual: %s > %s", ba.String(), bound.String())) + } + } + // debug info debug := builder.newDebugInfo("mustBeLessOrEq", a, " <= ", bound) @@ -191,21 +242,14 @@ func (builder *scs) mustBeLessOrEqCst(a expr.TermToRefactor, bound big.Int) { if bound.Bit(i) == 0 { // (1 - p(i+1) - ai) * ai == 0 - l := builder.Sub(1, p[i+1], aBits[i]).(expr.TermToRefactor) + l := builder.Sub(1, p[i+1], aBits[i]).(expr.Term) //l = builder.Sub(l, ).(term) - builder.addPlonkConstraint( - l, - aBits[i].(expr.TermToRefactor), - builder.zero(), - constraint.CoeffIdZero, - constraint.CoeffIdZero, - constraint.CoeffIdOne, - constraint.CoeffIdOne, - constraint.CoeffIdZero, - constraint.CoeffIdZero, - debug) - // builder.markBoolean(aBits[i].(term)) + builder.addPlonkConstraint(sparseR1C{ + xa: l.VID, + xb: aBits[i].(expr.Term).VID, + qM: builder.tOne, + }, debug) } else { builder.AssertIsBoolean(aBits[i]) } diff --git a/frontend/cs/scs/builder.go b/frontend/cs/scs/builder.go index b5d2881a86..1e6f1efb62 100644 --- a/frontend/cs/scs/builder.go +++ b/frontend/cs/scs/builder.go @@ -17,18 +17,18 @@ limitations under the License. 
package scs import ( - "fmt" "math/big" "reflect" "sort" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs" "github.com/consensys/gnark/frontend/internal/expr" "github.com/consensys/gnark/frontend/schema" + "github.com/consensys/gnark/internal/circuitdefer" + "github.com/consensys/gnark/internal/kvstore" "github.com/consensys/gnark/internal/tinyfield" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" @@ -40,6 +40,7 @@ import ( bn254r1cs "github.com/consensys/gnark/constraint/bn254" bw6633r1cs "github.com/consensys/gnark/constraint/bw6-633" bw6761r1cs "github.com/consensys/gnark/constraint/bw6-761" + "github.com/consensys/gnark/constraint/solver" tinyfieldr1cs "github.com/consensys/gnark/constraint/tinyfield" ) @@ -47,135 +48,204 @@ func NewBuilder(field *big.Int, config frontend.CompileConfig) (frontend.Builder return newBuilder(field, config), nil } -type scs struct { - cs constraint.SparseR1CS - - st cs.CoeffTable +type builder struct { + cs constraint.SparseR1CS config frontend.CompileConfig + kvstore.Store // map for recording boolean constrained variables (to not constrain them twice) - mtBooleans map[int]struct{} + mtBooleans map[expr.Term]struct{} + + // records multiplications constraint to avoid duplicate. + // see mulConstraintExist(...) + mMulInstructions map[uint64]int + + // same thing for addition gates + // see addConstraintExist(...) 
+ mAddInstructions map[uint64]int - q *big.Int + // frequently used coefficients + tOne, tMinusOne constraint.Element + + genericGate constraint.BlueprintID + mulGate, addGate, boolGate constraint.BlueprintID + + // used to avoid repeated allocations + bufL expr.LinearExpression + bufH []constraint.LinearExpression } // initialCapacity has quite some impact on frontend performance, especially on large circuits size // we may want to add build tags to tune that -// TODO @gbotrel restore capacity option! -func newBuilder(field *big.Int, config frontend.CompileConfig) *scs { - builder := scs{ - mtBooleans: make(map[int]struct{}), - st: cs.NewCoeffTable(), - config: config, +func newBuilder(field *big.Int, config frontend.CompileConfig) *builder { + b := builder{ + mtBooleans: make(map[expr.Term]struct{}), + mMulInstructions: make(map[uint64]int, config.Capacity/2), + mAddInstructions: make(map[uint64]int, config.Capacity/2), + config: config, + Store: kvstore.New(), + bufL: make(expr.LinearExpression, 20), } + // init hint buffer. 
+ _ = b.hintBuffer(256) curve := utils.FieldToCurve(field) switch curve { case ecc.BLS12_377: - builder.cs = bls12377r1cs.NewSparseR1CS(config.Capacity) + b.cs = bls12377r1cs.NewSparseR1CS(config.Capacity) case ecc.BLS12_381: - builder.cs = bls12381r1cs.NewSparseR1CS(config.Capacity) + b.cs = bls12381r1cs.NewSparseR1CS(config.Capacity) case ecc.BN254: - builder.cs = bn254r1cs.NewSparseR1CS(config.Capacity) + b.cs = bn254r1cs.NewSparseR1CS(config.Capacity) case ecc.BW6_761: - builder.cs = bw6761r1cs.NewSparseR1CS(config.Capacity) + b.cs = bw6761r1cs.NewSparseR1CS(config.Capacity) case ecc.BW6_633: - builder.cs = bw6633r1cs.NewSparseR1CS(config.Capacity) + b.cs = bw6633r1cs.NewSparseR1CS(config.Capacity) case ecc.BLS24_315: - builder.cs = bls24315r1cs.NewSparseR1CS(config.Capacity) + b.cs = bls24315r1cs.NewSparseR1CS(config.Capacity) case ecc.BLS24_317: - builder.cs = bls24317r1cs.NewSparseR1CS(config.Capacity) + b.cs = bls24317r1cs.NewSparseR1CS(config.Capacity) default: if field.Cmp(tinyfield.Modulus()) == 0 { - builder.cs = tinyfieldr1cs.NewSparseR1CS(config.Capacity) + b.cs = tinyfieldr1cs.NewSparseR1CS(config.Capacity) break } - panic("not implemtented") + panic("not implemented") } - builder.q = builder.cs.Field() - if builder.q.Cmp(field) != 0 { - panic("invalid modulus on cs impl") // sanity check - } + b.tOne = b.cs.One() + b.tMinusOne = b.cs.FromInterface(-1) - return &builder + b.genericGate = b.cs.AddBlueprint(&constraint.BlueprintGenericSparseR1C{}) + b.mulGate = b.cs.AddBlueprint(&constraint.BlueprintSparseR1CMul{}) + b.addGate = b.cs.AddBlueprint(&constraint.BlueprintSparseR1CAdd{}) + b.boolGate = b.cs.AddBlueprint(&constraint.BlueprintSparseR1CBool{}) + + return &b } -func (builder *scs) Field() *big.Int { +func (builder *builder) Field() *big.Int { return builder.cs.Field() } -func (builder *scs) FieldBitLen() int { +func (builder *builder) FieldBitLen() int { return builder.cs.FieldBitLen() } -// addPlonkConstraint creates a constraint of the for 
al+br+clr+k=0 +// TODO @gbotrel doing a 2-step refactoring for now, frontend only. need to update constraint/SparseR1C. // qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC == 0 -func (builder *scs) addPlonkConstraint(xa, xb, xc expr.TermToRefactor, qL, qR, qM1, qM2, qO, qC int, debug ...constraint.DebugInfo) { - // TODO @gbotrel the signature of this function is odd.. and confusing. need refactor. - // TODO @gbotrel restore debug info - // if len(debugID) > 0 { - // builder.MDebug[len(builder.Constraints)] = debugID[0] - // } else if debug.Debug { - // builder.MDebug[len(builder.Constraints)] = constraint.NewDebugInfo("") - // } - - xa.SetCoeffID(qL) - xb.SetCoeffID(qR) - xc.SetCoeffID(qO) - - u := xa - v := xb - u.SetCoeffID(qM1) - v.SetCoeffID(qM2) - L := builder.TOREFACTORMakeTerm(&builder.st.Coeffs[xa.CID], xa.VID) - R := builder.TOREFACTORMakeTerm(&builder.st.Coeffs[xb.CID], xb.VID) - O := builder.TOREFACTORMakeTerm(&builder.st.Coeffs[xc.CID], xc.VID) - U := builder.TOREFACTORMakeTerm(&builder.st.Coeffs[u.CID], u.VID) - V := builder.TOREFACTORMakeTerm(&builder.st.Coeffs[v.CID], v.VID) - K := builder.TOREFACTORMakeTerm(&builder.st.Coeffs[qC], 0) - K.MarkConstant() - builder.cs.AddConstraint(constraint.SparseR1C{L: L, R: R, O: O, M: [2]constraint.Term{U, V}, K: K.CoeffID()}, debug...) 
+type sparseR1C struct { + xa, xb, xc int // wires + qL, qR, qO, qM, qC constraint.Element // coefficients + commitment constraint.CommitmentConstraint +} + +// a * b == c +func (builder *builder) addMulGate(a, b, c expr.Term) { + qM := builder.cs.Mul(a.Coeff, b.Coeff) + QM := builder.cs.AddCoeff(qM) + + builder.cs.AddSparseR1C(constraint.SparseR1C{ + XA: uint32(a.VID), + XB: uint32(b.VID), + XC: uint32(c.VID), + QM: QM, + QO: constraint.CoeffIdMinusOne, + }, builder.mulGate) +} + +// a + b + k == c +func (builder *builder) addAddGate(a, b expr.Term, xc uint32, k constraint.Element) { + qL := builder.cs.AddCoeff(a.Coeff) + qR := builder.cs.AddCoeff(b.Coeff) + qC := builder.cs.AddCoeff(k) + + builder.cs.AddSparseR1C(constraint.SparseR1C{ + XA: uint32(a.VID), + XB: uint32(b.VID), + XC: xc, + QL: qL, + QR: qR, + QC: qC, + QO: constraint.CoeffIdMinusOne, + }, builder.addGate) +} + +func (builder *builder) addBoolGate(c sparseR1C, debugInfo ...constraint.DebugInfo) { + QL := builder.cs.AddCoeff(c.qL) + QM := builder.cs.AddCoeff(c.qM) + + cID := builder.cs.AddSparseR1C(constraint.SparseR1C{ + XA: uint32(c.xa), + QL: QL, + QM: QM}, + builder.boolGate) + if debug.Debug && len(debugInfo) == 1 { + builder.cs.AttachDebugInfo(debugInfo[0], []int{cID}) + } +} + +// addPlonkConstraint adds a sparseR1C to the underlying constraint system +func (builder *builder) addPlonkConstraint(c sparseR1C, debugInfo ...constraint.DebugInfo) { + if !c.qM.IsZero() && (c.xa == 0 || c.xb == 0) { + // TODO this is internal but not easy to detect; if qM is set, but one or both of xa / xb is not, + // since wireID == 0 is a valid wire, it may trigger unexpected behavior. 
+ log := logger.Logger() + log.Warn().Msg("adding a plonk constraint with qM set but xa or xb == 0 (wire 0)") + } + QL := builder.cs.AddCoeff(c.qL) + QR := builder.cs.AddCoeff(c.qR) + QO := builder.cs.AddCoeff(c.qO) + QM := builder.cs.AddCoeff(c.qM) + QC := builder.cs.AddCoeff(c.qC) + + cID := builder.cs.AddSparseR1C(constraint.SparseR1C{ + XA: uint32(c.xa), + XB: uint32(c.xb), + XC: uint32(c.xc), + QL: QL, + QR: QR, + QO: QO, + QM: QM, + QC: QC, Commitment: c.commitment}, builder.genericGate) + if debug.Debug && len(debugInfo) == 1 { + builder.cs.AttachDebugInfo(debugInfo[0], []int{cID}) + } } // newInternalVariable creates a new wire, appends it on the list of wires of the circuit, sets // the wire's id to the number of wires, and returns it -func (builder *scs) newInternalVariable() expr.TermToRefactor { +func (builder *builder) newInternalVariable() expr.Term { idx := builder.cs.AddInternalVariable() - return expr.NewTermToRefactor(idx, constraint.CoeffIdOne) + return expr.NewTerm(idx, builder.tOne) } // PublicVariable creates a new Public Variable -func (builder *scs) PublicVariable(f schema.LeafInfo) frontend.Variable { +func (builder *builder) PublicVariable(f schema.LeafInfo) frontend.Variable { idx := builder.cs.AddPublicVariable(f.FullName()) - return expr.NewTermToRefactor(idx, constraint.CoeffIdOne) + return expr.NewTerm(idx, builder.tOne) } // SecretVariable creates a new Secret Variable -func (builder *scs) SecretVariable(f schema.LeafInfo) frontend.Variable { +func (builder *builder) SecretVariable(f schema.LeafInfo) frontend.Variable { idx := builder.cs.AddSecretVariable(f.FullName()) - return expr.NewTermToRefactor(idx, constraint.CoeffIdOne) + return expr.NewTerm(idx, builder.tOne) } // reduces redundancy in linear expression // It factorizes Variable that appears multiple times with != coeff Ids // To ensure the determinism in the compile process, Variables are stored as public∥secret∥internal∥unset // for each visibility, the Variables are 
sorted from lowest ID to highest ID -func (builder *scs) reduce(l expr.LinearExpressionToRefactor) expr.LinearExpressionToRefactor { +func (builder *builder) reduce(l expr.LinearExpression) expr.LinearExpression { // ensure our linear expression is sorted, by visibility and by Variable ID sort.Sort(l) - c := new(big.Int) for i := 1; i < len(l); i++ { - pcID, pvID := l[i-1].Unpack() - ccID, cvID := l[i].Unpack() - if pvID == cvID { + if l[i-1].VID == l[i].VID { // we have redundancy - c.Add(&builder.st.Coeffs[pcID], &builder.st.Coeffs[ccID]) - c.Mod(c, builder.q) - l[i-1].SetCoeffID(builder.st.CoeffID(c)) + l[i-1].Coeff = builder.cs.Add(l[i-1].Coeff, l[i].Coeff) l = append(l[:i], l[i+1:]...) i-- } @@ -183,33 +253,28 @@ func (builder *scs) reduce(l expr.LinearExpressionToRefactor) expr.LinearExpress return l } -// to handle wires that don't exist (=coef 0) in a sparse constraint -func (builder *scs) zero() expr.TermToRefactor { - var a expr.TermToRefactor - return a -} - // IsBoolean returns true if given variable was marked as boolean in the compiler (see MarkBoolean) // Use with care; variable may not have been **constrained** to be boolean // This returns true if the v is a constant and v == 0 || v == 1. -func (builder *scs) IsBoolean(v frontend.Variable) bool { - if b, ok := builder.ConstantValue(v); ok { - return b.IsUint64() && b.Uint64() <= 1 +func (builder *builder) IsBoolean(v frontend.Variable) bool { + if b, ok := builder.constantValue(v); ok { + return (b.IsZero() || builder.cs.IsOne(b)) } - _, ok := builder.mtBooleans[int(v.(expr.TermToRefactor).CID|(int(v.(expr.TermToRefactor).VID)<<32))] // TODO @gbotrel fixme this is sketchy + _, ok := builder.mtBooleans[v.(expr.Term)] return ok } // MarkBoolean sets (but do not constraint!) v to be boolean // This is useful in scenarios where a variable is known to be boolean through a constraint // that is not api.AssertIsBoolean. If v is a constant, this is a no-op. 
-func (builder *scs) MarkBoolean(v frontend.Variable) { - if b, ok := builder.ConstantValue(v); ok { - if !(b.IsUint64() && b.Uint64() <= 1) { +func (builder *builder) MarkBoolean(v frontend.Variable) { + if _, ok := builder.constantValue(v); ok { + if !builder.IsBoolean(v) { panic("MarkBoolean called a non-boolean constant") } + return } - builder.mtBooleans[int(v.(expr.TermToRefactor).CID|(int(v.(expr.TermToRefactor).VID)<<32))] = struct{}{} // TODO @gbotrel fixme this is sketchy + builder.mtBooleans[v.(expr.Term)] = struct{}{} } var tVariable reflect.Type @@ -218,7 +283,7 @@ func init() { tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() } -func (builder *scs) Compile() (constraint.ConstraintSystem, error) { +func (builder *builder) Compile() (constraint.ConstraintSystem, error) { log := logger.Logger() log.Info(). Int("nbConstraints", builder.cs.GetNbConstraints()). @@ -236,21 +301,32 @@ func (builder *scs) Compile() (constraint.ConstraintSystem, error) { return builder.cs, nil } -// ConstantValue returns the big.Int value of v. It -// panics if v.IsConstant() == false -func (builder *scs) ConstantValue(v frontend.Variable) (*big.Int, bool) { - switch t := v.(type) { - case expr.TermToRefactor: +// ConstantValue returns the big.Int value of v. 
+// Will panic if v.IsConstant() == false +func (builder *builder) ConstantValue(v frontend.Variable) (*big.Int, bool) { + coeff, ok := builder.constantValue(v) + if !ok { return nil, false - default: - res := utils.FromInterface(t) - return &res, true } + return builder.cs.ToBigInt(coeff), true } -func (builder *scs) TOREFACTORMakeTerm(c *big.Int, vID int) constraint.Term { - cc := builder.cs.FromInterface(c) - return builder.cs.MakeTerm(&cc, vID) +func (builder *builder) constantValue(v frontend.Variable) (constraint.Element, bool) { + if _, ok := v.(expr.Term); ok { + return constraint.Element{}, false + } + return builder.cs.FromInterface(v), true +} + +func (builder *builder) hintBuffer(size int) []constraint.LinearExpression { + if cap(builder.bufH) < size { + builder.bufH = make([]constraint.LinearExpression, 2*size) + for i := 0; i < len(builder.bufH); i++ { + builder.bufH[i] = make(constraint.LinearExpression, 1) + } + } + + return builder.bufH[:size] } // NewHint initializes internal variables whose value will be evaluated using @@ -265,24 +341,31 @@ func (builder *scs) TOREFACTORMakeTerm(c *big.Int, vID int) constraint.Term { // // No new constraints are added to the newly created wire and must be added // manually in the circuit. Failing to do so leads to solver failure. -func (builder *scs) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { +func (builder *builder) NewHint(f solver.Hint, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { + return builder.newHint(f, solver.GetHintID(f), nbOutputs, inputs...) +} + +func (builder *builder) NewHintForId(id solver.HintID, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { + return builder.newHint(nil, id, nbOutputs, inputs...) 
+} - hintInputs := make([]constraint.LinearExpression, len(inputs)) +func (builder *builder) newHint(f solver.Hint, id solver.HintID, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { + hintInputs := builder.hintBuffer(len(inputs)) // ensure inputs are set and pack them in a []uint64 for i, in := range inputs { switch t := in.(type) { - case expr.TermToRefactor: - hintInputs[i] = constraint.LinearExpression{builder.TOREFACTORMakeTerm(&builder.st.Coeffs[t.CID], t.VID)} + case expr.Term: + hintInputs[i][0] = builder.cs.MakeTerm(t.Coeff, t.VID) default: - c := utils.FromInterface(in) - term := builder.TOREFACTORMakeTerm(&c, 0) + c := builder.cs.FromInterface(in) + term := builder.cs.MakeTerm(c, 0) term.MarkConstant() - hintInputs[i] = constraint.LinearExpression{term} + hintInputs[i][0] = term } } - internalVariables, err := builder.cs.AddSolverHint(f, hintInputs, nbOutputs) + internalVariables, err := builder.cs.AddSolverHint(f, id, hintInputs, nbOutputs) if err != nil { return nil, err } @@ -290,100 +373,320 @@ func (builder *scs) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.V // make the variables res := make([]frontend.Variable, len(internalVariables)) for i, idx := range internalVariables { - res[i] = expr.NewTermToRefactor(idx, constraint.CoeffIdOne) + res[i] = expr.NewTerm(idx, builder.tOne) } return res, nil } // returns in split into a slice of compiledTerm and the sum of all constants in in as a bigInt -func (builder *scs) filterConstantSum(in []frontend.Variable) (expr.LinearExpressionToRefactor, big.Int) { - res := make(expr.LinearExpressionToRefactor, 0, len(in)) - var b big.Int +func (builder *builder) filterConstantSum(in []frontend.Variable) (expr.LinearExpression, constraint.Element) { + var res expr.LinearExpression + if len(in) <= cap(builder.bufL) { + // we can use the temp buffer + res = builder.bufL[:0] + } else { + res = make(expr.LinearExpression, 0, len(in)) + } + + b := constraint.Element{} for i := 0; i 
< len(in); i++ { - switch t := in[i].(type) { - case expr.TermToRefactor: - res = append(res, t) - default: - n := utils.FromInterface(t) - b.Add(&b, &n) + if c, ok := builder.constantValue(in[i]); ok { + b = builder.cs.Add(b, c) + } else { + if inTerm := in[i].(expr.Term); !inTerm.Coeff.IsZero() { + // add only term if coefficient is not zero. + res = append(res, in[i].(expr.Term)) + } } } return res, b } -// returns in split into a slice of compiledTerm and the product of all constants in in as a bigInt -func (builder *scs) filterConstantProd(in []frontend.Variable) (expr.LinearExpressionToRefactor, big.Int) { - res := make(expr.LinearExpressionToRefactor, 0, len(in)) - var b big.Int - b.SetInt64(1) +// returns in split into a slice of compiledTerm and the product of all constants in in as a coeff +func (builder *builder) filterConstantProd(in []frontend.Variable) (expr.LinearExpression, constraint.Element) { + var res expr.LinearExpression + if len(in) <= cap(builder.bufL) { + // we can use the temp buffer + res = builder.bufL[:0] + } else { + res = make(expr.LinearExpression, 0, len(in)) + } + + b := builder.tOne for i := 0; i < len(in); i++ { - switch t := in[i].(type) { - case expr.TermToRefactor: - res = append(res, t) - default: - n := utils.FromInterface(t) - b.Mul(&b, &n).Mod(&b, builder.q) + if c, ok := builder.constantValue(in[i]); ok { + b = builder.cs.Mul(b, c) + } else { + res = append(res, in[i].(expr.Term)) } } return res, b } -func (builder *scs) splitSum(acc expr.TermToRefactor, r expr.LinearExpressionToRefactor) expr.TermToRefactor { - +func (builder *builder) splitSum(acc expr.Term, r expr.LinearExpression, k *constraint.Element) expr.Term { // floor case if len(r) == 0 { + if k != nil { + // we need to return acc + k + o, found := builder.addConstraintExist(acc, expr.Term{}, *k) + if !found { + o = builder.newInternalVariable() + builder.addAddGate(acc, expr.Term{}, uint32(o.VID), *k) + } + + return o + } return acc } - cl, _ := acc.Unpack() - 
cr, _ := r[0].Unpack()
-	o := builder.newInternalVariable()
-	builder.addPlonkConstraint(acc, r[0], o, cl, cr, constraint.CoeffIdZero, constraint.CoeffIdZero, constraint.CoeffIdMinusOne, constraint.CoeffIdZero)
-	return builder.splitSum(o, r[1:])
+	// constraint to add: acc + r[0] (+ k) == o
+	qC := constraint.Element{}
+	if k != nil {
+		qC = *k
+	}
+	o, found := builder.addConstraintExist(acc, r[0], qC)
+	if !found {
+		o = builder.newInternalVariable()
+		builder.addAddGate(acc, r[0], uint32(o.VID), qC)
+	}
+
+	return builder.splitSum(o, r[1:], nil)
 }

-func (builder *scs) splitProd(acc expr.TermToRefactor, r expr.LinearExpressionToRefactor) expr.TermToRefactor {
+// addConstraintExist checks if we recorded a constraint in the form
+// q1*xa + q2*xb + qC - xc == 0
+//
+// if we find one, this function returns the xc wire with the correct coefficients.
+// if we don't, and no previous addition was recorded with xa and xb, add an entry in the map
+// (this assumes that the caller will add a constraint just after this call if it's not found!)
+//
+// idea:
+// 1. take (xa | (xb << 32)) as an identifier of an addition that used wires xa and xb.
+// 2. look for an entry in builder.mAddConstraints for a previously added constraint that matches.
+// 3. if so, check that the coefficients match and we can re-use the xc wire.
+//
+// limitations:
+// 1. for efficiency, we just store the first addition that occurred with xa and xb;
+// so if we do 2*xa + 3*xb == c, then we want to compute xa + xb == d multiple times, the compiler is
+// not going to catch these duplicates.
+// 2. 
this piece of code assumes some behavior from constraint/ package (like coeffIDs, or append-style +// constraint management) +func (builder *builder) addConstraintExist(a, b expr.Term, k constraint.Element) (expr.Term, bool) { + // ensure deterministic combined identifier; + if a.VID < b.VID { + a, b = b, a + } + h := uint64(a.WireID()) | uint64(b.WireID()<<32) + + if iID, ok := builder.mAddInstructions[h]; ok { + // if we do custom gates with slices in the constraint + // should use a shared object here to avoid allocs. + var c constraint.SparseR1C + + // seems likely we have a fit, let's double check + inst := builder.cs.GetInstruction(iID) + // we know the blueprint we added it. + blueprint := constraint.BlueprintSparseR1CAdd{} + blueprint.DecompressSparseR1C(&c, inst) + + // qO == -1 + if a.WireID() == int(c.XB) { + a, b = b, a // ensure a is in qL + } + + tk := builder.cs.MakeTerm(k, 0) + if tk.CoeffID() != int(c.QC) { + // the constant part of the addition differs, no point going forward + // since we will need to add a new constraint anyway. + return expr.Term{}, false + } + + // check that the coeff matches + qL := a.Coeff + qR := b.Coeff + ta := builder.cs.MakeTerm(qL, 0) + tb := builder.cs.MakeTerm(qR, 0) + if int(c.QL) != ta.CoeffID() || int(c.QR) != tb.CoeffID() { + if !k.IsZero() { + // may be for some edge cases we could avoid adding a constraint here. 
+				return expr.Term{}, false
+			}
+			// we recorded an addition in the form q1*a + q2*b == c
+			// we want to record a new one in the form q3*a + q4*b == n*c
+			// question is: can we re-use c to avoid introducing a new wire & new constraint
+			// this is possible only if n == q3/q1 == q4/q2, that is, q3q2 == q1q4
+			q1 := builder.cs.GetCoefficient(int(c.QL))
+			q2 := builder.cs.GetCoefficient(int(c.QR))
+			q3 := qL
+			q4 := qR
+			q3 = builder.cs.Mul(q3, q2)
+			q1 = builder.cs.Mul(q1, q4)
+			if q1 == q3 {
+				// no need to introduce a new constraint;
+				// compute n, the coefficient for the output wire
+				q2, ok = builder.cs.Inverse(q2)
+				if !ok {
+					panic("div by 0") // shouldn't happen
+				}
+				q2 = builder.cs.Mul(q2, q4)
+				return expr.NewTerm(int(c.XC), q2), true
+			}
+			// we will need an additional constraint
+			return expr.Term{}, false
+		}
+
+		// we found the same constraint!
+		return expr.NewTerm(int(c.XC), builder.tOne), true
+	}
+	// we are going to add this constraint, so we mark it.
+	// ! assumes the caller adds an instruction immediately after the call to this function
+	builder.mAddInstructions[h] = builder.cs.GetNbInstructions()
+	return expr.Term{}, false
+}
+
+// mulConstraintExist checks if we recorded a constraint in the form
+// qM*xa*xb - xc == 0
+//
+// if we find one, this function returns the xc wire with the correct coefficients.
+// if we don't, and no previous multiplication was recorded with xa and xb, add an entry in the map
+// (this assumes that the caller will add a constraint just after this call if it's not found!)
+//
+// idea:
+// 1. take (xa | (xb << 32)) as an identifier of a multiplication that used wires xa and xb.
+// 2. look for an entry in builder.mMulConstraints for a previously added constraint that matches.
+// 3. if so, compute correct coefficient N for xc wire that matches qM'*xa*xb - N*xc == 0
+//
+// limitations:
+// 1. 
this piece of code assumes some behavior from constraint/ package (like coeffIDs, or append-style +// constraint management) +func (builder *builder) mulConstraintExist(a, b expr.Term) (expr.Term, bool) { + // ensure deterministic combined identifier; + if a.VID < b.VID { + a, b = b, a + } + h := uint64(a.WireID()) | uint64(b.WireID()<<32) + + if iID, ok := builder.mMulInstructions[h]; ok { + // if we do custom gates with slices in the constraint + // should use a shared object here to avoid allocs. + var c constraint.SparseR1C + + // seems likely we have a fit, let's double check + inst := builder.cs.GetInstruction(iID) + // we know the blueprint we added it. + blueprint := constraint.BlueprintSparseR1CMul{} + blueprint.DecompressSparseR1C(&c, inst) + + // qO == -1 + + if a.WireID() == int(c.XB) { + a, b = b, a // ensure a is in qL + } + // recompute the qM coeff and check that it matches; + qM := builder.cs.Mul(a.Coeff, b.Coeff) + tm := builder.cs.MakeTerm(qM, 0) + if int(c.QM) != tm.CoeffID() { + // so we wanted to compute + // N * xC == qM*xA*xB + // but found a constraint + // xC == qM'*xA*xB + // the coefficient for our resulting wire is different; + // N = qM / qM' + N := builder.cs.GetCoefficient(int(c.QM)) + N, ok := builder.cs.Inverse(N) + if !ok { + panic("div by 0") // sanity check. + } + N = builder.cs.Mul(N, qM) + + return expr.NewTerm(int(c.XC), N), true + } + + // we found the exact same constraint + return expr.NewTerm(int(c.XC), builder.tOne), true + } + + // we are going to add this constraint, so we mark it. + // ! 
assumes the caller add an instruction immediately after the call to this function + builder.mMulInstructions[h] = builder.cs.GetNbInstructions() + return expr.Term{}, false +} + +func (builder *builder) splitProd(acc expr.Term, r expr.LinearExpression) expr.Term { // floor case if len(r) == 0 { return acc } + // we want to add a constraint such that acc * r[0] == o + // let's check if we didn't already constrain a similar product + o, found := builder.mulConstraintExist(acc, r[0]) + + if !found { + // constraint to add: acc * r[0] == o + o = builder.newInternalVariable() + builder.addMulGate(acc, r[0], o) + } - cl, _ := acc.Unpack() - cr, _ := r[0].Unpack() - o := builder.newInternalVariable() - builder.addPlonkConstraint(acc, r[0], o, constraint.CoeffIdZero, constraint.CoeffIdZero, cl, cr, constraint.CoeffIdMinusOne, constraint.CoeffIdZero) return builder.splitProd(o, r[1:]) } -func (builder *scs) Commit(v ...frontend.Variable) (frontend.Variable, error) { - return nil, fmt.Errorf("not implemented") -} - // newDebugInfo this is temporary to restore debug logs // something more like builder.sprintf("my message %le %lv", l0, l1) // to build logs for both debug and println // and append some program location.. (see other todo in debug_info.go) -func (builder *scs) newDebugInfo(errName string, in ...interface{}) constraint.DebugInfo { +func (builder *builder) newDebugInfo(errName string, in ...interface{}) constraint.DebugInfo { for i := 0; i < len(in); i++ { // for inputs that are LinearExpressions or Term, we need to "Make" them in the backend. 
// TODO @gbotrel this is a duplicate effort with adding a constraint and should be taken care off switch t := in[i].(type) { - case *expr.LinearExpressionToRefactor, expr.LinearExpressionToRefactor: + case *expr.LinearExpression, expr.LinearExpression: // shouldn't happen - case expr.TermToRefactor: - in[i] = builder.TOREFACTORMakeTerm(&builder.st.Coeffs[t.CID], t.VID) - case *expr.TermToRefactor: - in[i] = builder.TOREFACTORMakeTerm(&builder.st.Coeffs[t.CID], t.VID) - case constraint.Coeff: - in[i] = builder.cs.String(&t) - case *constraint.Coeff: + case expr.Term: + in[i] = builder.cs.MakeTerm(t.Coeff, t.VID) + case *expr.Term: + in[i] = builder.cs.MakeTerm(t.Coeff, t.VID) + case constraint.Element: in[i] = builder.cs.String(t) + case *constraint.Element: + in[i] = builder.cs.String(*t) } } return builder.cs.NewDebugInfo(errName, in...) } + +func (builder *builder) Defer(cb func(frontend.API) error) { + circuitdefer.Put(builder, cb) +} + +// AddInstruction is used to add custom instructions to the constraint system. +func (builder *builder) AddInstruction(bID constraint.BlueprintID, calldata []uint32) []uint32 { + return builder.cs.AddInstruction(bID, calldata) +} + +// AddBlueprint adds a custom blueprint to the constraint system. +func (builder *builder) AddBlueprint(b constraint.Blueprint) constraint.BlueprintID { + return builder.cs.AddBlueprint(b) +} + +func (builder *builder) InternalVariable(wireID uint32) frontend.Variable { + return expr.NewTerm(int(wireID), builder.tOne) +} + +// ToCanonicalVariable converts a frontend.Variable to a constraint system specific Variable +// ! 
Experimental: use in conjunction with constraint.CustomizableSystem +func (builder *builder) ToCanonicalVariable(v frontend.Variable) frontend.CanonicalVariable { + switch t := v.(type) { + case expr.Term: + return builder.cs.MakeTerm(t.Coeff, t.VID) + default: + c := builder.cs.FromInterface(v) + term := builder.cs.MakeTerm(c, 0) + term.MarkConstant() + return term + } +} diff --git a/frontend/cs/scs/duplicate_test.go b/frontend/cs/scs/duplicate_test.go new file mode 100644 index 0000000000..cf7ad8f169 --- /dev/null +++ b/frontend/cs/scs/duplicate_test.go @@ -0,0 +1,156 @@ +package scs_test + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/test" + "github.com/stretchr/testify/require" +) + +type circuitDupAdd struct { + A, B frontend.Variable + R frontend.Variable +} + +func (c *circuitDupAdd) Define(api frontend.API) error { + + f := api.Add(c.A, c.B) // 1 constraint + f = api.Add(c.A, c.B, f) // 1 constraint + f = api.Add(c.A, c.B, f) // 1 constraint + + d := api.Add(api.Mul(c.A, 3), api.Mul(3, c.B)) // 3a + 3b --> 3 (a + b) shouldn't add a constraint. 
+ e := api.Mul(api.Add(c.A, c.B), 3) // no constraints + + api.AssertIsEqual(f, e) // 1 constraint + api.AssertIsEqual(d, f) // 1 constraint + api.AssertIsEqual(c.R, e) // 1 constraint + + return nil +} + +func TestDuplicateAdd(t *testing.T) { + assert := require.New(t) + + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &circuitDupAdd{}) + assert.NoError(err) + + assert.Equal(6, ccs.GetNbConstraints(), "comparing expected number of constraints") + + w, err := frontend.NewWitness(&circuitDupAdd{ + A: 13, + B: 42, + R: 165, + }, ecc.BN254.ScalarField()) + assert.NoError(err) + + _, err = ccs.Solve(w) + assert.NoError(err, "solving failed") +} + +type circuitDupMul struct { + A, B frontend.Variable + R1, R2 frontend.Variable +} + +func (c *circuitDupMul) Define(api frontend.API) error { + + f := api.Mul(c.A, c.B) // 1 constraint + f = api.Mul(c.A, c.B, f) // 1 constraint + f = api.Mul(c.A, c.B, f) // 1 constraint + // f == (a*b)**3 + + d := api.Mul(api.Mul(c.A, 2), api.Mul(3, c.B)) // no constraints + e := api.Mul(api.Mul(c.A, c.B), 1) // no constraints + e = api.Mul(e, e) // e**2 (no constraints) + e = api.Mul(e, api.Mul(c.A, c.B), 1) // e**3 (no constraints) + + api.AssertIsEqual(f, e) // 1 constraint + api.AssertIsEqual(d, c.R1) // 1 constraint + api.AssertIsEqual(c.R2, e) // 1 constraint + + return nil +} + +func TestDuplicateMul(t *testing.T) { + assert := require.New(t) + + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &circuitDupMul{}) + assert.NoError(err) + + assert.Equal(6, ccs.GetNbConstraints(), "comparing expected number of constraints") + + w, err := frontend.NewWitness(&circuitDupMul{ + A: 13, + B: 42, + R1: (13 * 2) * (42 * 3), + R2: (13 * 42) * (13 * 42) * (13 * 42), + }, ecc.BN254.ScalarField()) + assert.NoError(err) + + _, err = ccs.Solve(w) + assert.NoError(err, "solving failed") +} + +type IssueDiv0Circuit struct { + A1, B1 frontend.Variable + A2, B2 frontend.Variable + A3, B3 frontend.Variable + A4, 
B4 frontend.Variable + + Res1, Res2, Res3, Res4, Res5, Res6, Res7, Res8 frontend.Variable +} + +func (c *IssueDiv0Circuit) Define(api frontend.API) error { + // case 1 + t1 := api.Add(api.Mul(0, c.A1), api.Mul(4, c.B1), 0) + t2 := api.Add(api.Mul(0, c.A1), api.Mul(5, c.B1), 0) + + // case 2 + t3 := api.Add(api.Mul(4, c.A2), api.Mul(0, c.B2), 0) + t4 := api.Add(api.Mul(5, c.A2), api.Mul(0, c.B2), 0) + + // case 3 + t5 := api.Add(api.Mul(0, c.A3), api.Mul(0, c.B3), 0) + t6 := api.Add(api.Mul(0, c.A3), api.Mul(5, c.B3), 0) + + // case 4 + t7 := api.Add(api.Mul(0, c.A4), api.Mul(0, c.B4), 0) + t8 := api.Add(api.Mul(5, c.A4), api.Mul(0, c.B4), 0) + + // test solver + api.AssertIsEqual(t1, c.Res1) + api.AssertIsEqual(t2, c.Res2) + api.AssertIsEqual(t3, c.Res3) + api.AssertIsEqual(t4, c.Res4) + api.AssertIsEqual(t5, c.Res5) + api.AssertIsEqual(t6, c.Res6) + api.AssertIsEqual(t7, c.Res7) + api.AssertIsEqual(t8, c.Res8) + return nil +} + +func TestExistDiv0(t *testing.T) { + assert := test.NewAssert(t) + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &IssueDiv0Circuit{}) + if err != nil { + t.Fatal(err) + } + assert.NoError(err) + w, err := frontend.NewWitness(&IssueDiv0Circuit{ + A1: 11, B1: 21, + A2: 11, B2: 21, + A3: 11, B3: 21, + A4: 11, B4: 21, + Res1: 84, Res2: 105, + Res3: 44, Res4: 55, + Res5: 0, Res6: 105, + Res7: 0, Res8: 55, + }, ecc.BN254.ScalarField()) + assert.NoError(err) + solution, err := ccs.Solve(w) + assert.NoError(err) + _ = solution +} diff --git a/frontend/internal/expr/linear_expression.go b/frontend/internal/expr/linear_expression.go index 23efcbc146..682bf0985d 100644 --- a/frontend/internal/expr/linear_expression.go +++ b/frontend/internal/expr/linear_expression.go @@ -4,44 +4,20 @@ import ( "github.com/consensys/gnark/constraint" ) -// TODO @gbotrel --> storing a UUID in the linear expressions would enable better perf -// in the frontends -> check a linear expression is boolean, or has been converted to a -// "backend" 
constraint.LinearExpresion ... and avoid duplicating work would be interesting. - type LinearExpression []Term -func (l LinearExpression) Clone() LinearExpression { - res := make(LinearExpression, len(l)) - copy(res, l) - return res -} - // NewLinearExpression helper to initialize a linear expression with one term -func NewLinearExpression(vID int, cID constraint.Coeff) LinearExpression { +func NewLinearExpression(vID int, cID constraint.Element) LinearExpression { return LinearExpression{Term{Coeff: cID, VID: vID}} } -func NewTerm(vID int, cID constraint.Coeff) Term { - return Term{Coeff: cID, VID: vID} -} - -type Term struct { - VID int - Coeff constraint.Coeff -} - -func (t *Term) SetCoeff(c constraint.Coeff) { - t.Coeff = c -} -func (t Term) WireID() int { - return t.VID -} - -func (t Term) HashCode() uint64 { - return t.Coeff[0]*29 + uint64(t.VID<<12) +func (l LinearExpression) Clone() LinearExpression { + res := make(LinearExpression, len(l)) + copy(res, l) + return res } -// Len return the lenght of the Variable (implements Sort interface) +// Len return the length of the Variable (implements Sort interface) func (l LinearExpression) Len() int { return len(l) } diff --git a/frontend/internal/expr/linear_expression_scs_torefactor.go b/frontend/internal/expr/linear_expression_scs_torefactor.go deleted file mode 100644 index 6547c63bb5..0000000000 --- a/frontend/internal/expr/linear_expression_scs_torefactor.go +++ /dev/null @@ -1,78 +0,0 @@ -package expr - -type LinearExpressionToRefactor []TermToRefactor - -func (l LinearExpressionToRefactor) Clone() LinearExpressionToRefactor { - res := make(LinearExpressionToRefactor, len(l)) - copy(res, l) - return res -} - -func NewTermToRefactor(vID, cID int) TermToRefactor { - return TermToRefactor{CID: cID, VID: vID} -} - -type TermToRefactor struct { - CID int - VID int -} - -func (t TermToRefactor) Unpack() (cID, vID int) { - return t.CID, t.VID -} - -func (t *TermToRefactor) SetCoeffID(cID int) { - t.CID = cID -} 
-func (t TermToRefactor) WireID() int { - return t.VID -} - -func (t TermToRefactor) HashCode() uint64 { - return uint64(t.CID) + uint64(t.VID<<32) -} - -// Len return the lenght of the Variable (implements Sort interface) -func (l LinearExpressionToRefactor) Len() int { - return len(l) -} - -// Equals returns true if both SORTED expressions are the same -// -// pre conditions: l and o are sorted -func (l LinearExpressionToRefactor) Equal(o LinearExpressionToRefactor) bool { - if len(l) != len(o) { - return false - } - if (l == nil) != (o == nil) { - return false - } - for i := 0; i < len(l); i++ { - if l[i] != o[i] { - return false - } - } - return true -} - -// Swap swaps terms in the Variable (implements Sort interface) -func (l LinearExpressionToRefactor) Swap(i, j int) { - l[i], l[j] = l[j], l[i] -} - -// Less returns true if variableID for term at i is less than variableID for term at j (implements Sort interface) -func (l LinearExpressionToRefactor) Less(i, j int) bool { - iID := l[i].WireID() - jID := l[j].WireID() - return iID < jID -} - -// HashCode returns a fast-to-compute but NOT collision resistant hash code identifier for the linear -// expression -func (l LinearExpressionToRefactor) HashCode() uint64 { - h := uint64(17) - for _, val := range l { - h = h*23 + val.HashCode() // TODO @gbotrel revisit - } - return h -} diff --git a/frontend/internal/expr/term.go b/frontend/internal/expr/term.go new file mode 100644 index 0000000000..5c9ba5877c --- /dev/null +++ b/frontend/internal/expr/term.go @@ -0,0 +1,25 @@ +package expr + +import "github.com/consensys/gnark/constraint" + +type Term struct { + VID int + Coeff constraint.Element +} + +func NewTerm(vID int, cID constraint.Element) Term { + return Term{Coeff: cID, VID: vID} +} + +func (t *Term) SetCoeff(c constraint.Element) { + t.Coeff = c +} + +// TODO @gbotrel make that return a uint32 +func (t Term) WireID() int { + return t.VID +} + +func (t Term) HashCode() uint64 { + return t.Coeff[0]*29 + 
uint64(t.VID<<12) +} diff --git a/frontend/schema/schema.go b/frontend/schema/schema.go index 0f41a7ad22..f2903ba65a 100644 --- a/frontend/schema/schema.go +++ b/frontend/schema/schema.go @@ -361,7 +361,11 @@ func parse(r []Field, input interface{}, target reflect.Type, parentFullName, pa val := tValue.Index(j) if val.CanAddr() && val.Addr().CanInterface() { fqn := getFullName(parentFullName, strconv.Itoa(j), "") - subFields, err = parse(subFields, val.Addr().Interface(), target, fqn, fqn, parentTagName, parentVisibility, nbPublic, nbSecret) + ival := val.Addr().Interface() + if ih, hasInitHook := ival.(InitHook); hasInitHook { + ih.GnarkInitHook() + } + subFields, err = parse(subFields, ival, target, fqn, fqn, parentTagName, parentVisibility, nbPublic, nbSecret) if err != nil { return nil, err } diff --git a/frontend/schema/schema_test.go b/frontend/schema/schema_test.go index 929f1cbb51..23a3bfb27e 100644 --- a/frontend/schema/schema_test.go +++ b/frontend/schema/schema_test.go @@ -165,6 +165,31 @@ func TestSchemaInherit(t *testing.T) { } } +type initableVariable struct { + Val []variable +} + +func (iv *initableVariable) GnarkInitHook() { + if iv.Val == nil { + iv.Val = make([]variable, 2) + } +} + +type initableCircuit struct { + X [2]initableVariable + Y []initableVariable + Z initableVariable +} + +func TestVariableInitHook(t *testing.T) { + assert := require.New(t) + + witness := &initableCircuit{Y: make([]initableVariable, 2)} + s, err := New(witness, tVariable) + assert.NoError(err) + assert.Equal(s.NbSecret, 10) // X: 2*2, Y: 2*2, Z: 2 +} + func BenchmarkLargeSchema(b *testing.B) { const n1 = 1 << 12 const n2 = 1 << 12 diff --git a/frontend/schema/walk.go b/frontend/schema/walk.go index 07d56b880a..e62a18db0e 100644 --- a/frontend/schema/walk.go +++ b/frontend/schema/walk.go @@ -91,11 +91,23 @@ func (w *walker) Slice(value reflect.Value) error { return nil } -func (w *walker) SliceElem(index int, _ reflect.Value) error { +func (w *walker) 
arraySliceElem(index int, v reflect.Value) error { w.path.push(LeafInfo{Visibility: w.visibility(), name: strconv.Itoa(index)}) + if v.CanAddr() && v.Addr().CanInterface() { + // TODO @gbotrel don't like that hook, undesirable side effects + // will be hard to detect; (for example calling Parse multiple times will init multiple times!) + value := v.Addr().Interface() + if ih, hasInitHook := value.(InitHook); hasInitHook { + ih.GnarkInitHook() + } + } return nil } +func (w *walker) SliceElem(index int, v reflect.Value) error { + return w.arraySliceElem(index, v) +} + // Array handles array elements found within complex structures. func (w *walker) Array(value reflect.Value) error { if value.Type() == reflect.ArrayOf(value.Len(), w.target) { @@ -103,9 +115,8 @@ func (w *walker) Array(value reflect.Value) error { } return nil } -func (w *walker) ArrayElem(index int, _ reflect.Value) error { - w.path.push(LeafInfo{Visibility: w.visibility(), name: strconv.Itoa(index)}) - return nil +func (w *walker) ArrayElem(index int, v reflect.Value) error { + return w.arraySliceElem(index, v) } // process an array or slice of leaves; since it's quite common to have large array/slices diff --git a/frontend/variable.go b/frontend/variable.go index 9f00ba5167..82d33fbc90 100644 --- a/frontend/variable.go +++ b/frontend/variable.go @@ -32,8 +32,6 @@ func IsCanonical(v Variable) bool { switch v.(type) { case expr.LinearExpression, *expr.LinearExpression, expr.Term, *expr.Term: return true - case expr.LinearExpressionToRefactor, *expr.LinearExpressionToRefactor, expr.TermToRefactor, *expr.TermToRefactor: - return true } return false } diff --git a/go.mod b/go.mod index b4455c6cc0..64a83f50cf 100644 --- a/go.mod +++ b/go.mod @@ -1,31 +1,32 @@ module github.com/consensys/gnark -go 1.18 +go 1.19 require ( + github.com/bits-and-blooms/bitset v1.8.0 github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.1.13 - github.com/consensys/gnark-crypto v0.9.2 - github.com/fxamacker/cbor/v2 
v2.4.0 + github.com/consensys/gnark-crypto v0.11.2 + github.com/fxamacker/cbor/v2 v2.5.0 github.com/google/go-cmp v0.5.9 - github.com/google/pprof v0.0.0-20230207041349-798e818bf904 + github.com/google/pprof v0.0.0-20230817174616-7a8ec2ada47b github.com/leanovate/gopter v0.2.9 - github.com/rs/zerolog v1.29.0 - github.com/stretchr/testify v1.8.1 - golang.org/x/exp v0.0.0-20230213192124-5e25df0256eb + github.com/rs/zerolog v1.30.0 + github.com/stretchr/testify v1.8.4 + golang.org/x/crypto v0.12.0 + golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 + golang.org/x/sys v0.11.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/kr/pretty v0.3.1 // indirect - github.com/mattn/go-colorable v0.1.12 // indirect - github.com/mattn/go-isatty v0.0.14 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect github.com/mmcloughlin/addchain v0.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.11.0 // indirect github.com/x448/float16 v0.8.4 // indirect - golang.org/x/crypto v0.6.0 // indirect - golang.org/x/sys v0.5.0 // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect rsc.io/tmplfunc v0.0.3 // indirect ) diff --git a/go.sum b/go.sum index ae6875069a..e35a6b6f46 100644 --- a/go.sum +++ b/go.sum @@ -1,70 +1,63 @@ +github.com/bits-and-blooms/bitset v1.8.0 h1:FD+XqgOZDUxxZ8hzoBFuV9+cGWY9CslN6d5MS5JVb4c= +github.com/bits-and-blooms/bitset v1.8.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= github.com/consensys/bavard v0.1.13 h1:oLhMLOFGTLdlda/kma4VOJazblc7IM5y5QPd2A/YjhQ= github.com/consensys/bavard v0.1.13/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= 
-github.com/consensys/gnark-crypto v0.9.1 h1:mru55qKdWl3E035hAoh1jj9d7hVnYY5pfb6tmovSmII= -github.com/consensys/gnark-crypto v0.9.1/go.mod h1:a2DQL4+5ywF6safEeZFEPGRiiGbjzGFRUN2sg06VuU4= -github.com/consensys/gnark-crypto v0.9.2 h1:a4gsSAnQNgrt8dqxsd49H2rtLQPoekZCWpCmcKPRNus= -github.com/consensys/gnark-crypto v0.9.2/go.mod h1:a2DQL4+5ywF6safEeZFEPGRiiGbjzGFRUN2sg06VuU4= -github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/consensys/gnark-crypto v0.11.2 h1:GJjjtWJ+db1xGao7vTsOgAOGgjfPe7eRGPL+xxMX0qE= +github.com/consensys/gnark-crypto v0.11.2/go.mod h1:v2Gy7L/4ZRosZ7Ivs+9SfUDr0f5UlG+EM5t7MPHiLuY= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/fxamacker/cbor/v2 v2.4.0 h1:ri0ArlOR+5XunOP8CRUowT0pSJOwhW098ZCUyskZD88= -github.com/fxamacker/cbor/v2 v2.4.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= +github.com/fxamacker/cbor/v2 v2.5.0 h1:oHsG0V/Q6E/wqTS2O1Cozzsy69nqCiguo5Q1a1ADivE= +github.com/fxamacker/cbor/v2 v2.5.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U= -github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg= +github.com/google/pprof 
v0.0.0-20230817174616-7a8ec2ada47b h1:h9U78+dx9a4BKdQkBBos92HalKpaGKHrp+3Uo6yTodo= +github.com/google/pprof v0.0.0-20230817174616-7a8ec2ada47b/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leanovate/gopter v0.2.9 h1:fQjYxZaynp97ozCzfOyOuAGOU4aU/z37zf/tOujFk7c= github.com/leanovate/gopter v0.2.9/go.mod h1:U2L/78B+KVFIx2VmW6onHJQzXtFb+p5y3y2Sh+Jxxv8= -github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= -github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mmcloughlin/addchain v0.4.0 h1:SobOdjm2xLj1KkXN5/n0xTIWyZA2+s99UCY1iPfkHRY= github.com/mmcloughlin/addchain v0.4.0/go.mod 
h1:A86O+tHqZLMNO4w6ZZ4FlVQEadcoqkyU72HC5wJ4RlU= github.com/mmcloughlin/profile v0.1.1/go.mod h1:IhHD7q1ooxgwTgjxQYkACGA77oFTDdFVejUS1/tS/qU= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/rs/zerolog v1.29.0 h1:Zes4hju04hjbvkVkOhdl2HpZa+0PmVwigmo8XoORE5w= -github.com/rs/zerolog v1.29.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.30.0 h1:SymVODrcRsaRaSInD9yQtKbtWqwsfoPcRff/oRXLj4c= +github.com/rs/zerolog v1.30.0/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w= 
+github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= -golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= -golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= -golang.org/x/exp v0.0.0-20230213192124-5e25df0256eb h1:PaBZQdo+iSDyHT053FjUCgZQ/9uqVwPOcl7KSWhKn6w= -golang.org/x/exp v0.0.0-20230213192124-5e25df0256eb/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= +golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 
-gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= rsc.io/tmplfunc v0.0.3 h1:53XFQh69AfOa8Tw0Jm7t+GV7KZhOi6jzsCzTtKbMvzU= diff --git a/integration_test.go b/integration_test.go index 3024e003b2..598d9f9165 100644 --- a/integration_test.go +++ b/integration_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/backend/circuits" "github.com/consensys/gnark/test" ) @@ -55,9 +56,10 @@ func TestIntegrationAPI(t *testing.T) { assert.Run(func(assert *test.Assert) { assert.ProverSucceeded( tData.Circuit, tData.ValidAssignments[i], - test.WithProverOpts(backend.WithHints(tData.HintFunctions...)), + test.WithSolverOpts(solver.WithHints(tData.HintFunctions...)), test.WithCurves(tData.Curves[0], tData.Curves[1:]...), - test.WithBackends(backends[0], backends[1:]...)) + test.WithBackends(backends[0], backends[1:]...), + test.WithSolidity()) }, fmt.Sprintf("valid-%d", i)) } @@ -66,7 +68,7 @@ func TestIntegrationAPI(t *testing.T) { assert.ProverFailed( tData.Circuit, tData.InvalidAssignments[i], - test.WithProverOpts(backend.WithHints(tData.HintFunctions...)), + test.WithSolverOpts(solver.WithHints(tData.HintFunctions...)), test.WithCurves(tData.Curves[0], tData.Curves[1:]...), test.WithBackends(backends[0], backends[1:]...)) }, fmt.Sprintf("invalid-%d", i)) diff --git a/internal/backend/bls12-377/plonk/marshal.go b/internal/backend/bls12-377/plonk/marshal.go deleted file mode 100644 index 64daa4c899..0000000000 --- a/internal/backend/bls12-377/plonk/marshal.go +++ /dev/null @@ -1,244 
+0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - curve "github.com/consensys/gnark-crypto/ecc/bls12-377" - - "errors" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "io" -) - -// WriteTo writes binary encoding of Proof to w without point compression -func (proof *Proof) WriteRawTo(w io.Writer) (int64, error) { - return proof.writeTo(w, curve.RawEncoding()) -} - -// WriteTo writes binary encoding of Proof to w with point compression -func (proof *Proof) WriteTo(w io.Writer) (int64, error) { - return proof.writeTo(w) -} - -func (proof *Proof) writeTo(w io.Writer, options ...func(*curve.Encoder)) (int64, error) { - enc := curve.NewEncoder(w, options...) 
- - toEncode := []interface{}{ - &proof.LRO[0], - &proof.LRO[1], - &proof.LRO[2], - &proof.Z, - &proof.H[0], - &proof.H[1], - &proof.H[2], - &proof.BatchedProof.H, - proof.BatchedProof.ClaimedValues, - &proof.ZShiftedOpening.H, - &proof.ZShiftedOpening.ClaimedValue, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return enc.BytesWritten(), err - } - } - - return enc.BytesWritten(), nil -} - -// ReadFrom reads binary representation of Proof from r -func (proof *Proof) ReadFrom(r io.Reader) (int64, error) { - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - &proof.LRO[0], - &proof.LRO[1], - &proof.LRO[2], - &proof.Z, - &proof.H[0], - &proof.H[1], - &proof.H[2], - &proof.BatchedProof.H, - &proof.BatchedProof.ClaimedValues, - &proof.ZShiftedOpening.H, - &proof.ZShiftedOpening.ClaimedValue, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err - } - } - - return dec.BytesRead(), nil -} - -// WriteTo writes binary encoding of ProvingKey to w -func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { - // encode the verifying key - n, err = pk.Vk.WriteTo(w) - if err != nil { - return - } - - // fft domains - n2, err := pk.Domain[0].WriteTo(w) - if err != nil { - return - } - n += n2 - - n2, err = pk.Domain[1].WriteTo(w) - if err != nil { - return - } - n += n2 - - // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) - if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { - return n, errors.New("invalid permutation size, expected 3*domain cardinality") - } - - enc := curve.NewEncoder(w) - // note: type Polynomial, which is handled by default binary.Write(...) 
op and doesn't - // encode the size (nor does it convert from Montgomery to Regular form) - // so we explicitly transmit []fr.Element - toEncode := []interface{}{ - ([]fr.Element)(pk.Ql), - ([]fr.Element)(pk.Qr), - ([]fr.Element)(pk.Qm), - ([]fr.Element)(pk.Qo), - ([]fr.Element)(pk.CQk), - ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.S1Canonical), - ([]fr.Element)(pk.S2Canonical), - ([]fr.Element)(pk.S3Canonical), - pk.Permutation, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return n + enc.BytesWritten(), err - } - } - - return n + enc.BytesWritten(), nil -} - -// ReadFrom reads from binary representation in r into ProvingKey -func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { - pk.Vk = &VerifyingKey{} - n, err := pk.Vk.ReadFrom(r) - if err != nil { - return n, err - } - - n2, err := pk.Domain[0].ReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - n2, err = pk.Domain[1].ReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) - - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - (*[]fr.Element)(&pk.Ql), - (*[]fr.Element)(&pk.Qr), - (*[]fr.Element)(&pk.Qm), - (*[]fr.Element)(&pk.Qo), - (*[]fr.Element)(&pk.CQk), - (*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.S1Canonical), - (*[]fr.Element)(&pk.S2Canonical), - (*[]fr.Element)(&pk.S3Canonical), - &pk.Permutation, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err - } - } - - pk.computeLagrangeCosetPolys() - - return n + dec.BytesRead(), nil - -} - -// WriteTo writes binary encoding of VerifyingKey to w -func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - enc := curve.NewEncoder(w) - - toEncode := []interface{}{ - vk.Size, - &vk.SizeInv, - &vk.Generator, - vk.NbPublicVariables, - &vk.CosetShift, - &vk.S[0], - &vk.S[1], - &vk.S[2], - &vk.Ql, - &vk.Qr, - &vk.Qm, - &vk.Qo, - &vk.Qk, - } - - for _, v := range 
toEncode { - if err := enc.Encode(v); err != nil { - return enc.BytesWritten(), err - } - } - - return enc.BytesWritten(), nil -} - -// ReadFrom reads from binary representation in r into VerifyingKey -func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - &vk.Size, - &vk.SizeInv, - &vk.Generator, - &vk.NbPublicVariables, - &vk.CosetShift, - &vk.S[0], - &vk.S[1], - &vk.S[2], - &vk.Ql, - &vk.Qr, - &vk.Qm, - &vk.Qo, - &vk.Qk, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err - } - } - - return dec.BytesRead(), nil -} diff --git a/internal/backend/bls12-377/plonk/marshal_test.go b/internal/backend/bls12-377/plonk/marshal_test.go deleted file mode 100644 index 94a69f6295..0000000000 --- a/internal/backend/bls12-377/plonk/marshal_test.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - curve "github.com/consensys/gnark-crypto/ecc/bls12-377" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - - "bytes" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" - gnarkio "github.com/consensys/gnark/io" - "io" - "math/big" - "math/rand" - "reflect" - "testing" -) - -func TestProofSerialization(t *testing.T) { - // create a proof - var proof, reconstructed Proof - proof.randomize() - - roundTripCheck(t, &proof, &reconstructed) -} - -func TestProofSerializationRaw(t *testing.T) { - // create a proof - var proof, reconstructed Proof - proof.randomize() - - roundTripCheckRaw(t, &proof, &reconstructed) -} - -func TestProvingKeySerialization(t *testing.T) { - // random pk - var pk, reconstructed ProvingKey - pk.randomize() - - roundTripCheck(t, &pk, &reconstructed) -} - -func TestVerifyingKeySerialization(t *testing.T) { - // create a random vk - var vk, reconstructed VerifyingKey - vk.randomize() - - roundTripCheck(t, &vk, &reconstructed) -} - -func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { - var buf bytes.Buffer - written, err := from.WriteTo(&buf) - if err != nil { - t.Fatal("couldn't serialize", err) - } - - read, err := reconstructed.ReadFrom(&buf) - if err != nil { - t.Fatal("couldn't deserialize", err) - } - - if !reflect.DeepEqual(from, reconstructed) { - t.Fatal("reconstructed object don't match original") - } - - if written != read { - t.Fatal("bytes written / read don't match") - } -} - -func roundTripCheckRaw(t *testing.T, from gnarkio.WriterRawTo, reconstructed io.ReaderFrom) { - var buf bytes.Buffer - written, err := from.WriteRawTo(&buf) - if err != nil { - t.Fatal("couldn't serialize", err) - } - - read, err := reconstructed.ReadFrom(&buf) - if err != nil { - t.Fatal("couldn't deserialize", err) - } - - if !reflect.DeepEqual(from, reconstructed) { - t.Fatal("reconstructed object don't match original") - } - - if written != 
read { - t.Fatal("bytes written / read don't match") - } -} - -func (pk *ProvingKey) randomize() { - var vk VerifyingKey - vk.randomize() - pk.Vk = &vk - pk.Domain[0] = *fft.NewDomain(42) - pk.Domain[1] = *fft.NewDomain(4 * 42) - - n := int(pk.Domain[0].Cardinality) - pk.Ql = randomScalars(n) - pk.Qr = randomScalars(n) - pk.Qm = randomScalars(n) - pk.Qo = randomScalars(n) - pk.CQk = randomScalars(n) - pk.LQk = randomScalars(n) - pk.S1Canonical = randomScalars(n) - pk.S2Canonical = randomScalars(n) - pk.S3Canonical = randomScalars(n) - - pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) - pk.Permutation[0] = -12 - pk.Permutation[len(pk.Permutation)-1] = 8888 - - pk.computeLagrangeCosetPolys() -} - -func (vk *VerifyingKey) randomize() { - vk.Size = rand.Uint64() - vk.SizeInv.SetRandom() - vk.Generator.SetRandom() - vk.NbPublicVariables = rand.Uint64() - vk.CosetShift.SetRandom() - - vk.S[0] = randomPoint() - vk.S[1] = randomPoint() - vk.S[2] = randomPoint() - vk.Ql = randomPoint() - vk.Qr = randomPoint() - vk.Qm = randomPoint() - vk.Qo = randomPoint() - vk.Qk = randomPoint() -} - -func (proof *Proof) randomize() { - proof.LRO[0] = randomPoint() - proof.LRO[1] = randomPoint() - proof.LRO[2] = randomPoint() - proof.Z = randomPoint() - proof.H[0] = randomPoint() - proof.H[1] = randomPoint() - proof.H[2] = randomPoint() - proof.BatchedProof.H = randomPoint() - proof.BatchedProof.ClaimedValues = randomScalars(2) - proof.ZShiftedOpening.H = randomPoint() - proof.ZShiftedOpening.ClaimedValue.SetRandom() -} - -func randomPoint() curve.G1Affine { - _, _, r, _ := curve.Generators() - r.ScalarMultiplication(&r, big.NewInt(int64(rand.Uint64()))) - return r -} - -func randomScalars(n int) []fr.Element { - v := make([]fr.Element, n) - one := fr.One() - for i := 0; i < len(v); i++ { - if i == 0 { - v[i].SetRandom() - } else { - v[i].Add(&v[i-1], &one) - } - } - return v -} diff --git a/internal/backend/bls12-377/plonk/prove.go b/internal/backend/bls12-377/plonk/prove.go 
deleted file mode 100644 index 489ab0dc61..0000000000 --- a/internal/backend/bls12-377/plonk/prove.go +++ /dev/null @@ -1,634 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "crypto/sha256" - "math/big" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - - curve "github.com/consensys/gnark-crypto/ecc/bls12-377" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/kzg" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/iop" - "github.com/consensys/gnark/constraint/bls12-377" - - "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/logger" -) - -type Proof struct { - - // Commitments to the solution vectors - LRO [3]kzg.Digest - - // Commitment to Z, the permutation polynomial - Z kzg.Digest - - // Commitments to h1, h2, h3 such that h = h1 + Xh2 + X**2h3 is the quotient polynomial - H [3]kzg.Digest - - // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2 - BatchedProof kzg.BatchOpeningProof - - // Opening proof of Z at zeta*mu - ZShiftedOpening kzg.OpeningProof -} - -// Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt 
backend.ProverConfig) (*Proof, error) { - - log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", len(spr.Constraints)).Str("backend", "plonk").Logger() - start := time.Now() - // pick a hash function that will be used to derive the challenges - hFunc := sha256.New() - - // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") - - // result - proof := &Proof{} - - // compute the constraint system solution - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } - } - - // query l, r, o in Lagrange basis, not blinded - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) - - lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - liop := iop.NewPolynomial(&evaluationLDomainSmall, lagReg) - riop := iop.NewPolynomial(&evaluationRDomainSmall, lagReg) - oiop := iop.NewPolynomial(&evaluationODomainSmall, lagReg) - wliop := liop.ShallowClone() - wriop := riop.ShallowClone() - woiop := oiop.ShallowClone() - wliop.ToCanonical(&pk.Domain[0]).ToRegular() - wriop.ToCanonical(&pk.Domain[0]).ToRegular() - woiop.ToCanonical(&pk.Domain[0]).ToRegular() - - // Blind l, r, o before committing - // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. 
- bwliop := wliop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwriop := wriop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwoiop := woiop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // The first challenge is derived using the public data: the commitments to the permutation, - // the coefficients of the circuit, and the public inputs. - // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *pk.Vk, fullWitness[:len(spr.Public)]); err != nil { - return nil, err - } - gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) - if err != nil { - return nil, err - } - - // Fiat Shamir this - bbeta, err := fs.ComputeChallenge("beta") - if err != nil { - return nil, err - } - var beta fr.Element - beta.SetBytes(bbeta) - - // compute the copy constraint's ratio - // We copy liop, riop, oiop because they are fft'ed in the process. - // We could have not copied them at the cost of doing one more bit reverse - // per poly... 
- ziop, err := iop.BuildRatioCopyConstraint( - []*iop.Polynomial{ - liop.Clone(), - riop.Clone(), - oiop.Clone(), - }, - pk.Permutation, - beta, - gamma, - iop.Form{Basis: iop.Canonical, Layout: iop.Regular}, - &pk.Domain[0], - ) - if err != nil { - return proof, err - } - - // commit to the blinded version of z - bwziop := ziop // iop.NewWrappedPolynomial(&ziop) - bwziop.Blind(2) - proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Vk.KZGSRS, runtime.NumCPU()*2) - if err != nil { - return proof, err - } - - // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) - alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) - if err != nil { - return proof, err - } - - // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) - copy(qkCompletedCanonical, fullWitness[:len(spr.Public)]) - copy(qkCompletedCanonical[len(spr.Public):], pk.LQk[len(spr.Public):]) - pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) - fft.BitReverse(qkCompletedCanonical) - - // l, r, o are blinded here - bwliop.ToLagrangeCoset(&pk.Domain[1]) - bwriop.ToLagrangeCoset(&pk.Domain[1]) - bwoiop.ToLagrangeCoset(&pk.Domain[1]) - - lagrangeCosetBitReversed := iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse} - - // we don't mutate so no need to clone the coefficients from the proving key. 
- wqliop := iop.NewPolynomial(&pk.lQl, lagrangeCosetBitReversed) - wqriop := iop.NewPolynomial(&pk.lQr, lagrangeCosetBitReversed) - wqmiop := iop.NewPolynomial(&pk.lQm, lagrangeCosetBitReversed) - wqoiop := iop.NewPolynomial(&pk.lQo, lagrangeCosetBitReversed) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqkiop := iop.NewPolynomial(&qkCompletedCanonical, canReg) - wqkiop.ToLagrangeCoset(&pk.Domain[1]) - - // storing Id - id := make([]fr.Element, pk.Domain[1].Cardinality) - id[1].SetOne() - widiop := iop.NewPolynomial(&id, canReg) - widiop.ToLagrangeCoset(&pk.Domain[1]) - - // permutations in LagrangeCoset: we don't mutate so no need to clone the coefficients from the - // proving key. - ws1 := iop.NewPolynomial(&pk.lS1LagrangeCoset, lagrangeCosetBitReversed) - ws2 := iop.NewPolynomial(&pk.lS2LagrangeCoset, lagrangeCosetBitReversed) - ws3 := iop.NewPolynomial(&pk.lS3LagrangeCoset, lagrangeCosetBitReversed) - - // Store z(g*x), without reallocating a slice - bwsziop := bwziop.ShallowClone().Shift(1) - bwsziop.ToLagrangeCoset(&pk.Domain[1]) - - // L_{g^{0}} - cap := pk.Domain[1].Cardinality - if cap < pk.Domain[0].Cardinality { - cap = pk.Domain[0].Cardinality // sanity check - } - lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) - lone[0].SetOne() - loneiop := iop.NewPolynomial(&lone, lagReg) - wloneiop := loneiop.ToCanonical(&pk.Domain[0]). - ToRegular(). - ToLagrangeCoset(&pk.Domain[1]) - - // Full capture using latest gnark crypto... 
- fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element) fr.Element { - - var ic, tmp fr.Element - - ic.Mul(&fql, &l) - tmp.Mul(&fqr, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqm, &l).Mul(&tmp, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqo, &o) - ic.Add(&ic, &tmp).Add(&ic, &fqk) - - return ic - } - - fo := func(l, r, o, fid, fs1, fs2, fs3, fz, fzs fr.Element) fr.Element { - var uu fr.Element - u := pk.Domain[0].FrMultiplicativeGen - uu.Mul(&u, &u) - - var a, b, tmp fr.Element - a.Mul(&beta, &fid).Add(&a, &l).Add(&a, &gamma) - tmp.Mul(&beta, &u).Mul(&tmp, &fid).Add(&tmp, &r).Add(&tmp, &gamma) - a.Mul(&a, &tmp) - tmp.Mul(&beta, &uu).Mul(&tmp, &fid).Add(&tmp, &o).Add(&tmp, &gamma) - a.Mul(&a, &tmp).Mul(&a, &fz) - - b.Mul(&beta, &fs1).Add(&b, &l).Add(&b, &gamma) - tmp.Mul(&beta, &fs2).Add(&tmp, &r).Add(&tmp, &gamma) - b.Mul(&b, &tmp) - tmp.Mul(&beta, &fs3).Add(&tmp, &o).Add(&tmp, &gamma) - b.Mul(&b, &tmp).Mul(&b, &fzs) - - b.Sub(&b, &a) - - return b - } - - fone := func(fz, flone fr.Element) fr.Element { - one := fr.One() - one.Sub(&fz, &one).Mul(&one, &flone) - return one - } - - // 0 , 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 - // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk,lone - fm := func(x ...fr.Element) fr.Element { - - a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2]) - b := fo(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8]) - c := fone(x[7], x[14]) - - c.Mul(&c, &alpha).Add(&c, &b).Mul(&c, &alpha).Add(&c, &a) - - return c - } - testEval, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, - bwliop, - bwriop, - bwoiop, - widiop, - ws1, - ws2, - ws3, - bwziop, - bwsziop, - wqliop, - wqriop, - wqmiop, - wqoiop, - wqkiop, - wloneiop, - ) - if err != nil { - return nil, err - } - h, err := iop.DivideByXMinusOne(testEval, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) - if err != nil { - return nil, err - } - - // compute kzg commitments of h1, h2 and h3 - if err := commitToQuotient( - 
h.Coefficients()[:pk.Domain[0].Cardinality+2], - h.Coefficients()[pk.Domain[0].Cardinality+2:2*(pk.Domain[0].Cardinality+2)], - h.Coefficients()[2*(pk.Domain[0].Cardinality+2):3*(pk.Domain[0].Cardinality+2)], - proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // derive zeta - zeta, err := deriveRandomness(&fs, "zeta", &proof.H[0], &proof.H[1], &proof.H[2]) - if err != nil { - return nil, err - } - - // compute evaluations of (blinded version of) l, r, o, z at zeta - var blzeta, brzeta, bozeta fr.Element - - var wgEvals sync.WaitGroup - wgEvals.Add(3) - - go func() { - bwliop.ToCanonical(&pk.Domain[1]).ToRegular() - blzeta = bwliop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwriop.ToCanonical(&pk.Domain[1]).ToRegular() - brzeta = bwriop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwoiop.ToCanonical(&pk.Domain[1]).ToRegular() - bozeta = bwoiop.Evaluate(zeta) - wgEvals.Done() - }() - - // open blinded Z at zeta*z - bwziop.ToCanonical(&pk.Domain[1]).ToRegular() - var zetaShifted fr.Element - zetaShifted.Mul(&zeta, &pk.Vk.Generator) - proof.ZShiftedOpening, err = kzg.Open( - bwziop.Coefficients()[:bwziop.BlindedSize()], - zetaShifted, - pk.Vk.KZGSRS, - ) - if err != nil { - return nil, err - } - - // blinded z evaluated at u*zeta - bzuzeta := proof.ZShiftedOpening.ClaimedValue - - var ( - linearizedPolynomialCanonical []fr.Element - linearizedPolynomialDigest curve.G1Affine - errLPoly error - ) - - wgEvals.Wait() // wait for the evaluations - - // compute the linearization polynomial r at zeta - // (goal: save committing separately to z, ql, qr, qm, qo, k - linearizedPolynomialCanonical = computeLinearizedPolynomial( - blzeta, - brzeta, - bozeta, - alpha, - beta, - gamma, - zeta, - bzuzeta, - bwziop.Coefficients()[:bwziop.BlindedSize()], - pk, - ) - - // TODO this commitment is only necessary to derive the challenge, we should - // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = 
kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) - - // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) - var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) - var zetaPowerm fr.Element - zetaPowerm.Exp(zeta, &bSize) - zetaPowerm.BigInt(&bZetaPowerm) - foldedHDigest := proof.H[2] - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - - // foldedH = h1 + ζ*h2 + ζ²*h3 - foldedH := h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] - h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] - utils.Parallelize(len(foldedH), func(start, end int) { - for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² - foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 - } - }) - - if errLPoly != nil { - return nil, errLPoly - } - - // Batch open the first list of polynomials - proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - [][]fr.Element{ - foldedH, - linearizedPolynomialCanonical, - bwliop.Coefficients()[:bwliop.BlindedSize()], - bwriop.Coefficients()[:bwriop.BlindedSize()], - bwoiop.Coefficients()[:bwoiop.BlindedSize()], - pk.S1Canonical, - pk.S2Canonical, - }, - []kzg.Digest{ - foldedHDigest, - linearizedPolynomialDigest, - proof.LRO[0], - proof.LRO[1], - proof.LRO[2], - pk.Vk.S[0], - pk.Vk.S[1], - }, - zeta, - hFunc, - pk.Vk.KZGSRS, - ) - - log.Debug().Dur("took", time.Since(start)).Msg("prover done") - - if err != nil { - return 
nil, err - } - - return proof, nil - -} - -// fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.LRO[0], err0 = kzg.Commit(bcl, srs, n) - close(chCommit0) - }() - go func() { - proof.LRO[1], err1 = kzg.Commit(bcr, srs, n) - close(chCommit1) - }() - if proof.LRO[2], err2 = kzg.Commit(bco, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.H[0], err0 = kzg.Commit(h1, srs, n) - close(chCommit0) - }() - go func() { - proof.H[1], err1 = kzg.Commit(h2, srs, n) - close(chCommit1) - }() - if proof.H[2], err2 = kzg.Commit(h3, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. -// The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z -// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. 
-// -// The Linearized polynomial is: -// -// α²*L₁(ζ)*Z(X) -// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) -// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { - - // first part: individual constraints - var rl fr.Element - rl.Mul(&rZeta, &lZeta) - - // second part: - // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) - var s1, s2 fr.Element - chS1 := make(chan struct{}, 1) - go func() { - ps1 := iop.NewPolynomial(&pk.S1Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - s1 = ps1.Evaluate(zeta) // s1(ζ) - s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) - close(chS1) - }() - ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := ps2.Evaluate(zeta) // s2(ζ) - tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) - <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) - - var uzeta, uuzeta fr.Element - uzeta.Mul(&zeta, &pk.Vk.CosetShift) - uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) - - s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) - tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) - tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - // third part L₁(ζ)*α²*Z - var lagrangeZeta, one, den, frNbElmt fr.Element - one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) - lagrangeZeta.Set(&zeta). - Exp(lagrangeZeta, big.NewInt(nbElmt)). 
- Sub(&lagrangeZeta, &one) - frNbElmt.SetUint64(uint64(nbElmt)) - den.Sub(&zeta, &one). - Inverse(&den) - lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - - linPol := make([]fr.Element, len(blindedZCanonical)) - copy(linPol, blindedZCanonical) - - utils.Parallelize(len(linPol), func(start, end int) { - - var t0, t1 fr.Element - - for i := start; i < end; i++ { - - linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - if i < len(pk.S3Canonical) { - - t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) - - linPol[i].Add(&linPol[i], &t0) - } - - linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) - - if i < len(pk.Qm) { - - t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) - t0.Mul(&pk.Ql[i], &lZeta) - t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - - t0.Mul(&pk.Qr[i], &rZeta) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - - t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) - } - - t0.Mul(&blindedZCanonical[i], &lagrangeZeta) - linPol[i].Add(&linPol[i], &t0) // finish the computation - } - }) - return linPol -} diff --git a/internal/backend/bls12-377/plonk/setup.go b/internal/backend/bls12-377/plonk/setup.go deleted file mode 100644 index e7f0b7a553..0000000000 --- a/internal/backend/bls12-377/plonk/setup.go +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/iop" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/kzg" - "github.com/consensys/gnark/constraint/bls12-377" - - kzgg "github.com/consensys/gnark-crypto/kzg" -) - -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation -type ProvingKey struct { - // Verifying Key is embedded into the proving key (needed by Prove) - Vk *VerifyingKey - - // TODO store iop.Polynomial here, not []fr.Element for more "type safety" - - // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo []fr.Element - - // qr,ql,qm,qo (in lagrange coset basis) --> these are not serialized, but computed from Ql, Qr, Qm, Qo once. - lQl, lQr, lQm, lQo []fr.Element - - // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. - // Storing LQk in Lagrange basis saves a fft... - CQk, LQk []fr.Element - - // Domains used for the FFTs. 
- // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain - // Domain[0], Domain[1] fft.Domain - - // Permutation polynomials - S1Canonical, S2Canonical, S3Canonical []fr.Element - - // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. - lS1LagrangeCoset, lS2LagrangeCoset, lS3LagrangeCoset []fr.Element - - // position -> permuted position (position in [0,3*sizeSystem-1]) - Permutation []int64 -} - -// VerifyingKey stores the data needed to verify a proof: -// * The commitment scheme -// * Commitments of ql prepended with as many ones as there are public inputs -// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs -// * Commitments to S1, S2, S3 -type VerifyingKey struct { - // Size circuit - Size uint64 - SizeInv fr.Element - Generator fr.Element - NbPublicVariables uint64 - - // Commitment scheme that is used for an instantiation of PLONK - KZGSRS *kzg.SRS - - // cosetShift generator of the coset on the small domain - CosetShift fr.Element - - // S commitments to S1, S2, S3 - S [3]kzg.Digest - - // Commitments to ql, qr, qm, qo prepended with as many zeroes (ones for l) as there are public inputs. - // In particular Qk is not complete. - Ql, Qr, Qm, Qo, Qk kzg.Digest -} - -// Setup sets proving and verifying keys -func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) { - var pk ProvingKey - var vk VerifyingKey - - // The verifying key shares data with the proving key - pk.Vk = &vk - - nbConstraints := len(spr.Constraints) - - // fft domains - sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - - vk.Size = pk.Domain[0].Cardinality - vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) - vk.NbPublicVariables = uint64(len(spr.Public)) - - if err := pk.InitKZG(srs); err != nil { - return nil, nil, err - } - - // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) - pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) - pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) - - for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant - pk.Ql[i].SetOne().Neg(&pk.Ql[i]) - pk.Qr[i].SetZero() - pk.Qm[i].SetZero() - pk.Qo[i].SetZero() - pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // → to be completed by the prover - } - offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.Ql[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.Qr[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.Qm[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.Qm[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.Qo[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.CQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - } - - pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) - pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) - fft.BitReverse(pk.Ql) - fft.BitReverse(pk.Qr) - fft.BitReverse(pk.Qm) - fft.BitReverse(pk.Qo) - fft.BitReverse(pk.CQk) - - // build permutation. Note: at this stage, the permutation takes in account the placeholders - buildPermutation(spr, &pk) - - // set s1, s2, s3 - ccomputePermutationPolynomials(&pk) - - // compute the lagrange coset basis versions (not serialized) - pk.computeLagrangeCosetPolys() - - // Commit to the polynomials to set up the verifying key - var err error - if vk.Ql, err = kzg.Commit(pk.Ql, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qr, err = kzg.Commit(pk.Qr, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qm, err = kzg.Commit(pk.Qm, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qo, err = kzg.Commit(pk.Qo, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - - return &pk, &vk, nil - -} - -// buildPermutation builds the Permutation associated with a circuit. -// -// The permutation s is composed of cycles of maximum length such that -// -// s. 
(l∥r∥o) = (l∥r∥o) -// -// , where l∥r∥o is the concatenation of the indices of l, r, o in -// ql.l+qr.r+qm.l.r+qo.O+k = 0. -// -// The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab -// like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { - - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := int(pk.Domain[0].Cardinality) - - // init permutation - pk.Permutation = make([]int64, 3*sizeSolution) - for i := 0; i < len(pk.Permutation); i++ { - pk.Permutation[i] = -1 - } - - // init LRO position -> variable_ID - lro := make([]int, 3*sizeSolution) // position -> variable_ID - for i := 0; i < len(spr.Public); i++ { - lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) - } - - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() - } - - // init cycle: - // map ID -> last position the ID was seen - cycle := make([]int64, nbVariables) - for i := 0; i < len(cycle); i++ { - cycle[i] = -1 - } - - for i := 0; i < len(lro); i++ { - if cycle[lro[i]] != -1 { - // if != -1, it means we already encountered this value - // so we need to set the corresponding permutation index. 
- pk.Permutation[i] = cycle[lro[i]] - } - cycle[lro[i]] = int64(i) - } - - // complete the Permutation by filling the first IDs encountered - for i := 0; i < len(pk.Permutation); i++ { - if pk.Permutation[i] == -1 { - pk.Permutation[i] = cycle[lro[i]] - } - } -} - -func (pk *ProvingKey) computeLagrangeCosetPolys() { - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqliop := iop.NewPolynomial(clone(pk.Ql, pk.Domain[1].Cardinality), canReg) - wqriop := iop.NewPolynomial(clone(pk.Qr, pk.Domain[1].Cardinality), canReg) - wqmiop := iop.NewPolynomial(clone(pk.Qm, pk.Domain[1].Cardinality), canReg) - wqoiop := iop.NewPolynomial(clone(pk.Qo, pk.Domain[1].Cardinality), canReg) - - ws1 := iop.NewPolynomial(clone(pk.S1Canonical, pk.Domain[1].Cardinality), canReg) - ws2 := iop.NewPolynomial(clone(pk.S2Canonical, pk.Domain[1].Cardinality), canReg) - ws3 := iop.NewPolynomial(clone(pk.S3Canonical, pk.Domain[1].Cardinality), canReg) - - wqliop.ToLagrangeCoset(&pk.Domain[1]) - wqriop.ToLagrangeCoset(&pk.Domain[1]) - wqmiop.ToLagrangeCoset(&pk.Domain[1]) - wqoiop.ToLagrangeCoset(&pk.Domain[1]) - - ws1.ToLagrangeCoset(&pk.Domain[1]) - ws2.ToLagrangeCoset(&pk.Domain[1]) - ws3.ToLagrangeCoset(&pk.Domain[1]) - - pk.lQl = wqliop.Coefficients() - pk.lQr = wqriop.Coefficients() - pk.lQm = wqmiop.Coefficients() - pk.lQo = wqoiop.Coefficients() - - pk.lS1LagrangeCoset = ws1.Coefficients() - pk.lS2LagrangeCoset = ws2.Coefficients() - pk.lS3LagrangeCoset = ws3.Coefficients() -} - -func clone(input []fr.Element, capacity uint64) *[]fr.Element { - res := make([]fr.Element, len(input), capacity) - copy(res, input) - return &res -} - -// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations -// s1, s2, s3. -// -// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | -// -// | -// | Permutation -// -// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v -// \---------------/ \--------------------/ \------------------------/ -// -// s1 (LDE) s2 (LDE) s3 (LDE) -func ccomputePermutationPolynomials(pk *ProvingKey) { - - nbElmts := int(pk.Domain[0].Cardinality) - - // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) - - // Lagrange form of S1, S2, S3 - pk.S1Canonical = make([]fr.Element, nbElmts) - pk.S2Canonical = make([]fr.Element, nbElmts) - pk.S3Canonical = make([]fr.Element, nbElmts) - for i := 0; i < nbElmts; i++ { - pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) - pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) - pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) - } - - // Canonical form of S1, S2, S3 - pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) - fft.BitReverse(pk.S1Canonical) - fft.BitReverse(pk.S2Canonical) - fft.BitReverse(pk.S3Canonical) -} - -// getIDSmallDomain returns the Lagrange form of ID on the small domain -func getIDSmallDomain(domain *fft.Domain) []fr.Element { - - res := make([]fr.Element, 3*domain.Cardinality) - - res[0].SetOne() - res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) - res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) - - for i := uint64(1); i < domain.Cardinality; i++ { - res[i].Mul(&res[i-1], &domain.Generator) - res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) - res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) - } - - return res -} - -// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS -// -// This should be used after deserializing a ProvingKey -// as pk.Vk.KZG is NOT serialized -func (pk *ProvingKey) InitKZG(srs kzgg.SRS) error { - return pk.Vk.InitKZG(srs) -} - -// InitKZG inits vk.KZG using provided SRS -// -// This should be used after deserializing a 
VerifyingKey -// as vk.KZG is NOT serialized -// -// Note that this instantiate a new FFT domain using vk.Size -func (vk *VerifyingKey) InitKZG(srs kzgg.SRS) error { - _srs := srs.(*kzg.SRS) - - if len(_srs.G1) < int(vk.Size) { - return errors.New("kzg srs is too small") - } - vk.KZGSRS = _srs - - return nil -} - -// NbPublicWitness returns the expected public witness size (number of field elements) -func (vk *VerifyingKey) NbPublicWitness() int { - return int(vk.NbPublicVariables) -} - -// VerifyingKey returns pk.Vk -func (pk *ProvingKey) VerifyingKey() interface{} { - return pk.Vk -} diff --git a/internal/backend/bls12-381/plonk/marshal.go b/internal/backend/bls12-381/plonk/marshal.go deleted file mode 100644 index cdb2e18749..0000000000 --- a/internal/backend/bls12-381/plonk/marshal.go +++ /dev/null @@ -1,244 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - curve "github.com/consensys/gnark-crypto/ecc/bls12-381" - - "errors" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - "io" -) - -// WriteTo writes binary encoding of Proof to w without point compression -func (proof *Proof) WriteRawTo(w io.Writer) (int64, error) { - return proof.writeTo(w, curve.RawEncoding()) -} - -// WriteTo writes binary encoding of Proof to w with point compression -func (proof *Proof) WriteTo(w io.Writer) (int64, error) { - return proof.writeTo(w) -} - -func (proof *Proof) writeTo(w io.Writer, options ...func(*curve.Encoder)) (int64, error) { - enc := curve.NewEncoder(w, options...) - - toEncode := []interface{}{ - &proof.LRO[0], - &proof.LRO[1], - &proof.LRO[2], - &proof.Z, - &proof.H[0], - &proof.H[1], - &proof.H[2], - &proof.BatchedProof.H, - proof.BatchedProof.ClaimedValues, - &proof.ZShiftedOpening.H, - &proof.ZShiftedOpening.ClaimedValue, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return enc.BytesWritten(), err - } - } - - return enc.BytesWritten(), nil -} - -// ReadFrom reads binary representation of Proof from r -func (proof *Proof) ReadFrom(r io.Reader) (int64, error) { - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - &proof.LRO[0], - &proof.LRO[1], - &proof.LRO[2], - &proof.Z, - &proof.H[0], - &proof.H[1], - &proof.H[2], - &proof.BatchedProof.H, - &proof.BatchedProof.ClaimedValues, - &proof.ZShiftedOpening.H, - &proof.ZShiftedOpening.ClaimedValue, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err - } - } - - return dec.BytesRead(), nil -} - -// WriteTo writes binary encoding of ProvingKey to w -func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { - // encode the verifying key - n, err = pk.Vk.WriteTo(w) - if err != nil { - return - } - - // fft domains - n2, err := pk.Domain[0].WriteTo(w) - if err != nil { - return - } - n += n2 - - n2, 
err = pk.Domain[1].WriteTo(w) - if err != nil { - return - } - n += n2 - - // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) - if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { - return n, errors.New("invalid permutation size, expected 3*domain cardinality") - } - - enc := curve.NewEncoder(w) - // note: type Polynomial, which is handled by default binary.Write(...) op and doesn't - // encode the size (nor does it convert from Montgomery to Regular form) - // so we explicitly transmit []fr.Element - toEncode := []interface{}{ - ([]fr.Element)(pk.Ql), - ([]fr.Element)(pk.Qr), - ([]fr.Element)(pk.Qm), - ([]fr.Element)(pk.Qo), - ([]fr.Element)(pk.CQk), - ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.S1Canonical), - ([]fr.Element)(pk.S2Canonical), - ([]fr.Element)(pk.S3Canonical), - pk.Permutation, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return n + enc.BytesWritten(), err - } - } - - return n + enc.BytesWritten(), nil -} - -// ReadFrom reads from binary representation in r into ProvingKey -func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { - pk.Vk = &VerifyingKey{} - n, err := pk.Vk.ReadFrom(r) - if err != nil { - return n, err - } - - n2, err := pk.Domain[0].ReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - n2, err = pk.Domain[1].ReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) - - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - (*[]fr.Element)(&pk.Ql), - (*[]fr.Element)(&pk.Qr), - (*[]fr.Element)(&pk.Qm), - (*[]fr.Element)(&pk.Qo), - (*[]fr.Element)(&pk.CQk), - (*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.S1Canonical), - (*[]fr.Element)(&pk.S2Canonical), - (*[]fr.Element)(&pk.S3Canonical), - &pk.Permutation, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err - } - } - - pk.computeLagrangeCosetPolys() - - return n + dec.BytesRead(), 
nil - -} - -// WriteTo writes binary encoding of VerifyingKey to w -func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - enc := curve.NewEncoder(w) - - toEncode := []interface{}{ - vk.Size, - &vk.SizeInv, - &vk.Generator, - vk.NbPublicVariables, - &vk.CosetShift, - &vk.S[0], - &vk.S[1], - &vk.S[2], - &vk.Ql, - &vk.Qr, - &vk.Qm, - &vk.Qo, - &vk.Qk, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return enc.BytesWritten(), err - } - } - - return enc.BytesWritten(), nil -} - -// ReadFrom reads from binary representation in r into VerifyingKey -func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - &vk.Size, - &vk.SizeInv, - &vk.Generator, - &vk.NbPublicVariables, - &vk.CosetShift, - &vk.S[0], - &vk.S[1], - &vk.S[2], - &vk.Ql, - &vk.Qr, - &vk.Qm, - &vk.Qo, - &vk.Qk, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err - } - } - - return dec.BytesRead(), nil -} diff --git a/internal/backend/bls12-381/plonk/marshal_test.go b/internal/backend/bls12-381/plonk/marshal_test.go deleted file mode 100644 index 0e0380208d..0000000000 --- a/internal/backend/bls12-381/plonk/marshal_test.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - curve "github.com/consensys/gnark-crypto/ecc/bls12-381" - - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - - "bytes" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" - gnarkio "github.com/consensys/gnark/io" - "io" - "math/big" - "math/rand" - "reflect" - "testing" -) - -func TestProofSerialization(t *testing.T) { - // create a proof - var proof, reconstructed Proof - proof.randomize() - - roundTripCheck(t, &proof, &reconstructed) -} - -func TestProofSerializationRaw(t *testing.T) { - // create a proof - var proof, reconstructed Proof - proof.randomize() - - roundTripCheckRaw(t, &proof, &reconstructed) -} - -func TestProvingKeySerialization(t *testing.T) { - // random pk - var pk, reconstructed ProvingKey - pk.randomize() - - roundTripCheck(t, &pk, &reconstructed) -} - -func TestVerifyingKeySerialization(t *testing.T) { - // create a random vk - var vk, reconstructed VerifyingKey - vk.randomize() - - roundTripCheck(t, &vk, &reconstructed) -} - -func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { - var buf bytes.Buffer - written, err := from.WriteTo(&buf) - if err != nil { - t.Fatal("couldn't serialize", err) - } - - read, err := reconstructed.ReadFrom(&buf) - if err != nil { - t.Fatal("couldn't deserialize", err) - } - - if !reflect.DeepEqual(from, reconstructed) { - t.Fatal("reconstructed object don't match original") - } - - if written != read { - t.Fatal("bytes written / read don't match") - } -} - -func roundTripCheckRaw(t *testing.T, from gnarkio.WriterRawTo, reconstructed io.ReaderFrom) { - var buf bytes.Buffer - written, err := from.WriteRawTo(&buf) - if err != nil { - t.Fatal("couldn't serialize", err) - } - - read, err := reconstructed.ReadFrom(&buf) - if err != nil { - t.Fatal("couldn't deserialize", err) - } - - if !reflect.DeepEqual(from, reconstructed) { - t.Fatal("reconstructed object don't match original") - } - - if written != 
read { - t.Fatal("bytes written / read don't match") - } -} - -func (pk *ProvingKey) randomize() { - var vk VerifyingKey - vk.randomize() - pk.Vk = &vk - pk.Domain[0] = *fft.NewDomain(42) - pk.Domain[1] = *fft.NewDomain(4 * 42) - - n := int(pk.Domain[0].Cardinality) - pk.Ql = randomScalars(n) - pk.Qr = randomScalars(n) - pk.Qm = randomScalars(n) - pk.Qo = randomScalars(n) - pk.CQk = randomScalars(n) - pk.LQk = randomScalars(n) - pk.S1Canonical = randomScalars(n) - pk.S2Canonical = randomScalars(n) - pk.S3Canonical = randomScalars(n) - - pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) - pk.Permutation[0] = -12 - pk.Permutation[len(pk.Permutation)-1] = 8888 - - pk.computeLagrangeCosetPolys() -} - -func (vk *VerifyingKey) randomize() { - vk.Size = rand.Uint64() - vk.SizeInv.SetRandom() - vk.Generator.SetRandom() - vk.NbPublicVariables = rand.Uint64() - vk.CosetShift.SetRandom() - - vk.S[0] = randomPoint() - vk.S[1] = randomPoint() - vk.S[2] = randomPoint() - vk.Ql = randomPoint() - vk.Qr = randomPoint() - vk.Qm = randomPoint() - vk.Qo = randomPoint() - vk.Qk = randomPoint() -} - -func (proof *Proof) randomize() { - proof.LRO[0] = randomPoint() - proof.LRO[1] = randomPoint() - proof.LRO[2] = randomPoint() - proof.Z = randomPoint() - proof.H[0] = randomPoint() - proof.H[1] = randomPoint() - proof.H[2] = randomPoint() - proof.BatchedProof.H = randomPoint() - proof.BatchedProof.ClaimedValues = randomScalars(2) - proof.ZShiftedOpening.H = randomPoint() - proof.ZShiftedOpening.ClaimedValue.SetRandom() -} - -func randomPoint() curve.G1Affine { - _, _, r, _ := curve.Generators() - r.ScalarMultiplication(&r, big.NewInt(int64(rand.Uint64()))) - return r -} - -func randomScalars(n int) []fr.Element { - v := make([]fr.Element, n) - one := fr.One() - for i := 0; i < len(v); i++ { - if i == 0 { - v[i].SetRandom() - } else { - v[i].Add(&v[i-1], &one) - } - } - return v -} diff --git a/internal/backend/bls12-381/plonk/prove.go b/internal/backend/bls12-381/plonk/prove.go 
deleted file mode 100644 index 8d1402e02f..0000000000 --- a/internal/backend/bls12-381/plonk/prove.go +++ /dev/null @@ -1,634 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "crypto/sha256" - "math/big" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - - curve "github.com/consensys/gnark-crypto/ecc/bls12-381" - - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/kzg" - - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" - - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/iop" - "github.com/consensys/gnark/constraint/bls12-381" - - "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/logger" -) - -type Proof struct { - - // Commitments to the solution vectors - LRO [3]kzg.Digest - - // Commitment to Z, the permutation polynomial - Z kzg.Digest - - // Commitments to h1, h2, h3 such that h = h1 + Xh2 + X**2h3 is the quotient polynomial - H [3]kzg.Digest - - // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2 - BatchedProof kzg.BatchOpeningProof - - // Opening proof of Z at zeta*mu - ZShiftedOpening kzg.OpeningProof -} - -// Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt 
backend.ProverConfig) (*Proof, error) { - - log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", len(spr.Constraints)).Str("backend", "plonk").Logger() - start := time.Now() - // pick a hash function that will be used to derive the challenges - hFunc := sha256.New() - - // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") - - // result - proof := &Proof{} - - // compute the constraint system solution - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } - } - - // query l, r, o in Lagrange basis, not blinded - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) - - lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - liop := iop.NewPolynomial(&evaluationLDomainSmall, lagReg) - riop := iop.NewPolynomial(&evaluationRDomainSmall, lagReg) - oiop := iop.NewPolynomial(&evaluationODomainSmall, lagReg) - wliop := liop.ShallowClone() - wriop := riop.ShallowClone() - woiop := oiop.ShallowClone() - wliop.ToCanonical(&pk.Domain[0]).ToRegular() - wriop.ToCanonical(&pk.Domain[0]).ToRegular() - woiop.ToCanonical(&pk.Domain[0]).ToRegular() - - // Blind l, r, o before committing - // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. 
- bwliop := wliop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwriop := wriop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwoiop := woiop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // The first challenge is derived using the public data: the commitments to the permutation, - // the coefficients of the circuit, and the public inputs. - // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *pk.Vk, fullWitness[:len(spr.Public)]); err != nil { - return nil, err - } - gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) - if err != nil { - return nil, err - } - - // Fiat Shamir this - bbeta, err := fs.ComputeChallenge("beta") - if err != nil { - return nil, err - } - var beta fr.Element - beta.SetBytes(bbeta) - - // compute the copy constraint's ratio - // We copy liop, riop, oiop because they are fft'ed in the process. - // We could have not copied them at the cost of doing one more bit reverse - // per poly... 
- ziop, err := iop.BuildRatioCopyConstraint( - []*iop.Polynomial{ - liop.Clone(), - riop.Clone(), - oiop.Clone(), - }, - pk.Permutation, - beta, - gamma, - iop.Form{Basis: iop.Canonical, Layout: iop.Regular}, - &pk.Domain[0], - ) - if err != nil { - return proof, err - } - - // commit to the blinded version of z - bwziop := ziop // iop.NewWrappedPolynomial(&ziop) - bwziop.Blind(2) - proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Vk.KZGSRS, runtime.NumCPU()*2) - if err != nil { - return proof, err - } - - // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) - alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) - if err != nil { - return proof, err - } - - // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) - copy(qkCompletedCanonical, fullWitness[:len(spr.Public)]) - copy(qkCompletedCanonical[len(spr.Public):], pk.LQk[len(spr.Public):]) - pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) - fft.BitReverse(qkCompletedCanonical) - - // l, r, o are blinded here - bwliop.ToLagrangeCoset(&pk.Domain[1]) - bwriop.ToLagrangeCoset(&pk.Domain[1]) - bwoiop.ToLagrangeCoset(&pk.Domain[1]) - - lagrangeCosetBitReversed := iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse} - - // we don't mutate so no need to clone the coefficients from the proving key. 
- wqliop := iop.NewPolynomial(&pk.lQl, lagrangeCosetBitReversed) - wqriop := iop.NewPolynomial(&pk.lQr, lagrangeCosetBitReversed) - wqmiop := iop.NewPolynomial(&pk.lQm, lagrangeCosetBitReversed) - wqoiop := iop.NewPolynomial(&pk.lQo, lagrangeCosetBitReversed) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqkiop := iop.NewPolynomial(&qkCompletedCanonical, canReg) - wqkiop.ToLagrangeCoset(&pk.Domain[1]) - - // storing Id - id := make([]fr.Element, pk.Domain[1].Cardinality) - id[1].SetOne() - widiop := iop.NewPolynomial(&id, canReg) - widiop.ToLagrangeCoset(&pk.Domain[1]) - - // permutations in LagrangeCoset: we don't mutate so no need to clone the coefficients from the - // proving key. - ws1 := iop.NewPolynomial(&pk.lS1LagrangeCoset, lagrangeCosetBitReversed) - ws2 := iop.NewPolynomial(&pk.lS2LagrangeCoset, lagrangeCosetBitReversed) - ws3 := iop.NewPolynomial(&pk.lS3LagrangeCoset, lagrangeCosetBitReversed) - - // Store z(g*x), without reallocating a slice - bwsziop := bwziop.ShallowClone().Shift(1) - bwsziop.ToLagrangeCoset(&pk.Domain[1]) - - // L_{g^{0}} - cap := pk.Domain[1].Cardinality - if cap < pk.Domain[0].Cardinality { - cap = pk.Domain[0].Cardinality // sanity check - } - lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) - lone[0].SetOne() - loneiop := iop.NewPolynomial(&lone, lagReg) - wloneiop := loneiop.ToCanonical(&pk.Domain[0]). - ToRegular(). - ToLagrangeCoset(&pk.Domain[1]) - - // Full capture using latest gnark crypto... 
- fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element) fr.Element { - - var ic, tmp fr.Element - - ic.Mul(&fql, &l) - tmp.Mul(&fqr, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqm, &l).Mul(&tmp, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqo, &o) - ic.Add(&ic, &tmp).Add(&ic, &fqk) - - return ic - } - - fo := func(l, r, o, fid, fs1, fs2, fs3, fz, fzs fr.Element) fr.Element { - var uu fr.Element - u := pk.Domain[0].FrMultiplicativeGen - uu.Mul(&u, &u) - - var a, b, tmp fr.Element - a.Mul(&beta, &fid).Add(&a, &l).Add(&a, &gamma) - tmp.Mul(&beta, &u).Mul(&tmp, &fid).Add(&tmp, &r).Add(&tmp, &gamma) - a.Mul(&a, &tmp) - tmp.Mul(&beta, &uu).Mul(&tmp, &fid).Add(&tmp, &o).Add(&tmp, &gamma) - a.Mul(&a, &tmp).Mul(&a, &fz) - - b.Mul(&beta, &fs1).Add(&b, &l).Add(&b, &gamma) - tmp.Mul(&beta, &fs2).Add(&tmp, &r).Add(&tmp, &gamma) - b.Mul(&b, &tmp) - tmp.Mul(&beta, &fs3).Add(&tmp, &o).Add(&tmp, &gamma) - b.Mul(&b, &tmp).Mul(&b, &fzs) - - b.Sub(&b, &a) - - return b - } - - fone := func(fz, flone fr.Element) fr.Element { - one := fr.One() - one.Sub(&fz, &one).Mul(&one, &flone) - return one - } - - // 0 , 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 - // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk,lone - fm := func(x ...fr.Element) fr.Element { - - a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2]) - b := fo(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8]) - c := fone(x[7], x[14]) - - c.Mul(&c, &alpha).Add(&c, &b).Mul(&c, &alpha).Add(&c, &a) - - return c - } - testEval, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, - bwliop, - bwriop, - bwoiop, - widiop, - ws1, - ws2, - ws3, - bwziop, - bwsziop, - wqliop, - wqriop, - wqmiop, - wqoiop, - wqkiop, - wloneiop, - ) - if err != nil { - return nil, err - } - h, err := iop.DivideByXMinusOne(testEval, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) - if err != nil { - return nil, err - } - - // compute kzg commitments of h1, h2 and h3 - if err := commitToQuotient( - 
h.Coefficients()[:pk.Domain[0].Cardinality+2], - h.Coefficients()[pk.Domain[0].Cardinality+2:2*(pk.Domain[0].Cardinality+2)], - h.Coefficients()[2*(pk.Domain[0].Cardinality+2):3*(pk.Domain[0].Cardinality+2)], - proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // derive zeta - zeta, err := deriveRandomness(&fs, "zeta", &proof.H[0], &proof.H[1], &proof.H[2]) - if err != nil { - return nil, err - } - - // compute evaluations of (blinded version of) l, r, o, z at zeta - var blzeta, brzeta, bozeta fr.Element - - var wgEvals sync.WaitGroup - wgEvals.Add(3) - - go func() { - bwliop.ToCanonical(&pk.Domain[1]).ToRegular() - blzeta = bwliop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwriop.ToCanonical(&pk.Domain[1]).ToRegular() - brzeta = bwriop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwoiop.ToCanonical(&pk.Domain[1]).ToRegular() - bozeta = bwoiop.Evaluate(zeta) - wgEvals.Done() - }() - - // open blinded Z at zeta*z - bwziop.ToCanonical(&pk.Domain[1]).ToRegular() - var zetaShifted fr.Element - zetaShifted.Mul(&zeta, &pk.Vk.Generator) - proof.ZShiftedOpening, err = kzg.Open( - bwziop.Coefficients()[:bwziop.BlindedSize()], - zetaShifted, - pk.Vk.KZGSRS, - ) - if err != nil { - return nil, err - } - - // blinded z evaluated at u*zeta - bzuzeta := proof.ZShiftedOpening.ClaimedValue - - var ( - linearizedPolynomialCanonical []fr.Element - linearizedPolynomialDigest curve.G1Affine - errLPoly error - ) - - wgEvals.Wait() // wait for the evaluations - - // compute the linearization polynomial r at zeta - // (goal: save committing separately to z, ql, qr, qm, qo, k - linearizedPolynomialCanonical = computeLinearizedPolynomial( - blzeta, - brzeta, - bozeta, - alpha, - beta, - gamma, - zeta, - bzuzeta, - bwziop.Coefficients()[:bwziop.BlindedSize()], - pk, - ) - - // TODO this commitment is only necessary to derive the challenge, we should - // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = 
kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) - - // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) - var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) - var zetaPowerm fr.Element - zetaPowerm.Exp(zeta, &bSize) - zetaPowerm.BigInt(&bZetaPowerm) - foldedHDigest := proof.H[2] - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - - // foldedH = h1 + ζ*h2 + ζ²*h3 - foldedH := h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] - h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] - utils.Parallelize(len(foldedH), func(start, end int) { - for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² - foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 - } - }) - - if errLPoly != nil { - return nil, errLPoly - } - - // Batch open the first list of polynomials - proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - [][]fr.Element{ - foldedH, - linearizedPolynomialCanonical, - bwliop.Coefficients()[:bwliop.BlindedSize()], - bwriop.Coefficients()[:bwriop.BlindedSize()], - bwoiop.Coefficients()[:bwoiop.BlindedSize()], - pk.S1Canonical, - pk.S2Canonical, - }, - []kzg.Digest{ - foldedHDigest, - linearizedPolynomialDigest, - proof.LRO[0], - proof.LRO[1], - proof.LRO[2], - pk.Vk.S[0], - pk.Vk.S[1], - }, - zeta, - hFunc, - pk.Vk.KZGSRS, - ) - - log.Debug().Dur("took", time.Since(start)).Msg("prover done") - - if err != nil { - return 
nil, err - } - - return proof, nil - -} - -// fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.LRO[0], err0 = kzg.Commit(bcl, srs, n) - close(chCommit0) - }() - go func() { - proof.LRO[1], err1 = kzg.Commit(bcr, srs, n) - close(chCommit1) - }() - if proof.LRO[2], err2 = kzg.Commit(bco, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.H[0], err0 = kzg.Commit(h1, srs, n) - close(chCommit0) - }() - go func() { - proof.H[1], err1 = kzg.Commit(h2, srs, n) - close(chCommit1) - }() - if proof.H[2], err2 = kzg.Commit(h3, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. -// The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z -// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. 
-// -// The Linearized polynomial is: -// -// α²*L₁(ζ)*Z(X) -// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) -// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { - - // first part: individual constraints - var rl fr.Element - rl.Mul(&rZeta, &lZeta) - - // second part: - // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) - var s1, s2 fr.Element - chS1 := make(chan struct{}, 1) - go func() { - ps1 := iop.NewPolynomial(&pk.S1Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - s1 = ps1.Evaluate(zeta) // s1(ζ) - s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) - close(chS1) - }() - ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := ps2.Evaluate(zeta) // s2(ζ) - tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) - <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) - - var uzeta, uuzeta fr.Element - uzeta.Mul(&zeta, &pk.Vk.CosetShift) - uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) - - s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) - tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) - tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - // third part L₁(ζ)*α²*Z - var lagrangeZeta, one, den, frNbElmt fr.Element - one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) - lagrangeZeta.Set(&zeta). - Exp(lagrangeZeta, big.NewInt(nbElmt)). 
- Sub(&lagrangeZeta, &one) - frNbElmt.SetUint64(uint64(nbElmt)) - den.Sub(&zeta, &one). - Inverse(&den) - lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - - linPol := make([]fr.Element, len(blindedZCanonical)) - copy(linPol, blindedZCanonical) - - utils.Parallelize(len(linPol), func(start, end int) { - - var t0, t1 fr.Element - - for i := start; i < end; i++ { - - linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - if i < len(pk.S3Canonical) { - - t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) - - linPol[i].Add(&linPol[i], &t0) - } - - linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) - - if i < len(pk.Qm) { - - t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) - t0.Mul(&pk.Ql[i], &lZeta) - t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - - t0.Mul(&pk.Qr[i], &rZeta) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - - t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) - } - - t0.Mul(&blindedZCanonical[i], &lagrangeZeta) - linPol[i].Add(&linPol[i], &t0) // finish the computation - } - }) - return linPol -} diff --git a/internal/backend/bls12-381/plonk/setup.go b/internal/backend/bls12-381/plonk/setup.go deleted file mode 100644 index e86f945c71..0000000000 --- a/internal/backend/bls12-381/plonk/setup.go +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/iop" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/kzg" - "github.com/consensys/gnark/constraint/bls12-381" - - kzgg "github.com/consensys/gnark-crypto/kzg" -) - -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation -type ProvingKey struct { - // Verifying Key is embedded into the proving key (needed by Prove) - Vk *VerifyingKey - - // TODO store iop.Polynomial here, not []fr.Element for more "type safety" - - // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo []fr.Element - - // qr,ql,qm,qo (in lagrange coset basis) --> these are not serialized, but computed from Ql, Qr, Qm, Qo once. - lQl, lQr, lQm, lQo []fr.Element - - // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. - // Storing LQk in Lagrange basis saves a fft... - CQk, LQk []fr.Element - - // Domains used for the FFTs. 
- // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain - // Domain[0], Domain[1] fft.Domain - - // Permutation polynomials - S1Canonical, S2Canonical, S3Canonical []fr.Element - - // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. - lS1LagrangeCoset, lS2LagrangeCoset, lS3LagrangeCoset []fr.Element - - // position -> permuted position (position in [0,3*sizeSystem-1]) - Permutation []int64 -} - -// VerifyingKey stores the data needed to verify a proof: -// * The commitment scheme -// * Commitments of ql prepended with as many ones as there are public inputs -// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs -// * Commitments to S1, S2, S3 -type VerifyingKey struct { - // Size circuit - Size uint64 - SizeInv fr.Element - Generator fr.Element - NbPublicVariables uint64 - - // Commitment scheme that is used for an instantiation of PLONK - KZGSRS *kzg.SRS - - // cosetShift generator of the coset on the small domain - CosetShift fr.Element - - // S commitments to S1, S2, S3 - S [3]kzg.Digest - - // Commitments to ql, qr, qm, qo prepended with as many zeroes (ones for l) as there are public inputs. - // In particular Qk is not complete. - Ql, Qr, Qm, Qo, Qk kzg.Digest -} - -// Setup sets proving and verifying keys -func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) { - var pk ProvingKey - var vk VerifyingKey - - // The verifying key shares data with the proving key - pk.Vk = &vk - - nbConstraints := len(spr.Constraints) - - // fft domains - sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - - vk.Size = pk.Domain[0].Cardinality - vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) - vk.NbPublicVariables = uint64(len(spr.Public)) - - if err := pk.InitKZG(srs); err != nil { - return nil, nil, err - } - - // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) - pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) - pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) - - for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant - pk.Ql[i].SetOne().Neg(&pk.Ql[i]) - pk.Qr[i].SetZero() - pk.Qm[i].SetZero() - pk.Qo[i].SetZero() - pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // → to be completed by the prover - } - offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.Ql[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.Qr[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.Qm[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.Qm[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.Qo[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.CQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - } - - pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) - pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) - fft.BitReverse(pk.Ql) - fft.BitReverse(pk.Qr) - fft.BitReverse(pk.Qm) - fft.BitReverse(pk.Qo) - fft.BitReverse(pk.CQk) - - // build permutation. Note: at this stage, the permutation takes in account the placeholders - buildPermutation(spr, &pk) - - // set s1, s2, s3 - ccomputePermutationPolynomials(&pk) - - // compute the lagrange coset basis versions (not serialized) - pk.computeLagrangeCosetPolys() - - // Commit to the polynomials to set up the verifying key - var err error - if vk.Ql, err = kzg.Commit(pk.Ql, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qr, err = kzg.Commit(pk.Qr, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qm, err = kzg.Commit(pk.Qm, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qo, err = kzg.Commit(pk.Qo, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - - return &pk, &vk, nil - -} - -// buildPermutation builds the Permutation associated with a circuit. -// -// The permutation s is composed of cycles of maximum length such that -// -// s. 
(l∥r∥o) = (l∥r∥o) -// -// , where l∥r∥o is the concatenation of the indices of l, r, o in -// ql.l+qr.r+qm.l.r+qo.O+k = 0. -// -// The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab -// like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { - - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := int(pk.Domain[0].Cardinality) - - // init permutation - pk.Permutation = make([]int64, 3*sizeSolution) - for i := 0; i < len(pk.Permutation); i++ { - pk.Permutation[i] = -1 - } - - // init LRO position -> variable_ID - lro := make([]int, 3*sizeSolution) // position -> variable_ID - for i := 0; i < len(spr.Public); i++ { - lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) - } - - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() - } - - // init cycle: - // map ID -> last position the ID was seen - cycle := make([]int64, nbVariables) - for i := 0; i < len(cycle); i++ { - cycle[i] = -1 - } - - for i := 0; i < len(lro); i++ { - if cycle[lro[i]] != -1 { - // if != -1, it means we already encountered this value - // so we need to set the corresponding permutation index. 
- pk.Permutation[i] = cycle[lro[i]] - } - cycle[lro[i]] = int64(i) - } - - // complete the Permutation by filling the first IDs encountered - for i := 0; i < len(pk.Permutation); i++ { - if pk.Permutation[i] == -1 { - pk.Permutation[i] = cycle[lro[i]] - } - } -} - -func (pk *ProvingKey) computeLagrangeCosetPolys() { - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqliop := iop.NewPolynomial(clone(pk.Ql, pk.Domain[1].Cardinality), canReg) - wqriop := iop.NewPolynomial(clone(pk.Qr, pk.Domain[1].Cardinality), canReg) - wqmiop := iop.NewPolynomial(clone(pk.Qm, pk.Domain[1].Cardinality), canReg) - wqoiop := iop.NewPolynomial(clone(pk.Qo, pk.Domain[1].Cardinality), canReg) - - ws1 := iop.NewPolynomial(clone(pk.S1Canonical, pk.Domain[1].Cardinality), canReg) - ws2 := iop.NewPolynomial(clone(pk.S2Canonical, pk.Domain[1].Cardinality), canReg) - ws3 := iop.NewPolynomial(clone(pk.S3Canonical, pk.Domain[1].Cardinality), canReg) - - wqliop.ToLagrangeCoset(&pk.Domain[1]) - wqriop.ToLagrangeCoset(&pk.Domain[1]) - wqmiop.ToLagrangeCoset(&pk.Domain[1]) - wqoiop.ToLagrangeCoset(&pk.Domain[1]) - - ws1.ToLagrangeCoset(&pk.Domain[1]) - ws2.ToLagrangeCoset(&pk.Domain[1]) - ws3.ToLagrangeCoset(&pk.Domain[1]) - - pk.lQl = wqliop.Coefficients() - pk.lQr = wqriop.Coefficients() - pk.lQm = wqmiop.Coefficients() - pk.lQo = wqoiop.Coefficients() - - pk.lS1LagrangeCoset = ws1.Coefficients() - pk.lS2LagrangeCoset = ws2.Coefficients() - pk.lS3LagrangeCoset = ws3.Coefficients() -} - -func clone(input []fr.Element, capacity uint64) *[]fr.Element { - res := make([]fr.Element, len(input), capacity) - copy(res, input) - return &res -} - -// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations -// s1, s2, s3. -// -// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | -// -// | -// | Permutation -// -// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v -// \---------------/ \--------------------/ \------------------------/ -// -// s1 (LDE) s2 (LDE) s3 (LDE) -func ccomputePermutationPolynomials(pk *ProvingKey) { - - nbElmts := int(pk.Domain[0].Cardinality) - - // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) - - // Lagrange form of S1, S2, S3 - pk.S1Canonical = make([]fr.Element, nbElmts) - pk.S2Canonical = make([]fr.Element, nbElmts) - pk.S3Canonical = make([]fr.Element, nbElmts) - for i := 0; i < nbElmts; i++ { - pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) - pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) - pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) - } - - // Canonical form of S1, S2, S3 - pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) - fft.BitReverse(pk.S1Canonical) - fft.BitReverse(pk.S2Canonical) - fft.BitReverse(pk.S3Canonical) -} - -// getIDSmallDomain returns the Lagrange form of ID on the small domain -func getIDSmallDomain(domain *fft.Domain) []fr.Element { - - res := make([]fr.Element, 3*domain.Cardinality) - - res[0].SetOne() - res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) - res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) - - for i := uint64(1); i < domain.Cardinality; i++ { - res[i].Mul(&res[i-1], &domain.Generator) - res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) - res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) - } - - return res -} - -// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS -// -// This should be used after deserializing a ProvingKey -// as pk.Vk.KZG is NOT serialized -func (pk *ProvingKey) InitKZG(srs kzgg.SRS) error { - return pk.Vk.InitKZG(srs) -} - -// InitKZG inits vk.KZG using provided SRS -// -// This should be used after deserializing a 
VerifyingKey -// as vk.KZG is NOT serialized -// -// Note that this instantiate a new FFT domain using vk.Size -func (vk *VerifyingKey) InitKZG(srs kzgg.SRS) error { - _srs := srs.(*kzg.SRS) - - if len(_srs.G1) < int(vk.Size) { - return errors.New("kzg srs is too small") - } - vk.KZGSRS = _srs - - return nil -} - -// NbPublicWitness returns the expected public witness size (number of field elements) -func (vk *VerifyingKey) NbPublicWitness() int { - return int(vk.NbPublicVariables) -} - -// VerifyingKey returns pk.Vk -func (pk *ProvingKey) VerifyingKey() interface{} { - return pk.Vk -} diff --git a/internal/backend/bls24-315/plonk/marshal.go b/internal/backend/bls24-315/plonk/marshal.go deleted file mode 100644 index 69aad24641..0000000000 --- a/internal/backend/bls24-315/plonk/marshal.go +++ /dev/null @@ -1,244 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - curve "github.com/consensys/gnark-crypto/ecc/bls24-315" - - "errors" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - "io" -) - -// WriteTo writes binary encoding of Proof to w without point compression -func (proof *Proof) WriteRawTo(w io.Writer) (int64, error) { - return proof.writeTo(w, curve.RawEncoding()) -} - -// WriteTo writes binary encoding of Proof to w with point compression -func (proof *Proof) WriteTo(w io.Writer) (int64, error) { - return proof.writeTo(w) -} - -func (proof *Proof) writeTo(w io.Writer, options ...func(*curve.Encoder)) (int64, error) { - enc := curve.NewEncoder(w, options...) - - toEncode := []interface{}{ - &proof.LRO[0], - &proof.LRO[1], - &proof.LRO[2], - &proof.Z, - &proof.H[0], - &proof.H[1], - &proof.H[2], - &proof.BatchedProof.H, - proof.BatchedProof.ClaimedValues, - &proof.ZShiftedOpening.H, - &proof.ZShiftedOpening.ClaimedValue, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return enc.BytesWritten(), err - } - } - - return enc.BytesWritten(), nil -} - -// ReadFrom reads binary representation of Proof from r -func (proof *Proof) ReadFrom(r io.Reader) (int64, error) { - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - &proof.LRO[0], - &proof.LRO[1], - &proof.LRO[2], - &proof.Z, - &proof.H[0], - &proof.H[1], - &proof.H[2], - &proof.BatchedProof.H, - &proof.BatchedProof.ClaimedValues, - &proof.ZShiftedOpening.H, - &proof.ZShiftedOpening.ClaimedValue, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err - } - } - - return dec.BytesRead(), nil -} - -// WriteTo writes binary encoding of ProvingKey to w -func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { - // encode the verifying key - n, err = pk.Vk.WriteTo(w) - if err != nil { - return - } - - // fft domains - n2, err := pk.Domain[0].WriteTo(w) - if err != nil { - return - } - n += n2 - - n2, 
err = pk.Domain[1].WriteTo(w) - if err != nil { - return - } - n += n2 - - // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) - if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { - return n, errors.New("invalid permutation size, expected 3*domain cardinality") - } - - enc := curve.NewEncoder(w) - // note: type Polynomial, which is handled by default binary.Write(...) op and doesn't - // encode the size (nor does it convert from Montgomery to Regular form) - // so we explicitly transmit []fr.Element - toEncode := []interface{}{ - ([]fr.Element)(pk.Ql), - ([]fr.Element)(pk.Qr), - ([]fr.Element)(pk.Qm), - ([]fr.Element)(pk.Qo), - ([]fr.Element)(pk.CQk), - ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.S1Canonical), - ([]fr.Element)(pk.S2Canonical), - ([]fr.Element)(pk.S3Canonical), - pk.Permutation, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return n + enc.BytesWritten(), err - } - } - - return n + enc.BytesWritten(), nil -} - -// ReadFrom reads from binary representation in r into ProvingKey -func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { - pk.Vk = &VerifyingKey{} - n, err := pk.Vk.ReadFrom(r) - if err != nil { - return n, err - } - - n2, err := pk.Domain[0].ReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - n2, err = pk.Domain[1].ReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) - - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - (*[]fr.Element)(&pk.Ql), - (*[]fr.Element)(&pk.Qr), - (*[]fr.Element)(&pk.Qm), - (*[]fr.Element)(&pk.Qo), - (*[]fr.Element)(&pk.CQk), - (*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.S1Canonical), - (*[]fr.Element)(&pk.S2Canonical), - (*[]fr.Element)(&pk.S3Canonical), - &pk.Permutation, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err - } - } - - pk.computeLagrangeCosetPolys() - - return n + dec.BytesRead(), 
nil - -} - -// WriteTo writes binary encoding of VerifyingKey to w -func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - enc := curve.NewEncoder(w) - - toEncode := []interface{}{ - vk.Size, - &vk.SizeInv, - &vk.Generator, - vk.NbPublicVariables, - &vk.CosetShift, - &vk.S[0], - &vk.S[1], - &vk.S[2], - &vk.Ql, - &vk.Qr, - &vk.Qm, - &vk.Qo, - &vk.Qk, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return enc.BytesWritten(), err - } - } - - return enc.BytesWritten(), nil -} - -// ReadFrom reads from binary representation in r into VerifyingKey -func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - &vk.Size, - &vk.SizeInv, - &vk.Generator, - &vk.NbPublicVariables, - &vk.CosetShift, - &vk.S[0], - &vk.S[1], - &vk.S[2], - &vk.Ql, - &vk.Qr, - &vk.Qm, - &vk.Qo, - &vk.Qk, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err - } - } - - return dec.BytesRead(), nil -} diff --git a/internal/backend/bls24-315/plonk/marshal_test.go b/internal/backend/bls24-315/plonk/marshal_test.go deleted file mode 100644 index 485366fd77..0000000000 --- a/internal/backend/bls24-315/plonk/marshal_test.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - curve "github.com/consensys/gnark-crypto/ecc/bls24-315" - - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - - "bytes" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" - gnarkio "github.com/consensys/gnark/io" - "io" - "math/big" - "math/rand" - "reflect" - "testing" -) - -func TestProofSerialization(t *testing.T) { - // create a proof - var proof, reconstructed Proof - proof.randomize() - - roundTripCheck(t, &proof, &reconstructed) -} - -func TestProofSerializationRaw(t *testing.T) { - // create a proof - var proof, reconstructed Proof - proof.randomize() - - roundTripCheckRaw(t, &proof, &reconstructed) -} - -func TestProvingKeySerialization(t *testing.T) { - // random pk - var pk, reconstructed ProvingKey - pk.randomize() - - roundTripCheck(t, &pk, &reconstructed) -} - -func TestVerifyingKeySerialization(t *testing.T) { - // create a random vk - var vk, reconstructed VerifyingKey - vk.randomize() - - roundTripCheck(t, &vk, &reconstructed) -} - -func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { - var buf bytes.Buffer - written, err := from.WriteTo(&buf) - if err != nil { - t.Fatal("couldn't serialize", err) - } - - read, err := reconstructed.ReadFrom(&buf) - if err != nil { - t.Fatal("couldn't deserialize", err) - } - - if !reflect.DeepEqual(from, reconstructed) { - t.Fatal("reconstructed object don't match original") - } - - if written != read { - t.Fatal("bytes written / read don't match") - } -} - -func roundTripCheckRaw(t *testing.T, from gnarkio.WriterRawTo, reconstructed io.ReaderFrom) { - var buf bytes.Buffer - written, err := from.WriteRawTo(&buf) - if err != nil { - t.Fatal("couldn't serialize", err) - } - - read, err := reconstructed.ReadFrom(&buf) - if err != nil { - t.Fatal("couldn't deserialize", err) - } - - if !reflect.DeepEqual(from, reconstructed) { - t.Fatal("reconstructed object don't match original") - } - - if written != 
read { - t.Fatal("bytes written / read don't match") - } -} - -func (pk *ProvingKey) randomize() { - var vk VerifyingKey - vk.randomize() - pk.Vk = &vk - pk.Domain[0] = *fft.NewDomain(42) - pk.Domain[1] = *fft.NewDomain(4 * 42) - - n := int(pk.Domain[0].Cardinality) - pk.Ql = randomScalars(n) - pk.Qr = randomScalars(n) - pk.Qm = randomScalars(n) - pk.Qo = randomScalars(n) - pk.CQk = randomScalars(n) - pk.LQk = randomScalars(n) - pk.S1Canonical = randomScalars(n) - pk.S2Canonical = randomScalars(n) - pk.S3Canonical = randomScalars(n) - - pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) - pk.Permutation[0] = -12 - pk.Permutation[len(pk.Permutation)-1] = 8888 - - pk.computeLagrangeCosetPolys() -} - -func (vk *VerifyingKey) randomize() { - vk.Size = rand.Uint64() - vk.SizeInv.SetRandom() - vk.Generator.SetRandom() - vk.NbPublicVariables = rand.Uint64() - vk.CosetShift.SetRandom() - - vk.S[0] = randomPoint() - vk.S[1] = randomPoint() - vk.S[2] = randomPoint() - vk.Ql = randomPoint() - vk.Qr = randomPoint() - vk.Qm = randomPoint() - vk.Qo = randomPoint() - vk.Qk = randomPoint() -} - -func (proof *Proof) randomize() { - proof.LRO[0] = randomPoint() - proof.LRO[1] = randomPoint() - proof.LRO[2] = randomPoint() - proof.Z = randomPoint() - proof.H[0] = randomPoint() - proof.H[1] = randomPoint() - proof.H[2] = randomPoint() - proof.BatchedProof.H = randomPoint() - proof.BatchedProof.ClaimedValues = randomScalars(2) - proof.ZShiftedOpening.H = randomPoint() - proof.ZShiftedOpening.ClaimedValue.SetRandom() -} - -func randomPoint() curve.G1Affine { - _, _, r, _ := curve.Generators() - r.ScalarMultiplication(&r, big.NewInt(int64(rand.Uint64()))) - return r -} - -func randomScalars(n int) []fr.Element { - v := make([]fr.Element, n) - one := fr.One() - for i := 0; i < len(v); i++ { - if i == 0 { - v[i].SetRandom() - } else { - v[i].Add(&v[i-1], &one) - } - } - return v -} diff --git a/internal/backend/bls24-315/plonk/prove.go b/internal/backend/bls24-315/plonk/prove.go 
deleted file mode 100644 index eccaefd544..0000000000 --- a/internal/backend/bls24-315/plonk/prove.go +++ /dev/null @@ -1,634 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "crypto/sha256" - "math/big" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - - curve "github.com/consensys/gnark-crypto/ecc/bls24-315" - - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/kzg" - - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" - - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/iop" - "github.com/consensys/gnark/constraint/bls24-315" - - "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/logger" -) - -type Proof struct { - - // Commitments to the solution vectors - LRO [3]kzg.Digest - - // Commitment to Z, the permutation polynomial - Z kzg.Digest - - // Commitments to h1, h2, h3 such that h = h1 + Xh2 + X**2h3 is the quotient polynomial - H [3]kzg.Digest - - // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2 - BatchedProof kzg.BatchOpeningProof - - // Opening proof of Z at zeta*mu - ZShiftedOpening kzg.OpeningProof -} - -// Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt 
backend.ProverConfig) (*Proof, error) { - - log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", len(spr.Constraints)).Str("backend", "plonk").Logger() - start := time.Now() - // pick a hash function that will be used to derive the challenges - hFunc := sha256.New() - - // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") - - // result - proof := &Proof{} - - // compute the constraint system solution - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } - } - - // query l, r, o in Lagrange basis, not blinded - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) - - lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - liop := iop.NewPolynomial(&evaluationLDomainSmall, lagReg) - riop := iop.NewPolynomial(&evaluationRDomainSmall, lagReg) - oiop := iop.NewPolynomial(&evaluationODomainSmall, lagReg) - wliop := liop.ShallowClone() - wriop := riop.ShallowClone() - woiop := oiop.ShallowClone() - wliop.ToCanonical(&pk.Domain[0]).ToRegular() - wriop.ToCanonical(&pk.Domain[0]).ToRegular() - woiop.ToCanonical(&pk.Domain[0]).ToRegular() - - // Blind l, r, o before committing - // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. 
- bwliop := wliop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwriop := wriop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwoiop := woiop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // The first challenge is derived using the public data: the commitments to the permutation, - // the coefficients of the circuit, and the public inputs. - // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *pk.Vk, fullWitness[:len(spr.Public)]); err != nil { - return nil, err - } - gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) - if err != nil { - return nil, err - } - - // Fiat Shamir this - bbeta, err := fs.ComputeChallenge("beta") - if err != nil { - return nil, err - } - var beta fr.Element - beta.SetBytes(bbeta) - - // compute the copy constraint's ratio - // We copy liop, riop, oiop because they are fft'ed in the process. - // We could have not copied them at the cost of doing one more bit reverse - // per poly... 
- ziop, err := iop.BuildRatioCopyConstraint( - []*iop.Polynomial{ - liop.Clone(), - riop.Clone(), - oiop.Clone(), - }, - pk.Permutation, - beta, - gamma, - iop.Form{Basis: iop.Canonical, Layout: iop.Regular}, - &pk.Domain[0], - ) - if err != nil { - return proof, err - } - - // commit to the blinded version of z - bwziop := ziop // iop.NewWrappedPolynomial(&ziop) - bwziop.Blind(2) - proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Vk.KZGSRS, runtime.NumCPU()*2) - if err != nil { - return proof, err - } - - // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) - alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) - if err != nil { - return proof, err - } - - // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) - copy(qkCompletedCanonical, fullWitness[:len(spr.Public)]) - copy(qkCompletedCanonical[len(spr.Public):], pk.LQk[len(spr.Public):]) - pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) - fft.BitReverse(qkCompletedCanonical) - - // l, r, o are blinded here - bwliop.ToLagrangeCoset(&pk.Domain[1]) - bwriop.ToLagrangeCoset(&pk.Domain[1]) - bwoiop.ToLagrangeCoset(&pk.Domain[1]) - - lagrangeCosetBitReversed := iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse} - - // we don't mutate so no need to clone the coefficients from the proving key. 
- wqliop := iop.NewPolynomial(&pk.lQl, lagrangeCosetBitReversed) - wqriop := iop.NewPolynomial(&pk.lQr, lagrangeCosetBitReversed) - wqmiop := iop.NewPolynomial(&pk.lQm, lagrangeCosetBitReversed) - wqoiop := iop.NewPolynomial(&pk.lQo, lagrangeCosetBitReversed) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqkiop := iop.NewPolynomial(&qkCompletedCanonical, canReg) - wqkiop.ToLagrangeCoset(&pk.Domain[1]) - - // storing Id - id := make([]fr.Element, pk.Domain[1].Cardinality) - id[1].SetOne() - widiop := iop.NewPolynomial(&id, canReg) - widiop.ToLagrangeCoset(&pk.Domain[1]) - - // permutations in LagrangeCoset: we don't mutate so no need to clone the coefficients from the - // proving key. - ws1 := iop.NewPolynomial(&pk.lS1LagrangeCoset, lagrangeCosetBitReversed) - ws2 := iop.NewPolynomial(&pk.lS2LagrangeCoset, lagrangeCosetBitReversed) - ws3 := iop.NewPolynomial(&pk.lS3LagrangeCoset, lagrangeCosetBitReversed) - - // Store z(g*x), without reallocating a slice - bwsziop := bwziop.ShallowClone().Shift(1) - bwsziop.ToLagrangeCoset(&pk.Domain[1]) - - // L_{g^{0}} - cap := pk.Domain[1].Cardinality - if cap < pk.Domain[0].Cardinality { - cap = pk.Domain[0].Cardinality // sanity check - } - lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) - lone[0].SetOne() - loneiop := iop.NewPolynomial(&lone, lagReg) - wloneiop := loneiop.ToCanonical(&pk.Domain[0]). - ToRegular(). - ToLagrangeCoset(&pk.Domain[1]) - - // Full capture using latest gnark crypto... 
- fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element) fr.Element { - - var ic, tmp fr.Element - - ic.Mul(&fql, &l) - tmp.Mul(&fqr, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqm, &l).Mul(&tmp, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqo, &o) - ic.Add(&ic, &tmp).Add(&ic, &fqk) - - return ic - } - - fo := func(l, r, o, fid, fs1, fs2, fs3, fz, fzs fr.Element) fr.Element { - var uu fr.Element - u := pk.Domain[0].FrMultiplicativeGen - uu.Mul(&u, &u) - - var a, b, tmp fr.Element - a.Mul(&beta, &fid).Add(&a, &l).Add(&a, &gamma) - tmp.Mul(&beta, &u).Mul(&tmp, &fid).Add(&tmp, &r).Add(&tmp, &gamma) - a.Mul(&a, &tmp) - tmp.Mul(&beta, &uu).Mul(&tmp, &fid).Add(&tmp, &o).Add(&tmp, &gamma) - a.Mul(&a, &tmp).Mul(&a, &fz) - - b.Mul(&beta, &fs1).Add(&b, &l).Add(&b, &gamma) - tmp.Mul(&beta, &fs2).Add(&tmp, &r).Add(&tmp, &gamma) - b.Mul(&b, &tmp) - tmp.Mul(&beta, &fs3).Add(&tmp, &o).Add(&tmp, &gamma) - b.Mul(&b, &tmp).Mul(&b, &fzs) - - b.Sub(&b, &a) - - return b - } - - fone := func(fz, flone fr.Element) fr.Element { - one := fr.One() - one.Sub(&fz, &one).Mul(&one, &flone) - return one - } - - // 0 , 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 - // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk,lone - fm := func(x ...fr.Element) fr.Element { - - a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2]) - b := fo(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8]) - c := fone(x[7], x[14]) - - c.Mul(&c, &alpha).Add(&c, &b).Mul(&c, &alpha).Add(&c, &a) - - return c - } - testEval, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, - bwliop, - bwriop, - bwoiop, - widiop, - ws1, - ws2, - ws3, - bwziop, - bwsziop, - wqliop, - wqriop, - wqmiop, - wqoiop, - wqkiop, - wloneiop, - ) - if err != nil { - return nil, err - } - h, err := iop.DivideByXMinusOne(testEval, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) - if err != nil { - return nil, err - } - - // compute kzg commitments of h1, h2 and h3 - if err := commitToQuotient( - 
h.Coefficients()[:pk.Domain[0].Cardinality+2], - h.Coefficients()[pk.Domain[0].Cardinality+2:2*(pk.Domain[0].Cardinality+2)], - h.Coefficients()[2*(pk.Domain[0].Cardinality+2):3*(pk.Domain[0].Cardinality+2)], - proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // derive zeta - zeta, err := deriveRandomness(&fs, "zeta", &proof.H[0], &proof.H[1], &proof.H[2]) - if err != nil { - return nil, err - } - - // compute evaluations of (blinded version of) l, r, o, z at zeta - var blzeta, brzeta, bozeta fr.Element - - var wgEvals sync.WaitGroup - wgEvals.Add(3) - - go func() { - bwliop.ToCanonical(&pk.Domain[1]).ToRegular() - blzeta = bwliop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwriop.ToCanonical(&pk.Domain[1]).ToRegular() - brzeta = bwriop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwoiop.ToCanonical(&pk.Domain[1]).ToRegular() - bozeta = bwoiop.Evaluate(zeta) - wgEvals.Done() - }() - - // open blinded Z at zeta*z - bwziop.ToCanonical(&pk.Domain[1]).ToRegular() - var zetaShifted fr.Element - zetaShifted.Mul(&zeta, &pk.Vk.Generator) - proof.ZShiftedOpening, err = kzg.Open( - bwziop.Coefficients()[:bwziop.BlindedSize()], - zetaShifted, - pk.Vk.KZGSRS, - ) - if err != nil { - return nil, err - } - - // blinded z evaluated at u*zeta - bzuzeta := proof.ZShiftedOpening.ClaimedValue - - var ( - linearizedPolynomialCanonical []fr.Element - linearizedPolynomialDigest curve.G1Affine - errLPoly error - ) - - wgEvals.Wait() // wait for the evaluations - - // compute the linearization polynomial r at zeta - // (goal: save committing separately to z, ql, qr, qm, qo, k - linearizedPolynomialCanonical = computeLinearizedPolynomial( - blzeta, - brzeta, - bozeta, - alpha, - beta, - gamma, - zeta, - bzuzeta, - bwziop.Coefficients()[:bwziop.BlindedSize()], - pk, - ) - - // TODO this commitment is only necessary to derive the challenge, we should - // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = 
kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) - - // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) - var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) - var zetaPowerm fr.Element - zetaPowerm.Exp(zeta, &bSize) - zetaPowerm.BigInt(&bZetaPowerm) - foldedHDigest := proof.H[2] - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - - // foldedH = h1 + ζ*h2 + ζ²*h3 - foldedH := h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] - h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] - utils.Parallelize(len(foldedH), func(start, end int) { - for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² - foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 - } - }) - - if errLPoly != nil { - return nil, errLPoly - } - - // Batch open the first list of polynomials - proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - [][]fr.Element{ - foldedH, - linearizedPolynomialCanonical, - bwliop.Coefficients()[:bwliop.BlindedSize()], - bwriop.Coefficients()[:bwriop.BlindedSize()], - bwoiop.Coefficients()[:bwoiop.BlindedSize()], - pk.S1Canonical, - pk.S2Canonical, - }, - []kzg.Digest{ - foldedHDigest, - linearizedPolynomialDigest, - proof.LRO[0], - proof.LRO[1], - proof.LRO[2], - pk.Vk.S[0], - pk.Vk.S[1], - }, - zeta, - hFunc, - pk.Vk.KZGSRS, - ) - - log.Debug().Dur("took", time.Since(start)).Msg("prover done") - - if err != nil { - return 
nil, err - } - - return proof, nil - -} - -// fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.LRO[0], err0 = kzg.Commit(bcl, srs, n) - close(chCommit0) - }() - go func() { - proof.LRO[1], err1 = kzg.Commit(bcr, srs, n) - close(chCommit1) - }() - if proof.LRO[2], err2 = kzg.Commit(bco, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.H[0], err0 = kzg.Commit(h1, srs, n) - close(chCommit0) - }() - go func() { - proof.H[1], err1 = kzg.Commit(h2, srs, n) - close(chCommit1) - }() - if proof.H[2], err2 = kzg.Commit(h3, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. -// The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z -// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. 
-// -// The Linearized polynomial is: -// -// α²*L₁(ζ)*Z(X) -// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) -// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { - - // first part: individual constraints - var rl fr.Element - rl.Mul(&rZeta, &lZeta) - - // second part: - // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) - var s1, s2 fr.Element - chS1 := make(chan struct{}, 1) - go func() { - ps1 := iop.NewPolynomial(&pk.S1Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - s1 = ps1.Evaluate(zeta) // s1(ζ) - s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) - close(chS1) - }() - ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := ps2.Evaluate(zeta) // s2(ζ) - tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) - <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) - - var uzeta, uuzeta fr.Element - uzeta.Mul(&zeta, &pk.Vk.CosetShift) - uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) - - s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) - tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) - tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - // third part L₁(ζ)*α²*Z - var lagrangeZeta, one, den, frNbElmt fr.Element - one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) - lagrangeZeta.Set(&zeta). - Exp(lagrangeZeta, big.NewInt(nbElmt)). 
- Sub(&lagrangeZeta, &one) - frNbElmt.SetUint64(uint64(nbElmt)) - den.Sub(&zeta, &one). - Inverse(&den) - lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - - linPol := make([]fr.Element, len(blindedZCanonical)) - copy(linPol, blindedZCanonical) - - utils.Parallelize(len(linPol), func(start, end int) { - - var t0, t1 fr.Element - - for i := start; i < end; i++ { - - linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - if i < len(pk.S3Canonical) { - - t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) - - linPol[i].Add(&linPol[i], &t0) - } - - linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) - - if i < len(pk.Qm) { - - t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) - t0.Mul(&pk.Ql[i], &lZeta) - t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - - t0.Mul(&pk.Qr[i], &rZeta) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - - t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) - } - - t0.Mul(&blindedZCanonical[i], &lagrangeZeta) - linPol[i].Add(&linPol[i], &t0) // finish the computation - } - }) - return linPol -} diff --git a/internal/backend/bls24-315/plonk/setup.go b/internal/backend/bls24-315/plonk/setup.go deleted file mode 100644 index 0bdfcb2ff9..0000000000 --- a/internal/backend/bls24-315/plonk/setup.go +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/iop" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/kzg" - "github.com/consensys/gnark/constraint/bls24-315" - - kzgg "github.com/consensys/gnark-crypto/kzg" -) - -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation -type ProvingKey struct { - // Verifying Key is embedded into the proving key (needed by Prove) - Vk *VerifyingKey - - // TODO store iop.Polynomial here, not []fr.Element for more "type safety" - - // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo []fr.Element - - // qr,ql,qm,qo (in lagrange coset basis) --> these are not serialized, but computed from Ql, Qr, Qm, Qo once. - lQl, lQr, lQm, lQo []fr.Element - - // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. - // Storing LQk in Lagrange basis saves a fft... - CQk, LQk []fr.Element - - // Domains used for the FFTs. 
- // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain - // Domain[0], Domain[1] fft.Domain - - // Permutation polynomials - S1Canonical, S2Canonical, S3Canonical []fr.Element - - // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. - lS1LagrangeCoset, lS2LagrangeCoset, lS3LagrangeCoset []fr.Element - - // position -> permuted position (position in [0,3*sizeSystem-1]) - Permutation []int64 -} - -// VerifyingKey stores the data needed to verify a proof: -// * The commitment scheme -// * Commitments of ql prepended with as many ones as there are public inputs -// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs -// * Commitments to S1, S2, S3 -type VerifyingKey struct { - // Size circuit - Size uint64 - SizeInv fr.Element - Generator fr.Element - NbPublicVariables uint64 - - // Commitment scheme that is used for an instantiation of PLONK - KZGSRS *kzg.SRS - - // cosetShift generator of the coset on the small domain - CosetShift fr.Element - - // S commitments to S1, S2, S3 - S [3]kzg.Digest - - // Commitments to ql, qr, qm, qo prepended with as many zeroes (ones for l) as there are public inputs. - // In particular Qk is not complete. - Ql, Qr, Qm, Qo, Qk kzg.Digest -} - -// Setup sets proving and verifying keys -func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) { - var pk ProvingKey - var vk VerifyingKey - - // The verifying key shares data with the proving key - pk.Vk = &vk - - nbConstraints := len(spr.Constraints) - - // fft domains - sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - - vk.Size = pk.Domain[0].Cardinality - vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) - vk.NbPublicVariables = uint64(len(spr.Public)) - - if err := pk.InitKZG(srs); err != nil { - return nil, nil, err - } - - // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) - pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) - pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) - - for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant - pk.Ql[i].SetOne().Neg(&pk.Ql[i]) - pk.Qr[i].SetZero() - pk.Qm[i].SetZero() - pk.Qo[i].SetZero() - pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // → to be completed by the prover - } - offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.Ql[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.Qr[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.Qm[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.Qm[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.Qo[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.CQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - } - - pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) - pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) - fft.BitReverse(pk.Ql) - fft.BitReverse(pk.Qr) - fft.BitReverse(pk.Qm) - fft.BitReverse(pk.Qo) - fft.BitReverse(pk.CQk) - - // build permutation. Note: at this stage, the permutation takes in account the placeholders - buildPermutation(spr, &pk) - - // set s1, s2, s3 - ccomputePermutationPolynomials(&pk) - - // compute the lagrange coset basis versions (not serialized) - pk.computeLagrangeCosetPolys() - - // Commit to the polynomials to set up the verifying key - var err error - if vk.Ql, err = kzg.Commit(pk.Ql, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qr, err = kzg.Commit(pk.Qr, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qm, err = kzg.Commit(pk.Qm, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qo, err = kzg.Commit(pk.Qo, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - - return &pk, &vk, nil - -} - -// buildPermutation builds the Permutation associated with a circuit. -// -// The permutation s is composed of cycles of maximum length such that -// -// s. 
(l∥r∥o) = (l∥r∥o) -// -// , where l∥r∥o is the concatenation of the indices of l, r, o in -// ql.l+qr.r+qm.l.r+qo.O+k = 0. -// -// The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab -// like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { - - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := int(pk.Domain[0].Cardinality) - - // init permutation - pk.Permutation = make([]int64, 3*sizeSolution) - for i := 0; i < len(pk.Permutation); i++ { - pk.Permutation[i] = -1 - } - - // init LRO position -> variable_ID - lro := make([]int, 3*sizeSolution) // position -> variable_ID - for i := 0; i < len(spr.Public); i++ { - lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) - } - - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() - } - - // init cycle: - // map ID -> last position the ID was seen - cycle := make([]int64, nbVariables) - for i := 0; i < len(cycle); i++ { - cycle[i] = -1 - } - - for i := 0; i < len(lro); i++ { - if cycle[lro[i]] != -1 { - // if != -1, it means we already encountered this value - // so we need to set the corresponding permutation index. 
- pk.Permutation[i] = cycle[lro[i]] - } - cycle[lro[i]] = int64(i) - } - - // complete the Permutation by filling the first IDs encountered - for i := 0; i < len(pk.Permutation); i++ { - if pk.Permutation[i] == -1 { - pk.Permutation[i] = cycle[lro[i]] - } - } -} - -func (pk *ProvingKey) computeLagrangeCosetPolys() { - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqliop := iop.NewPolynomial(clone(pk.Ql, pk.Domain[1].Cardinality), canReg) - wqriop := iop.NewPolynomial(clone(pk.Qr, pk.Domain[1].Cardinality), canReg) - wqmiop := iop.NewPolynomial(clone(pk.Qm, pk.Domain[1].Cardinality), canReg) - wqoiop := iop.NewPolynomial(clone(pk.Qo, pk.Domain[1].Cardinality), canReg) - - ws1 := iop.NewPolynomial(clone(pk.S1Canonical, pk.Domain[1].Cardinality), canReg) - ws2 := iop.NewPolynomial(clone(pk.S2Canonical, pk.Domain[1].Cardinality), canReg) - ws3 := iop.NewPolynomial(clone(pk.S3Canonical, pk.Domain[1].Cardinality), canReg) - - wqliop.ToLagrangeCoset(&pk.Domain[1]) - wqriop.ToLagrangeCoset(&pk.Domain[1]) - wqmiop.ToLagrangeCoset(&pk.Domain[1]) - wqoiop.ToLagrangeCoset(&pk.Domain[1]) - - ws1.ToLagrangeCoset(&pk.Domain[1]) - ws2.ToLagrangeCoset(&pk.Domain[1]) - ws3.ToLagrangeCoset(&pk.Domain[1]) - - pk.lQl = wqliop.Coefficients() - pk.lQr = wqriop.Coefficients() - pk.lQm = wqmiop.Coefficients() - pk.lQo = wqoiop.Coefficients() - - pk.lS1LagrangeCoset = ws1.Coefficients() - pk.lS2LagrangeCoset = ws2.Coefficients() - pk.lS3LagrangeCoset = ws3.Coefficients() -} - -func clone(input []fr.Element, capacity uint64) *[]fr.Element { - res := make([]fr.Element, len(input), capacity) - copy(res, input) - return &res -} - -// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations -// s1, s2, s3. -// -// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | -// -// | -// | Permutation -// -// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v -// \---------------/ \--------------------/ \------------------------/ -// -// s1 (LDE) s2 (LDE) s3 (LDE) -func ccomputePermutationPolynomials(pk *ProvingKey) { - - nbElmts := int(pk.Domain[0].Cardinality) - - // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) - - // Lagrange form of S1, S2, S3 - pk.S1Canonical = make([]fr.Element, nbElmts) - pk.S2Canonical = make([]fr.Element, nbElmts) - pk.S3Canonical = make([]fr.Element, nbElmts) - for i := 0; i < nbElmts; i++ { - pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) - pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) - pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) - } - - // Canonical form of S1, S2, S3 - pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) - fft.BitReverse(pk.S1Canonical) - fft.BitReverse(pk.S2Canonical) - fft.BitReverse(pk.S3Canonical) -} - -// getIDSmallDomain returns the Lagrange form of ID on the small domain -func getIDSmallDomain(domain *fft.Domain) []fr.Element { - - res := make([]fr.Element, 3*domain.Cardinality) - - res[0].SetOne() - res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) - res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) - - for i := uint64(1); i < domain.Cardinality; i++ { - res[i].Mul(&res[i-1], &domain.Generator) - res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) - res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) - } - - return res -} - -// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS -// -// This should be used after deserializing a ProvingKey -// as pk.Vk.KZG is NOT serialized -func (pk *ProvingKey) InitKZG(srs kzgg.SRS) error { - return pk.Vk.InitKZG(srs) -} - -// InitKZG inits vk.KZG using provided SRS -// -// This should be used after deserializing a 
VerifyingKey -// as vk.KZG is NOT serialized -// -// Note that this instantiate a new FFT domain using vk.Size -func (vk *VerifyingKey) InitKZG(srs kzgg.SRS) error { - _srs := srs.(*kzg.SRS) - - if len(_srs.G1) < int(vk.Size) { - return errors.New("kzg srs is too small") - } - vk.KZGSRS = _srs - - return nil -} - -// NbPublicWitness returns the expected public witness size (number of field elements) -func (vk *VerifyingKey) NbPublicWitness() int { - return int(vk.NbPublicVariables) -} - -// VerifyingKey returns pk.Vk -func (pk *ProvingKey) VerifyingKey() interface{} { - return pk.Vk -} diff --git a/internal/backend/bls24-317/plonk/marshal.go b/internal/backend/bls24-317/plonk/marshal.go deleted file mode 100644 index 5b9eb8b6f1..0000000000 --- a/internal/backend/bls24-317/plonk/marshal.go +++ /dev/null @@ -1,244 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - curve "github.com/consensys/gnark-crypto/ecc/bls24-317" - - "errors" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - "io" -) - -// WriteTo writes binary encoding of Proof to w without point compression -func (proof *Proof) WriteRawTo(w io.Writer) (int64, error) { - return proof.writeTo(w, curve.RawEncoding()) -} - -// WriteTo writes binary encoding of Proof to w with point compression -func (proof *Proof) WriteTo(w io.Writer) (int64, error) { - return proof.writeTo(w) -} - -func (proof *Proof) writeTo(w io.Writer, options ...func(*curve.Encoder)) (int64, error) { - enc := curve.NewEncoder(w, options...) - - toEncode := []interface{}{ - &proof.LRO[0], - &proof.LRO[1], - &proof.LRO[2], - &proof.Z, - &proof.H[0], - &proof.H[1], - &proof.H[2], - &proof.BatchedProof.H, - proof.BatchedProof.ClaimedValues, - &proof.ZShiftedOpening.H, - &proof.ZShiftedOpening.ClaimedValue, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return enc.BytesWritten(), err - } - } - - return enc.BytesWritten(), nil -} - -// ReadFrom reads binary representation of Proof from r -func (proof *Proof) ReadFrom(r io.Reader) (int64, error) { - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - &proof.LRO[0], - &proof.LRO[1], - &proof.LRO[2], - &proof.Z, - &proof.H[0], - &proof.H[1], - &proof.H[2], - &proof.BatchedProof.H, - &proof.BatchedProof.ClaimedValues, - &proof.ZShiftedOpening.H, - &proof.ZShiftedOpening.ClaimedValue, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err - } - } - - return dec.BytesRead(), nil -} - -// WriteTo writes binary encoding of ProvingKey to w -func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { - // encode the verifying key - n, err = pk.Vk.WriteTo(w) - if err != nil { - return - } - - // fft domains - n2, err := pk.Domain[0].WriteTo(w) - if err != nil { - return - } - n += n2 - - n2, 
err = pk.Domain[1].WriteTo(w) - if err != nil { - return - } - n += n2 - - // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) - if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { - return n, errors.New("invalid permutation size, expected 3*domain cardinality") - } - - enc := curve.NewEncoder(w) - // note: type Polynomial, which is handled by default binary.Write(...) op and doesn't - // encode the size (nor does it convert from Montgomery to Regular form) - // so we explicitly transmit []fr.Element - toEncode := []interface{}{ - ([]fr.Element)(pk.Ql), - ([]fr.Element)(pk.Qr), - ([]fr.Element)(pk.Qm), - ([]fr.Element)(pk.Qo), - ([]fr.Element)(pk.CQk), - ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.S1Canonical), - ([]fr.Element)(pk.S2Canonical), - ([]fr.Element)(pk.S3Canonical), - pk.Permutation, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return n + enc.BytesWritten(), err - } - } - - return n + enc.BytesWritten(), nil -} - -// ReadFrom reads from binary representation in r into ProvingKey -func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { - pk.Vk = &VerifyingKey{} - n, err := pk.Vk.ReadFrom(r) - if err != nil { - return n, err - } - - n2, err := pk.Domain[0].ReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - n2, err = pk.Domain[1].ReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) - - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - (*[]fr.Element)(&pk.Ql), - (*[]fr.Element)(&pk.Qr), - (*[]fr.Element)(&pk.Qm), - (*[]fr.Element)(&pk.Qo), - (*[]fr.Element)(&pk.CQk), - (*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.S1Canonical), - (*[]fr.Element)(&pk.S2Canonical), - (*[]fr.Element)(&pk.S3Canonical), - &pk.Permutation, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err - } - } - - pk.computeLagrangeCosetPolys() - - return n + dec.BytesRead(), 
nil - -} - -// WriteTo writes binary encoding of VerifyingKey to w -func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - enc := curve.NewEncoder(w) - - toEncode := []interface{}{ - vk.Size, - &vk.SizeInv, - &vk.Generator, - vk.NbPublicVariables, - &vk.CosetShift, - &vk.S[0], - &vk.S[1], - &vk.S[2], - &vk.Ql, - &vk.Qr, - &vk.Qm, - &vk.Qo, - &vk.Qk, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return enc.BytesWritten(), err - } - } - - return enc.BytesWritten(), nil -} - -// ReadFrom reads from binary representation in r into VerifyingKey -func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - &vk.Size, - &vk.SizeInv, - &vk.Generator, - &vk.NbPublicVariables, - &vk.CosetShift, - &vk.S[0], - &vk.S[1], - &vk.S[2], - &vk.Ql, - &vk.Qr, - &vk.Qm, - &vk.Qo, - &vk.Qk, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err - } - } - - return dec.BytesRead(), nil -} diff --git a/internal/backend/bls24-317/plonk/marshal_test.go b/internal/backend/bls24-317/plonk/marshal_test.go deleted file mode 100644 index 64618d989e..0000000000 --- a/internal/backend/bls24-317/plonk/marshal_test.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - curve "github.com/consensys/gnark-crypto/ecc/bls24-317" - - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - - "bytes" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" - gnarkio "github.com/consensys/gnark/io" - "io" - "math/big" - "math/rand" - "reflect" - "testing" -) - -func TestProofSerialization(t *testing.T) { - // create a proof - var proof, reconstructed Proof - proof.randomize() - - roundTripCheck(t, &proof, &reconstructed) -} - -func TestProofSerializationRaw(t *testing.T) { - // create a proof - var proof, reconstructed Proof - proof.randomize() - - roundTripCheckRaw(t, &proof, &reconstructed) -} - -func TestProvingKeySerialization(t *testing.T) { - // random pk - var pk, reconstructed ProvingKey - pk.randomize() - - roundTripCheck(t, &pk, &reconstructed) -} - -func TestVerifyingKeySerialization(t *testing.T) { - // create a random vk - var vk, reconstructed VerifyingKey - vk.randomize() - - roundTripCheck(t, &vk, &reconstructed) -} - -func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { - var buf bytes.Buffer - written, err := from.WriteTo(&buf) - if err != nil { - t.Fatal("couldn't serialize", err) - } - - read, err := reconstructed.ReadFrom(&buf) - if err != nil { - t.Fatal("couldn't deserialize", err) - } - - if !reflect.DeepEqual(from, reconstructed) { - t.Fatal("reconstructed object don't match original") - } - - if written != read { - t.Fatal("bytes written / read don't match") - } -} - -func roundTripCheckRaw(t *testing.T, from gnarkio.WriterRawTo, reconstructed io.ReaderFrom) { - var buf bytes.Buffer - written, err := from.WriteRawTo(&buf) - if err != nil { - t.Fatal("couldn't serialize", err) - } - - read, err := reconstructed.ReadFrom(&buf) - if err != nil { - t.Fatal("couldn't deserialize", err) - } - - if !reflect.DeepEqual(from, reconstructed) { - t.Fatal("reconstructed object don't match original") - } - - if written != 
read { - t.Fatal("bytes written / read don't match") - } -} - -func (pk *ProvingKey) randomize() { - var vk VerifyingKey - vk.randomize() - pk.Vk = &vk - pk.Domain[0] = *fft.NewDomain(42) - pk.Domain[1] = *fft.NewDomain(4 * 42) - - n := int(pk.Domain[0].Cardinality) - pk.Ql = randomScalars(n) - pk.Qr = randomScalars(n) - pk.Qm = randomScalars(n) - pk.Qo = randomScalars(n) - pk.CQk = randomScalars(n) - pk.LQk = randomScalars(n) - pk.S1Canonical = randomScalars(n) - pk.S2Canonical = randomScalars(n) - pk.S3Canonical = randomScalars(n) - - pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) - pk.Permutation[0] = -12 - pk.Permutation[len(pk.Permutation)-1] = 8888 - - pk.computeLagrangeCosetPolys() -} - -func (vk *VerifyingKey) randomize() { - vk.Size = rand.Uint64() - vk.SizeInv.SetRandom() - vk.Generator.SetRandom() - vk.NbPublicVariables = rand.Uint64() - vk.CosetShift.SetRandom() - - vk.S[0] = randomPoint() - vk.S[1] = randomPoint() - vk.S[2] = randomPoint() - vk.Ql = randomPoint() - vk.Qr = randomPoint() - vk.Qm = randomPoint() - vk.Qo = randomPoint() - vk.Qk = randomPoint() -} - -func (proof *Proof) randomize() { - proof.LRO[0] = randomPoint() - proof.LRO[1] = randomPoint() - proof.LRO[2] = randomPoint() - proof.Z = randomPoint() - proof.H[0] = randomPoint() - proof.H[1] = randomPoint() - proof.H[2] = randomPoint() - proof.BatchedProof.H = randomPoint() - proof.BatchedProof.ClaimedValues = randomScalars(2) - proof.ZShiftedOpening.H = randomPoint() - proof.ZShiftedOpening.ClaimedValue.SetRandom() -} - -func randomPoint() curve.G1Affine { - _, _, r, _ := curve.Generators() - r.ScalarMultiplication(&r, big.NewInt(int64(rand.Uint64()))) - return r -} - -func randomScalars(n int) []fr.Element { - v := make([]fr.Element, n) - one := fr.One() - for i := 0; i < len(v); i++ { - if i == 0 { - v[i].SetRandom() - } else { - v[i].Add(&v[i-1], &one) - } - } - return v -} diff --git a/internal/backend/bls24-317/plonk/prove.go b/internal/backend/bls24-317/plonk/prove.go 
deleted file mode 100644 index 162b2a6e28..0000000000 --- a/internal/backend/bls24-317/plonk/prove.go +++ /dev/null @@ -1,634 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "crypto/sha256" - "math/big" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - - curve "github.com/consensys/gnark-crypto/ecc/bls24-317" - - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/kzg" - - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" - - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/iop" - "github.com/consensys/gnark/constraint/bls24-317" - - "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/logger" -) - -type Proof struct { - - // Commitments to the solution vectors - LRO [3]kzg.Digest - - // Commitment to Z, the permutation polynomial - Z kzg.Digest - - // Commitments to h1, h2, h3 such that h = h1 + Xh2 + X**2h3 is the quotient polynomial - H [3]kzg.Digest - - // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2 - BatchedProof kzg.BatchOpeningProof - - // Opening proof of Z at zeta*mu - ZShiftedOpening kzg.OpeningProof -} - -// Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt 
backend.ProverConfig) (*Proof, error) { - - log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", len(spr.Constraints)).Str("backend", "plonk").Logger() - start := time.Now() - // pick a hash function that will be used to derive the challenges - hFunc := sha256.New() - - // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") - - // result - proof := &Proof{} - - // compute the constraint system solution - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } - } - - // query l, r, o in Lagrange basis, not blinded - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) - - lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - liop := iop.NewPolynomial(&evaluationLDomainSmall, lagReg) - riop := iop.NewPolynomial(&evaluationRDomainSmall, lagReg) - oiop := iop.NewPolynomial(&evaluationODomainSmall, lagReg) - wliop := liop.ShallowClone() - wriop := riop.ShallowClone() - woiop := oiop.ShallowClone() - wliop.ToCanonical(&pk.Domain[0]).ToRegular() - wriop.ToCanonical(&pk.Domain[0]).ToRegular() - woiop.ToCanonical(&pk.Domain[0]).ToRegular() - - // Blind l, r, o before committing - // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. 
- bwliop := wliop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwriop := wriop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwoiop := woiop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // The first challenge is derived using the public data: the commitments to the permutation, - // the coefficients of the circuit, and the public inputs. - // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *pk.Vk, fullWitness[:len(spr.Public)]); err != nil { - return nil, err - } - gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) - if err != nil { - return nil, err - } - - // Fiat Shamir this - bbeta, err := fs.ComputeChallenge("beta") - if err != nil { - return nil, err - } - var beta fr.Element - beta.SetBytes(bbeta) - - // compute the copy constraint's ratio - // We copy liop, riop, oiop because they are fft'ed in the process. - // We could have not copied them at the cost of doing one more bit reverse - // per poly... 
- ziop, err := iop.BuildRatioCopyConstraint( - []*iop.Polynomial{ - liop.Clone(), - riop.Clone(), - oiop.Clone(), - }, - pk.Permutation, - beta, - gamma, - iop.Form{Basis: iop.Canonical, Layout: iop.Regular}, - &pk.Domain[0], - ) - if err != nil { - return proof, err - } - - // commit to the blinded version of z - bwziop := ziop // iop.NewWrappedPolynomial(&ziop) - bwziop.Blind(2) - proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Vk.KZGSRS, runtime.NumCPU()*2) - if err != nil { - return proof, err - } - - // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) - alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) - if err != nil { - return proof, err - } - - // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) - copy(qkCompletedCanonical, fullWitness[:len(spr.Public)]) - copy(qkCompletedCanonical[len(spr.Public):], pk.LQk[len(spr.Public):]) - pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) - fft.BitReverse(qkCompletedCanonical) - - // l, r, o are blinded here - bwliop.ToLagrangeCoset(&pk.Domain[1]) - bwriop.ToLagrangeCoset(&pk.Domain[1]) - bwoiop.ToLagrangeCoset(&pk.Domain[1]) - - lagrangeCosetBitReversed := iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse} - - // we don't mutate so no need to clone the coefficients from the proving key. 
- wqliop := iop.NewPolynomial(&pk.lQl, lagrangeCosetBitReversed) - wqriop := iop.NewPolynomial(&pk.lQr, lagrangeCosetBitReversed) - wqmiop := iop.NewPolynomial(&pk.lQm, lagrangeCosetBitReversed) - wqoiop := iop.NewPolynomial(&pk.lQo, lagrangeCosetBitReversed) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqkiop := iop.NewPolynomial(&qkCompletedCanonical, canReg) - wqkiop.ToLagrangeCoset(&pk.Domain[1]) - - // storing Id - id := make([]fr.Element, pk.Domain[1].Cardinality) - id[1].SetOne() - widiop := iop.NewPolynomial(&id, canReg) - widiop.ToLagrangeCoset(&pk.Domain[1]) - - // permutations in LagrangeCoset: we don't mutate so no need to clone the coefficients from the - // proving key. - ws1 := iop.NewPolynomial(&pk.lS1LagrangeCoset, lagrangeCosetBitReversed) - ws2 := iop.NewPolynomial(&pk.lS2LagrangeCoset, lagrangeCosetBitReversed) - ws3 := iop.NewPolynomial(&pk.lS3LagrangeCoset, lagrangeCosetBitReversed) - - // Store z(g*x), without reallocating a slice - bwsziop := bwziop.ShallowClone().Shift(1) - bwsziop.ToLagrangeCoset(&pk.Domain[1]) - - // L_{g^{0}} - cap := pk.Domain[1].Cardinality - if cap < pk.Domain[0].Cardinality { - cap = pk.Domain[0].Cardinality // sanity check - } - lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) - lone[0].SetOne() - loneiop := iop.NewPolynomial(&lone, lagReg) - wloneiop := loneiop.ToCanonical(&pk.Domain[0]). - ToRegular(). - ToLagrangeCoset(&pk.Domain[1]) - - // Full capture using latest gnark crypto... 
- fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element) fr.Element { - - var ic, tmp fr.Element - - ic.Mul(&fql, &l) - tmp.Mul(&fqr, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqm, &l).Mul(&tmp, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqo, &o) - ic.Add(&ic, &tmp).Add(&ic, &fqk) - - return ic - } - - fo := func(l, r, o, fid, fs1, fs2, fs3, fz, fzs fr.Element) fr.Element { - var uu fr.Element - u := pk.Domain[0].FrMultiplicativeGen - uu.Mul(&u, &u) - - var a, b, tmp fr.Element - a.Mul(&beta, &fid).Add(&a, &l).Add(&a, &gamma) - tmp.Mul(&beta, &u).Mul(&tmp, &fid).Add(&tmp, &r).Add(&tmp, &gamma) - a.Mul(&a, &tmp) - tmp.Mul(&beta, &uu).Mul(&tmp, &fid).Add(&tmp, &o).Add(&tmp, &gamma) - a.Mul(&a, &tmp).Mul(&a, &fz) - - b.Mul(&beta, &fs1).Add(&b, &l).Add(&b, &gamma) - tmp.Mul(&beta, &fs2).Add(&tmp, &r).Add(&tmp, &gamma) - b.Mul(&b, &tmp) - tmp.Mul(&beta, &fs3).Add(&tmp, &o).Add(&tmp, &gamma) - b.Mul(&b, &tmp).Mul(&b, &fzs) - - b.Sub(&b, &a) - - return b - } - - fone := func(fz, flone fr.Element) fr.Element { - one := fr.One() - one.Sub(&fz, &one).Mul(&one, &flone) - return one - } - - // 0 , 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 - // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk,lone - fm := func(x ...fr.Element) fr.Element { - - a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2]) - b := fo(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8]) - c := fone(x[7], x[14]) - - c.Mul(&c, &alpha).Add(&c, &b).Mul(&c, &alpha).Add(&c, &a) - - return c - } - testEval, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, - bwliop, - bwriop, - bwoiop, - widiop, - ws1, - ws2, - ws3, - bwziop, - bwsziop, - wqliop, - wqriop, - wqmiop, - wqoiop, - wqkiop, - wloneiop, - ) - if err != nil { - return nil, err - } - h, err := iop.DivideByXMinusOne(testEval, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) - if err != nil { - return nil, err - } - - // compute kzg commitments of h1, h2 and h3 - if err := commitToQuotient( - 
h.Coefficients()[:pk.Domain[0].Cardinality+2], - h.Coefficients()[pk.Domain[0].Cardinality+2:2*(pk.Domain[0].Cardinality+2)], - h.Coefficients()[2*(pk.Domain[0].Cardinality+2):3*(pk.Domain[0].Cardinality+2)], - proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // derive zeta - zeta, err := deriveRandomness(&fs, "zeta", &proof.H[0], &proof.H[1], &proof.H[2]) - if err != nil { - return nil, err - } - - // compute evaluations of (blinded version of) l, r, o, z at zeta - var blzeta, brzeta, bozeta fr.Element - - var wgEvals sync.WaitGroup - wgEvals.Add(3) - - go func() { - bwliop.ToCanonical(&pk.Domain[1]).ToRegular() - blzeta = bwliop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwriop.ToCanonical(&pk.Domain[1]).ToRegular() - brzeta = bwriop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwoiop.ToCanonical(&pk.Domain[1]).ToRegular() - bozeta = bwoiop.Evaluate(zeta) - wgEvals.Done() - }() - - // open blinded Z at zeta*z - bwziop.ToCanonical(&pk.Domain[1]).ToRegular() - var zetaShifted fr.Element - zetaShifted.Mul(&zeta, &pk.Vk.Generator) - proof.ZShiftedOpening, err = kzg.Open( - bwziop.Coefficients()[:bwziop.BlindedSize()], - zetaShifted, - pk.Vk.KZGSRS, - ) - if err != nil { - return nil, err - } - - // blinded z evaluated at u*zeta - bzuzeta := proof.ZShiftedOpening.ClaimedValue - - var ( - linearizedPolynomialCanonical []fr.Element - linearizedPolynomialDigest curve.G1Affine - errLPoly error - ) - - wgEvals.Wait() // wait for the evaluations - - // compute the linearization polynomial r at zeta - // (goal: save committing separately to z, ql, qr, qm, qo, k - linearizedPolynomialCanonical = computeLinearizedPolynomial( - blzeta, - brzeta, - bozeta, - alpha, - beta, - gamma, - zeta, - bzuzeta, - bwziop.Coefficients()[:bwziop.BlindedSize()], - pk, - ) - - // TODO this commitment is only necessary to derive the challenge, we should - // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = 
kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) - - // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) - var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) - var zetaPowerm fr.Element - zetaPowerm.Exp(zeta, &bSize) - zetaPowerm.BigInt(&bZetaPowerm) - foldedHDigest := proof.H[2] - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - - // foldedH = h1 + ζ*h2 + ζ²*h3 - foldedH := h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] - h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] - utils.Parallelize(len(foldedH), func(start, end int) { - for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² - foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 - } - }) - - if errLPoly != nil { - return nil, errLPoly - } - - // Batch open the first list of polynomials - proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - [][]fr.Element{ - foldedH, - linearizedPolynomialCanonical, - bwliop.Coefficients()[:bwliop.BlindedSize()], - bwriop.Coefficients()[:bwriop.BlindedSize()], - bwoiop.Coefficients()[:bwoiop.BlindedSize()], - pk.S1Canonical, - pk.S2Canonical, - }, - []kzg.Digest{ - foldedHDigest, - linearizedPolynomialDigest, - proof.LRO[0], - proof.LRO[1], - proof.LRO[2], - pk.Vk.S[0], - pk.Vk.S[1], - }, - zeta, - hFunc, - pk.Vk.KZGSRS, - ) - - log.Debug().Dur("took", time.Since(start)).Msg("prover done") - - if err != nil { - return 
nil, err - } - - return proof, nil - -} - -// fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.LRO[0], err0 = kzg.Commit(bcl, srs, n) - close(chCommit0) - }() - go func() { - proof.LRO[1], err1 = kzg.Commit(bcr, srs, n) - close(chCommit1) - }() - if proof.LRO[2], err2 = kzg.Commit(bco, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.H[0], err0 = kzg.Commit(h1, srs, n) - close(chCommit0) - }() - go func() { - proof.H[1], err1 = kzg.Commit(h2, srs, n) - close(chCommit1) - }() - if proof.H[2], err2 = kzg.Commit(h3, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. -// The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z -// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. 
-// -// The Linearized polynomial is: -// -// α²*L₁(ζ)*Z(X) -// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) -// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { - - // first part: individual constraints - var rl fr.Element - rl.Mul(&rZeta, &lZeta) - - // second part: - // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) - var s1, s2 fr.Element - chS1 := make(chan struct{}, 1) - go func() { - ps1 := iop.NewPolynomial(&pk.S1Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - s1 = ps1.Evaluate(zeta) // s1(ζ) - s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) - close(chS1) - }() - ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := ps2.Evaluate(zeta) // s2(ζ) - tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) - <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) - - var uzeta, uuzeta fr.Element - uzeta.Mul(&zeta, &pk.Vk.CosetShift) - uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) - - s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) - tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) - tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - // third part L₁(ζ)*α²*Z - var lagrangeZeta, one, den, frNbElmt fr.Element - one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) - lagrangeZeta.Set(&zeta). - Exp(lagrangeZeta, big.NewInt(nbElmt)). 
- Sub(&lagrangeZeta, &one) - frNbElmt.SetUint64(uint64(nbElmt)) - den.Sub(&zeta, &one). - Inverse(&den) - lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - - linPol := make([]fr.Element, len(blindedZCanonical)) - copy(linPol, blindedZCanonical) - - utils.Parallelize(len(linPol), func(start, end int) { - - var t0, t1 fr.Element - - for i := start; i < end; i++ { - - linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - if i < len(pk.S3Canonical) { - - t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) - - linPol[i].Add(&linPol[i], &t0) - } - - linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) - - if i < len(pk.Qm) { - - t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) - t0.Mul(&pk.Ql[i], &lZeta) - t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - - t0.Mul(&pk.Qr[i], &rZeta) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - - t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) - } - - t0.Mul(&blindedZCanonical[i], &lagrangeZeta) - linPol[i].Add(&linPol[i], &t0) // finish the computation - } - }) - return linPol -} diff --git a/internal/backend/bls24-317/plonk/setup.go b/internal/backend/bls24-317/plonk/setup.go deleted file mode 100644 index 08c3689967..0000000000 --- a/internal/backend/bls24-317/plonk/setup.go +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/iop" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/kzg" - "github.com/consensys/gnark/constraint/bls24-317" - - kzgg "github.com/consensys/gnark-crypto/kzg" -) - -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation -type ProvingKey struct { - // Verifying Key is embedded into the proving key (needed by Prove) - Vk *VerifyingKey - - // TODO store iop.Polynomial here, not []fr.Element for more "type safety" - - // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo []fr.Element - - // qr,ql,qm,qo (in lagrange coset basis) --> these are not serialized, but computed from Ql, Qr, Qm, Qo once. - lQl, lQr, lQm, lQo []fr.Element - - // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. - // Storing LQk in Lagrange basis saves a fft... - CQk, LQk []fr.Element - - // Domains used for the FFTs. 
- // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain - // Domain[0], Domain[1] fft.Domain - - // Permutation polynomials - S1Canonical, S2Canonical, S3Canonical []fr.Element - - // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. - lS1LagrangeCoset, lS2LagrangeCoset, lS3LagrangeCoset []fr.Element - - // position -> permuted position (position in [0,3*sizeSystem-1]) - Permutation []int64 -} - -// VerifyingKey stores the data needed to verify a proof: -// * The commitment scheme -// * Commitments of ql prepended with as many ones as there are public inputs -// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs -// * Commitments to S1, S2, S3 -type VerifyingKey struct { - // Size circuit - Size uint64 - SizeInv fr.Element - Generator fr.Element - NbPublicVariables uint64 - - // Commitment scheme that is used for an instantiation of PLONK - KZGSRS *kzg.SRS - - // cosetShift generator of the coset on the small domain - CosetShift fr.Element - - // S commitments to S1, S2, S3 - S [3]kzg.Digest - - // Commitments to ql, qr, qm, qo prepended with as many zeroes (ones for l) as there are public inputs. - // In particular Qk is not complete. - Ql, Qr, Qm, Qo, Qk kzg.Digest -} - -// Setup sets proving and verifying keys -func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) { - var pk ProvingKey - var vk VerifyingKey - - // The verifying key shares data with the proving key - pk.Vk = &vk - - nbConstraints := len(spr.Constraints) - - // fft domains - sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - - vk.Size = pk.Domain[0].Cardinality - vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) - vk.NbPublicVariables = uint64(len(spr.Public)) - - if err := pk.InitKZG(srs); err != nil { - return nil, nil, err - } - - // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) - pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) - pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) - - for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant - pk.Ql[i].SetOne().Neg(&pk.Ql[i]) - pk.Qr[i].SetZero() - pk.Qm[i].SetZero() - pk.Qo[i].SetZero() - pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // → to be completed by the prover - } - offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.Ql[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.Qr[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.Qm[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.Qm[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.Qo[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.CQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - } - - pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) - pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) - fft.BitReverse(pk.Ql) - fft.BitReverse(pk.Qr) - fft.BitReverse(pk.Qm) - fft.BitReverse(pk.Qo) - fft.BitReverse(pk.CQk) - - // build permutation. Note: at this stage, the permutation takes in account the placeholders - buildPermutation(spr, &pk) - - // set s1, s2, s3 - ccomputePermutationPolynomials(&pk) - - // compute the lagrange coset basis versions (not serialized) - pk.computeLagrangeCosetPolys() - - // Commit to the polynomials to set up the verifying key - var err error - if vk.Ql, err = kzg.Commit(pk.Ql, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qr, err = kzg.Commit(pk.Qr, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qm, err = kzg.Commit(pk.Qm, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qo, err = kzg.Commit(pk.Qo, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - - return &pk, &vk, nil - -} - -// buildPermutation builds the Permutation associated with a circuit. -// -// The permutation s is composed of cycles of maximum length such that -// -// s. 
(l∥r∥o) = (l∥r∥o) -// -// , where l∥r∥o is the concatenation of the indices of l, r, o in -// ql.l+qr.r+qm.l.r+qo.O+k = 0. -// -// The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab -// like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { - - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := int(pk.Domain[0].Cardinality) - - // init permutation - pk.Permutation = make([]int64, 3*sizeSolution) - for i := 0; i < len(pk.Permutation); i++ { - pk.Permutation[i] = -1 - } - - // init LRO position -> variable_ID - lro := make([]int, 3*sizeSolution) // position -> variable_ID - for i := 0; i < len(spr.Public); i++ { - lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) - } - - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() - } - - // init cycle: - // map ID -> last position the ID was seen - cycle := make([]int64, nbVariables) - for i := 0; i < len(cycle); i++ { - cycle[i] = -1 - } - - for i := 0; i < len(lro); i++ { - if cycle[lro[i]] != -1 { - // if != -1, it means we already encountered this value - // so we need to set the corresponding permutation index. 
- pk.Permutation[i] = cycle[lro[i]] - } - cycle[lro[i]] = int64(i) - } - - // complete the Permutation by filling the first IDs encountered - for i := 0; i < len(pk.Permutation); i++ { - if pk.Permutation[i] == -1 { - pk.Permutation[i] = cycle[lro[i]] - } - } -} - -func (pk *ProvingKey) computeLagrangeCosetPolys() { - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqliop := iop.NewPolynomial(clone(pk.Ql, pk.Domain[1].Cardinality), canReg) - wqriop := iop.NewPolynomial(clone(pk.Qr, pk.Domain[1].Cardinality), canReg) - wqmiop := iop.NewPolynomial(clone(pk.Qm, pk.Domain[1].Cardinality), canReg) - wqoiop := iop.NewPolynomial(clone(pk.Qo, pk.Domain[1].Cardinality), canReg) - - ws1 := iop.NewPolynomial(clone(pk.S1Canonical, pk.Domain[1].Cardinality), canReg) - ws2 := iop.NewPolynomial(clone(pk.S2Canonical, pk.Domain[1].Cardinality), canReg) - ws3 := iop.NewPolynomial(clone(pk.S3Canonical, pk.Domain[1].Cardinality), canReg) - - wqliop.ToLagrangeCoset(&pk.Domain[1]) - wqriop.ToLagrangeCoset(&pk.Domain[1]) - wqmiop.ToLagrangeCoset(&pk.Domain[1]) - wqoiop.ToLagrangeCoset(&pk.Domain[1]) - - ws1.ToLagrangeCoset(&pk.Domain[1]) - ws2.ToLagrangeCoset(&pk.Domain[1]) - ws3.ToLagrangeCoset(&pk.Domain[1]) - - pk.lQl = wqliop.Coefficients() - pk.lQr = wqriop.Coefficients() - pk.lQm = wqmiop.Coefficients() - pk.lQo = wqoiop.Coefficients() - - pk.lS1LagrangeCoset = ws1.Coefficients() - pk.lS2LagrangeCoset = ws2.Coefficients() - pk.lS3LagrangeCoset = ws3.Coefficients() -} - -func clone(input []fr.Element, capacity uint64) *[]fr.Element { - res := make([]fr.Element, len(input), capacity) - copy(res, input) - return &res -} - -// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations -// s1, s2, s3. -// -// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | -// -// | -// | Permutation -// -// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v -// \---------------/ \--------------------/ \------------------------/ -// -// s1 (LDE) s2 (LDE) s3 (LDE) -func ccomputePermutationPolynomials(pk *ProvingKey) { - - nbElmts := int(pk.Domain[0].Cardinality) - - // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) - - // Lagrange form of S1, S2, S3 - pk.S1Canonical = make([]fr.Element, nbElmts) - pk.S2Canonical = make([]fr.Element, nbElmts) - pk.S3Canonical = make([]fr.Element, nbElmts) - for i := 0; i < nbElmts; i++ { - pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) - pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) - pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) - } - - // Canonical form of S1, S2, S3 - pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) - fft.BitReverse(pk.S1Canonical) - fft.BitReverse(pk.S2Canonical) - fft.BitReverse(pk.S3Canonical) -} - -// getIDSmallDomain returns the Lagrange form of ID on the small domain -func getIDSmallDomain(domain *fft.Domain) []fr.Element { - - res := make([]fr.Element, 3*domain.Cardinality) - - res[0].SetOne() - res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) - res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) - - for i := uint64(1); i < domain.Cardinality; i++ { - res[i].Mul(&res[i-1], &domain.Generator) - res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) - res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) - } - - return res -} - -// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS -// -// This should be used after deserializing a ProvingKey -// as pk.Vk.KZG is NOT serialized -func (pk *ProvingKey) InitKZG(srs kzgg.SRS) error { - return pk.Vk.InitKZG(srs) -} - -// InitKZG inits vk.KZG using provided SRS -// -// This should be used after deserializing a 
VerifyingKey -// as vk.KZG is NOT serialized -// -// Note that this instantiate a new FFT domain using vk.Size -func (vk *VerifyingKey) InitKZG(srs kzgg.SRS) error { - _srs := srs.(*kzg.SRS) - - if len(_srs.G1) < int(vk.Size) { - return errors.New("kzg srs is too small") - } - vk.KZGSRS = _srs - - return nil -} - -// NbPublicWitness returns the expected public witness size (number of field elements) -func (vk *VerifyingKey) NbPublicWitness() int { - return int(vk.NbPublicVariables) -} - -// VerifyingKey returns pk.Vk -func (pk *ProvingKey) VerifyingKey() interface{} { - return pk.Vk -} diff --git a/internal/backend/bn254/groth16/utils_test.go b/internal/backend/bn254/groth16/utils_test.go deleted file mode 100644 index 4552aa5e37..0000000000 --- a/internal/backend/bn254/groth16/utils_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package groth16 - -import ( - "testing" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "github.com/stretchr/testify/assert" -) - -func assertSliceEquals[T any](t *testing.T, expected []T, seen []T) { - assert.Equal(t, len(expected), len(seen)) - for i := range expected { - assert.Equal(t, expected[i], seen[i]) - } -} - -func TestRemoveIndex(t *testing.T) { - elems := []fr.Element{{0}, {1}, {2}, {3}} - r := filter(elems, []int{1, 2}) - expected := []fr.Element{{0}, {3}} - assertSliceEquals(t, expected, r) -} diff --git a/internal/backend/bn254/plonk/marshal.go b/internal/backend/bn254/plonk/marshal.go deleted file mode 100644 index 400e50e808..0000000000 --- a/internal/backend/bn254/plonk/marshal.go +++ /dev/null @@ -1,244 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - curve "github.com/consensys/gnark-crypto/ecc/bn254" - - "errors" - "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "io" -) - -// WriteTo writes binary encoding of Proof to w without point compression -func (proof *Proof) WriteRawTo(w io.Writer) (int64, error) { - return proof.writeTo(w, curve.RawEncoding()) -} - -// WriteTo writes binary encoding of Proof to w with point compression -func (proof *Proof) WriteTo(w io.Writer) (int64, error) { - return proof.writeTo(w) -} - -func (proof *Proof) writeTo(w io.Writer, options ...func(*curve.Encoder)) (int64, error) { - enc := curve.NewEncoder(w, options...) 
- - toEncode := []interface{}{ - &proof.LRO[0], - &proof.LRO[1], - &proof.LRO[2], - &proof.Z, - &proof.H[0], - &proof.H[1], - &proof.H[2], - &proof.BatchedProof.H, - proof.BatchedProof.ClaimedValues, - &proof.ZShiftedOpening.H, - &proof.ZShiftedOpening.ClaimedValue, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return enc.BytesWritten(), err - } - } - - return enc.BytesWritten(), nil -} - -// ReadFrom reads binary representation of Proof from r -func (proof *Proof) ReadFrom(r io.Reader) (int64, error) { - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - &proof.LRO[0], - &proof.LRO[1], - &proof.LRO[2], - &proof.Z, - &proof.H[0], - &proof.H[1], - &proof.H[2], - &proof.BatchedProof.H, - &proof.BatchedProof.ClaimedValues, - &proof.ZShiftedOpening.H, - &proof.ZShiftedOpening.ClaimedValue, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err - } - } - - return dec.BytesRead(), nil -} - -// WriteTo writes binary encoding of ProvingKey to w -func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { - // encode the verifying key - n, err = pk.Vk.WriteTo(w) - if err != nil { - return - } - - // fft domains - n2, err := pk.Domain[0].WriteTo(w) - if err != nil { - return - } - n += n2 - - n2, err = pk.Domain[1].WriteTo(w) - if err != nil { - return - } - n += n2 - - // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) - if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { - return n, errors.New("invalid permutation size, expected 3*domain cardinality") - } - - enc := curve.NewEncoder(w) - // note: type Polynomial, which is handled by default binary.Write(...) 
op and doesn't - // encode the size (nor does it convert from Montgomery to Regular form) - // so we explicitly transmit []fr.Element - toEncode := []interface{}{ - ([]fr.Element)(pk.Ql), - ([]fr.Element)(pk.Qr), - ([]fr.Element)(pk.Qm), - ([]fr.Element)(pk.Qo), - ([]fr.Element)(pk.CQk), - ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.S1Canonical), - ([]fr.Element)(pk.S2Canonical), - ([]fr.Element)(pk.S3Canonical), - pk.Permutation, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return n + enc.BytesWritten(), err - } - } - - return n + enc.BytesWritten(), nil -} - -// ReadFrom reads from binary representation in r into ProvingKey -func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { - pk.Vk = &VerifyingKey{} - n, err := pk.Vk.ReadFrom(r) - if err != nil { - return n, err - } - - n2, err := pk.Domain[0].ReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - n2, err = pk.Domain[1].ReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) - - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - (*[]fr.Element)(&pk.Ql), - (*[]fr.Element)(&pk.Qr), - (*[]fr.Element)(&pk.Qm), - (*[]fr.Element)(&pk.Qo), - (*[]fr.Element)(&pk.CQk), - (*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.S1Canonical), - (*[]fr.Element)(&pk.S2Canonical), - (*[]fr.Element)(&pk.S3Canonical), - &pk.Permutation, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err - } - } - - pk.computeLagrangeCosetPolys() - - return n + dec.BytesRead(), nil - -} - -// WriteTo writes binary encoding of VerifyingKey to w -func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - enc := curve.NewEncoder(w) - - toEncode := []interface{}{ - vk.Size, - &vk.SizeInv, - &vk.Generator, - vk.NbPublicVariables, - &vk.CosetShift, - &vk.S[0], - &vk.S[1], - &vk.S[2], - &vk.Ql, - &vk.Qr, - &vk.Qm, - &vk.Qo, - &vk.Qk, - } - - for _, v := range 
toEncode { - if err := enc.Encode(v); err != nil { - return enc.BytesWritten(), err - } - } - - return enc.BytesWritten(), nil -} - -// ReadFrom reads from binary representation in r into VerifyingKey -func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - &vk.Size, - &vk.SizeInv, - &vk.Generator, - &vk.NbPublicVariables, - &vk.CosetShift, - &vk.S[0], - &vk.S[1], - &vk.S[2], - &vk.Ql, - &vk.Qr, - &vk.Qm, - &vk.Qo, - &vk.Qk, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err - } - } - - return dec.BytesRead(), nil -} diff --git a/internal/backend/bn254/plonk/marshal_test.go b/internal/backend/bn254/plonk/marshal_test.go deleted file mode 100644 index 8840ad242e..0000000000 --- a/internal/backend/bn254/plonk/marshal_test.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - curve "github.com/consensys/gnark-crypto/ecc/bn254" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr" - - "bytes" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" - gnarkio "github.com/consensys/gnark/io" - "io" - "math/big" - "math/rand" - "reflect" - "testing" -) - -func TestProofSerialization(t *testing.T) { - // create a proof - var proof, reconstructed Proof - proof.randomize() - - roundTripCheck(t, &proof, &reconstructed) -} - -func TestProofSerializationRaw(t *testing.T) { - // create a proof - var proof, reconstructed Proof - proof.randomize() - - roundTripCheckRaw(t, &proof, &reconstructed) -} - -func TestProvingKeySerialization(t *testing.T) { - // random pk - var pk, reconstructed ProvingKey - pk.randomize() - - roundTripCheck(t, &pk, &reconstructed) -} - -func TestVerifyingKeySerialization(t *testing.T) { - // create a random vk - var vk, reconstructed VerifyingKey - vk.randomize() - - roundTripCheck(t, &vk, &reconstructed) -} - -func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { - var buf bytes.Buffer - written, err := from.WriteTo(&buf) - if err != nil { - t.Fatal("couldn't serialize", err) - } - - read, err := reconstructed.ReadFrom(&buf) - if err != nil { - t.Fatal("couldn't deserialize", err) - } - - if !reflect.DeepEqual(from, reconstructed) { - t.Fatal("reconstructed object don't match original") - } - - if written != read { - t.Fatal("bytes written / read don't match") - } -} - -func roundTripCheckRaw(t *testing.T, from gnarkio.WriterRawTo, reconstructed io.ReaderFrom) { - var buf bytes.Buffer - written, err := from.WriteRawTo(&buf) - if err != nil { - t.Fatal("couldn't serialize", err) - } - - read, err := reconstructed.ReadFrom(&buf) - if err != nil { - t.Fatal("couldn't deserialize", err) - } - - if !reflect.DeepEqual(from, reconstructed) { - t.Fatal("reconstructed object don't match original") - } - - if written != read { - 
t.Fatal("bytes written / read don't match") - } -} - -func (pk *ProvingKey) randomize() { - var vk VerifyingKey - vk.randomize() - pk.Vk = &vk - pk.Domain[0] = *fft.NewDomain(42) - pk.Domain[1] = *fft.NewDomain(4 * 42) - - n := int(pk.Domain[0].Cardinality) - pk.Ql = randomScalars(n) - pk.Qr = randomScalars(n) - pk.Qm = randomScalars(n) - pk.Qo = randomScalars(n) - pk.CQk = randomScalars(n) - pk.LQk = randomScalars(n) - pk.S1Canonical = randomScalars(n) - pk.S2Canonical = randomScalars(n) - pk.S3Canonical = randomScalars(n) - - pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) - pk.Permutation[0] = -12 - pk.Permutation[len(pk.Permutation)-1] = 8888 - - pk.computeLagrangeCosetPolys() -} - -func (vk *VerifyingKey) randomize() { - vk.Size = rand.Uint64() - vk.SizeInv.SetRandom() - vk.Generator.SetRandom() - vk.NbPublicVariables = rand.Uint64() - vk.CosetShift.SetRandom() - - vk.S[0] = randomPoint() - vk.S[1] = randomPoint() - vk.S[2] = randomPoint() - vk.Ql = randomPoint() - vk.Qr = randomPoint() - vk.Qm = randomPoint() - vk.Qo = randomPoint() - vk.Qk = randomPoint() -} - -func (proof *Proof) randomize() { - proof.LRO[0] = randomPoint() - proof.LRO[1] = randomPoint() - proof.LRO[2] = randomPoint() - proof.Z = randomPoint() - proof.H[0] = randomPoint() - proof.H[1] = randomPoint() - proof.H[2] = randomPoint() - proof.BatchedProof.H = randomPoint() - proof.BatchedProof.ClaimedValues = randomScalars(2) - proof.ZShiftedOpening.H = randomPoint() - proof.ZShiftedOpening.ClaimedValue.SetRandom() -} - -func randomPoint() curve.G1Affine { - _, _, r, _ := curve.Generators() - r.ScalarMultiplication(&r, big.NewInt(int64(rand.Uint64()))) - return r -} - -func randomScalars(n int) []fr.Element { - v := make([]fr.Element, n) - one := fr.One() - for i := 0; i < len(v); i++ { - if i == 0 { - v[i].SetRandom() - } else { - v[i].Add(&v[i-1], &one) - } - } - return v -} diff --git a/internal/backend/bn254/plonk/prove.go b/internal/backend/bn254/plonk/prove.go deleted file mode 
100644 index 0f5042889d..0000000000 --- a/internal/backend/bn254/plonk/prove.go +++ /dev/null @@ -1,634 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "crypto/sha256" - "math/big" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr" - - curve "github.com/consensys/gnark-crypto/ecc/bn254" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr/kzg" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr/iop" - "github.com/consensys/gnark/constraint/bn254" - - "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/logger" -) - -type Proof struct { - - // Commitments to the solution vectors - LRO [3]kzg.Digest - - // Commitment to Z, the permutation polynomial - Z kzg.Digest - - // Commitments to h1, h2, h3 such that h = h1 + Xh2 + X**2h3 is the quotient polynomial - H [3]kzg.Digest - - // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2 - BatchedProof kzg.BatchOpeningProof - - // Opening proof of Z at zeta*mu - ZShiftedOpening kzg.OpeningProof -} - -// Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - - log := 
logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", len(spr.Constraints)).Str("backend", "plonk").Logger() - start := time.Now() - // pick a hash function that will be used to derive the challenges - hFunc := sha256.New() - - // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") - - // result - proof := &Proof{} - - // compute the constraint system solution - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } - } - - // query l, r, o in Lagrange basis, not blinded - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) - - lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - liop := iop.NewPolynomial(&evaluationLDomainSmall, lagReg) - riop := iop.NewPolynomial(&evaluationRDomainSmall, lagReg) - oiop := iop.NewPolynomial(&evaluationODomainSmall, lagReg) - wliop := liop.ShallowClone() - wriop := riop.ShallowClone() - woiop := oiop.ShallowClone() - wliop.ToCanonical(&pk.Domain[0]).ToRegular() - wriop.ToCanonical(&pk.Domain[0]).ToRegular() - woiop.ToCanonical(&pk.Domain[0]).ToRegular() - - // Blind l, r, o before committing - // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. 
- bwliop := wliop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwriop := wriop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwoiop := woiop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // The first challenge is derived using the public data: the commitments to the permutation, - // the coefficients of the circuit, and the public inputs. - // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *pk.Vk, fullWitness[:len(spr.Public)]); err != nil { - return nil, err - } - gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) - if err != nil { - return nil, err - } - - // Fiat Shamir this - bbeta, err := fs.ComputeChallenge("beta") - if err != nil { - return nil, err - } - var beta fr.Element - beta.SetBytes(bbeta) - - // compute the copy constraint's ratio - // We copy liop, riop, oiop because they are fft'ed in the process. - // We could have not copied them at the cost of doing one more bit reverse - // per poly... 
- ziop, err := iop.BuildRatioCopyConstraint( - []*iop.Polynomial{ - liop.Clone(), - riop.Clone(), - oiop.Clone(), - }, - pk.Permutation, - beta, - gamma, - iop.Form{Basis: iop.Canonical, Layout: iop.Regular}, - &pk.Domain[0], - ) - if err != nil { - return proof, err - } - - // commit to the blinded version of z - bwziop := ziop // iop.NewWrappedPolynomial(&ziop) - bwziop.Blind(2) - proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Vk.KZGSRS, runtime.NumCPU()*2) - if err != nil { - return proof, err - } - - // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) - alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) - if err != nil { - return proof, err - } - - // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) - copy(qkCompletedCanonical, fullWitness[:len(spr.Public)]) - copy(qkCompletedCanonical[len(spr.Public):], pk.LQk[len(spr.Public):]) - pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) - fft.BitReverse(qkCompletedCanonical) - - // l, r, o are blinded here - bwliop.ToLagrangeCoset(&pk.Domain[1]) - bwriop.ToLagrangeCoset(&pk.Domain[1]) - bwoiop.ToLagrangeCoset(&pk.Domain[1]) - - lagrangeCosetBitReversed := iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse} - - // we don't mutate so no need to clone the coefficients from the proving key. 
- wqliop := iop.NewPolynomial(&pk.lQl, lagrangeCosetBitReversed) - wqriop := iop.NewPolynomial(&pk.lQr, lagrangeCosetBitReversed) - wqmiop := iop.NewPolynomial(&pk.lQm, lagrangeCosetBitReversed) - wqoiop := iop.NewPolynomial(&pk.lQo, lagrangeCosetBitReversed) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqkiop := iop.NewPolynomial(&qkCompletedCanonical, canReg) - wqkiop.ToLagrangeCoset(&pk.Domain[1]) - - // storing Id - id := make([]fr.Element, pk.Domain[1].Cardinality) - id[1].SetOne() - widiop := iop.NewPolynomial(&id, canReg) - widiop.ToLagrangeCoset(&pk.Domain[1]) - - // permutations in LagrangeCoset: we don't mutate so no need to clone the coefficients from the - // proving key. - ws1 := iop.NewPolynomial(&pk.lS1LagrangeCoset, lagrangeCosetBitReversed) - ws2 := iop.NewPolynomial(&pk.lS2LagrangeCoset, lagrangeCosetBitReversed) - ws3 := iop.NewPolynomial(&pk.lS3LagrangeCoset, lagrangeCosetBitReversed) - - // Store z(g*x), without reallocating a slice - bwsziop := bwziop.ShallowClone().Shift(1) - bwsziop.ToLagrangeCoset(&pk.Domain[1]) - - // L_{g^{0}} - cap := pk.Domain[1].Cardinality - if cap < pk.Domain[0].Cardinality { - cap = pk.Domain[0].Cardinality // sanity check - } - lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) - lone[0].SetOne() - loneiop := iop.NewPolynomial(&lone, lagReg) - wloneiop := loneiop.ToCanonical(&pk.Domain[0]). - ToRegular(). - ToLagrangeCoset(&pk.Domain[1]) - - // Full capture using latest gnark crypto... 
- fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element) fr.Element { - - var ic, tmp fr.Element - - ic.Mul(&fql, &l) - tmp.Mul(&fqr, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqm, &l).Mul(&tmp, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqo, &o) - ic.Add(&ic, &tmp).Add(&ic, &fqk) - - return ic - } - - fo := func(l, r, o, fid, fs1, fs2, fs3, fz, fzs fr.Element) fr.Element { - var uu fr.Element - u := pk.Domain[0].FrMultiplicativeGen - uu.Mul(&u, &u) - - var a, b, tmp fr.Element - a.Mul(&beta, &fid).Add(&a, &l).Add(&a, &gamma) - tmp.Mul(&beta, &u).Mul(&tmp, &fid).Add(&tmp, &r).Add(&tmp, &gamma) - a.Mul(&a, &tmp) - tmp.Mul(&beta, &uu).Mul(&tmp, &fid).Add(&tmp, &o).Add(&tmp, &gamma) - a.Mul(&a, &tmp).Mul(&a, &fz) - - b.Mul(&beta, &fs1).Add(&b, &l).Add(&b, &gamma) - tmp.Mul(&beta, &fs2).Add(&tmp, &r).Add(&tmp, &gamma) - b.Mul(&b, &tmp) - tmp.Mul(&beta, &fs3).Add(&tmp, &o).Add(&tmp, &gamma) - b.Mul(&b, &tmp).Mul(&b, &fzs) - - b.Sub(&b, &a) - - return b - } - - fone := func(fz, flone fr.Element) fr.Element { - one := fr.One() - one.Sub(&fz, &one).Mul(&one, &flone) - return one - } - - // 0 , 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 - // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk,lone - fm := func(x ...fr.Element) fr.Element { - - a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2]) - b := fo(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8]) - c := fone(x[7], x[14]) - - c.Mul(&c, &alpha).Add(&c, &b).Mul(&c, &alpha).Add(&c, &a) - - return c - } - testEval, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, - bwliop, - bwriop, - bwoiop, - widiop, - ws1, - ws2, - ws3, - bwziop, - bwsziop, - wqliop, - wqriop, - wqmiop, - wqoiop, - wqkiop, - wloneiop, - ) - if err != nil { - return nil, err - } - h, err := iop.DivideByXMinusOne(testEval, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) - if err != nil { - return nil, err - } - - // compute kzg commitments of h1, h2 and h3 - if err := commitToQuotient( - 
h.Coefficients()[:pk.Domain[0].Cardinality+2], - h.Coefficients()[pk.Domain[0].Cardinality+2:2*(pk.Domain[0].Cardinality+2)], - h.Coefficients()[2*(pk.Domain[0].Cardinality+2):3*(pk.Domain[0].Cardinality+2)], - proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // derive zeta - zeta, err := deriveRandomness(&fs, "zeta", &proof.H[0], &proof.H[1], &proof.H[2]) - if err != nil { - return nil, err - } - - // compute evaluations of (blinded version of) l, r, o, z at zeta - var blzeta, brzeta, bozeta fr.Element - - var wgEvals sync.WaitGroup - wgEvals.Add(3) - - go func() { - bwliop.ToCanonical(&pk.Domain[1]).ToRegular() - blzeta = bwliop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwriop.ToCanonical(&pk.Domain[1]).ToRegular() - brzeta = bwriop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwoiop.ToCanonical(&pk.Domain[1]).ToRegular() - bozeta = bwoiop.Evaluate(zeta) - wgEvals.Done() - }() - - // open blinded Z at zeta*z - bwziop.ToCanonical(&pk.Domain[1]).ToRegular() - var zetaShifted fr.Element - zetaShifted.Mul(&zeta, &pk.Vk.Generator) - proof.ZShiftedOpening, err = kzg.Open( - bwziop.Coefficients()[:bwziop.BlindedSize()], - zetaShifted, - pk.Vk.KZGSRS, - ) - if err != nil { - return nil, err - } - - // blinded z evaluated at u*zeta - bzuzeta := proof.ZShiftedOpening.ClaimedValue - - var ( - linearizedPolynomialCanonical []fr.Element - linearizedPolynomialDigest curve.G1Affine - errLPoly error - ) - - wgEvals.Wait() // wait for the evaluations - - // compute the linearization polynomial r at zeta - // (goal: save committing separately to z, ql, qr, qm, qo, k - linearizedPolynomialCanonical = computeLinearizedPolynomial( - blzeta, - brzeta, - bozeta, - alpha, - beta, - gamma, - zeta, - bzuzeta, - bwziop.Coefficients()[:bwziop.BlindedSize()], - pk, - ) - - // TODO this commitment is only necessary to derive the challenge, we should - // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = 
kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) - - // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) - var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) - var zetaPowerm fr.Element - zetaPowerm.Exp(zeta, &bSize) - zetaPowerm.BigInt(&bZetaPowerm) - foldedHDigest := proof.H[2] - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - - // foldedH = h1 + ζ*h2 + ζ²*h3 - foldedH := h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] - h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] - utils.Parallelize(len(foldedH), func(start, end int) { - for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² - foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 - } - }) - - if errLPoly != nil { - return nil, errLPoly - } - - // Batch open the first list of polynomials - proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - [][]fr.Element{ - foldedH, - linearizedPolynomialCanonical, - bwliop.Coefficients()[:bwliop.BlindedSize()], - bwriop.Coefficients()[:bwriop.BlindedSize()], - bwoiop.Coefficients()[:bwoiop.BlindedSize()], - pk.S1Canonical, - pk.S2Canonical, - }, - []kzg.Digest{ - foldedHDigest, - linearizedPolynomialDigest, - proof.LRO[0], - proof.LRO[1], - proof.LRO[2], - pk.Vk.S[0], - pk.Vk.S[1], - }, - zeta, - hFunc, - pk.Vk.KZGSRS, - ) - - log.Debug().Dur("took", time.Since(start)).Msg("prover done") - - if err != nil { - return 
nil, err - } - - return proof, nil - -} - -// fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.LRO[0], err0 = kzg.Commit(bcl, srs, n) - close(chCommit0) - }() - go func() { - proof.LRO[1], err1 = kzg.Commit(bcr, srs, n) - close(chCommit1) - }() - if proof.LRO[2], err2 = kzg.Commit(bco, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.H[0], err0 = kzg.Commit(h1, srs, n) - close(chCommit0) - }() - go func() { - proof.H[1], err1 = kzg.Commit(h2, srs, n) - close(chCommit1) - }() - if proof.H[2], err2 = kzg.Commit(h3, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. -// The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z -// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. 
-// -// The Linearized polynomial is: -// -// α²*L₁(ζ)*Z(X) -// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) -// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { - - // first part: individual constraints - var rl fr.Element - rl.Mul(&rZeta, &lZeta) - - // second part: - // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) - var s1, s2 fr.Element - chS1 := make(chan struct{}, 1) - go func() { - ps1 := iop.NewPolynomial(&pk.S1Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - s1 = ps1.Evaluate(zeta) // s1(ζ) - s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) - close(chS1) - }() - ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := ps2.Evaluate(zeta) // s2(ζ) - tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) - <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) - - var uzeta, uuzeta fr.Element - uzeta.Mul(&zeta, &pk.Vk.CosetShift) - uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) - - s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) - tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) - tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - // third part L₁(ζ)*α²*Z - var lagrangeZeta, one, den, frNbElmt fr.Element - one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) - lagrangeZeta.Set(&zeta). - Exp(lagrangeZeta, big.NewInt(nbElmt)). 
- Sub(&lagrangeZeta, &one) - frNbElmt.SetUint64(uint64(nbElmt)) - den.Sub(&zeta, &one). - Inverse(&den) - lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - - linPol := make([]fr.Element, len(blindedZCanonical)) - copy(linPol, blindedZCanonical) - - utils.Parallelize(len(linPol), func(start, end int) { - - var t0, t1 fr.Element - - for i := start; i < end; i++ { - - linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - if i < len(pk.S3Canonical) { - - t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) - - linPol[i].Add(&linPol[i], &t0) - } - - linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) - - if i < len(pk.Qm) { - - t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) - t0.Mul(&pk.Ql[i], &lZeta) - t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - - t0.Mul(&pk.Qr[i], &rZeta) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - - t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) - } - - t0.Mul(&blindedZCanonical[i], &lagrangeZeta) - linPol[i].Add(&linPol[i], &t0) // finish the computation - } - }) - return linPol -} diff --git a/internal/backend/bn254/plonk/setup.go b/internal/backend/bn254/plonk/setup.go deleted file mode 100644 index 1cefa3c469..0000000000 --- a/internal/backend/bn254/plonk/setup.go +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/iop" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/kzg" - "github.com/consensys/gnark/constraint/bn254" - - kzgg "github.com/consensys/gnark-crypto/kzg" -) - -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation -type ProvingKey struct { - // Verifying Key is embedded into the proving key (needed by Prove) - Vk *VerifyingKey - - // TODO store iop.Polynomial here, not []fr.Element for more "type safety" - - // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo []fr.Element - - // qr,ql,qm,qo (in lagrange coset basis) --> these are not serialized, but computed from Ql, Qr, Qm, Qo once. - lQl, lQr, lQm, lQo []fr.Element - - // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. - // Storing LQk in Lagrange basis saves a fft... - CQk, LQk []fr.Element - - // Domains used for the FFTs. 
- // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain - // Domain[0], Domain[1] fft.Domain - - // Permutation polynomials - S1Canonical, S2Canonical, S3Canonical []fr.Element - - // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. - lS1LagrangeCoset, lS2LagrangeCoset, lS3LagrangeCoset []fr.Element - - // position -> permuted position (position in [0,3*sizeSystem-1]) - Permutation []int64 -} - -// VerifyingKey stores the data needed to verify a proof: -// * The commitment scheme -// * Commitments of ql prepended with as many ones as there are public inputs -// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs -// * Commitments to S1, S2, S3 -type VerifyingKey struct { - // Size circuit - Size uint64 - SizeInv fr.Element - Generator fr.Element - NbPublicVariables uint64 - - // Commitment scheme that is used for an instantiation of PLONK - KZGSRS *kzg.SRS - - // cosetShift generator of the coset on the small domain - CosetShift fr.Element - - // S commitments to S1, S2, S3 - S [3]kzg.Digest - - // Commitments to ql, qr, qm, qo prepended with as many zeroes (ones for l) as there are public inputs. - // In particular Qk is not complete. - Ql, Qr, Qm, Qo, Qk kzg.Digest -} - -// Setup sets proving and verifying keys -func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) { - var pk ProvingKey - var vk VerifyingKey - - // The verifying key shares data with the proving key - pk.Vk = &vk - - nbConstraints := len(spr.Constraints) - - // fft domains - sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - - vk.Size = pk.Domain[0].Cardinality - vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) - vk.NbPublicVariables = uint64(len(spr.Public)) - - if err := pk.InitKZG(srs); err != nil { - return nil, nil, err - } - - // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) - pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) - pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) - - for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant - pk.Ql[i].SetOne().Neg(&pk.Ql[i]) - pk.Qr[i].SetZero() - pk.Qm[i].SetZero() - pk.Qo[i].SetZero() - pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // → to be completed by the prover - } - offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.Ql[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.Qr[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.Qm[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.Qm[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.Qo[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.CQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - } - - pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) - pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) - fft.BitReverse(pk.Ql) - fft.BitReverse(pk.Qr) - fft.BitReverse(pk.Qm) - fft.BitReverse(pk.Qo) - fft.BitReverse(pk.CQk) - - // build permutation. Note: at this stage, the permutation takes in account the placeholders - buildPermutation(spr, &pk) - - // set s1, s2, s3 - ccomputePermutationPolynomials(&pk) - - // compute the lagrange coset basis versions (not serialized) - pk.computeLagrangeCosetPolys() - - // Commit to the polynomials to set up the verifying key - var err error - if vk.Ql, err = kzg.Commit(pk.Ql, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qr, err = kzg.Commit(pk.Qr, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qm, err = kzg.Commit(pk.Qm, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qo, err = kzg.Commit(pk.Qo, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - - return &pk, &vk, nil - -} - -// buildPermutation builds the Permutation associated with a circuit. -// -// The permutation s is composed of cycles of maximum length such that -// -// s. 
(l∥r∥o) = (l∥r∥o) -// -// , where l∥r∥o is the concatenation of the indices of l, r, o in -// ql.l+qr.r+qm.l.r+qo.O+k = 0. -// -// The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab -// like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { - - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := int(pk.Domain[0].Cardinality) - - // init permutation - pk.Permutation = make([]int64, 3*sizeSolution) - for i := 0; i < len(pk.Permutation); i++ { - pk.Permutation[i] = -1 - } - - // init LRO position -> variable_ID - lro := make([]int, 3*sizeSolution) // position -> variable_ID - for i := 0; i < len(spr.Public); i++ { - lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) - } - - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() - } - - // init cycle: - // map ID -> last position the ID was seen - cycle := make([]int64, nbVariables) - for i := 0; i < len(cycle); i++ { - cycle[i] = -1 - } - - for i := 0; i < len(lro); i++ { - if cycle[lro[i]] != -1 { - // if != -1, it means we already encountered this value - // so we need to set the corresponding permutation index. 
- pk.Permutation[i] = cycle[lro[i]] - } - cycle[lro[i]] = int64(i) - } - - // complete the Permutation by filling the first IDs encountered - for i := 0; i < len(pk.Permutation); i++ { - if pk.Permutation[i] == -1 { - pk.Permutation[i] = cycle[lro[i]] - } - } -} - -func (pk *ProvingKey) computeLagrangeCosetPolys() { - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqliop := iop.NewPolynomial(clone(pk.Ql, pk.Domain[1].Cardinality), canReg) - wqriop := iop.NewPolynomial(clone(pk.Qr, pk.Domain[1].Cardinality), canReg) - wqmiop := iop.NewPolynomial(clone(pk.Qm, pk.Domain[1].Cardinality), canReg) - wqoiop := iop.NewPolynomial(clone(pk.Qo, pk.Domain[1].Cardinality), canReg) - - ws1 := iop.NewPolynomial(clone(pk.S1Canonical, pk.Domain[1].Cardinality), canReg) - ws2 := iop.NewPolynomial(clone(pk.S2Canonical, pk.Domain[1].Cardinality), canReg) - ws3 := iop.NewPolynomial(clone(pk.S3Canonical, pk.Domain[1].Cardinality), canReg) - - wqliop.ToLagrangeCoset(&pk.Domain[1]) - wqriop.ToLagrangeCoset(&pk.Domain[1]) - wqmiop.ToLagrangeCoset(&pk.Domain[1]) - wqoiop.ToLagrangeCoset(&pk.Domain[1]) - - ws1.ToLagrangeCoset(&pk.Domain[1]) - ws2.ToLagrangeCoset(&pk.Domain[1]) - ws3.ToLagrangeCoset(&pk.Domain[1]) - - pk.lQl = wqliop.Coefficients() - pk.lQr = wqriop.Coefficients() - pk.lQm = wqmiop.Coefficients() - pk.lQo = wqoiop.Coefficients() - - pk.lS1LagrangeCoset = ws1.Coefficients() - pk.lS2LagrangeCoset = ws2.Coefficients() - pk.lS3LagrangeCoset = ws3.Coefficients() -} - -func clone(input []fr.Element, capacity uint64) *[]fr.Element { - res := make([]fr.Element, len(input), capacity) - copy(res, input) - return &res -} - -// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations -// s1, s2, s3. -// -// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | -// -// | -// | Permutation -// -// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v -// \---------------/ \--------------------/ \------------------------/ -// -// s1 (LDE) s2 (LDE) s3 (LDE) -func ccomputePermutationPolynomials(pk *ProvingKey) { - - nbElmts := int(pk.Domain[0].Cardinality) - - // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) - - // Lagrange form of S1, S2, S3 - pk.S1Canonical = make([]fr.Element, nbElmts) - pk.S2Canonical = make([]fr.Element, nbElmts) - pk.S3Canonical = make([]fr.Element, nbElmts) - for i := 0; i < nbElmts; i++ { - pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) - pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) - pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) - } - - // Canonical form of S1, S2, S3 - pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) - fft.BitReverse(pk.S1Canonical) - fft.BitReverse(pk.S2Canonical) - fft.BitReverse(pk.S3Canonical) -} - -// getIDSmallDomain returns the Lagrange form of ID on the small domain -func getIDSmallDomain(domain *fft.Domain) []fr.Element { - - res := make([]fr.Element, 3*domain.Cardinality) - - res[0].SetOne() - res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) - res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) - - for i := uint64(1); i < domain.Cardinality; i++ { - res[i].Mul(&res[i-1], &domain.Generator) - res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) - res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) - } - - return res -} - -// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS -// -// This should be used after deserializing a ProvingKey -// as pk.Vk.KZG is NOT serialized -func (pk *ProvingKey) InitKZG(srs kzgg.SRS) error { - return pk.Vk.InitKZG(srs) -} - -// InitKZG inits vk.KZG using provided SRS -// -// This should be used after deserializing a 
VerifyingKey -// as vk.KZG is NOT serialized -// -// Note that this instantiate a new FFT domain using vk.Size -func (vk *VerifyingKey) InitKZG(srs kzgg.SRS) error { - _srs := srs.(*kzg.SRS) - - if len(_srs.G1) < int(vk.Size) { - return errors.New("kzg srs is too small") - } - vk.KZGSRS = _srs - - return nil -} - -// NbPublicWitness returns the expected public witness size (number of field elements) -func (vk *VerifyingKey) NbPublicWitness() int { - return int(vk.NbPublicVariables) -} - -// VerifyingKey returns pk.Vk -func (pk *ProvingKey) VerifyingKey() interface{} { - return pk.Vk -} diff --git a/internal/backend/bn254/plonk/solidity.go b/internal/backend/bn254/plonk/solidity.go deleted file mode 100644 index 7fddbd204f..0000000000 --- a/internal/backend/bn254/plonk/solidity.go +++ /dev/null @@ -1,907 +0,0 @@ -package plonk - -const solidityTemplate = ` -// Warning this code was contributed into gnark here: -// https://github.com/ConsenSys/gnark/pull/358 -// -// It has not been audited and is provided as-is, we make no guarantees or warranties to its safety and reliability. 
-// -// According to https://eprint.iacr.org/archive/2019/953/1585767119.pdf -pragma solidity ^0.8.0; -pragma experimental ABIEncoderV2; - -library PairingsBn254 { - uint256 constant q_mod = 21888242871839275222246405745257275088696311157297823662689037894645226208583; - uint256 constant r_mod = 21888242871839275222246405745257275088548364400416034343698204186575808495617; - uint256 constant bn254_b_coeff = 3; - - struct G1Point { - uint256 X; - uint256 Y; - } - - struct Fr { - uint256 value; - } - - function new_fr(uint256 fr) internal pure returns (Fr memory) { - require(fr < r_mod); - return Fr({value: fr}); - } - - function copy(Fr memory self) internal pure returns (Fr memory n) { - n.value = self.value; - } - - function assign(Fr memory self, Fr memory other) internal pure { - self.value = other.value; - } - - function inverse(Fr memory fr) internal view returns (Fr memory) { - require(fr.value != 0); - return pow(fr, r_mod-2); - } - - function add_assign(Fr memory self, Fr memory other) internal pure { - self.value = addmod(self.value, other.value, r_mod); - } - - function sub_assign(Fr memory self, Fr memory other) internal pure { - self.value = addmod(self.value, r_mod - other.value, r_mod); - } - - function mul_assign(Fr memory self, Fr memory other) internal pure { - self.value = mulmod(self.value, other.value, r_mod); - } - - function pow(Fr memory self, uint256 power) internal view returns (Fr memory) { - uint256[6] memory input = [32, 32, 32, self.value, power, r_mod]; - uint256[1] memory result; - bool success; - assembly { - success := staticcall(gas(), 0x05, input, 0xc0, result, 0x20) - } - require(success); - return Fr({value: result[0]}); - } - - // Encoding of field elements is: X[0] * z + X[1] - struct G2Point { - uint[2] X; - uint[2] Y; - } - - function P1() internal pure returns (G1Point memory) { - return G1Point(1, 2); - } - - function new_g1(uint256 x, uint256 y) internal pure returns (G1Point memory) { - return G1Point(x, y); - } - - 
function new_g1_checked(uint256 x, uint256 y) internal pure returns (G1Point memory) { - if (x == 0 && y == 0) { - // point of infinity is (0,0) - return G1Point(x, y); - } - - // check encoding - require(x < q_mod); - require(y < q_mod); - // check on curve - uint256 lhs = mulmod(y, y, q_mod); // y^2 - uint256 rhs = mulmod(x, x, q_mod); // x^2 - rhs = mulmod(rhs, x, q_mod); // x^3 - rhs = addmod(rhs, bn254_b_coeff, q_mod); // x^3 + b - require(lhs == rhs); - - return G1Point(x, y); - } - - function new_g2(uint256[2] memory x, uint256[2] memory y) internal pure returns (G2Point memory) { - return G2Point(x, y); - } - - function copy_g1(G1Point memory self) internal pure returns (G1Point memory result) { - result.X = self.X; - result.Y = self.Y; - } - - function P2() internal pure returns (G2Point memory) { - // for some reason ethereum expects to have c1*v + c0 form - - return G2Point( - [0x198e9393920d483a7260bfb731fb5d25f1aa493335a9e71297e485b7aef312c2, - 0x1800deef121f1e76426a00665e5c4479674322d4f75edadd46debd5cd992f6ed], - [0x090689d0585ff075ec9e99ad690c3395bc4b313370b38ef355acdadcd122975b, - 0x12c85ea5db8c6deb4aab71808dcb408fe3d1e7690c43d37b4ce6cc0166fa7daa] - ); - } - - function negate(G1Point memory self) internal pure { - // The prime q in the base field F_q for G1 - if (self.Y == 0) { - require(self.X == 0); - return; - } - - self.Y = q_mod - self.Y; - } - - function point_add(G1Point memory p1, G1Point memory p2) - internal view returns (G1Point memory r) - { - point_add_into_dest(p1, p2, r); - return r; - } - - function point_add_assign(G1Point memory p1, G1Point memory p2) - internal view - { - point_add_into_dest(p1, p2, p1); - } - - function point_add_into_dest(G1Point memory p1, G1Point memory p2, G1Point memory dest) - internal view - { - if (p2.X == 0 && p2.Y == 0) { - // we add zero, nothing happens - dest.X = p1.X; - dest.Y = p1.Y; - return; - } else if (p1.X == 0 && p1.Y == 0) { - // we add into zero, and we add non-zero point - dest.X = p2.X; - 
dest.Y = p2.Y; - return; - } else { - uint256[4] memory input; - - input[0] = p1.X; - input[1] = p1.Y; - input[2] = p2.X; - input[3] = p2.Y; - - bool success = false; - assembly { - success := staticcall(gas(), 6, input, 0x80, dest, 0x40) - } - require(success); - } - } - - function point_sub_assign(G1Point memory p1, G1Point memory p2) - internal view - { - point_sub_into_dest(p1, p2, p1); - } - - function point_sub_into_dest(G1Point memory p1, G1Point memory p2, G1Point memory dest) - internal view - { - if (p2.X == 0 && p2.Y == 0) { - // we subtracted zero, nothing happens - dest.X = p1.X; - dest.Y = p1.Y; - return; - } else if (p1.X == 0 && p1.Y == 0) { - // we subtract from zero, and we subtract non-zero point - dest.X = p2.X; - dest.Y = q_mod - p2.Y; - return; - } else { - uint256[4] memory input; - - input[0] = p1.X; - input[1] = p1.Y; - input[2] = p2.X; - input[3] = q_mod - p2.Y; - - bool success = false; - assembly { - success := staticcall(gas(), 6, input, 0x80, dest, 0x40) - } - require(success); - } - } - - function point_mul(G1Point memory p, Fr memory s) - internal view returns (G1Point memory r) - { - point_mul_into_dest(p, s, r); - return r; - } - - function point_mul_assign(G1Point memory p, Fr memory s) - internal view - { - point_mul_into_dest(p, s, p); - } - - function point_mul_into_dest(G1Point memory p, Fr memory s, G1Point memory dest) - internal view - { - uint[3] memory input; - input[0] = p.X; - input[1] = p.Y; - input[2] = s.value; - bool success; - assembly { - success := staticcall(gas(), 7, input, 0x60, dest, 0x40) - } - require(success); - } - - function pairing(G1Point[] memory p1, G2Point[] memory p2) - internal view returns (bool) - { - require(p1.length == p2.length); - uint elements = p1.length; - uint inputSize = elements * 6; - uint[] memory input = new uint[](inputSize); - for (uint i = 0; i < elements; i++) - { - input[i * 6 + 0] = p1[i].X; - input[i * 6 + 1] = p1[i].Y; - input[i * 6 + 2] = p2[i].X[0]; - input[i * 6 + 3] = 
p2[i].X[1]; - input[i * 6 + 4] = p2[i].Y[0]; - input[i * 6 + 5] = p2[i].Y[1]; - } - uint[1] memory out; - bool success; - assembly { - success := staticcall(gas(), 8, add(input, 0x20), mul(inputSize, 0x20), out, 0x20) - } - require(success); - return out[0] != 0; - } - - /// Convenience method for a pairing check for two pairs. - function pairingProd2(G1Point memory a1, G2Point memory a2, G1Point memory b1, G2Point memory b2) - internal view returns (bool) - { - G1Point[] memory p1 = new G1Point[](2); - G2Point[] memory p2 = new G2Point[](2); - p1[0] = a1; - p1[1] = b1; - p2[0] = a2; - p2[1] = b2; - return pairing(p1, p2); - } -} - -library TranscriptLibrary { - uint32 constant DST_0 = 0; - uint32 constant DST_1 = 1; - uint32 constant DST_CHALLENGE = 2; - - struct Transcript { - bytes32 previous_randomness; - bytes bindings; - string name; - uint32 challenge_counter; - } - - function new_transcript() internal pure returns (Transcript memory t) { - t.challenge_counter = 0; - } - - function set_challenge_name(Transcript memory self, string memory name) internal pure { - self.name = name; - } - - function update_with_u256(Transcript memory self, uint256 value) internal pure { - self.bindings = abi.encodePacked(self.bindings, value); - } - - function update_with_fr(Transcript memory self, PairingsBn254.Fr memory value) internal pure { - self.bindings = abi.encodePacked(self.bindings, value.value); - } - - function update_with_g1(Transcript memory self, PairingsBn254.G1Point memory p) internal pure { - self.bindings = abi.encodePacked(self.bindings, p.X, p.Y); - } - - function get_encode(Transcript memory self) internal pure returns(bytes memory query) { - if (self.challenge_counter != 0) { - query = abi.encodePacked(self.name, self.previous_randomness, self.bindings); - } else { - query = abi.encodePacked(self.name, self.bindings); - } - return query; - } - function get_challenge(Transcript memory self) internal pure returns(PairingsBn254.Fr memory challenge) { - 
bytes32 query; - if (self.challenge_counter != 0) { - query = sha256(abi.encodePacked(self.name, self.previous_randomness, self.bindings)); - } else { - query = sha256(abi.encodePacked(self.name, self.bindings)); - } - self.challenge_counter += 1; - self.previous_randomness = query; - challenge = PairingsBn254.Fr({value: uint256(query) % PairingsBn254.r_mod}); - self.bindings = ""; - } -} - -contract PlonkVerifier { - using PairingsBn254 for PairingsBn254.G1Point; - using PairingsBn254 for PairingsBn254.G2Point; - using PairingsBn254 for PairingsBn254.Fr; - - using TranscriptLibrary for TranscriptLibrary.Transcript; - - uint256 constant STATE_WIDTH = 3; - - struct VerificationKey { - uint256 domain_size; - uint256 num_inputs; - PairingsBn254.Fr omega; // w - PairingsBn254.G1Point[STATE_WIDTH+2] selector_commitments; // STATE_WIDTH for witness + multiplication + constant - PairingsBn254.G1Point[STATE_WIDTH] permutation_commitments; // [Sσ1(x)],[Sσ2(x)],[Sσ3(x)] - PairingsBn254.Fr[STATE_WIDTH-1] permutation_non_residues; // k1, k2 - PairingsBn254.G2Point g2_x; - } - - struct Proof { - uint256[] input_values; - PairingsBn254.G1Point[STATE_WIDTH] wire_commitments; // [a(x)]/[b(x)]/[c(x)] - PairingsBn254.G1Point grand_product_commitment; // [z(x)] - PairingsBn254.G1Point[STATE_WIDTH] quotient_poly_commitments; // [t_lo]/[t_mid]/[t_hi] - PairingsBn254.Fr[STATE_WIDTH] wire_values_at_zeta; // a(zeta)/b(zeta)/c(zeta) - PairingsBn254.Fr grand_product_at_zeta_omega; // z(w*zeta) - PairingsBn254.Fr quotient_polynomial_at_zeta; // t(zeta) - PairingsBn254.Fr linearization_polynomial_at_zeta; // r(zeta) - PairingsBn254.Fr[STATE_WIDTH-1] permutation_polynomials_at_zeta; // Sσ1(zeta),Sσ2(zeta) - - PairingsBn254.G1Point opening_at_zeta_proof; // [Wzeta] - PairingsBn254.G1Point opening_at_zeta_omega_proof; // [Wzeta*omega] - } - - struct PartialVerifierState { - PairingsBn254.Fr alpha; - PairingsBn254.Fr beta; - PairingsBn254.Fr gamma; - PairingsBn254.Fr v; - PairingsBn254.Fr u; - 
PairingsBn254.Fr zeta; - PairingsBn254.Fr[] cached_lagrange_evals; - - PairingsBn254.G1Point cached_fold_quotient_ploy_commitments; - } - - function verify_initial( - PartialVerifierState memory state, - Proof memory proof, - VerificationKey memory vk) internal view returns (bool) { - - require(proof.input_values.length == vk.num_inputs, "not match"); - require(vk.num_inputs >= 1, "inv input"); - - TranscriptLibrary.Transcript memory t = TranscriptLibrary.new_transcript(); - t.set_challenge_name("gamma"); - for (uint256 i = 0; i < vk.permutation_commitments.length; i++) { - t.update_with_g1(vk.permutation_commitments[i]); - } - // this is gnark order: Ql, Qr, Qm, Qo, Qk - // - t.update_with_g1(vk.selector_commitments[0]); - t.update_with_g1(vk.selector_commitments[1]); - t.update_with_g1(vk.selector_commitments[3]); - t.update_with_g1(vk.selector_commitments[2]); - t.update_with_g1(vk.selector_commitments[4]); - - for (uint256 i = 0; i < proof.input_values.length; i++) { - t.update_with_u256(proof.input_values[i]); - } - state.gamma = t.get_challenge(); - - t.set_challenge_name("beta"); - state.beta = t.get_challenge(); - - t.set_challenge_name("alpha"); - t.update_with_g1(proof.grand_product_commitment); - state.alpha = t.get_challenge(); - - t.set_challenge_name("zeta"); - for (uint256 i = 0; i < proof.quotient_poly_commitments.length; i++) { - t.update_with_g1(proof.quotient_poly_commitments[i]); - } - state.zeta = t.get_challenge(); - - uint256[] memory lagrange_poly_numbers = new uint256[](vk.num_inputs); - for (uint256 i = 0; i < lagrange_poly_numbers.length; i++) { - lagrange_poly_numbers[i] = i; - } - state.cached_lagrange_evals = batch_evaluate_lagrange_poly_out_of_domain( - lagrange_poly_numbers, - vk.domain_size, - vk.omega, state.zeta - ); - - bool valid = verify_quotient_poly_eval_at_zeta(state, proof, vk); - return valid; - } - - function verify_commitments( - PartialVerifierState memory state, - Proof memory proof, - VerificationKey memory vk - ) 
internal view returns (bool) { - PairingsBn254.G1Point memory d = reconstruct_d(state, proof, vk); - - PairingsBn254.G1Point memory tmp_g1 = PairingsBn254.P1(); - - PairingsBn254.Fr memory aggregation_challenge = PairingsBn254.new_fr(1); - - PairingsBn254.G1Point memory commitment_aggregation = PairingsBn254.copy_g1(state.cached_fold_quotient_ploy_commitments); - PairingsBn254.Fr memory tmp_fr = PairingsBn254.new_fr(1); - - aggregation_challenge.mul_assign(state.v); - commitment_aggregation.point_add_assign(d); - - for (uint i = 0; i < proof.wire_commitments.length; i++) { - aggregation_challenge.mul_assign(state.v); - tmp_g1 = proof.wire_commitments[i].point_mul(aggregation_challenge); - commitment_aggregation.point_add_assign(tmp_g1); - } - - for (uint i = 0; i < vk.permutation_commitments.length - 1; i++) { - aggregation_challenge.mul_assign(state.v); - tmp_g1 = vk.permutation_commitments[i].point_mul(aggregation_challenge); - commitment_aggregation.point_add_assign(tmp_g1); - } - - // collect opening values - aggregation_challenge = PairingsBn254.new_fr(1); - - PairingsBn254.Fr memory aggregated_value = PairingsBn254.copy(proof.quotient_polynomial_at_zeta); - - aggregation_challenge.mul_assign(state.v); - - tmp_fr.assign(proof.linearization_polynomial_at_zeta); - tmp_fr.mul_assign(aggregation_challenge); - aggregated_value.add_assign(tmp_fr); - - for (uint i = 0; i < proof.wire_values_at_zeta.length; i++) { - aggregation_challenge.mul_assign(state.v); - - tmp_fr.assign(proof.wire_values_at_zeta[i]); - tmp_fr.mul_assign(aggregation_challenge); - aggregated_value.add_assign(tmp_fr); - } - - for (uint i = 0; i < proof.permutation_polynomials_at_zeta.length; i++) { - aggregation_challenge.mul_assign(state.v); - - tmp_fr.assign(proof.permutation_polynomials_at_zeta[i]); - tmp_fr.mul_assign(aggregation_challenge); - aggregated_value.add_assign(tmp_fr); - } - tmp_fr.assign(proof.grand_product_at_zeta_omega); - tmp_fr.mul_assign(state.u); - 
aggregated_value.add_assign(tmp_fr); - - commitment_aggregation.point_sub_assign(PairingsBn254.P1().point_mul(aggregated_value)); - - PairingsBn254.G1Point memory pair_with_generator = commitment_aggregation; - pair_with_generator.point_add_assign(proof.opening_at_zeta_proof.point_mul(state.zeta)); - - tmp_fr.assign(state.zeta); - tmp_fr.mul_assign(vk.omega); - tmp_fr.mul_assign(state.u); - pair_with_generator.point_add_assign(proof.opening_at_zeta_omega_proof.point_mul(tmp_fr)); - - PairingsBn254.G1Point memory pair_with_x = proof.opening_at_zeta_omega_proof.point_mul(state.u); - pair_with_x.point_add_assign(proof.opening_at_zeta_proof); - pair_with_x.negate(); - - return PairingsBn254.pairingProd2(pair_with_generator, PairingsBn254.P2(), pair_with_x, vk.g2_x); - } - - function reconstruct_d( - PartialVerifierState memory state, - Proof memory proof, - VerificationKey memory vk - ) internal view returns (PairingsBn254.G1Point memory res) { - res = PairingsBn254.copy_g1(vk.selector_commitments[STATE_WIDTH + 1]); - - PairingsBn254.G1Point memory tmp_g1 = PairingsBn254.P1(); - PairingsBn254.Fr memory tmp_fr = PairingsBn254.new_fr(0); - - // addition gates - for (uint256 i = 0; i < STATE_WIDTH; i++) { - tmp_g1 = vk.selector_commitments[i].point_mul(proof.wire_values_at_zeta[i]); - res.point_add_assign(tmp_g1); - } - - // multiplication gate - tmp_fr.assign(proof.wire_values_at_zeta[0]); - tmp_fr.mul_assign(proof.wire_values_at_zeta[1]); - tmp_g1 = vk.selector_commitments[STATE_WIDTH].point_mul(tmp_fr); - res.point_add_assign(tmp_g1); - - // z * non_res * beta + gamma + a - PairingsBn254.Fr memory grand_product_part_at_z = PairingsBn254.copy(state.zeta); - grand_product_part_at_z.mul_assign(state.beta); - grand_product_part_at_z.add_assign(proof.wire_values_at_zeta[0]); - grand_product_part_at_z.add_assign(state.gamma); - for (uint256 i = 0; i < vk.permutation_non_residues.length; i++) { - tmp_fr.assign(state.zeta); - tmp_fr.mul_assign(vk.permutation_non_residues[i]); 
- tmp_fr.mul_assign(state.beta); - tmp_fr.add_assign(state.gamma); - tmp_fr.add_assign(proof.wire_values_at_zeta[i+1]); - - grand_product_part_at_z.mul_assign(tmp_fr); - } - - grand_product_part_at_z.mul_assign(state.alpha); - - tmp_fr.assign(state.cached_lagrange_evals[0]); - tmp_fr.mul_assign(state.alpha); - tmp_fr.mul_assign(state.alpha); - // NOTICE - grand_product_part_at_z.sub_assign(tmp_fr); - PairingsBn254.Fr memory last_permutation_part_at_z = PairingsBn254.new_fr(1); - for (uint256 i = 0; i < proof.permutation_polynomials_at_zeta.length; i++) { - tmp_fr.assign(state.beta); - tmp_fr.mul_assign(proof.permutation_polynomials_at_zeta[i]); - tmp_fr.add_assign(state.gamma); - tmp_fr.add_assign(proof.wire_values_at_zeta[i]); - - last_permutation_part_at_z.mul_assign(tmp_fr); - } - - last_permutation_part_at_z.mul_assign(state.beta); - last_permutation_part_at_z.mul_assign(proof.grand_product_at_zeta_omega); - last_permutation_part_at_z.mul_assign(state.alpha); - - // gnark implementation: add third part and sub second second part - // plonk paper implementation: add second part and sub third part - /* - tmp_g1 = proof.grand_product_commitment.point_mul(grand_product_part_at_z); - tmp_g1.point_sub_assign(vk.permutation_commitments[STATE_WIDTH - 1].point_mul(last_permutation_part_at_z)); - */ - // add to the linearization - - tmp_g1 = vk.permutation_commitments[STATE_WIDTH - 1].point_mul(last_permutation_part_at_z); - tmp_g1.point_sub_assign(proof.grand_product_commitment.point_mul(grand_product_part_at_z)); - res.point_add_assign(tmp_g1); - - generate_uv_challenge(state, proof, vk, res); - - res.point_mul_assign(state.v); - res.point_add_assign(proof.grand_product_commitment.point_mul(state.u)); - } - - // gnark v generation process: - // sha256(zeta, proof.quotient_poly_commitments, linearizedPolynomialDigest, proof.wire_commitments, vk.permutation_commitments[0..1], ) - // NOTICE: gnark use "gamma" name for v, it's not reasonable - // NOTICE: gnark use 
zeta^(n+2) which is a bit different with plonk paper - // generate_v_challenge(); - function generate_uv_challenge( - PartialVerifierState memory state, - Proof memory proof, - VerificationKey memory vk, - PairingsBn254.G1Point memory linearization_point) view internal { - TranscriptLibrary.Transcript memory transcript = TranscriptLibrary.new_transcript(); - transcript.set_challenge_name("gamma"); - transcript.update_with_fr(state.zeta); - PairingsBn254.Fr memory zeta_plus_two = PairingsBn254.copy(state.zeta); - PairingsBn254.Fr memory n_plus_two = PairingsBn254.new_fr(vk.domain_size); - n_plus_two.add_assign(PairingsBn254.new_fr(2)); - zeta_plus_two = zeta_plus_two.pow(n_plus_two.value); - state.cached_fold_quotient_ploy_commitments = PairingsBn254.copy_g1(proof.quotient_poly_commitments[STATE_WIDTH-1]); - for (uint256 i = 0; i < STATE_WIDTH - 1; i++) { - state.cached_fold_quotient_ploy_commitments.point_mul_assign(zeta_plus_two); - state.cached_fold_quotient_ploy_commitments.point_add_assign(proof.quotient_poly_commitments[STATE_WIDTH - 2 - i]); - } - transcript.update_with_g1(state.cached_fold_quotient_ploy_commitments); - transcript.update_with_g1(linearization_point); - - for (uint256 i = 0; i < proof.wire_commitments.length; i++) { - transcript.update_with_g1(proof.wire_commitments[i]); - } - for (uint256 i = 0; i < vk.permutation_commitments.length - 1; i++) { - transcript.update_with_g1(vk.permutation_commitments[i]); - } - state.v = transcript.get_challenge(); - // gnark use local randomness to generate u - // we use opening_at_zeta_proof and opening_at_zeta_omega_proof - transcript.set_challenge_name("u"); - transcript.update_with_g1(proof.opening_at_zeta_proof); - transcript.update_with_g1(proof.opening_at_zeta_omega_proof); - state.u = transcript.get_challenge(); - } - - function batch_evaluate_lagrange_poly_out_of_domain( - uint256[] memory poly_nums, - uint256 domain_size, - PairingsBn254.Fr memory omega, - PairingsBn254.Fr memory at - ) internal view 
returns (PairingsBn254.Fr[] memory res) { - PairingsBn254.Fr memory one = PairingsBn254.new_fr(1); - PairingsBn254.Fr memory tmp_1 = PairingsBn254.new_fr(0); - PairingsBn254.Fr memory tmp_2 = PairingsBn254.new_fr(domain_size); - PairingsBn254.Fr memory vanishing_at_zeta = at.pow(domain_size); - vanishing_at_zeta.sub_assign(one); - // we can not have random point z be in domain - require(vanishing_at_zeta.value != 0); - PairingsBn254.Fr[] memory nums = new PairingsBn254.Fr[](poly_nums.length); - PairingsBn254.Fr[] memory dens = new PairingsBn254.Fr[](poly_nums.length); - // numerators in a form omega^i * (z^n - 1) - // denoms in a form (z - omega^i) * N - for (uint i = 0; i < poly_nums.length; i++) { - tmp_1 = omega.pow(poly_nums[i]); // power of omega - nums[i].assign(vanishing_at_zeta); - nums[i].mul_assign(tmp_1); - - dens[i].assign(at); // (X - omega^i) * N - dens[i].sub_assign(tmp_1); - dens[i].mul_assign(tmp_2); // mul by domain size - } - - PairingsBn254.Fr[] memory partial_products = new PairingsBn254.Fr[](poly_nums.length); - partial_products[0].assign(PairingsBn254.new_fr(1)); - for (uint i = 1; i < dens.length; i++) { - partial_products[i].assign(dens[i-1]); - partial_products[i].mul_assign(partial_products[i-1]); - } - - tmp_2.assign(partial_products[partial_products.length - 1]); - tmp_2.mul_assign(dens[dens.length - 1]); - tmp_2 = tmp_2.inverse(); // tmp_2 contains a^-1 * b^-1 (with! 
the last one) - - for (uint i = dens.length; i > 0; i--) { - tmp_1.assign(tmp_2); // all inversed - tmp_1.mul_assign(partial_products[i-1]); // clear lowest terms - tmp_2.mul_assign(dens[i-1]); - dens[i-1].assign(tmp_1); - } - - for (uint i = 0; i < nums.length; i++) { - nums[i].mul_assign(dens[i]); - } - - return nums; - } - - // plonk paper verify process step8: Compute quotient polynomial evaluation - function verify_quotient_poly_eval_at_zeta( - PartialVerifierState memory state, - Proof memory proof, - VerificationKey memory vk - ) internal view returns (bool) { - PairingsBn254.Fr memory lhs = evaluate_vanishing(vk.domain_size, state.zeta); - require(lhs.value != 0); // we can not check a polynomial relationship if point z is in the domain - lhs.mul_assign(proof.quotient_polynomial_at_zeta); - - PairingsBn254.Fr memory quotient_challenge = PairingsBn254.new_fr(1); - PairingsBn254.Fr memory rhs = PairingsBn254.copy(proof.linearization_polynomial_at_zeta); - - // public inputs - PairingsBn254.Fr memory tmp = PairingsBn254.new_fr(0); - for (uint256 i = 0; i < proof.input_values.length; i++) { - tmp.assign(state.cached_lagrange_evals[i]); - tmp.mul_assign(PairingsBn254.new_fr(proof.input_values[i])); - rhs.add_assign(tmp); - } - - quotient_challenge.mul_assign(state.alpha); - - PairingsBn254.Fr memory z_part = PairingsBn254.copy(proof.grand_product_at_zeta_omega); - for (uint256 i = 0; i < proof.permutation_polynomials_at_zeta.length; i++) { - tmp.assign(proof.permutation_polynomials_at_zeta[i]); - tmp.mul_assign(state.beta); - tmp.add_assign(state.gamma); - tmp.add_assign(proof.wire_values_at_zeta[i]); - - z_part.mul_assign(tmp); - } - - tmp.assign(state.gamma); - // we need a wire value of the last polynomial in enumeration - tmp.add_assign(proof.wire_values_at_zeta[STATE_WIDTH - 1]); - - z_part.mul_assign(tmp); - z_part.mul_assign(quotient_challenge); - - // NOTICE: this is different with plonk paper - // plonk paper should be: rhs.sub_assign(z_part); - 
rhs.add_assign(z_part); - - quotient_challenge.mul_assign(state.alpha); - - tmp.assign(state.cached_lagrange_evals[0]); - tmp.mul_assign(quotient_challenge); - - rhs.sub_assign(tmp); - - return lhs.value == rhs.value; - } - - function evaluate_vanishing( - uint256 domain_size, - PairingsBn254.Fr memory at - ) internal view returns (PairingsBn254.Fr memory res) { - res = at.pow(domain_size); - res.sub_assign(PairingsBn254.new_fr(1)); - } - - // This verifier is for a PLONK with a state width 3 - // and main gate equation - // q_a(X) * a(X) + - // q_b(X) * b(X) + - // q_c(X) * c(X) + - // q_m(X) * a(X) * b(X) + - // q_constants(X)+ - // where q_{}(X) are selectors a, b, c - state (witness) polynomials - - function verify(Proof memory proof, VerificationKey memory vk) internal view returns (bool) { - PartialVerifierState memory state; - - bool valid = verify_initial(state, proof, vk); - - if (valid == false) { - return false; - } - - valid = verify_commitments(state, proof, vk); - - return valid; - } -} - -contract KeyedPlonkVerifier is PlonkVerifier { - uint256 constant SERIALIZED_PROOF_LENGTH = 26; - using PairingsBn254 for PairingsBn254.Fr; - function get_verification_key() internal pure returns(VerificationKey memory vk) { - vk.domain_size = {{.Size}}; - vk.num_inputs = {{.NbPublicVariables}}; - vk.omega = PairingsBn254.new_fr(uint256({{.Generator.String}})); - vk.selector_commitments[0] = PairingsBn254.new_g1( - uint256({{.Ql.X.String}}), - uint256({{.Ql.Y.String}}) - ); - vk.selector_commitments[1] = PairingsBn254.new_g1( - uint256({{.Qr.X.String}}), - uint256({{.Qr.Y.String}}) - ); - vk.selector_commitments[2] = PairingsBn254.new_g1( - uint256({{.Qo.X.String}}), - uint256({{.Qo.Y.String}}) - ); - vk.selector_commitments[3] = PairingsBn254.new_g1( - uint256({{.Qm.X.String}}), - uint256({{.Qm.Y.String}}) - ); - vk.selector_commitments[4] = PairingsBn254.new_g1( - uint256({{.Qk.X.String}}), - uint256({{.Qk.Y.String}}) - ); - - vk.permutation_commitments[0] = 
PairingsBn254.new_g1( - uint256({{(index .S 0).X.String}}), - uint256({{(index .S 0).Y.String}}) - ); - vk.permutation_commitments[1] = PairingsBn254.new_g1( - uint256({{(index .S 1).X.String}}), - uint256({{(index .S 1).Y.String}}) - ); - vk.permutation_commitments[2] = PairingsBn254.new_g1( - uint256({{(index .S 2).X.String}}), - uint256({{(index .S 2).Y.String}}) - ); - - vk.permutation_non_residues[0] = PairingsBn254.new_fr( - uint256({{.CosetShift.String}}) - ); - vk.permutation_non_residues[1] = PairingsBn254.copy( - vk.permutation_non_residues[0] - ); - vk.permutation_non_residues[1].mul_assign(vk.permutation_non_residues[0]); - - vk.g2_x = PairingsBn254.new_g2( - [uint256({{(index .KZGSRS.G2 1).X.A1.String}}), - uint256({{(index .KZGSRS.G2 1).X.A0.String}})], - [uint256({{(index .KZGSRS.G2 1).Y.A1.String}}), - uint256({{(index .KZGSRS.G2 1).Y.A0.String}})] - ); - } - - - function deserialize_proof( - uint256[] memory public_inputs, - uint256[] memory serialized_proof - ) internal pure returns(Proof memory proof) { - require(serialized_proof.length == SERIALIZED_PROOF_LENGTH); - proof.input_values = new uint256[](public_inputs.length); - for (uint256 i = 0; i < public_inputs.length; i++) { - proof.input_values[i] = public_inputs[i]; - } - - uint256 j = 0; - for (uint256 i = 0; i < STATE_WIDTH; i++) { - proof.wire_commitments[i] = PairingsBn254.new_g1_checked( - serialized_proof[j], - serialized_proof[j+1] - ); - - j += 2; - } - - proof.grand_product_commitment = PairingsBn254.new_g1_checked( - serialized_proof[j], - serialized_proof[j+1] - ); - j += 2; - - for (uint256 i = 0; i < STATE_WIDTH; i++) { - proof.quotient_poly_commitments[i] = PairingsBn254.new_g1_checked( - serialized_proof[j], - serialized_proof[j+1] - ); - - j += 2; - } - - for (uint256 i = 0; i < STATE_WIDTH; i++) { - proof.wire_values_at_zeta[i] = PairingsBn254.new_fr( - serialized_proof[j] - ); - - j += 1; - } - - proof.grand_product_at_zeta_omega = PairingsBn254.new_fr( - 
serialized_proof[j] - ); - - j += 1; - - proof.quotient_polynomial_at_zeta = PairingsBn254.new_fr( - serialized_proof[j] - ); - - j += 1; - - proof.linearization_polynomial_at_zeta = PairingsBn254.new_fr( - serialized_proof[j] - ); - - j += 1; - - for (uint256 i = 0; i < proof.permutation_polynomials_at_zeta.length; i++) { - proof.permutation_polynomials_at_zeta[i] = PairingsBn254.new_fr( - serialized_proof[j] - ); - - j += 1; - } - - proof.opening_at_zeta_proof = PairingsBn254.new_g1_checked( - serialized_proof[j], - serialized_proof[j+1] - ); - j += 2; - - proof.opening_at_zeta_omega_proof = PairingsBn254.new_g1_checked( - serialized_proof[j], - serialized_proof[j+1] - ); - } - - function verify_serialized_proof( - uint256[] memory public_inputs, - uint256[] memory serialized_proof - ) public view returns (bool) { - VerificationKey memory vk = get_verification_key(); - require(vk.num_inputs == public_inputs.length); - Proof memory proof = deserialize_proof(public_inputs, serialized_proof); - bool valid = verify(proof, vk); - return valid; - } -} -` diff --git a/internal/backend/bw6-633/plonk/marshal.go b/internal/backend/bw6-633/plonk/marshal.go deleted file mode 100644 index 6ae656e777..0000000000 --- a/internal/backend/bw6-633/plonk/marshal.go +++ /dev/null @@ -1,244 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - curve "github.com/consensys/gnark-crypto/ecc/bw6-633" - - "errors" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - "io" -) - -// WriteTo writes binary encoding of Proof to w without point compression -func (proof *Proof) WriteRawTo(w io.Writer) (int64, error) { - return proof.writeTo(w, curve.RawEncoding()) -} - -// WriteTo writes binary encoding of Proof to w with point compression -func (proof *Proof) WriteTo(w io.Writer) (int64, error) { - return proof.writeTo(w) -} - -func (proof *Proof) writeTo(w io.Writer, options ...func(*curve.Encoder)) (int64, error) { - enc := curve.NewEncoder(w, options...) - - toEncode := []interface{}{ - &proof.LRO[0], - &proof.LRO[1], - &proof.LRO[2], - &proof.Z, - &proof.H[0], - &proof.H[1], - &proof.H[2], - &proof.BatchedProof.H, - proof.BatchedProof.ClaimedValues, - &proof.ZShiftedOpening.H, - &proof.ZShiftedOpening.ClaimedValue, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return enc.BytesWritten(), err - } - } - - return enc.BytesWritten(), nil -} - -// ReadFrom reads binary representation of Proof from r -func (proof *Proof) ReadFrom(r io.Reader) (int64, error) { - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - &proof.LRO[0], - &proof.LRO[1], - &proof.LRO[2], - &proof.Z, - &proof.H[0], - &proof.H[1], - &proof.H[2], - &proof.BatchedProof.H, - &proof.BatchedProof.ClaimedValues, - &proof.ZShiftedOpening.H, - &proof.ZShiftedOpening.ClaimedValue, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err - } - } - - return dec.BytesRead(), nil -} - -// WriteTo writes binary encoding of ProvingKey to w -func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { - // encode the verifying key - n, err = pk.Vk.WriteTo(w) - if err != nil { - return - } - - // fft domains - n2, err := pk.Domain[0].WriteTo(w) - if err != nil { - return - } - n += n2 - - n2, err = 
pk.Domain[1].WriteTo(w) - if err != nil { - return - } - n += n2 - - // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) - if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { - return n, errors.New("invalid permutation size, expected 3*domain cardinality") - } - - enc := curve.NewEncoder(w) - // note: type Polynomial, which is handled by default binary.Write(...) op and doesn't - // encode the size (nor does it convert from Montgomery to Regular form) - // so we explicitly transmit []fr.Element - toEncode := []interface{}{ - ([]fr.Element)(pk.Ql), - ([]fr.Element)(pk.Qr), - ([]fr.Element)(pk.Qm), - ([]fr.Element)(pk.Qo), - ([]fr.Element)(pk.CQk), - ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.S1Canonical), - ([]fr.Element)(pk.S2Canonical), - ([]fr.Element)(pk.S3Canonical), - pk.Permutation, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return n + enc.BytesWritten(), err - } - } - - return n + enc.BytesWritten(), nil -} - -// ReadFrom reads from binary representation in r into ProvingKey -func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { - pk.Vk = &VerifyingKey{} - n, err := pk.Vk.ReadFrom(r) - if err != nil { - return n, err - } - - n2, err := pk.Domain[0].ReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - n2, err = pk.Domain[1].ReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) - - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - (*[]fr.Element)(&pk.Ql), - (*[]fr.Element)(&pk.Qr), - (*[]fr.Element)(&pk.Qm), - (*[]fr.Element)(&pk.Qo), - (*[]fr.Element)(&pk.CQk), - (*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.S1Canonical), - (*[]fr.Element)(&pk.S2Canonical), - (*[]fr.Element)(&pk.S3Canonical), - &pk.Permutation, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err - } - } - - pk.computeLagrangeCosetPolys() - - return n + dec.BytesRead(), nil - 
-} - -// WriteTo writes binary encoding of VerifyingKey to w -func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - enc := curve.NewEncoder(w) - - toEncode := []interface{}{ - vk.Size, - &vk.SizeInv, - &vk.Generator, - vk.NbPublicVariables, - &vk.CosetShift, - &vk.S[0], - &vk.S[1], - &vk.S[2], - &vk.Ql, - &vk.Qr, - &vk.Qm, - &vk.Qo, - &vk.Qk, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return enc.BytesWritten(), err - } - } - - return enc.BytesWritten(), nil -} - -// ReadFrom reads from binary representation in r into VerifyingKey -func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - &vk.Size, - &vk.SizeInv, - &vk.Generator, - &vk.NbPublicVariables, - &vk.CosetShift, - &vk.S[0], - &vk.S[1], - &vk.S[2], - &vk.Ql, - &vk.Qr, - &vk.Qm, - &vk.Qo, - &vk.Qk, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err - } - } - - return dec.BytesRead(), nil -} diff --git a/internal/backend/bw6-633/plonk/marshal_test.go b/internal/backend/bw6-633/plonk/marshal_test.go deleted file mode 100644 index 09064a8900..0000000000 --- a/internal/backend/bw6-633/plonk/marshal_test.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - curve "github.com/consensys/gnark-crypto/ecc/bw6-633" - - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - - "bytes" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" - gnarkio "github.com/consensys/gnark/io" - "io" - "math/big" - "math/rand" - "reflect" - "testing" -) - -func TestProofSerialization(t *testing.T) { - // create a proof - var proof, reconstructed Proof - proof.randomize() - - roundTripCheck(t, &proof, &reconstructed) -} - -func TestProofSerializationRaw(t *testing.T) { - // create a proof - var proof, reconstructed Proof - proof.randomize() - - roundTripCheckRaw(t, &proof, &reconstructed) -} - -func TestProvingKeySerialization(t *testing.T) { - // random pk - var pk, reconstructed ProvingKey - pk.randomize() - - roundTripCheck(t, &pk, &reconstructed) -} - -func TestVerifyingKeySerialization(t *testing.T) { - // create a random vk - var vk, reconstructed VerifyingKey - vk.randomize() - - roundTripCheck(t, &vk, &reconstructed) -} - -func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { - var buf bytes.Buffer - written, err := from.WriteTo(&buf) - if err != nil { - t.Fatal("couldn't serialize", err) - } - - read, err := reconstructed.ReadFrom(&buf) - if err != nil { - t.Fatal("couldn't deserialize", err) - } - - if !reflect.DeepEqual(from, reconstructed) { - t.Fatal("reconstructed object don't match original") - } - - if written != read { - t.Fatal("bytes written / read don't match") - } -} - -func roundTripCheckRaw(t *testing.T, from gnarkio.WriterRawTo, reconstructed io.ReaderFrom) { - var buf bytes.Buffer - written, err := from.WriteRawTo(&buf) - if err != nil { - t.Fatal("couldn't serialize", err) - } - - read, err := reconstructed.ReadFrom(&buf) - if err != nil { - t.Fatal("couldn't deserialize", err) - } - - if !reflect.DeepEqual(from, reconstructed) { - t.Fatal("reconstructed object don't match original") - } - - if written != read { - 
t.Fatal("bytes written / read don't match") - } -} - -func (pk *ProvingKey) randomize() { - var vk VerifyingKey - vk.randomize() - pk.Vk = &vk - pk.Domain[0] = *fft.NewDomain(42) - pk.Domain[1] = *fft.NewDomain(4 * 42) - - n := int(pk.Domain[0].Cardinality) - pk.Ql = randomScalars(n) - pk.Qr = randomScalars(n) - pk.Qm = randomScalars(n) - pk.Qo = randomScalars(n) - pk.CQk = randomScalars(n) - pk.LQk = randomScalars(n) - pk.S1Canonical = randomScalars(n) - pk.S2Canonical = randomScalars(n) - pk.S3Canonical = randomScalars(n) - - pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) - pk.Permutation[0] = -12 - pk.Permutation[len(pk.Permutation)-1] = 8888 - - pk.computeLagrangeCosetPolys() -} - -func (vk *VerifyingKey) randomize() { - vk.Size = rand.Uint64() - vk.SizeInv.SetRandom() - vk.Generator.SetRandom() - vk.NbPublicVariables = rand.Uint64() - vk.CosetShift.SetRandom() - - vk.S[0] = randomPoint() - vk.S[1] = randomPoint() - vk.S[2] = randomPoint() - vk.Ql = randomPoint() - vk.Qr = randomPoint() - vk.Qm = randomPoint() - vk.Qo = randomPoint() - vk.Qk = randomPoint() -} - -func (proof *Proof) randomize() { - proof.LRO[0] = randomPoint() - proof.LRO[1] = randomPoint() - proof.LRO[2] = randomPoint() - proof.Z = randomPoint() - proof.H[0] = randomPoint() - proof.H[1] = randomPoint() - proof.H[2] = randomPoint() - proof.BatchedProof.H = randomPoint() - proof.BatchedProof.ClaimedValues = randomScalars(2) - proof.ZShiftedOpening.H = randomPoint() - proof.ZShiftedOpening.ClaimedValue.SetRandom() -} - -func randomPoint() curve.G1Affine { - _, _, r, _ := curve.Generators() - r.ScalarMultiplication(&r, big.NewInt(int64(rand.Uint64()))) - return r -} - -func randomScalars(n int) []fr.Element { - v := make([]fr.Element, n) - one := fr.One() - for i := 0; i < len(v); i++ { - if i == 0 { - v[i].SetRandom() - } else { - v[i].Add(&v[i-1], &one) - } - } - return v -} diff --git a/internal/backend/bw6-633/plonk/prove.go b/internal/backend/bw6-633/plonk/prove.go deleted file 
mode 100644 index bd6f3ff24b..0000000000 --- a/internal/backend/bw6-633/plonk/prove.go +++ /dev/null @@ -1,634 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "crypto/sha256" - "math/big" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - - curve "github.com/consensys/gnark-crypto/ecc/bw6-633" - - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/kzg" - - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" - - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/iop" - "github.com/consensys/gnark/constraint/bw6-633" - - "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/logger" -) - -type Proof struct { - - // Commitments to the solution vectors - LRO [3]kzg.Digest - - // Commitment to Z, the permutation polynomial - Z kzg.Digest - - // Commitments to h1, h2, h3 such that h = h1 + Xh2 + X**2h3 is the quotient polynomial - H [3]kzg.Digest - - // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2 - BatchedProof kzg.BatchOpeningProof - - // Opening proof of Z at zeta*mu - ZShiftedOpening kzg.OpeningProof -} - -// Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - - 
log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", len(spr.Constraints)).Str("backend", "plonk").Logger() - start := time.Now() - // pick a hash function that will be used to derive the challenges - hFunc := sha256.New() - - // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") - - // result - proof := &Proof{} - - // compute the constraint system solution - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } - } - - // query l, r, o in Lagrange basis, not blinded - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) - - lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - liop := iop.NewPolynomial(&evaluationLDomainSmall, lagReg) - riop := iop.NewPolynomial(&evaluationRDomainSmall, lagReg) - oiop := iop.NewPolynomial(&evaluationODomainSmall, lagReg) - wliop := liop.ShallowClone() - wriop := riop.ShallowClone() - woiop := oiop.ShallowClone() - wliop.ToCanonical(&pk.Domain[0]).ToRegular() - wriop.ToCanonical(&pk.Domain[0]).ToRegular() - woiop.ToCanonical(&pk.Domain[0]).ToRegular() - - // Blind l, r, o before committing - // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. 
- bwliop := wliop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwriop := wriop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwoiop := woiop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // The first challenge is derived using the public data: the commitments to the permutation, - // the coefficients of the circuit, and the public inputs. - // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *pk.Vk, fullWitness[:len(spr.Public)]); err != nil { - return nil, err - } - gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) - if err != nil { - return nil, err - } - - // Fiat Shamir this - bbeta, err := fs.ComputeChallenge("beta") - if err != nil { - return nil, err - } - var beta fr.Element - beta.SetBytes(bbeta) - - // compute the copy constraint's ratio - // We copy liop, riop, oiop because they are fft'ed in the process. - // We could have not copied them at the cost of doing one more bit reverse - // per poly... 
- ziop, err := iop.BuildRatioCopyConstraint( - []*iop.Polynomial{ - liop.Clone(), - riop.Clone(), - oiop.Clone(), - }, - pk.Permutation, - beta, - gamma, - iop.Form{Basis: iop.Canonical, Layout: iop.Regular}, - &pk.Domain[0], - ) - if err != nil { - return proof, err - } - - // commit to the blinded version of z - bwziop := ziop // iop.NewWrappedPolynomial(&ziop) - bwziop.Blind(2) - proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Vk.KZGSRS, runtime.NumCPU()*2) - if err != nil { - return proof, err - } - - // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) - alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) - if err != nil { - return proof, err - } - - // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) - copy(qkCompletedCanonical, fullWitness[:len(spr.Public)]) - copy(qkCompletedCanonical[len(spr.Public):], pk.LQk[len(spr.Public):]) - pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) - fft.BitReverse(qkCompletedCanonical) - - // l, r, o are blinded here - bwliop.ToLagrangeCoset(&pk.Domain[1]) - bwriop.ToLagrangeCoset(&pk.Domain[1]) - bwoiop.ToLagrangeCoset(&pk.Domain[1]) - - lagrangeCosetBitReversed := iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse} - - // we don't mutate so no need to clone the coefficients from the proving key. 
- wqliop := iop.NewPolynomial(&pk.lQl, lagrangeCosetBitReversed) - wqriop := iop.NewPolynomial(&pk.lQr, lagrangeCosetBitReversed) - wqmiop := iop.NewPolynomial(&pk.lQm, lagrangeCosetBitReversed) - wqoiop := iop.NewPolynomial(&pk.lQo, lagrangeCosetBitReversed) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqkiop := iop.NewPolynomial(&qkCompletedCanonical, canReg) - wqkiop.ToLagrangeCoset(&pk.Domain[1]) - - // storing Id - id := make([]fr.Element, pk.Domain[1].Cardinality) - id[1].SetOne() - widiop := iop.NewPolynomial(&id, canReg) - widiop.ToLagrangeCoset(&pk.Domain[1]) - - // permutations in LagrangeCoset: we don't mutate so no need to clone the coefficients from the - // proving key. - ws1 := iop.NewPolynomial(&pk.lS1LagrangeCoset, lagrangeCosetBitReversed) - ws2 := iop.NewPolynomial(&pk.lS2LagrangeCoset, lagrangeCosetBitReversed) - ws3 := iop.NewPolynomial(&pk.lS3LagrangeCoset, lagrangeCosetBitReversed) - - // Store z(g*x), without reallocating a slice - bwsziop := bwziop.ShallowClone().Shift(1) - bwsziop.ToLagrangeCoset(&pk.Domain[1]) - - // L_{g^{0}} - cap := pk.Domain[1].Cardinality - if cap < pk.Domain[0].Cardinality { - cap = pk.Domain[0].Cardinality // sanity check - } - lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) - lone[0].SetOne() - loneiop := iop.NewPolynomial(&lone, lagReg) - wloneiop := loneiop.ToCanonical(&pk.Domain[0]). - ToRegular(). - ToLagrangeCoset(&pk.Domain[1]) - - // Full capture using latest gnark crypto... 
- fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element) fr.Element { - - var ic, tmp fr.Element - - ic.Mul(&fql, &l) - tmp.Mul(&fqr, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqm, &l).Mul(&tmp, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqo, &o) - ic.Add(&ic, &tmp).Add(&ic, &fqk) - - return ic - } - - fo := func(l, r, o, fid, fs1, fs2, fs3, fz, fzs fr.Element) fr.Element { - var uu fr.Element - u := pk.Domain[0].FrMultiplicativeGen - uu.Mul(&u, &u) - - var a, b, tmp fr.Element - a.Mul(&beta, &fid).Add(&a, &l).Add(&a, &gamma) - tmp.Mul(&beta, &u).Mul(&tmp, &fid).Add(&tmp, &r).Add(&tmp, &gamma) - a.Mul(&a, &tmp) - tmp.Mul(&beta, &uu).Mul(&tmp, &fid).Add(&tmp, &o).Add(&tmp, &gamma) - a.Mul(&a, &tmp).Mul(&a, &fz) - - b.Mul(&beta, &fs1).Add(&b, &l).Add(&b, &gamma) - tmp.Mul(&beta, &fs2).Add(&tmp, &r).Add(&tmp, &gamma) - b.Mul(&b, &tmp) - tmp.Mul(&beta, &fs3).Add(&tmp, &o).Add(&tmp, &gamma) - b.Mul(&b, &tmp).Mul(&b, &fzs) - - b.Sub(&b, &a) - - return b - } - - fone := func(fz, flone fr.Element) fr.Element { - one := fr.One() - one.Sub(&fz, &one).Mul(&one, &flone) - return one - } - - // 0 , 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 - // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk,lone - fm := func(x ...fr.Element) fr.Element { - - a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2]) - b := fo(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8]) - c := fone(x[7], x[14]) - - c.Mul(&c, &alpha).Add(&c, &b).Mul(&c, &alpha).Add(&c, &a) - - return c - } - testEval, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, - bwliop, - bwriop, - bwoiop, - widiop, - ws1, - ws2, - ws3, - bwziop, - bwsziop, - wqliop, - wqriop, - wqmiop, - wqoiop, - wqkiop, - wloneiop, - ) - if err != nil { - return nil, err - } - h, err := iop.DivideByXMinusOne(testEval, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) - if err != nil { - return nil, err - } - - // compute kzg commitments of h1, h2 and h3 - if err := commitToQuotient( - 
h.Coefficients()[:pk.Domain[0].Cardinality+2], - h.Coefficients()[pk.Domain[0].Cardinality+2:2*(pk.Domain[0].Cardinality+2)], - h.Coefficients()[2*(pk.Domain[0].Cardinality+2):3*(pk.Domain[0].Cardinality+2)], - proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // derive zeta - zeta, err := deriveRandomness(&fs, "zeta", &proof.H[0], &proof.H[1], &proof.H[2]) - if err != nil { - return nil, err - } - - // compute evaluations of (blinded version of) l, r, o, z at zeta - var blzeta, brzeta, bozeta fr.Element - - var wgEvals sync.WaitGroup - wgEvals.Add(3) - - go func() { - bwliop.ToCanonical(&pk.Domain[1]).ToRegular() - blzeta = bwliop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwriop.ToCanonical(&pk.Domain[1]).ToRegular() - brzeta = bwriop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwoiop.ToCanonical(&pk.Domain[1]).ToRegular() - bozeta = bwoiop.Evaluate(zeta) - wgEvals.Done() - }() - - // open blinded Z at zeta*z - bwziop.ToCanonical(&pk.Domain[1]).ToRegular() - var zetaShifted fr.Element - zetaShifted.Mul(&zeta, &pk.Vk.Generator) - proof.ZShiftedOpening, err = kzg.Open( - bwziop.Coefficients()[:bwziop.BlindedSize()], - zetaShifted, - pk.Vk.KZGSRS, - ) - if err != nil { - return nil, err - } - - // blinded z evaluated at u*zeta - bzuzeta := proof.ZShiftedOpening.ClaimedValue - - var ( - linearizedPolynomialCanonical []fr.Element - linearizedPolynomialDigest curve.G1Affine - errLPoly error - ) - - wgEvals.Wait() // wait for the evaluations - - // compute the linearization polynomial r at zeta - // (goal: save committing separately to z, ql, qr, qm, qo, k - linearizedPolynomialCanonical = computeLinearizedPolynomial( - blzeta, - brzeta, - bozeta, - alpha, - beta, - gamma, - zeta, - bzuzeta, - bwziop.Coefficients()[:bwziop.BlindedSize()], - pk, - ) - - // TODO this commitment is only necessary to derive the challenge, we should - // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = 
kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) - - // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) - var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) - var zetaPowerm fr.Element - zetaPowerm.Exp(zeta, &bSize) - zetaPowerm.BigInt(&bZetaPowerm) - foldedHDigest := proof.H[2] - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - - // foldedH = h1 + ζ*h2 + ζ²*h3 - foldedH := h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] - h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] - utils.Parallelize(len(foldedH), func(start, end int) { - for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² - foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 - } - }) - - if errLPoly != nil { - return nil, errLPoly - } - - // Batch open the first list of polynomials - proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - [][]fr.Element{ - foldedH, - linearizedPolynomialCanonical, - bwliop.Coefficients()[:bwliop.BlindedSize()], - bwriop.Coefficients()[:bwriop.BlindedSize()], - bwoiop.Coefficients()[:bwoiop.BlindedSize()], - pk.S1Canonical, - pk.S2Canonical, - }, - []kzg.Digest{ - foldedHDigest, - linearizedPolynomialDigest, - proof.LRO[0], - proof.LRO[1], - proof.LRO[2], - pk.Vk.S[0], - pk.Vk.S[1], - }, - zeta, - hFunc, - pk.Vk.KZGSRS, - ) - - log.Debug().Dur("took", time.Since(start)).Msg("prover done") - - if err != nil { - return 
nil, err - } - - return proof, nil - -} - -// fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.LRO[0], err0 = kzg.Commit(bcl, srs, n) - close(chCommit0) - }() - go func() { - proof.LRO[1], err1 = kzg.Commit(bcr, srs, n) - close(chCommit1) - }() - if proof.LRO[2], err2 = kzg.Commit(bco, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.H[0], err0 = kzg.Commit(h1, srs, n) - close(chCommit0) - }() - go func() { - proof.H[1], err1 = kzg.Commit(h2, srs, n) - close(chCommit1) - }() - if proof.H[2], err2 = kzg.Commit(h3, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. -// The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z -// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. 
-// -// The Linearized polynomial is: -// -// α²*L₁(ζ)*Z(X) -// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) -// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { - - // first part: individual constraints - var rl fr.Element - rl.Mul(&rZeta, &lZeta) - - // second part: - // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) - var s1, s2 fr.Element - chS1 := make(chan struct{}, 1) - go func() { - ps1 := iop.NewPolynomial(&pk.S1Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - s1 = ps1.Evaluate(zeta) // s1(ζ) - s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) - close(chS1) - }() - ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := ps2.Evaluate(zeta) // s2(ζ) - tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) - <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) - - var uzeta, uuzeta fr.Element - uzeta.Mul(&zeta, &pk.Vk.CosetShift) - uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) - - s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) - tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) - tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - // third part L₁(ζ)*α²*Z - var lagrangeZeta, one, den, frNbElmt fr.Element - one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) - lagrangeZeta.Set(&zeta). - Exp(lagrangeZeta, big.NewInt(nbElmt)). 
- Sub(&lagrangeZeta, &one) - frNbElmt.SetUint64(uint64(nbElmt)) - den.Sub(&zeta, &one). - Inverse(&den) - lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - - linPol := make([]fr.Element, len(blindedZCanonical)) - copy(linPol, blindedZCanonical) - - utils.Parallelize(len(linPol), func(start, end int) { - - var t0, t1 fr.Element - - for i := start; i < end; i++ { - - linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - if i < len(pk.S3Canonical) { - - t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) - - linPol[i].Add(&linPol[i], &t0) - } - - linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) - - if i < len(pk.Qm) { - - t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) - t0.Mul(&pk.Ql[i], &lZeta) - t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - - t0.Mul(&pk.Qr[i], &rZeta) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - - t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) - } - - t0.Mul(&blindedZCanonical[i], &lagrangeZeta) - linPol[i].Add(&linPol[i], &t0) // finish the computation - } - }) - return linPol -} diff --git a/internal/backend/bw6-633/plonk/setup.go b/internal/backend/bw6-633/plonk/setup.go deleted file mode 100644 index 4e6d650887..0000000000 --- a/internal/backend/bw6-633/plonk/setup.go +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/iop" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/kzg" - "github.com/consensys/gnark/constraint/bw6-633" - - kzgg "github.com/consensys/gnark-crypto/kzg" -) - -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation -type ProvingKey struct { - // Verifying Key is embedded into the proving key (needed by Prove) - Vk *VerifyingKey - - // TODO store iop.Polynomial here, not []fr.Element for more "type safety" - - // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo []fr.Element - - // qr,ql,qm,qo (in lagrange coset basis) --> these are not serialized, but computed from Ql, Qr, Qm, Qo once. - lQl, lQr, lQm, lQo []fr.Element - - // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. - // Storing LQk in Lagrange basis saves a fft... - CQk, LQk []fr.Element - - // Domains used for the FFTs. 
- // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain - // Domain[0], Domain[1] fft.Domain - - // Permutation polynomials - S1Canonical, S2Canonical, S3Canonical []fr.Element - - // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. - lS1LagrangeCoset, lS2LagrangeCoset, lS3LagrangeCoset []fr.Element - - // position -> permuted position (position in [0,3*sizeSystem-1]) - Permutation []int64 -} - -// VerifyingKey stores the data needed to verify a proof: -// * The commitment scheme -// * Commitments of ql prepended with as many ones as there are public inputs -// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs -// * Commitments to S1, S2, S3 -type VerifyingKey struct { - // Size circuit - Size uint64 - SizeInv fr.Element - Generator fr.Element - NbPublicVariables uint64 - - // Commitment scheme that is used for an instantiation of PLONK - KZGSRS *kzg.SRS - - // cosetShift generator of the coset on the small domain - CosetShift fr.Element - - // S commitments to S1, S2, S3 - S [3]kzg.Digest - - // Commitments to ql, qr, qm, qo prepended with as many zeroes (ones for l) as there are public inputs. - // In particular Qk is not complete. - Ql, Qr, Qm, Qo, Qk kzg.Digest -} - -// Setup sets proving and verifying keys -func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) { - var pk ProvingKey - var vk VerifyingKey - - // The verifying key shares data with the proving key - pk.Vk = &vk - - nbConstraints := len(spr.Constraints) - - // fft domains - sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - - vk.Size = pk.Domain[0].Cardinality - vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) - vk.NbPublicVariables = uint64(len(spr.Public)) - - if err := pk.InitKZG(srs); err != nil { - return nil, nil, err - } - - // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) - pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) - pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) - - for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant - pk.Ql[i].SetOne().Neg(&pk.Ql[i]) - pk.Qr[i].SetZero() - pk.Qm[i].SetZero() - pk.Qo[i].SetZero() - pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // → to be completed by the prover - } - offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.Ql[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.Qr[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.Qm[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.Qm[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.Qo[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.CQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - } - - pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) - pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) - fft.BitReverse(pk.Ql) - fft.BitReverse(pk.Qr) - fft.BitReverse(pk.Qm) - fft.BitReverse(pk.Qo) - fft.BitReverse(pk.CQk) - - // build permutation. Note: at this stage, the permutation takes in account the placeholders - buildPermutation(spr, &pk) - - // set s1, s2, s3 - ccomputePermutationPolynomials(&pk) - - // compute the lagrange coset basis versions (not serialized) - pk.computeLagrangeCosetPolys() - - // Commit to the polynomials to set up the verifying key - var err error - if vk.Ql, err = kzg.Commit(pk.Ql, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qr, err = kzg.Commit(pk.Qr, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qm, err = kzg.Commit(pk.Qm, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qo, err = kzg.Commit(pk.Qo, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - - return &pk, &vk, nil - -} - -// buildPermutation builds the Permutation associated with a circuit. -// -// The permutation s is composed of cycles of maximum length such that -// -// s. 
(l∥r∥o) = (l∥r∥o) -// -// , where l∥r∥o is the concatenation of the indices of l, r, o in -// ql.l+qr.r+qm.l.r+qo.O+k = 0. -// -// The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab -// like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { - - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := int(pk.Domain[0].Cardinality) - - // init permutation - pk.Permutation = make([]int64, 3*sizeSolution) - for i := 0; i < len(pk.Permutation); i++ { - pk.Permutation[i] = -1 - } - - // init LRO position -> variable_ID - lro := make([]int, 3*sizeSolution) // position -> variable_ID - for i := 0; i < len(spr.Public); i++ { - lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) - } - - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() - } - - // init cycle: - // map ID -> last position the ID was seen - cycle := make([]int64, nbVariables) - for i := 0; i < len(cycle); i++ { - cycle[i] = -1 - } - - for i := 0; i < len(lro); i++ { - if cycle[lro[i]] != -1 { - // if != -1, it means we already encountered this value - // so we need to set the corresponding permutation index. 
- pk.Permutation[i] = cycle[lro[i]] - } - cycle[lro[i]] = int64(i) - } - - // complete the Permutation by filling the first IDs encountered - for i := 0; i < len(pk.Permutation); i++ { - if pk.Permutation[i] == -1 { - pk.Permutation[i] = cycle[lro[i]] - } - } -} - -func (pk *ProvingKey) computeLagrangeCosetPolys() { - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqliop := iop.NewPolynomial(clone(pk.Ql, pk.Domain[1].Cardinality), canReg) - wqriop := iop.NewPolynomial(clone(pk.Qr, pk.Domain[1].Cardinality), canReg) - wqmiop := iop.NewPolynomial(clone(pk.Qm, pk.Domain[1].Cardinality), canReg) - wqoiop := iop.NewPolynomial(clone(pk.Qo, pk.Domain[1].Cardinality), canReg) - - ws1 := iop.NewPolynomial(clone(pk.S1Canonical, pk.Domain[1].Cardinality), canReg) - ws2 := iop.NewPolynomial(clone(pk.S2Canonical, pk.Domain[1].Cardinality), canReg) - ws3 := iop.NewPolynomial(clone(pk.S3Canonical, pk.Domain[1].Cardinality), canReg) - - wqliop.ToLagrangeCoset(&pk.Domain[1]) - wqriop.ToLagrangeCoset(&pk.Domain[1]) - wqmiop.ToLagrangeCoset(&pk.Domain[1]) - wqoiop.ToLagrangeCoset(&pk.Domain[1]) - - ws1.ToLagrangeCoset(&pk.Domain[1]) - ws2.ToLagrangeCoset(&pk.Domain[1]) - ws3.ToLagrangeCoset(&pk.Domain[1]) - - pk.lQl = wqliop.Coefficients() - pk.lQr = wqriop.Coefficients() - pk.lQm = wqmiop.Coefficients() - pk.lQo = wqoiop.Coefficients() - - pk.lS1LagrangeCoset = ws1.Coefficients() - pk.lS2LagrangeCoset = ws2.Coefficients() - pk.lS3LagrangeCoset = ws3.Coefficients() -} - -func clone(input []fr.Element, capacity uint64) *[]fr.Element { - res := make([]fr.Element, len(input), capacity) - copy(res, input) - return &res -} - -// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations -// s1, s2, s3. -// -// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | -// -// | -// | Permutation -// -// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v -// \---------------/ \--------------------/ \------------------------/ -// -// s1 (LDE) s2 (LDE) s3 (LDE) -func ccomputePermutationPolynomials(pk *ProvingKey) { - - nbElmts := int(pk.Domain[0].Cardinality) - - // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) - - // Lagrange form of S1, S2, S3 - pk.S1Canonical = make([]fr.Element, nbElmts) - pk.S2Canonical = make([]fr.Element, nbElmts) - pk.S3Canonical = make([]fr.Element, nbElmts) - for i := 0; i < nbElmts; i++ { - pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) - pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) - pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) - } - - // Canonical form of S1, S2, S3 - pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) - fft.BitReverse(pk.S1Canonical) - fft.BitReverse(pk.S2Canonical) - fft.BitReverse(pk.S3Canonical) -} - -// getIDSmallDomain returns the Lagrange form of ID on the small domain -func getIDSmallDomain(domain *fft.Domain) []fr.Element { - - res := make([]fr.Element, 3*domain.Cardinality) - - res[0].SetOne() - res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) - res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) - - for i := uint64(1); i < domain.Cardinality; i++ { - res[i].Mul(&res[i-1], &domain.Generator) - res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) - res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) - } - - return res -} - -// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS -// -// This should be used after deserializing a ProvingKey -// as pk.Vk.KZG is NOT serialized -func (pk *ProvingKey) InitKZG(srs kzgg.SRS) error { - return pk.Vk.InitKZG(srs) -} - -// InitKZG inits vk.KZG using provided SRS -// -// This should be used after deserializing a 
VerifyingKey -// as vk.KZG is NOT serialized -// -// Note that this instantiate a new FFT domain using vk.Size -func (vk *VerifyingKey) InitKZG(srs kzgg.SRS) error { - _srs := srs.(*kzg.SRS) - - if len(_srs.G1) < int(vk.Size) { - return errors.New("kzg srs is too small") - } - vk.KZGSRS = _srs - - return nil -} - -// NbPublicWitness returns the expected public witness size (number of field elements) -func (vk *VerifyingKey) NbPublicWitness() int { - return int(vk.NbPublicVariables) -} - -// VerifyingKey returns pk.Vk -func (pk *ProvingKey) VerifyingKey() interface{} { - return pk.Vk -} diff --git a/internal/backend/bw6-761/plonk/marshal.go b/internal/backend/bw6-761/plonk/marshal.go deleted file mode 100644 index d51e9ef211..0000000000 --- a/internal/backend/bw6-761/plonk/marshal.go +++ /dev/null @@ -1,244 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - curve "github.com/consensys/gnark-crypto/ecc/bw6-761" - - "errors" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - "io" -) - -// WriteTo writes binary encoding of Proof to w without point compression -func (proof *Proof) WriteRawTo(w io.Writer) (int64, error) { - return proof.writeTo(w, curve.RawEncoding()) -} - -// WriteTo writes binary encoding of Proof to w with point compression -func (proof *Proof) WriteTo(w io.Writer) (int64, error) { - return proof.writeTo(w) -} - -func (proof *Proof) writeTo(w io.Writer, options ...func(*curve.Encoder)) (int64, error) { - enc := curve.NewEncoder(w, options...) - - toEncode := []interface{}{ - &proof.LRO[0], - &proof.LRO[1], - &proof.LRO[2], - &proof.Z, - &proof.H[0], - &proof.H[1], - &proof.H[2], - &proof.BatchedProof.H, - proof.BatchedProof.ClaimedValues, - &proof.ZShiftedOpening.H, - &proof.ZShiftedOpening.ClaimedValue, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return enc.BytesWritten(), err - } - } - - return enc.BytesWritten(), nil -} - -// ReadFrom reads binary representation of Proof from r -func (proof *Proof) ReadFrom(r io.Reader) (int64, error) { - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - &proof.LRO[0], - &proof.LRO[1], - &proof.LRO[2], - &proof.Z, - &proof.H[0], - &proof.H[1], - &proof.H[2], - &proof.BatchedProof.H, - &proof.BatchedProof.ClaimedValues, - &proof.ZShiftedOpening.H, - &proof.ZShiftedOpening.ClaimedValue, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err - } - } - - return dec.BytesRead(), nil -} - -// WriteTo writes binary encoding of ProvingKey to w -func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { - // encode the verifying key - n, err = pk.Vk.WriteTo(w) - if err != nil { - return - } - - // fft domains - n2, err := pk.Domain[0].WriteTo(w) - if err != nil { - return - } - n += n2 - - n2, err = 
pk.Domain[1].WriteTo(w) - if err != nil { - return - } - n += n2 - - // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) - if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { - return n, errors.New("invalid permutation size, expected 3*domain cardinality") - } - - enc := curve.NewEncoder(w) - // note: type Polynomial, which is handled by default binary.Write(...) op and doesn't - // encode the size (nor does it convert from Montgomery to Regular form) - // so we explicitly transmit []fr.Element - toEncode := []interface{}{ - ([]fr.Element)(pk.Ql), - ([]fr.Element)(pk.Qr), - ([]fr.Element)(pk.Qm), - ([]fr.Element)(pk.Qo), - ([]fr.Element)(pk.CQk), - ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.S1Canonical), - ([]fr.Element)(pk.S2Canonical), - ([]fr.Element)(pk.S3Canonical), - pk.Permutation, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return n + enc.BytesWritten(), err - } - } - - return n + enc.BytesWritten(), nil -} - -// ReadFrom reads from binary representation in r into ProvingKey -func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { - pk.Vk = &VerifyingKey{} - n, err := pk.Vk.ReadFrom(r) - if err != nil { - return n, err - } - - n2, err := pk.Domain[0].ReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - n2, err = pk.Domain[1].ReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) - - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - (*[]fr.Element)(&pk.Ql), - (*[]fr.Element)(&pk.Qr), - (*[]fr.Element)(&pk.Qm), - (*[]fr.Element)(&pk.Qo), - (*[]fr.Element)(&pk.CQk), - (*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.S1Canonical), - (*[]fr.Element)(&pk.S2Canonical), - (*[]fr.Element)(&pk.S3Canonical), - &pk.Permutation, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err - } - } - - pk.computeLagrangeCosetPolys() - - return n + dec.BytesRead(), nil - 
-} - -// WriteTo writes binary encoding of VerifyingKey to w -func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - enc := curve.NewEncoder(w) - - toEncode := []interface{}{ - vk.Size, - &vk.SizeInv, - &vk.Generator, - vk.NbPublicVariables, - &vk.CosetShift, - &vk.S[0], - &vk.S[1], - &vk.S[2], - &vk.Ql, - &vk.Qr, - &vk.Qm, - &vk.Qo, - &vk.Qk, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return enc.BytesWritten(), err - } - } - - return enc.BytesWritten(), nil -} - -// ReadFrom reads from binary representation in r into VerifyingKey -func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - dec := curve.NewDecoder(r) - toDecode := []interface{}{ - &vk.Size, - &vk.SizeInv, - &vk.Generator, - &vk.NbPublicVariables, - &vk.CosetShift, - &vk.S[0], - &vk.S[1], - &vk.S[2], - &vk.Ql, - &vk.Qr, - &vk.Qm, - &vk.Qo, - &vk.Qk, - } - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err - } - } - - return dec.BytesRead(), nil -} diff --git a/internal/backend/bw6-761/plonk/marshal_test.go b/internal/backend/bw6-761/plonk/marshal_test.go deleted file mode 100644 index 548044159d..0000000000 --- a/internal/backend/bw6-761/plonk/marshal_test.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - curve "github.com/consensys/gnark-crypto/ecc/bw6-761" - - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - - "bytes" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" - gnarkio "github.com/consensys/gnark/io" - "io" - "math/big" - "math/rand" - "reflect" - "testing" -) - -func TestProofSerialization(t *testing.T) { - // create a proof - var proof, reconstructed Proof - proof.randomize() - - roundTripCheck(t, &proof, &reconstructed) -} - -func TestProofSerializationRaw(t *testing.T) { - // create a proof - var proof, reconstructed Proof - proof.randomize() - - roundTripCheckRaw(t, &proof, &reconstructed) -} - -func TestProvingKeySerialization(t *testing.T) { - // random pk - var pk, reconstructed ProvingKey - pk.randomize() - - roundTripCheck(t, &pk, &reconstructed) -} - -func TestVerifyingKeySerialization(t *testing.T) { - // create a random vk - var vk, reconstructed VerifyingKey - vk.randomize() - - roundTripCheck(t, &vk, &reconstructed) -} - -func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { - var buf bytes.Buffer - written, err := from.WriteTo(&buf) - if err != nil { - t.Fatal("couldn't serialize", err) - } - - read, err := reconstructed.ReadFrom(&buf) - if err != nil { - t.Fatal("couldn't deserialize", err) - } - - if !reflect.DeepEqual(from, reconstructed) { - t.Fatal("reconstructed object don't match original") - } - - if written != read { - t.Fatal("bytes written / read don't match") - } -} - -func roundTripCheckRaw(t *testing.T, from gnarkio.WriterRawTo, reconstructed io.ReaderFrom) { - var buf bytes.Buffer - written, err := from.WriteRawTo(&buf) - if err != nil { - t.Fatal("couldn't serialize", err) - } - - read, err := reconstructed.ReadFrom(&buf) - if err != nil { - t.Fatal("couldn't deserialize", err) - } - - if !reflect.DeepEqual(from, reconstructed) { - t.Fatal("reconstructed object don't match original") - } - - if written != read { - 
t.Fatal("bytes written / read don't match") - } -} - -func (pk *ProvingKey) randomize() { - var vk VerifyingKey - vk.randomize() - pk.Vk = &vk - pk.Domain[0] = *fft.NewDomain(42) - pk.Domain[1] = *fft.NewDomain(4 * 42) - - n := int(pk.Domain[0].Cardinality) - pk.Ql = randomScalars(n) - pk.Qr = randomScalars(n) - pk.Qm = randomScalars(n) - pk.Qo = randomScalars(n) - pk.CQk = randomScalars(n) - pk.LQk = randomScalars(n) - pk.S1Canonical = randomScalars(n) - pk.S2Canonical = randomScalars(n) - pk.S3Canonical = randomScalars(n) - - pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) - pk.Permutation[0] = -12 - pk.Permutation[len(pk.Permutation)-1] = 8888 - - pk.computeLagrangeCosetPolys() -} - -func (vk *VerifyingKey) randomize() { - vk.Size = rand.Uint64() - vk.SizeInv.SetRandom() - vk.Generator.SetRandom() - vk.NbPublicVariables = rand.Uint64() - vk.CosetShift.SetRandom() - - vk.S[0] = randomPoint() - vk.S[1] = randomPoint() - vk.S[2] = randomPoint() - vk.Ql = randomPoint() - vk.Qr = randomPoint() - vk.Qm = randomPoint() - vk.Qo = randomPoint() - vk.Qk = randomPoint() -} - -func (proof *Proof) randomize() { - proof.LRO[0] = randomPoint() - proof.LRO[1] = randomPoint() - proof.LRO[2] = randomPoint() - proof.Z = randomPoint() - proof.H[0] = randomPoint() - proof.H[1] = randomPoint() - proof.H[2] = randomPoint() - proof.BatchedProof.H = randomPoint() - proof.BatchedProof.ClaimedValues = randomScalars(2) - proof.ZShiftedOpening.H = randomPoint() - proof.ZShiftedOpening.ClaimedValue.SetRandom() -} - -func randomPoint() curve.G1Affine { - _, _, r, _ := curve.Generators() - r.ScalarMultiplication(&r, big.NewInt(int64(rand.Uint64()))) - return r -} - -func randomScalars(n int) []fr.Element { - v := make([]fr.Element, n) - one := fr.One() - for i := 0; i < len(v); i++ { - if i == 0 { - v[i].SetRandom() - } else { - v[i].Add(&v[i-1], &one) - } - } - return v -} diff --git a/internal/backend/bw6-761/plonk/prove.go b/internal/backend/bw6-761/plonk/prove.go deleted file 
mode 100644 index 5a372eb87b..0000000000 --- a/internal/backend/bw6-761/plonk/prove.go +++ /dev/null @@ -1,634 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "crypto/sha256" - "math/big" - "runtime" - "sync" - "time" - - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - - curve "github.com/consensys/gnark-crypto/ecc/bw6-761" - - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/kzg" - - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" - - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/iop" - "github.com/consensys/gnark/constraint/bw6-761" - - "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/logger" -) - -type Proof struct { - - // Commitments to the solution vectors - LRO [3]kzg.Digest - - // Commitment to Z, the permutation polynomial - Z kzg.Digest - - // Commitments to h1, h2, h3 such that h = h1 + Xh2 + X**2h3 is the quotient polynomial - H [3]kzg.Digest - - // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2 - BatchedProof kzg.BatchOpeningProof - - // Opening proof of Z at zeta*mu - ZShiftedOpening kzg.OpeningProof -} - -// Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - - 
log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", len(spr.Constraints)).Str("backend", "plonk").Logger() - start := time.Now() - // pick a hash function that will be used to derive the challenges - hFunc := sha256.New() - - // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") - - // result - proof := &Proof{} - - // compute the constraint system solution - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } - } - - // query l, r, o in Lagrange basis, not blinded - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) - - lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - liop := iop.NewPolynomial(&evaluationLDomainSmall, lagReg) - riop := iop.NewPolynomial(&evaluationRDomainSmall, lagReg) - oiop := iop.NewPolynomial(&evaluationODomainSmall, lagReg) - wliop := liop.ShallowClone() - wriop := riop.ShallowClone() - woiop := oiop.ShallowClone() - wliop.ToCanonical(&pk.Domain[0]).ToRegular() - wriop.ToCanonical(&pk.Domain[0]).ToRegular() - woiop.ToCanonical(&pk.Domain[0]).ToRegular() - - // Blind l, r, o before committing - // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. 
- bwliop := wliop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwriop := wriop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwoiop := woiop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // The first challenge is derived using the public data: the commitments to the permutation, - // the coefficients of the circuit, and the public inputs. - // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *pk.Vk, fullWitness[:len(spr.Public)]); err != nil { - return nil, err - } - gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) - if err != nil { - return nil, err - } - - // Fiat Shamir this - bbeta, err := fs.ComputeChallenge("beta") - if err != nil { - return nil, err - } - var beta fr.Element - beta.SetBytes(bbeta) - - // compute the copy constraint's ratio - // We copy liop, riop, oiop because they are fft'ed in the process. - // We could have not copied them at the cost of doing one more bit reverse - // per poly... 
- ziop, err := iop.BuildRatioCopyConstraint( - []*iop.Polynomial{ - liop.Clone(), - riop.Clone(), - oiop.Clone(), - }, - pk.Permutation, - beta, - gamma, - iop.Form{Basis: iop.Canonical, Layout: iop.Regular}, - &pk.Domain[0], - ) - if err != nil { - return proof, err - } - - // commit to the blinded version of z - bwziop := ziop // iop.NewWrappedPolynomial(&ziop) - bwziop.Blind(2) - proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Vk.KZGSRS, runtime.NumCPU()*2) - if err != nil { - return proof, err - } - - // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) - alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) - if err != nil { - return proof, err - } - - // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) - copy(qkCompletedCanonical, fullWitness[:len(spr.Public)]) - copy(qkCompletedCanonical[len(spr.Public):], pk.LQk[len(spr.Public):]) - pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) - fft.BitReverse(qkCompletedCanonical) - - // l, r, o are blinded here - bwliop.ToLagrangeCoset(&pk.Domain[1]) - bwriop.ToLagrangeCoset(&pk.Domain[1]) - bwoiop.ToLagrangeCoset(&pk.Domain[1]) - - lagrangeCosetBitReversed := iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse} - - // we don't mutate so no need to clone the coefficients from the proving key. 
- wqliop := iop.NewPolynomial(&pk.lQl, lagrangeCosetBitReversed) - wqriop := iop.NewPolynomial(&pk.lQr, lagrangeCosetBitReversed) - wqmiop := iop.NewPolynomial(&pk.lQm, lagrangeCosetBitReversed) - wqoiop := iop.NewPolynomial(&pk.lQo, lagrangeCosetBitReversed) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqkiop := iop.NewPolynomial(&qkCompletedCanonical, canReg) - wqkiop.ToLagrangeCoset(&pk.Domain[1]) - - // storing Id - id := make([]fr.Element, pk.Domain[1].Cardinality) - id[1].SetOne() - widiop := iop.NewPolynomial(&id, canReg) - widiop.ToLagrangeCoset(&pk.Domain[1]) - - // permutations in LagrangeCoset: we don't mutate so no need to clone the coefficients from the - // proving key. - ws1 := iop.NewPolynomial(&pk.lS1LagrangeCoset, lagrangeCosetBitReversed) - ws2 := iop.NewPolynomial(&pk.lS2LagrangeCoset, lagrangeCosetBitReversed) - ws3 := iop.NewPolynomial(&pk.lS3LagrangeCoset, lagrangeCosetBitReversed) - - // Store z(g*x), without reallocating a slice - bwsziop := bwziop.ShallowClone().Shift(1) - bwsziop.ToLagrangeCoset(&pk.Domain[1]) - - // L_{g^{0}} - cap := pk.Domain[1].Cardinality - if cap < pk.Domain[0].Cardinality { - cap = pk.Domain[0].Cardinality // sanity check - } - lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) - lone[0].SetOne() - loneiop := iop.NewPolynomial(&lone, lagReg) - wloneiop := loneiop.ToCanonical(&pk.Domain[0]). - ToRegular(). - ToLagrangeCoset(&pk.Domain[1]) - - // Full capture using latest gnark crypto... 
- fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element) fr.Element { - - var ic, tmp fr.Element - - ic.Mul(&fql, &l) - tmp.Mul(&fqr, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqm, &l).Mul(&tmp, &r) - ic.Add(&ic, &tmp) - tmp.Mul(&fqo, &o) - ic.Add(&ic, &tmp).Add(&ic, &fqk) - - return ic - } - - fo := func(l, r, o, fid, fs1, fs2, fs3, fz, fzs fr.Element) fr.Element { - var uu fr.Element - u := pk.Domain[0].FrMultiplicativeGen - uu.Mul(&u, &u) - - var a, b, tmp fr.Element - a.Mul(&beta, &fid).Add(&a, &l).Add(&a, &gamma) - tmp.Mul(&beta, &u).Mul(&tmp, &fid).Add(&tmp, &r).Add(&tmp, &gamma) - a.Mul(&a, &tmp) - tmp.Mul(&beta, &uu).Mul(&tmp, &fid).Add(&tmp, &o).Add(&tmp, &gamma) - a.Mul(&a, &tmp).Mul(&a, &fz) - - b.Mul(&beta, &fs1).Add(&b, &l).Add(&b, &gamma) - tmp.Mul(&beta, &fs2).Add(&tmp, &r).Add(&tmp, &gamma) - b.Mul(&b, &tmp) - tmp.Mul(&beta, &fs3).Add(&tmp, &o).Add(&tmp, &gamma) - b.Mul(&b, &tmp).Mul(&b, &fzs) - - b.Sub(&b, &a) - - return b - } - - fone := func(fz, flone fr.Element) fr.Element { - one := fr.One() - one.Sub(&fz, &one).Mul(&one, &flone) - return one - } - - // 0 , 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 - // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk,lone - fm := func(x ...fr.Element) fr.Element { - - a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2]) - b := fo(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8]) - c := fone(x[7], x[14]) - - c.Mul(&c, &alpha).Add(&c, &b).Mul(&c, &alpha).Add(&c, &a) - - return c - } - testEval, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, - bwliop, - bwriop, - bwoiop, - widiop, - ws1, - ws2, - ws3, - bwziop, - bwsziop, - wqliop, - wqriop, - wqmiop, - wqoiop, - wqkiop, - wloneiop, - ) - if err != nil { - return nil, err - } - h, err := iop.DivideByXMinusOne(testEval, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) - if err != nil { - return nil, err - } - - // compute kzg commitments of h1, h2 and h3 - if err := commitToQuotient( - 
h.Coefficients()[:pk.Domain[0].Cardinality+2], - h.Coefficients()[pk.Domain[0].Cardinality+2:2*(pk.Domain[0].Cardinality+2)], - h.Coefficients()[2*(pk.Domain[0].Cardinality+2):3*(pk.Domain[0].Cardinality+2)], - proof, pk.Vk.KZGSRS); err != nil { - return nil, err - } - - // derive zeta - zeta, err := deriveRandomness(&fs, "zeta", &proof.H[0], &proof.H[1], &proof.H[2]) - if err != nil { - return nil, err - } - - // compute evaluations of (blinded version of) l, r, o, z at zeta - var blzeta, brzeta, bozeta fr.Element - - var wgEvals sync.WaitGroup - wgEvals.Add(3) - - go func() { - bwliop.ToCanonical(&pk.Domain[1]).ToRegular() - blzeta = bwliop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwriop.ToCanonical(&pk.Domain[1]).ToRegular() - brzeta = bwriop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwoiop.ToCanonical(&pk.Domain[1]).ToRegular() - bozeta = bwoiop.Evaluate(zeta) - wgEvals.Done() - }() - - // open blinded Z at zeta*z - bwziop.ToCanonical(&pk.Domain[1]).ToRegular() - var zetaShifted fr.Element - zetaShifted.Mul(&zeta, &pk.Vk.Generator) - proof.ZShiftedOpening, err = kzg.Open( - bwziop.Coefficients()[:bwziop.BlindedSize()], - zetaShifted, - pk.Vk.KZGSRS, - ) - if err != nil { - return nil, err - } - - // blinded z evaluated at u*zeta - bzuzeta := proof.ZShiftedOpening.ClaimedValue - - var ( - linearizedPolynomialCanonical []fr.Element - linearizedPolynomialDigest curve.G1Affine - errLPoly error - ) - - wgEvals.Wait() // wait for the evaluations - - // compute the linearization polynomial r at zeta - // (goal: save committing separately to z, ql, qr, qm, qo, k - linearizedPolynomialCanonical = computeLinearizedPolynomial( - blzeta, - brzeta, - bozeta, - alpha, - beta, - gamma, - zeta, - bzuzeta, - bwziop.Coefficients()[:bwziop.BlindedSize()], - pk, - ) - - // TODO this commitment is only necessary to derive the challenge, we should - // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = 
kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) - - // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) - var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) - var zetaPowerm fr.Element - zetaPowerm.Exp(zeta, &bSize) - zetaPowerm.BigInt(&bZetaPowerm) - foldedHDigest := proof.H[2] - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - - // foldedH = h1 + ζ*h2 + ζ²*h3 - foldedH := h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] - h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] - utils.Parallelize(len(foldedH), func(start, end int) { - for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² - foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 - } - }) - - if errLPoly != nil { - return nil, errLPoly - } - - // Batch open the first list of polynomials - proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - [][]fr.Element{ - foldedH, - linearizedPolynomialCanonical, - bwliop.Coefficients()[:bwliop.BlindedSize()], - bwriop.Coefficients()[:bwriop.BlindedSize()], - bwoiop.Coefficients()[:bwoiop.BlindedSize()], - pk.S1Canonical, - pk.S2Canonical, - }, - []kzg.Digest{ - foldedHDigest, - linearizedPolynomialDigest, - proof.LRO[0], - proof.LRO[1], - proof.LRO[2], - pk.Vk.S[0], - pk.Vk.S[1], - }, - zeta, - hFunc, - pk.Vk.KZGSRS, - ) - - log.Debug().Dur("took", time.Since(start)).Msg("prover done") - - if err != nil { - return 
nil, err - } - - return proof, nil - -} - -// fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.LRO[0], err0 = kzg.Commit(bcl, srs, n) - close(chCommit0) - }() - go func() { - proof.LRO[1], err1 = kzg.Commit(bcr, srs, n) - close(chCommit1) - }() - if proof.LRO[2], err2 = kzg.Commit(bco, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 - var err0, err1, err2 error - chCommit0 := make(chan struct{}, 1) - chCommit1 := make(chan struct{}, 1) - go func() { - proof.H[0], err0 = kzg.Commit(h1, srs, n) - close(chCommit0) - }() - go func() { - proof.H[1], err1 = kzg.Commit(h2, srs, n) - close(chCommit1) - }() - if proof.H[2], err2 = kzg.Commit(h3, srs, n); err2 != nil { - return err2 - } - <-chCommit0 - <-chCommit1 - - if err0 != nil { - return err0 - } - - return err1 -} - -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. -// The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z -// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. 
-// -// The Linearized polynomial is: -// -// α²*L₁(ζ)*Z(X) -// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) -// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { - - // first part: individual constraints - var rl fr.Element - rl.Mul(&rZeta, &lZeta) - - // second part: - // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) - var s1, s2 fr.Element - chS1 := make(chan struct{}, 1) - go func() { - ps1 := iop.NewPolynomial(&pk.S1Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - s1 = ps1.Evaluate(zeta) // s1(ζ) - s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) - close(chS1) - }() - ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := ps2.Evaluate(zeta) // s2(ζ) - tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) - <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) - - var uzeta, uuzeta fr.Element - uzeta.Mul(&zeta, &pk.Vk.CosetShift) - uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) - - s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) - tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) - tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) - s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - // third part L₁(ζ)*α²*Z - var lagrangeZeta, one, den, frNbElmt fr.Element - one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) - lagrangeZeta.Set(&zeta). - Exp(lagrangeZeta, big.NewInt(nbElmt)). 
- Sub(&lagrangeZeta, &one) - frNbElmt.SetUint64(uint64(nbElmt)) - den.Sub(&zeta, &one). - Inverse(&den) - lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - - linPol := make([]fr.Element, len(blindedZCanonical)) - copy(linPol, blindedZCanonical) - - utils.Parallelize(len(linPol), func(start, end int) { - - var t0, t1 fr.Element - - for i := start; i < end; i++ { - - linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - - if i < len(pk.S3Canonical) { - - t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) - - linPol[i].Add(&linPol[i], &t0) - } - - linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) - - if i < len(pk.Qm) { - - t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) - t0.Mul(&pk.Ql[i], &lZeta) - t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - - t0.Mul(&pk.Qr[i], &rZeta) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - - t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) - } - - t0.Mul(&blindedZCanonical[i], &lagrangeZeta) - linPol[i].Add(&linPol[i], &t0) // finish the computation - } - }) - return linPol -} diff --git a/internal/backend/bw6-761/plonk/setup.go b/internal/backend/bw6-761/plonk/setup.go deleted file mode 100644 index 096aba6f95..0000000000 --- a/internal/backend/bw6-761/plonk/setup.go +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package plonk - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/iop" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/kzg" - "github.com/consensys/gnark/constraint/bw6-761" - - kzgg "github.com/consensys/gnark-crypto/kzg" -) - -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation -type ProvingKey struct { - // Verifying Key is embedded into the proving key (needed by Prove) - Vk *VerifyingKey - - // TODO store iop.Polynomial here, not []fr.Element for more "type safety" - - // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo []fr.Element - - // qr,ql,qm,qo (in lagrange coset basis) --> these are not serialized, but computed from Ql, Qr, Qm, Qo once. - lQl, lQr, lQm, lQo []fr.Element - - // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. - // Storing LQk in Lagrange basis saves a fft... - CQk, LQk []fr.Element - - // Domains used for the FFTs. 
- // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain - // Domain[0], Domain[1] fft.Domain - - // Permutation polynomials - S1Canonical, S2Canonical, S3Canonical []fr.Element - - // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. - lS1LagrangeCoset, lS2LagrangeCoset, lS3LagrangeCoset []fr.Element - - // position -> permuted position (position in [0,3*sizeSystem-1]) - Permutation []int64 -} - -// VerifyingKey stores the data needed to verify a proof: -// * The commitment scheme -// * Commitments of ql prepended with as many ones as there are public inputs -// * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs -// * Commitments to S1, S2, S3 -type VerifyingKey struct { - // Size circuit - Size uint64 - SizeInv fr.Element - Generator fr.Element - NbPublicVariables uint64 - - // Commitment scheme that is used for an instantiation of PLONK - KZGSRS *kzg.SRS - - // cosetShift generator of the coset on the small domain - CosetShift fr.Element - - // S commitments to S1, S2, S3 - S [3]kzg.Digest - - // Commitments to ql, qr, qm, qo prepended with as many zeroes (ones for l) as there are public inputs. - // In particular Qk is not complete. - Ql, Qr, Qm, Qo, Qk kzg.Digest -} - -// Setup sets proving and verifying keys -func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) { - var pk ProvingKey - var vk VerifyingKey - - // The verifying key shares data with the proving key - pk.Vk = &vk - - nbConstraints := len(spr.Constraints) - - // fft domains - sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 
4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - - vk.Size = pk.Domain[0].Cardinality - vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) - vk.NbPublicVariables = uint64(len(spr.Public)) - - if err := pk.InitKZG(srs); err != nil { - return nil, nil, err - } - - // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) - pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) - pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) - - for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant - pk.Ql[i].SetOne().Neg(&pk.Ql[i]) - pk.Qr[i].SetZero() - pk.Qm[i].SetZero() - pk.Qo[i].SetZero() - pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // → to be completed by the prover - } - offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.Ql[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.Qr[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.Qm[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.Qm[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.Qo[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.CQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - } - - pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) - pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) - fft.BitReverse(pk.Ql) - fft.BitReverse(pk.Qr) - fft.BitReverse(pk.Qm) - fft.BitReverse(pk.Qo) - fft.BitReverse(pk.CQk) - - // build permutation. Note: at this stage, the permutation takes in account the placeholders - buildPermutation(spr, &pk) - - // set s1, s2, s3 - ccomputePermutationPolynomials(&pk) - - // compute the lagrange coset basis versions (not serialized) - pk.computeLagrangeCosetPolys() - - // Commit to the polynomials to set up the verifying key - var err error - if vk.Ql, err = kzg.Commit(pk.Ql, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qr, err = kzg.Commit(pk.Qr, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qm, err = kzg.Commit(pk.Qm, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qo, err = kzg.Commit(pk.Qo, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { - return nil, nil, err - } - - return &pk, &vk, nil - -} - -// buildPermutation builds the Permutation associated with a circuit. -// -// The permutation s is composed of cycles of maximum length such that -// -// s. 
(l∥r∥o) = (l∥r∥o) -// -// , where l∥r∥o is the concatenation of the indices of l, r, o in -// ql.l+qr.r+qm.l.r+qo.O+k = 0. -// -// The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab -// like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { - - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := int(pk.Domain[0].Cardinality) - - // init permutation - pk.Permutation = make([]int64, 3*sizeSolution) - for i := 0; i < len(pk.Permutation); i++ { - pk.Permutation[i] = -1 - } - - // init LRO position -> variable_ID - lro := make([]int, 3*sizeSolution) // position -> variable_ID - for i := 0; i < len(spr.Public); i++ { - lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) - } - - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() - } - - // init cycle: - // map ID -> last position the ID was seen - cycle := make([]int64, nbVariables) - for i := 0; i < len(cycle); i++ { - cycle[i] = -1 - } - - for i := 0; i < len(lro); i++ { - if cycle[lro[i]] != -1 { - // if != -1, it means we already encountered this value - // so we need to set the corresponding permutation index. 
- pk.Permutation[i] = cycle[lro[i]] - } - cycle[lro[i]] = int64(i) - } - - // complete the Permutation by filling the first IDs encountered - for i := 0; i < len(pk.Permutation); i++ { - if pk.Permutation[i] == -1 { - pk.Permutation[i] = cycle[lro[i]] - } - } -} - -func (pk *ProvingKey) computeLagrangeCosetPolys() { - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqliop := iop.NewPolynomial(clone(pk.Ql, pk.Domain[1].Cardinality), canReg) - wqriop := iop.NewPolynomial(clone(pk.Qr, pk.Domain[1].Cardinality), canReg) - wqmiop := iop.NewPolynomial(clone(pk.Qm, pk.Domain[1].Cardinality), canReg) - wqoiop := iop.NewPolynomial(clone(pk.Qo, pk.Domain[1].Cardinality), canReg) - - ws1 := iop.NewPolynomial(clone(pk.S1Canonical, pk.Domain[1].Cardinality), canReg) - ws2 := iop.NewPolynomial(clone(pk.S2Canonical, pk.Domain[1].Cardinality), canReg) - ws3 := iop.NewPolynomial(clone(pk.S3Canonical, pk.Domain[1].Cardinality), canReg) - - wqliop.ToLagrangeCoset(&pk.Domain[1]) - wqriop.ToLagrangeCoset(&pk.Domain[1]) - wqmiop.ToLagrangeCoset(&pk.Domain[1]) - wqoiop.ToLagrangeCoset(&pk.Domain[1]) - - ws1.ToLagrangeCoset(&pk.Domain[1]) - ws2.ToLagrangeCoset(&pk.Domain[1]) - ws3.ToLagrangeCoset(&pk.Domain[1]) - - pk.lQl = wqliop.Coefficients() - pk.lQr = wqriop.Coefficients() - pk.lQm = wqmiop.Coefficients() - pk.lQo = wqoiop.Coefficients() - - pk.lS1LagrangeCoset = ws1.Coefficients() - pk.lS2LagrangeCoset = ws2.Coefficients() - pk.lS3LagrangeCoset = ws3.Coefficients() -} - -func clone(input []fr.Element, capacity uint64) *[]fr.Element { - res := make([]fr.Element, len(input), capacity) - copy(res, input) - return &res -} - -// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations -// s1, s2, s3. -// -// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | -// -// | -// | Permutation -// -// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v -// \---------------/ \--------------------/ \------------------------/ -// -// s1 (LDE) s2 (LDE) s3 (LDE) -func ccomputePermutationPolynomials(pk *ProvingKey) { - - nbElmts := int(pk.Domain[0].Cardinality) - - // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) - - // Lagrange form of S1, S2, S3 - pk.S1Canonical = make([]fr.Element, nbElmts) - pk.S2Canonical = make([]fr.Element, nbElmts) - pk.S3Canonical = make([]fr.Element, nbElmts) - for i := 0; i < nbElmts; i++ { - pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) - pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) - pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) - } - - // Canonical form of S1, S2, S3 - pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) - fft.BitReverse(pk.S1Canonical) - fft.BitReverse(pk.S2Canonical) - fft.BitReverse(pk.S3Canonical) -} - -// getIDSmallDomain returns the Lagrange form of ID on the small domain -func getIDSmallDomain(domain *fft.Domain) []fr.Element { - - res := make([]fr.Element, 3*domain.Cardinality) - - res[0].SetOne() - res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) - res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) - - for i := uint64(1); i < domain.Cardinality; i++ { - res[i].Mul(&res[i-1], &domain.Generator) - res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) - res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) - } - - return res -} - -// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS -// -// This should be used after deserializing a ProvingKey -// as pk.Vk.KZG is NOT serialized -func (pk *ProvingKey) InitKZG(srs kzgg.SRS) error { - return pk.Vk.InitKZG(srs) -} - -// InitKZG inits vk.KZG using provided SRS -// -// This should be used after deserializing a 
VerifyingKey -// as vk.KZG is NOT serialized -// -// Note that this instantiate a new FFT domain using vk.Size -func (vk *VerifyingKey) InitKZG(srs kzgg.SRS) error { - _srs := srs.(*kzg.SRS) - - if len(_srs.G1) < int(vk.Size) { - return errors.New("kzg srs is too small") - } - vk.KZGSRS = _srs - - return nil -} - -// NbPublicWitness returns the expected public witness size (number of field elements) -func (vk *VerifyingKey) NbPublicWitness() int { - return int(vk.NbPublicVariables) -} - -// VerifyingKey returns pk.Vk -func (pk *ProvingKey) VerifyingKey() interface{} { - return pk.Vk -} diff --git a/internal/backend/circuits/circuits.go b/internal/backend/circuits/circuits.go index 51bdb7b9a9..92b306de73 100644 --- a/internal/backend/circuits/circuits.go +++ b/internal/backend/circuits/circuits.go @@ -3,7 +3,8 @@ package circuits import ( "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" + + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -11,7 +12,7 @@ import ( type TestCircuit struct { Circuit frontend.Circuit ValidAssignments, InvalidAssignments []frontend.Circuit // good and bad witness for the prover + public verifier data - HintFunctions []hint.Function + HintFunctions []solver.Hint Curves []ecc.ID } @@ -30,13 +31,14 @@ func addEntry(name string, circuit, proverGood, proverBad frontend.Circuit, curv Circuits[name] = TestCircuit{circuit, []frontend.Circuit{proverGood}, []frontend.Circuit{proverBad}, nil, curves} } -func addNewEntry(name string, circuit frontend.Circuit, proverGood, proverBad []frontend.Circuit, curves []ecc.ID, hintFunctions ...hint.Function) { +func addNewEntry(name string, circuit frontend.Circuit, proverGood, proverBad []frontend.Circuit, curves []ecc.ID, hintFunctions ...solver.Hint) { if Circuits == nil { Circuits = make(map[string]TestCircuit) } if _, ok := Circuits[name]; ok { panic("name " + name + "already taken by another test circuit ") } + 
solver.RegisterHint(hintFunctions...) Circuits[name] = TestCircuit{circuit, proverGood, proverBad, hintFunctions, curves} } diff --git a/internal/backend/circuits/hint.go b/internal/backend/circuits/hint.go index 5ed7a009de..1542438d5f 100644 --- a/internal/backend/circuits/hint.go +++ b/internal/backend/circuits/hint.go @@ -91,7 +91,7 @@ func init() { }, } - addNewEntry("recursive_hint", &recursiveHint{}, good, bad, gnark.Curves(), make3, bits.NBits) + addNewEntry("recursive_hint", &recursiveHint{}, good, bad, gnark.Curves(), make3, bits.GetHints()[1]) } { diff --git a/internal/circuitdefer/defer.go b/internal/circuitdefer/defer.go new file mode 100644 index 0000000000..4bda7383e7 --- /dev/null +++ b/internal/circuitdefer/defer.go @@ -0,0 +1,43 @@ +package circuitdefer + +import ( + "github.com/consensys/gnark/internal/kvstore" +) + +type deferKey struct{} + +func Put[T any](builder any, cb T) { + // we use generics for type safety but to avoid import cycles. + // TODO: compare with using any and type asserting at caller + kv, ok := builder.(kvstore.Store) + if !ok { + panic("builder does not implement kvstore.Store") + } + val := kv.GetKeyValue(deferKey{}) + var deferred []T + if val != nil { + var ok bool + deferred, ok = val.([]T) + if !ok { + panic("stored deferred functions not []func(frontend.API) error") + } + } + deferred = append(deferred, cb) + kv.SetKeyValue(deferKey{}, deferred) +} + +func GetAll[T any](builder any) []T { + kv, ok := builder.(kvstore.Store) + if !ok { + panic("builder does not implement kvstore.Store") + } + val := kv.GetKeyValue(deferKey{}) + if val == nil { + return nil + } + deferred, ok := val.([]T) + if !ok { + panic("stored deferred functions not []func(frontend.API) error") + } + return deferred +} diff --git a/internal/frontendtype/frontendtype.go b/internal/frontendtype/frontendtype.go new file mode 100644 index 0000000000..040af53c5b --- /dev/null +++ b/internal/frontendtype/frontendtype.go @@ -0,0 +1,13 @@ +// Package 
frontendtype allows to assert frontend type. +package frontendtype + +type Type int + +const ( + R1CS Type = iota + SCS +) + +type FrontendTyper interface { + FrontendType() Type +} diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index df801f8650..b5c0bfe1a8 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -4,6 +4,7 @@ import ( "os" "os/exec" "path/filepath" + "strings" "sync" "github.com/consensys/bavard" @@ -19,43 +20,43 @@ var bgen = bavard.NewBatchGenerator(copyrightHolder, 2020, "gnark") func main() { bls12_377 := templateData{ - RootPath: "../../../internal/backend/bls12-377/", + RootPath: "../../../backend/{?}/bls12-377/", CSPath: "../../../constraint/bls12-377/", Curve: "BLS12-377", CurveID: "BLS12_377", } bls12_381 := templateData{ - RootPath: "../../../internal/backend/bls12-381/", + RootPath: "../../../backend/{?}/bls12-381/", CSPath: "../../../constraint/bls12-381/", Curve: "BLS12-381", CurveID: "BLS12_381", } bn254 := templateData{ - RootPath: "../../../internal/backend/bn254/", + RootPath: "../../../backend/{?}/bn254/", CSPath: "../../../constraint/bn254/", Curve: "BN254", CurveID: "BN254", } bw6_761 := templateData{ - RootPath: "../../../internal/backend/bw6-761/", + RootPath: "../../../backend/{?}/bw6-761/", CSPath: "../../../constraint/bw6-761/", Curve: "BW6-761", CurveID: "BW6_761", } bls24_315 := templateData{ - RootPath: "../../../internal/backend/bls24-315/", + RootPath: "../../../backend/{?}/bls24-315/", CSPath: "../../../constraint/bls24-315/", Curve: "BLS24-315", CurveID: "BLS24_315", } bls24_317 := templateData{ - RootPath: "../../../internal/backend/bls24-317/", + RootPath: "../../../backend/{?}/bls24-317/", CSPath: "../../../constraint/bls24-317/", Curve: "BLS24-317", CurveID: "BLS24_317", } bw6_633 := templateData{ - RootPath: "../../../internal/backend/bw6-633/", + RootPath: "../../../backend/{?}/bw6-633/", CSPath: "../../../constraint/bw6-633/", Curve: 
"BW6-633", CurveID: "BW6_633", @@ -97,16 +98,22 @@ func main() { wg.Add(1) go func(d templateData) { - defer wg.Done() - if err := os.MkdirAll(d.RootPath+"groth16", 0700); err != nil { + var ( + groth16Dir = strings.Replace(d.RootPath, "{?}", "groth16", 1) + groth16MpcSetupDir = filepath.Join(groth16Dir, "mpcsetup") + plonkDir = strings.Replace(d.RootPath, "{?}", "plonk", 1) + plonkFriDir = strings.Replace(d.RootPath, "{?}", "plonkfri", 1) + ) + + if err := os.MkdirAll(groth16Dir, 0700); err != nil { panic(err) } - if err := os.MkdirAll(d.RootPath+"plonk", 0700); err != nil { + if err := os.MkdirAll(plonkDir, 0700); err != nil { panic(err) } - if err := os.MkdirAll(d.RootPath+"plonkfri", 0700); err != nil { + if err := os.MkdirAll(plonkFriDir, 0700); err != nil { panic(err) } @@ -114,15 +121,22 @@ func main() { // constraint systems entries := []bavard.Entry{ - {File: filepath.Join(csDir, "r1cs.go"), Templates: []string{"r1cs.go.tmpl", importCurve}}, + {File: filepath.Join(csDir, "system.go"), Templates: []string{"system.go.tmpl", importCurve}}, {File: filepath.Join(csDir, "coeff.go"), Templates: []string{"coeff.go.tmpl", importCurve}}, - {File: filepath.Join(csDir, "r1cs_sparse.go"), Templates: []string{"r1cs.sparse.go.tmpl", importCurve}}, - {File: filepath.Join(csDir, "solution.go"), Templates: []string{"solution.go.tmpl", importCurve}}, + {File: filepath.Join(csDir, "solver.go"), Templates: []string{"solver.go.tmpl", importCurve}}, } if err := bgen.Generate(d, "cs", "./template/representations/", entries...); err != nil { panic(err) } + // gkr backend + if d.Curve != "tinyfield" { + entries = []bavard.Entry{{File: filepath.Join(csDir, "gkr.go"), Templates: []string{"gkr.go.tmpl", importCurve}}} + if err := bgen.Generate(d, "cs", "./template/representations/", entries...); err != nil { + panic(err) + } + } + entries = []bavard.Entry{ {File: filepath.Join(csDir, "r1cs_test.go"), Templates: []string{"tests/r1cs.go.tmpl", importCurve}}, } @@ -136,10 +150,6 @@ func 
main() { return } - plonkFriDir := filepath.Join(d.RootPath, "plonkfri") - groth16Dir := filepath.Join(d.RootPath, "groth16") - plonkDir := filepath.Join(d.RootPath, "plonk") - if err := os.MkdirAll(groth16Dir, 0700); err != nil { panic(err) } @@ -166,6 +176,22 @@ func main() { panic(err) // TODO handle } + // groth16 mpcsetup + entries = []bavard.Entry{ + {File: filepath.Join(groth16MpcSetupDir, "lagrange.go"), Templates: []string{"groth16/mpcsetup/lagrange.go.tmpl", importCurve}}, + {File: filepath.Join(groth16MpcSetupDir, "marshal.go"), Templates: []string{"groth16/mpcsetup/marshal.go.tmpl", importCurve}}, + {File: filepath.Join(groth16MpcSetupDir, "marshal_test.go"), Templates: []string{"groth16/mpcsetup/marshal_test.go.tmpl", importCurve}}, + {File: filepath.Join(groth16MpcSetupDir, "phase1.go"), Templates: []string{"groth16/mpcsetup/phase1.go.tmpl", importCurve}}, + {File: filepath.Join(groth16MpcSetupDir, "phase2.go"), Templates: []string{"groth16/mpcsetup/phase2.go.tmpl", importCurve}}, + {File: filepath.Join(groth16MpcSetupDir, "setup.go"), Templates: []string{"groth16/mpcsetup/setup.go.tmpl", importCurve}}, + {File: filepath.Join(groth16MpcSetupDir, "setup_test.go"), Templates: []string{"groth16/mpcsetup/setup_test.go.tmpl", importCurve}}, + {File: filepath.Join(groth16MpcSetupDir, "utils.go"), Templates: []string{"groth16/mpcsetup/utils.go.tmpl", importCurve}}, + } + + if err := bgen.Generate(d, "mpcsetup", "./template/zkpschemes/", entries...); err != nil { + panic(err) // TODO handle + } + // plonk entries = []bavard.Entry{ {File: filepath.Join(plonkDir, "verify.go"), Templates: []string{"plonk/plonk.verify.go.tmpl", importCurve}}, diff --git a/internal/generator/backend/template/imports.go.tmpl b/internal/generator/backend/template/imports.go.tmpl index 56639bdfb2..949969f9d7 100644 --- a/internal/generator/backend/template/imports.go.tmpl +++ b/internal/generator/backend/template/imports.go.tmpl @@ -22,7 +22,7 @@ {{- if eq .Curve "tinyfield"}} 
"github.com/consensys/gnark/constraint/tinyfield" {{- else}} - "github.com/consensys/gnark/constraint/{{toLower .Curve}}" + cs "github.com/consensys/gnark/constraint/{{toLower .Curve}}" {{- end}} {{- end }} diff --git a/internal/generator/backend/template/representations/coeff.go.tmpl b/internal/generator/backend/template/representations/coeff.go.tmpl index 9ebe15314e..3ba357d1e6 100644 --- a/internal/generator/backend/template/representations/coeff.go.tmpl +++ b/internal/generator/backend/template/representations/coeff.go.tmpl @@ -1,7 +1,7 @@ import ( "github.com/consensys/gnark/constraint" - "math/big" "github.com/consensys/gnark/internal/utils" + "math/big" {{ template "import_fr" . }} ) @@ -27,8 +27,7 @@ func newCoeffTable(capacity int) CoeffTable { } - -func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constraint.Term { +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { c := (*fr.Element)(coeff[:]) var cID uint32 if c.IsZero() { @@ -51,7 +50,11 @@ func (ct *CoeffTable) MakeTerm(coeff *constraint.Coeff, variableID int) constrai ct.mCoeffs[cc] = cID } } - + return cID +} + +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) return constraint.Term{VID: uint32(variableID), CID: cID} } @@ -60,8 +63,10 @@ func (ct *CoeffTable) CoeffToString(cID int) string { return ct.Coefficients[cID].String() } +// implements constraint.Field +type field struct{} -var _ constraint.CoeffEngine = &arithEngine{} +var _ constraint.Field = &field{} var ( two fr.Element @@ -78,11 +83,9 @@ func init() { } -// implements constraint.CoeffEngine -type arithEngine struct{} -func (engine *arithEngine) FromInterface(i interface{}) constraint.Coeff { +func (engine *field) FromInterface(i interface{}) constraint.Element { var e fr.Element if _, err := e.SetInterface(i); err != nil { // need to clean that --> some code path are dissimilar @@ -91,55 +94,83 @@ func (engine *arithEngine) 
FromInterface(i interface{}) constraint.Coeff { b := utils.FromInterface(i) e.SetBigInt(&b) } - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) ToBigInt(c *constraint.Coeff) *big.Int { +func (engine *field) ToBigInt(c constraint.Element) *big.Int { e := (*fr.Element)(c[:]) r := new(big.Int) e.BigInt(r) return r } -func (engine *arithEngine) Mul(a, b *constraint.Coeff) { +func (engine *field) Mul(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Mul(_a, _b) + return a } -func (engine *arithEngine) Add(a, b *constraint.Coeff) { + +func (engine *field) Add(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Add(_a, _b) + return a } -func (engine *arithEngine) Sub(a, b *constraint.Coeff) { +func (engine *field) Sub(a, b constraint.Element) constraint.Element { _a := (*fr.Element)(a[:]) _b := (*fr.Element)(b[:]) _a.Sub(_a, _b) + return a } -func (engine *arithEngine) Neg(a *constraint.Coeff) { +func (engine *field) Neg(a constraint.Element) constraint.Element { e := (*fr.Element)(a[:]) e.Neg(e) + return a } -func (engine *arithEngine) Inverse(a *constraint.Coeff) { +func (engine *field) Inverse(a constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + e.Inverse(e) + return a, true } -func (engine *arithEngine) IsOne(a *constraint.Coeff) bool { +func (engine *field) IsOne(a constraint.Element) bool { e := (*fr.Element)(a[:]) return e.IsOne() } -func (engine *arithEngine) One() constraint.Coeff { +func (engine *field) One() constraint.Element { e := fr.One() - var r constraint.Coeff + var r constraint.Element copy(r[:], e[:]) return r } -func (engine *arithEngine) String(a *constraint.Coeff) string { +func 
(engine *field) String(a constraint.Element) string { e := (*fr.Element)(a[:]) return e.String() +} + +func (engine *field) Uint64(a constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true } \ No newline at end of file diff --git a/internal/generator/backend/template/representations/gkr.go.tmpl b/internal/generator/backend/template/representations/gkr.go.tmpl new file mode 100644 index 0000000000..110276de60 --- /dev/null +++ b/internal/generator/backend/template/representations/gkr.go.tmpl @@ -0,0 +1,178 @@ +import ( + "fmt" + {{- template "import_fr" .}} + {{- template "import_gkr" .}} + {{- template "import_polynomial" .}} + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + hint "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/utils/algo_utils" + "hash" + "math/big" +) + +type GkrSolvingData struct { + assignments gkr.WireAssignment + circuit gkr.Circuit + memoryPool polynomial.Pool + workers *utils.WorkerPool +} + +func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { + resCircuit := make(gkr.Circuit, len(noPtr)) + var found bool + for i := range noPtr { + if resCircuit[i].Gate, found = gkr.Gates[noPtr[i].Gate]; !found && noPtr[i].Gate != "" { + return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) + } + resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) + } + return resCircuit, nil +} + +func (d *GkrSolvingData) init(info constraint.GkrInfo) (assignment gkrAssignment, err error) { + if d.circuit, err = convertCircuit(info.Circuit); err != nil { + return + } + d.memoryPool = polynomial.NewPool(d.circuit.MemoryRequirements(info.NbInstances)...) 
+ d.workers = utils.NewWorkerPool() + + assignment = make(gkrAssignment, len(d.circuit)) + d.assignments = make(gkr.WireAssignment, len(d.circuit)) + for i := range assignment { + assignment[i] = d.memoryPool.Make(info.NbInstances) + d.assignments[&d.circuit[i]] = assignment[i] + } + return +} + +func (d *GkrSolvingData) dumpAssignments() { + for _, p := range d.assignments { + d.memoryPool.Dump(p) + } +} + +// this module assumes that wire and instance indexes respect dependencies + +type gkrAssignment [][]fr.Element //gkrAssignment is indexed wire first, instance second + +func (a gkrAssignment) setOuts(circuit constraint.GkrCircuit, outs []*big.Int) { + outsI := 0 + for i := range circuit { + if circuit[i].IsOutput() { + for j := range a[i] { + a[i][j].BigInt(outs[outsI]) + outsI++ + } + } + } + // Check if outsI == len(outs)? +} + +func GkrSolveHint(info constraint.GkrInfo, solvingData *GkrSolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + // assumes assignmentVector is arranged wire first, instance second in order of solution + circuit := info.Circuit + nbInstances := info.NbInstances + offsets := info.AssignmentOffsets() + assignment, err := solvingData.init(info) + if err != nil { + return err + } + chunks := circuit.Chunks(nbInstances) + + solveTask := func(chunkOffset int) utils.Task { + return func(startInChunk, endInChunk int) { + start := startInChunk + chunkOffset + end := endInChunk + chunkOffset + inputs := solvingData.memoryPool.Make(info.MaxNIns) + dependencyHeads := make([]int, len(circuit)) + for wI, w := range circuit { + dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { + return w.Dependencies[i].InputInstance + }, len(w.Dependencies), start) + } + + for instanceI := start; instanceI < end; instanceI++ { + for wireI, wire := range circuit { + if wire.IsInput() { + if dependencyHeads[wireI] < len(wire.Dependencies) && instanceI == wire.Dependencies[dependencyHeads[wireI]].InputInstance { + dep := 
wire.Dependencies[dependencyHeads[wireI]] + assignment[wireI][instanceI].Set(&assignment[dep.OutputWire][dep.OutputInstance]) + dependencyHeads[wireI]++ + } else { + assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) + } + } else { + // assemble the inputs + inputIndexes := info.Circuit[wireI].Inputs + for i, inputI := range inputIndexes { + inputs[i].Set(&assignment[inputI][instanceI]) + } + gate := solvingData.circuit[wireI].Gate + assignment[wireI][instanceI] = gate.Evaluate(inputs[:len(inputIndexes)]...) + } + } + } + solvingData.memoryPool.Dump(inputs) + } + } + + start := 0 + for _, end := range chunks { + solvingData.workers.Submit(end-start, solveTask(start), 1024).Wait() + start = end + } + + assignment.setOuts(info.Circuit, outs) + + return nil + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} + +func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { + + return func(_ *big.Int, ins, outs []*big.Int) error { + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + b := make([]byte, fr.Bytes) + i.FillBytes(b) + return b[:] + }) + + hsh := HashBuilderRegistry[hashName]() + + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + if err != nil { + return err + } + + // serialize proof: TODO: In gnark-crypto? 
+ offset := 0 + for i := range proof { + for _, poly := range proof[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if proof[i].FinalEvalProof != nil { + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } + + data.dumpAssignments() + + return nil + + } +} + +// TODO: Move to gnark-crypto +var HashBuilderRegistry = make(map[string]func() hash.Hash) \ No newline at end of file diff --git a/internal/generator/backend/template/representations/r1cs.go.tmpl b/internal/generator/backend/template/representations/r1cs.go.tmpl deleted file mode 100644 index d65b171737..0000000000 --- a/internal/generator/backend/template/representations/r1cs.go.tmpl +++ /dev/null @@ -1,448 +0,0 @@ -import ( - "errors" - "fmt" - "io" - "runtime" - "time" - "sync" - "github.com/fxamacker/cbor/v2" - - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - "github.com/consensys/gnark/backend/witness" - - "math" - "github.com/consensys/gnark-crypto/ecc" - - {{ template "import_fr" . 
}} -) - -// R1CS describes a set of R1CS constraint -type R1CS struct { - constraint.R1CSCore - CoeffTable - arithEngine -} - -// NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values -// -// capacity pre-allocates memory for capacity nbConstraints -func NewR1CS(capacity int) *R1CS { - r := R1CS{ - R1CSCore: constraint.R1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.R1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - return &r -} - - -func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, r1c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - - cs.UpdateLevel(cID, &r1c) - - return cID -} - - - -// Solve sets all the wires and returns the a, b, c vectors. -// the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. -// a, b, c vectors: ab-c = hz -// witness = [publicWires | secretWires] (without the ONE_WIRE !) 
-// returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() - - nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables - solution, err := newSolution(nbWires, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return make(fr.Vector, nbWires), err - } - start := time.Now() - - if len(witness) != len(cs.Public)-1+len(cs.Secret) { // - 1 for ONE_WIRE - err = fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), int(len(cs.Public)-1+len(cs.Secret)), len(cs.Public)-1, len(cs.Secret)) - log.Err(err).Send() - return solution.values, err - } - - // compute the wires and the a, b, c polynomials - if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { - err = errors.New("invalid input size: len(a, b, c) == len(Constraints)") - log.Err(err).Send() - return solution.values, err - } - - solution.solved[0] = true // ONE_WIRE - solution.values[0].SetOne() - copy(solution.values[1:], witness) - for i := range witness { - solution.solved[i+1] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness) + 1) - - // now that we know all inputs are set, defer log printing once all solution.values are computed - // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - if err := cs.parallelSolve(a, b, c, &solution); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity 
check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil -} - - - -func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - chError <- &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - var debugInfo *string - if dID, ok := cs.MDebug[i]; ok { - debugInfo = new(string) - *debugInfo = solution.logValue(cs.DebugInfo[dID]) - } - return &UnsatisfiedConstraintError{CID: i, Err: err, DebugInfo: debugInfo} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - -// IsSolved returns nil if given witness solves the R1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) 
- if err != nil { - return err - } - - a := make(fr.Vector, len(cs.Constraints)) - b := make(fr.Vector, len(cs.Constraints)) - c := make(fr.Vector, len(cs.Constraints)) - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, a, b, c, opt) - return err -} - -// divByCoeff sets res = res / t.Coeff -func (cs *R1CS) divByCoeff(res *fr.Element, t constraint.Term) { - cID := t.CoeffID() - switch cID { - case constraint.CoeffIdOne: - return - case constraint.CoeffIdMinusOne: - res.Neg(res) - case constraint.CoeffIdZero: - panic("division by 0") - default: - // this is slow, but shouldn't happen as divByCoeff is called to - // remove the coeff of an unsolved wire - // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 - res.Div(res, &cs.Coefficients[cID]) - } -} - - - -// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly -// -// returns an error if the solver called a hint function that errored -// returns false, nil if there was no wire to solve -// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that -// the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r constraint.R1C, solution *solution, a,b,c *fr.Element) error { - - // the index of the non-zero entry shows if L, R or O has an uninstantiated wire - // the content is the ID of the wire non instantiated - var loc uint8 - - var termToCompute constraint.Term - - processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) error { - for _, t := range l { - vID := t.WireID() - - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - solution.accumulateInto(t, val) - continue - } - - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err - } - // now that the wire is saved, accumulate it into a, b or c - solution.accumulateInto(t, val) - continue - } - - if loc != 0 { - panic("found more than one wire to instantiate") - } - termToCompute = t - loc = locValue - } - return nil - } - - if err := processLExp(r.L, a, 1); err != nil { - return err - } - - if err := processLExp(r.R, b, 2); err != nil { - return err - } - - if err := processLExp(r.O, c, 3); err != nil { - return err - } - - if loc == 0 { - // there is nothing to solve, may happen if we have an assertion - // (ie a constraints that doesn't yield any output) - // or if we solved the unsolved wires with hint functions - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - return nil - } - - // we compute the wire value and instantiate it - wID := termToCompute.WireID() - - // solver result - var wire fr.Element - - - switch loc { - case 1: - if !b.IsZero() { - wire.Div(c, b). - Sub(&wire, a) - a.Add(a, &wire) - } else { - // we didn't actually ensure that a * b == c - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 2: - if !a.IsZero() { - wire.Div(c, a). 
- Sub(&wire, b) - b.Add(b, &wire) - } else { - var check fr.Element - if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) - } - } - case 3: - wire.Mul(a, b). - Sub(&wire, c) - - c.Add(c, &wire) - } - - // wire is the term (coeff * value) - // but in the solution we want to store the value only - // note that in gnark frontend, coeff here is always 1 or -1 - cs.divByCoeff(&wire, termToCompute) - solution.set(wID, wire) - - - return nil -} - -// GetConstraints return the list of R1C and a coefficient resolver -func (cs *R1CS) GetConstraints() ([]constraint.R1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *R1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto -func (cs *R1CS) CurveID() ecc.ID { - return ecc.{{.CurveID}} -} - -// WriteTo encodes R1CS into provided io.Writer using cbor -func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - -// ReadFrom attempts to decode R1CS from io.Reader using cbor -func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(&cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - - return int64(decoder.NumBytesRead()), nil -} - diff --git 
a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl deleted file mode 100644 index f5a35ea03e..0000000000 --- a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl +++ /dev/null @@ -1,453 +0,0 @@ -import ( - "fmt" - "io" - "github.com/fxamacker/cbor/v2" - "github.com/consensys/gnark-crypto/ecc" - "sync" - "runtime" - "math" - "errors" - "time" - - "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/logger" - "github.com/consensys/gnark/profile" - "github.com/consensys/gnark/backend/witness" - - {{ template "import_fr" . }} -) - -// SparseR1CS represents a Plonk like circuit -type SparseR1CS struct { - constraint.SparseR1CSCore - CoeffTable - arithEngine -} - -// NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values -func NewSparseR1CS(capacity int) *SparseR1CS { - cs := SparseR1CS{ - SparseR1CSCore: constraint.SparseR1CSCore{ - System: constraint.NewSystem(fr.Modulus()), - Constraints: make([]constraint.SparseR1C, 0, capacity), - }, - CoeffTable: newCoeffTable(capacity / 10), - } - - return &cs -} - - -func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constraint.DebugInfo) int { - profile.RecordConstraint() - cs.Constraints = append(cs.Constraints, c) - cID := len(cs.Constraints) - 1 - if len(debugInfo) == 1 { - cs.DebugInfo = append(cs.DebugInfo, constraint.LogEntry(debugInfo[0])) - cs.MDebug[cID] = len(cs.DebugInfo) - 1 - } - cs.UpdateLevel(cID, &c) - - return cID -} - - -// Solve sets all the wires. 
-// solution.values = [publicInputs | secretInputs | internalVariables ] -// witness: contains the input variables -// it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { - log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() - - // set the slices holding the solution.values and monitoring which variables have been solved - nbVariables := cs.NbInternalVariables + len(cs.Secret) + len(cs.Public) - - start := time.Now() - - expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) - if len(witness) != expectedWitnessSize { - return make(fr.Vector, nbVariables), fmt.Errorf( - "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", - len(witness), - expectedWitnessSize, - len(cs.Public), - len(cs.Secret), - ) - } - - // keep track of wire that have a value - solution, err := newSolution( nbVariables, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) - if err != nil { - return solution.values, err - } - - - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs - copy(solution.values, witness) - for i := 0; i < len(witness); i++ { - solution.solved[i] = true - } - - // keep track of the number of wire instantiations we do, for a sanity check to ensure - // we instantiated all wires - solution.nbSolved += uint64(len(witness)) - - // defer log printing once all solution.values are computed - defer solution.printLogs(opt.CircuitLogger, cs.Logs) - - // batch invert the coefficients to avoid many divisions in the solver - coefficientsNegInv := fr.BatchInvert(cs.Coefficients) - for i:=0; i < len(coefficientsNegInv);i++ { - coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) - } - - if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { - if unsatisfiedErr, ok := err.(*UnsatisfiedConstraintError); ok { - 
log.Err(errors.New("unsatisfied constraint")).Int("id", unsatisfiedErr.CID).Send() - } else { - log.Err(err).Send() - } - return solution.values, err - } - - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - log.Err(errors.New("solver didn't instantiate all wires")).Send() - panic("solver didn't instantiate all wires") - } - - log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") - - return solution.values, nil - -} - - -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { - // minWorkPerCPU is the minimum target number of constraint a task should hold - // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed - // sequentially without sync. - const minWorkPerCPU = 50.0 - - // cs.Levels has a list of levels, where all constraints in a level l(n) are independent - // and may only have dependencies on previous levels - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan *UnsatisfiedConstraintError, runtime.NumCPU()) - - // start a worker pool - // each worker wait on chTasks - // a task is a slice of constraint indexes to be solved - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - // for each constraint in the task, solve it. 
- if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - wg.Done() - return - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - chError <- &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } else { - chError <- &UnsatisfiedConstraintError{CID: i, Err: err} - } - wg.Done() - return - } - } - wg.Done() - } - }() - } - - // clean up pool go routines - defer func() { - close(chTasks) - close(chError) - }() - - // for each level, we push the tasks - for _, level := range cs.Levels { - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - - if maxCPU <= 1.0 { - // we do it sequentially - for _, i := range level { - if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { - if dID, ok := cs.MDebug[i]; ok { - errMsg := solution.logValue(cs.DebugInfo[dID]) - return &UnsatisfiedConstraintError{CID: i, DebugInfo: &errMsg} - } - return &UnsatisfiedConstraintError{CID: i, Err: err} - } - } - continue - } - - // number of tasks for this level is set to num cpus - // but if we don't have enough work for all our CPUS, it can be lower. 
- nbTasks := runtime.NumCPU() - maxTasks := int(math.Ceil(maxCPU)) - if nbTasks > maxTasks { - nbTasks = maxTasks - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - // note: this depends on minWorkPerCPU constant - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - // since we're never pushing more than num CPU tasks - // we will never be blocked here - chTasks <- level[_start:_end] - } - - // wait for the level to be done - wg.Wait() - - if len(chError) > 0 { - return <-chError - } - } - - return nil -} - - - -// computeHints computes wires associated with a hint function, if any -// if there is no remaining wire to solve, returns -1 -// else returns the wire position (L -> 0, R -> 1, O -> 2) -func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) ( int, error) { - r := -1 - lID, rID, oID := c.L.WireID(), c.R.WireID(), c.O.WireID() - - if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { - // check if it's a hint - if hint, ok := cs.MHints[lID]; ok { - if err := solution.solveWithHint(lID, hint); err != nil { - return -1, err - } - } else { - r = 0 - } - - } - - if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { - // check if it's a hint - if hint, ok := cs.MHints[rID]; ok { - if err := solution.solveWithHint(rID, hint); err != nil { - return -1, err - } - } else { - r = 1 - } - } - - if (c.O.CoeffID() != 0) && !solution.solved[oID] { - // check if it's a hint - if hint, ok := cs.MHints[oID]; ok { - if err := solution.solveWithHint(oID, hint); err != nil { - return -1, err - } - } else { - r = 2 - 
} - } - return r, nil -} - - -// solveConstraint solve any unsolved wire in given constraint and update the solution -// a SparseR1C may have up to one unsolved wire (excluding hints) -// if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { - - lro, err := cs.computeHints(c, solution) - if err != nil { - return err - } - if lro == -1 { - // no unsolved wire - // can happen if the constraint contained only hint wires. - return nil - } - if lro == 1 { // we solve for R: u1L+u2R+u3LR+u4O+k=0 => R(u2+u3L)+u1L+u4O+k = 0 - if !solution.solved[c.L.WireID()] { - panic("L wire should be instantiated when we solve R") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.L.WireID()]).Add(&den, &u2) - - v1 = solution.computeTerm(c.L) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - } - - if lro == 0 { // we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0 - if !solution.solved[c.R.WireID()] { - panic("R wire should be instantiated when we solve L") - } - var u1, u2, u3, den, num, v1, v2 fr.Element - u3.Mul(&cs.Coefficients[c.M[0].CoeffID()], &cs.Coefficients[c.M[1].CoeffID()]) - u1.Set(&cs.Coefficients[c.L.CoeffID()]) - u2.Set(&cs.Coefficients[c.R.CoeffID()]) - den.Mul(&u3, &solution.values[c.R.WireID()]).Add(&den, &u1) - - v1 = solution.computeTerm(c.R) - v2 = solution.computeTerm(c.O) - num.Add(&v1, &v2).Add(&num, &cs.Coefficients[c.K]) - - // TODO find a way to do lazy div (/ batch inversion) - num.Div(&num, &den).Neg(&num) - solution.set(c.L.WireID(), num) - return nil - 
- } - // O we solve for O - var o fr.Element - cID, vID := c.O.CoeffID(), c.O.WireID() - - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - - // o = - ((m0 * m1) + l + r + c.K) / c.O - o.Mul(&m0, &m1).Add(&o, &l).Add(&o, &r).Add(&o, &cs.Coefficients[c.K]) - o.Mul(&o, &coefficientsNegInv[cID]) - - solution.set(vID, o) - - return nil -} - -// IsSolved returns nil if given witness solves the SparseR1CS and error otherwise -// this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { - opt, err := backend.NewProverConfig(opts...) - if err != nil { - return err - } - - v := witness.Vector().(fr.Vector) - _, err = cs.Solve(v, opt) - return err -} - -// GetConstraints return the list of SparseR1C and a coefficient resolver -func (cs *SparseR1CS) GetConstraints() ([]constraint.SparseR1C, constraint.Resolver) { - return cs.Constraints, cs -} - -// checkConstraint verifies that the constraint holds -func (cs *SparseR1CS) checkConstraint(c constraint.SparseR1C, solution *solution) error { - l := solution.computeTerm(c.L) - r := solution.computeTerm(c.R) - m0 := solution.computeTerm(c.M[0]) - m1 := solution.computeTerm(c.M[1]) - o := solution.computeTerm(c.O) - - // l + r + (m0 * m1) + o + c.K == 0 - var t fr.Element - t.Mul(&m0, &m1).Add(&t, &l).Add(&t, &r).Add(&t, &o).Add(&t, &cs.Coefficients[c.K]) - if !t.IsZero() { - return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + (%s × %s) + %s != 0", - l.String(), - r.String(), - o.String(), - m0.String(), - m1.String(), - cs.Coefficients[c.K].String(), - ) - } - return nil - -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - return len(cs.Coefficients) -} - -// CurveID returns curve ID as defined in gnark-crypto (ecc.{{.Curve}}) -func 
(cs *SparseR1CS) CurveID() ecc.ID { - return ecc.{{.CurveID}} -} - -// WriteTo encodes SparseR1CS into provided io.Writer using cbor -func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { - _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - enc, err := cbor.CoreDetEncOptions().EncMode() - if err != nil { - return 0, err - } - encoder := enc.NewEncoder(&_w) - - // encode our object - err = encoder.Encode(cs) - return _w.N, err -} - - -// ReadFrom attempts to decode SparseR1CS from io.Reader using cbor -func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { - dm, err := cbor.DecOptions{ - MaxArrayElements: 134217728, - MaxMapPairs: 134217728, - }.DecMode() - if err != nil { - return 0, err - } - decoder := dm.NewDecoder(r) - - // initialize coeff table - cs.CoeffTable = newCoeffTable(0) - - if err := decoder.Decode(cs); err != nil { - return int64(decoder.NumBytesRead()), err - } - - if err := cs.CheckSerializationHeader(); err != nil { - return int64(decoder.NumBytesRead()), err - } - - return int64(decoder.NumBytesRead()), nil -} diff --git a/internal/generator/backend/template/representations/solution.go.tmpl b/internal/generator/backend/template/representations/solution.go.tmpl deleted file mode 100644 index e323a614f5..0000000000 --- a/internal/generator/backend/template/representations/solution.go.tmpl +++ /dev/null @@ -1,279 +0,0 @@ -import ( - "errors" - "fmt" - "math/big" - "sync/atomic" - "strings" - "strconv" - "github.com/consensys/gnark/debug" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/constraint" - "github.com/rs/zerolog" - {{ template "import_fr" . 
}} -) - -// solution represents elements needed to compute -// a solution to a R1CS or SparseR1CS -type solution struct { - values, coefficients []fr.Element - solved []bool - nbSolved uint64 - mHintsFunctions map[hint.ID]hint.Function // maps hintID to hint function - mHints map[int]*constraint.Hint // maps wireID to hint - st *debug.SymbolTable -} - -func newSolution( nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, mHints map[int]*constraint.Hint, coefficients []fr.Element, st *debug.SymbolTable) (solution, error) { - - s := solution{ - st: st, - values: make([]fr.Element, nbWires), - coefficients: coefficients, - solved: make([]bool, nbWires), - mHintsFunctions: hintFunctions, - mHints: mHints, - } - - // hintsDependencies is from compile time; it contains the list of hints the solver **needs** - var missing []string - for hintUUID, hintID := range hintsDependencies { - if _, ok := s.mHintsFunctions[hintUUID]; !ok { - missing = append(missing, hintID) - } - } - - if len(missing) > 0 { - return s, fmt.Errorf("solver missing hint(s): %v", missing) - } - - return s, nil -} - -func (s *solution) set(id int, value fr.Element) { - if s.solved[id] { - panic("solving the same wire twice should never happen.") - } - s.values[id] = value - s.solved[id] = true - atomic.AddUint64(&s.nbSolved, 1) - // s.nbSolved++ -} - -func (s *solution) isValid() bool { - return int(s.nbSolved) == len(s.values) -} - -// computeTerm computes coeff*variable -func (s *solution) computeTerm(t constraint.Term) fr.Element { - cID, vID := t.CoeffID(), t.WireID() - if cID != 0 && !s.solved[vID] { - panic("computing a term with an unsolved wire") - } - switch cID { - case constraint.CoeffIdZero: - return fr.Element{} - case constraint.CoeffIdOne: - return s.values[vID] - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - return res - case constraint.CoeffIdMinusOne: - var res fr.Element - res.Neg(&s.values[vID]) - return res - 
default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - return res - } -} - -// r += (t.coeff*t.value) -func (s *solution) accumulateInto(t constraint.Term, r *fr.Element) { - cID := t.CoeffID() - - if t.IsConstant() { - // needed for logs, we may want to not put this in the hot path if we need to - // optimize constraint system solver further. - r.Add(r, &s.coefficients[cID]) - return - } - - vID := t.WireID() - switch cID { - case constraint.CoeffIdZero: - return - case constraint.CoeffIdOne: - r.Add(r, &s.values[vID]) - case constraint.CoeffIdTwo: - var res fr.Element - res.Double(&s.values[vID]) - r.Add(r, &res) - case constraint.CoeffIdMinusOne: - r.Sub(r, &s.values[vID]) - default: - var res fr.Element - res.Mul(&s.coefficients[cID], &s.values[vID]) - r.Add(r, &res) - } -} - -// solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveWithHint(vID int, h *constraint.Hint) error { - // skip if the wire is already solved by a call to the same hint - // function on the same inputs - if s.solved[vID] { - return nil - } - // ensure hint function was provided - f, ok := s.mHintsFunctions[h.ID] - if !ok { - return errors.New("missing hint function") - } - - // tmp IO big int memory - nbInputs := len(h.Inputs) - nbOutputs := len(h.Wires) - inputs := make([]*big.Int, nbInputs) - outputs := make([]*big.Int, nbOutputs) - for i :=0; i < nbOutputs; i++ { - outputs[i] = big.NewInt(0) - } - - q := fr.Modulus() - - // for each input, we set its big int value, IF all the wires are solved - // the only case where all wires may not be solved, is if one of the input of this hint - // is the output of another hint. - // it is safe to recursively solve this with the parallel solver, since all hints-output wires - // that we can solve this way are marked to be solved with the current constraint we are processing. 
- recursiveSolve := func(t constraint.Term) error { - if t.IsConstant() { - return nil - } - wID := t.WireID() - if s.solved[wID] { - return nil - } - // unsolved dependency - if h, ok := s.mHints[wID]; ok { - // solve recursively. - return s.solveWithHint(wID, h) - } - - // it's not a hint, we panic. - panic("solver can't compute hint; one or more input wires are unsolved") - } - - for i := 0; i < nbInputs; i++ { - inputs[i] = big.NewInt(0) - - var v fr.Element - for _, term := range h.Inputs[i] { - if err := recursiveSolve(term); err != nil { - return err - } - s.accumulateInto(term, &v) - } - v.BigInt(inputs[i]) - } - - - err := f(q, inputs, outputs) - - var v fr.Element - for i := range outputs { - v.SetBigInt(outputs[i]) - s.set(h.Wires[i], v) - } - - return err -} - -func (s *solution) printLogs(log zerolog.Logger, logs []constraint.LogEntry) { - if log.GetLevel() == zerolog.Disabled { - return - } - - for i := 0; i < len(logs); i++ { - logLine := s.logValue(logs[i]) - log.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) - } -} - -const unsolvedVariable = "" - -func (s *solution) logValue(log constraint.LogEntry) string { - var toResolve []interface{} - var ( - eval fr.Element - missingValue bool - ) - for j := 0; j < len(log.ToResolve); j++ { - // before eval le - - missingValue = false - eval.SetZero() - - for _, t := range log.ToResolve[j] { - // for each term in the linear expression - - cID, vID := t.CoeffID(), t.WireID() - if t.IsConstant() { - // just add the constant - eval.Add(&eval, &s.coefficients[cID]) - continue - } - - if !s.solved[vID] { - missingValue = true - break // stop the loop we can't evaluate. 
- } - - tv := s.computeTerm(t) - eval.Add(&eval, &tv) - } - - - // after - if missingValue { - toResolve = append(toResolve, unsolvedVariable) - } else { - // we have to append our accumulator - toResolve = append(toResolve, eval.String()) - } - - } - if len(log.Stack) > 0 { - var sbb strings.Builder - for _, lID := range log.Stack { - location := s.st.Locations[lID] - function := s.st.Functions[location.FunctionID] - - sbb.WriteString(function.Name) - sbb.WriteByte('\n') - sbb.WriteByte('\t') - sbb.WriteString(function.Filename) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(int(location.Line))) - sbb.WriteByte('\n') - } - toResolve = append(toResolve, sbb.String()) - } - return fmt.Sprintf(log.Format, toResolve...) -} - -// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint -type UnsatisfiedConstraintError struct { - Err error - CID int // constraint ID - DebugInfo *string // optional debug info -} - -func (r *UnsatisfiedConstraintError) Error() string { - if r.DebugInfo != nil { - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) - } - return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) -} \ No newline at end of file diff --git a/internal/generator/backend/template/representations/solver.go.tmpl b/internal/generator/backend/template/representations/solver.go.tmpl new file mode 100644 index 0000000000..aa7b60cd4d --- /dev/null +++ b/internal/generator/backend/template/representations/solver.go.tmpl @@ -0,0 +1,650 @@ +import ( + "errors" + "fmt" + "math/big" + "sync/atomic" + "strings" + "strconv" + "runtime" + "sync" + "math" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/rs/zerolog" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/pool" + {{ template "import_fr" . }} +) + +// solver represent the state of the solver during a call to System.Solve(...) 
+type solver struct { + *system + + // values and solved are index by the wire (variable) id + values []fr.Element + solved []bool + nbSolved uint64 + + // maps hintID to hint function + mHintsFunctions map[csolver.HintID]csolver.Hint + + // used to out api.Println + logger zerolog.Logger + + a,b,c fr.Vector // R1CS solver will compute the a,b,c matrices + + q *big.Int +} + +func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { + // parse options + opt, err := csolver.NewConfig(opts...) + if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public)-witnessOffset+len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + 
witnessOffset) + + + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + + +// computeTerm computes coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. 
+func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i :=0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start) + i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. 
+ } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + + + + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID:cID,VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. 
+func (s *solver) Read(calldata []uint32) (constraint.Element, int) {
+	if s.Type == constraint.SystemSparseR1CS {
+		if calldata[0] != 1 {
+			panic("invalid calldata")
+		}
+		return s.GetValue(calldata[1], calldata[2]), 3
+	}
+	var r fr.Element
+	n := int(calldata[0])
+	j := 1
+	for k:= 0; k < n; k++ {
+		// we read the k-th of the n Terms (a CID, VID pair)
+		s.accumulateInto(constraint.Term{CID:calldata[j],VID: calldata[j+1]} , &r)
+		j+=2
+	}
+
+	var ret constraint.Element
+	copy(ret[:], r[:])
+	return ret, j
+}
+
+
+
+// processInstruction decodes the instruction and executes blueprint-defined logic.
+// an instruction can encode a hint, a custom constraint or a generic constraint.
+func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error {
+	// fetch the blueprint
+	blueprint := solver.Blueprints[pi.BlueprintID]
+	inst := pi.Unpack(&solver.System)
+	cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only
+
+	if solver.Type == constraint.SystemR1CS {
+		if bc, ok := blueprint.(constraint.BlueprintR1C); ok {
+			// TODO @gbotrel we use the solveR1C method for now, having user-defined
+			// blueprint for R1CS would require constraint.Solver interface to add methods
+			// to set a,b,c since it's more efficient to compute these while we solve.
+			bc.DecompressR1C(&scratch.tR1C, inst)
+			return solver.solveR1C(cID, &scratch.tR1C)
+		}
+	}
+
+	// blueprint declared "I know how to solve this."
+	if bc, ok := blueprint.(constraint.BlueprintSolvable); ok {
+		if err := bc.Solve(solver, inst); err != nil {
+			return solver.wrapErrWithDebugInfo(cID, err)
+		}
+		return nil
+	}
+
+	// blueprint encodes a hint, we execute.
+	// TODO @gbotrel may be worth it to move hint logic in blueprint "solve"
+	if bc, ok := blueprint.(constraint.BlueprintHint); ok {
+		bc.DecompressHint(&scratch.tHint,inst)
+		return solver.solveWithHint(&scratch.tHint)
+	}
+
+
+	return nil
+}
+
+
+// run runs the solver. 
it returns an error if a constraint is not satisfied or if not all wires
+// were instantiated.
+func (solver *solver) run() error {
+	// minWorkPerCPU is the minimum target number of constraint a task should hold
+	// in other words, if a level has less than minWorkPerCPU, it will not be parallelized and will be executed
+	// sequentially without sync.
+	const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks.
+
+	// cs.Levels has a list of levels, where all constraints in a level l(n) are independent
+	// and may only have dependencies on previous levels
+	// for each constraint
+	// we are guaranteed that each R1C contains at most one unsolved wire
+	// first we solve the unsolved wire (if any)
+	// then we check that the constraint is valid
+	// if a[i] * b[i] != c[i]; it means the constraint is not satisfied
+	var wg sync.WaitGroup
+	chTasks := make(chan []int, runtime.NumCPU())
+	chError := make(chan error, runtime.NumCPU())
+
+	// start a worker pool
+	// each worker waits on chTasks
+	// a task is a slice of constraint indexes to be solved
+	for i := 0; i < runtime.NumCPU(); i++ {
+		go func() {
+			var scratch scratch
+			for t := range chTasks {
+				for _, i := range t {
+					if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil {
+						chError <- err
+						wg.Done()
+						return
+					}
+				}
+				wg.Done()
+			}
+		}()
+	}
+
+	// clean up pool go routines
+	defer func() {
+		close(chTasks)
+		close(chError)
+	}()
+
+	var scratch scratch
+
+	// for each level, we push the tasks
+	for _, level := range solver.Levels {
+
+		// max CPU to use
+		maxCPU := float64(len(level)) / minWorkPerCPU
+
+		if maxCPU <= 1.0 {
+			// we do it sequentially
+			for _, i := range level {
+				if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil {
+					return err
+				}
+			}
+			continue
+		}
+
+		// number of tasks for this level is set to number of CPU
+		// but if we don't have enough work for all our CPU, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + if int(solver.nbSolved) != len(solver.values) { + return errors.New("solver didn't assign a value to all wires") + } + + return nil +} + + + +// solveR1C compute unsolved wires in the constraint, if any and set the solver accordingly +// +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. 
+func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error { + a, b, c := &solver.a[cID],&solver.b[cID], &solver.c[cID] + + // the index of the non-zero entry shows if L, R or O has an uninstantiated wire + // the content is the ID of the wire non instantiated + var loc uint8 + + var termToCompute constraint.Term + + processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) { + for _, t := range l { + vID := t.WireID() + + // wire is already computed, we just accumulate in val + if solver.solved[vID] { + solver.accumulateInto(t, val) + continue + } + + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue + } + } + + processLExp(r.L, a, 1) + processLExp(r.R, b, 2) + processLExp(r.O, c, 3) + + + + if loc == 0 { + // there is nothing to solve, may happen if we have an assertion + // (ie a constraints that doesn't yield any output) + // or if we solved the unsolved wires with hint functions + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + return nil + } + + // we compute the wire value and instantiate it + wID := termToCompute.WireID() + + // solver result + var wire fr.Element + + + switch loc { + case 1: + if !b.IsZero() { + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) + } else { + // we didn't actually ensure that a * b == c + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 2: + if !a.IsZero() { + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) + } else { + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 3: + wire.Mul(a, b). 
+ Sub(&wire, c) + + c.Add(c, &wire) + } + + // wire is the term (coeff * value) + // but in the solver we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + solver.divByCoeff(&wire, termToCompute.CID) + solver.set(wID, wire) + + return nil +} + +// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint +type UnsatisfiedConstraintError struct { + Err error + CID int // constraint ID + DebugInfo *string // optional debug info +} + +func (r *UnsatisfiedConstraintError) Error() string { + if r.DebugInfo != nil { + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) + } + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) +} + + +func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { + var debugInfo *string + if dID, ok := solver.MDebug[int(cID)]; ok { + debugInfo = new(string) + *debugInfo = solver.logValue(solver.DebugInfo[dID]) + } + return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo} +} + +// temporary variables to avoid memallocs in hotloop +type scratch struct { + tR1C constraint.R1C + tHint constraint.HintMapping +} + diff --git a/internal/generator/backend/template/representations/system.go.tmpl b/internal/generator/backend/template/representations/system.go.tmpl new file mode 100644 index 0000000000..939dc150e7 --- /dev/null +++ b/internal/generator/backend/template/representations/system.go.tmpl @@ -0,0 +1,376 @@ +import ( + "io" + "time" + "github.com/fxamacker/cbor/v2" + + "github.com/consensys/gnark/internal/backend/ioutils" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/logger" + "github.com/consensys/gnark/backend/witness" + "reflect" + + "github.com/consensys/gnark-crypto/ecc" + + {{ template "import_fr" . 
}} +) + +type R1CS = system +type SparseR1CS = system + +// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element) +type system struct { + constraint.System + CoeffTable + field +} + + +func NewR1CS(capacity int) *R1CS { + return newSystem(capacity, constraint.SystemR1CS) +} + +func NewSparseR1CS(capacity int) *SparseR1CS { + return newSystem(capacity, constraint.SystemSparseR1CS) +} + +func newSystem(capacity int, t constraint.SystemType) *system { + return &system{ + System: constraint.NewSystem(fr.Modulus(), capacity, t), + CoeffTable: newCoeffTable(capacity / 10), + } +} + + +// Solve solves the constraint system with provided witness. +// If it's a R1CS returns R1CSSolution +// If it's a SparseR1CS returns SparseR1CSSolution +func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) { + log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger() + start := time.Now() + + + v := witness.Vector().(fr.Vector) + + // init the solver + solver, err := newSolver(cs, v, opts...) + if err != nil { + log.Err(err).Send() + return nil, err + } + + // defer log printing once all solver.values are computed + // (or sooner, if a constraint is not satisfied) + defer solver.printLogs(cs.Logs) + + // run it. + if err := solver.run(); err != nil { + log.Err(err).Send() + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") + + // format the solution + // TODO @gbotrel revisit post-refactor + if cs.Type == constraint.SystemR1CS { + var res R1CSSolution + res.W = solver.values + res.A = solver.a + res.B = solver.b + res.C = solver.c + return &res, nil + } else { + // sparse R1CS + var res SparseR1CSSolution + // query l, r, o in Lagrange basis, not blinded + res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values) + + return &res, nil + } + +} + +// IsSolved +// Deprecated: use _, err := Solve(...) 
instead +func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error { + _, err := cs.Solve(witness, opts...) + return err +} + + +// GetR1Cs return the list of R1C +func (cs *system) GetR1Cs() []constraint.R1C { + toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + var r1c constraint.R1C + bc.DecompressR1C(&r1c, inst.Unpack(&cs.System)) + toReturn = append(toReturn, r1c) + } else { + panic("not implemented") + } + } + return toReturn +} + +// GetNbCoefficients return the number of unique coefficients needed in the R1CS +func (cs *system) GetNbCoefficients() int { + return len(cs.Coefficients) +} + +// CurveID returns curve ID as defined in gnark-crypto +func (cs *system) CurveID() ecc.ID { + return ecc.{{.CurveID}} +} + +// WriteTo encodes R1CS into provided io.Writer using cbor +func (cs *system) WriteTo(w io.Writer) (int64, error) { + _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written + ts := getTagSet() + enc, err := cbor.CoreDetEncOptions().EncModeWithTags(ts) + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) + + // encode our object + err = encoder.Encode(cs) + return _w.N, err +} + +// ReadFrom attempts to decode R1CS from io.Reader using cbor +func (cs *system) ReadFrom(r io.Reader) (int64, error) { + ts := getTagSet() + dm, err := cbor.DecOptions{ + MaxArrayElements: 2147483647, + MaxMapPairs: 2147483647, + }.DecModeWithTags(ts) + + if err != nil { + return 0, err + } + decoder := dm.NewDecoder(r) + + // initialize coeff table + cs.CoeffTable = newCoeffTable(0) + + if err := decoder.Decode(&cs); err != nil { + return int64(decoder.NumBytesRead()), err + } + + + if err := cs.CheckSerializationHeader(); err != nil { + return int64(decoder.NumBytesRead()), err + } + + switch v := cs.CommitmentInfo.(type) { + case 
*constraint.Groth16Commitments: + cs.CommitmentInfo = *v + case *constraint.PlonkCommitments: + cs.CommitmentInfo = *v + } + + return int64(decoder.NumBytesRead()), nil +} + +func (cs *system) GetCoefficient(i int) (r constraint.Element) { + copy(r[:], cs.Coefficients[i][:]) + return +} + + +// GetSparseR1Cs return the list of SparseR1C +func (cs *system) GetSparseR1Cs() []constraint.SparseR1C { + + toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + var sparseR1C constraint.SparseR1C + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + toReturn = append(toReturn, sparseR1C) + } else { + panic("not implemented") + } + } + return toReturn +} + + + +// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form. +// solver = [ public | secret | internal ] +// TODO @gbotrel refactor; this seems to be a small util function for plonk +func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { + + //s := int(pk.Domain[0].Cardinality) + s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints + s = int(ecc.NextPowerOfTwo(uint64(s))) + + var l, r, o []fr.Element + l = make([]fr.Element, s) + r = make([]fr.Element, s) + o = make([]fr.Element, s) + s0 := solution[0] + + for i := 0; i < len(cs.Public); i++ { // placeholders + l[i] = solution[i] + r[i] = s0 + o[i] = s0 + } + offset := len(cs.Public) + nbConstraints := cs.GetNbConstraints() + + + var sparseR1C constraint.SparseR1C + j := 0 + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + + l[offset+j] = solution[sparseR1C.XA] + r[offset+j] = solution[sparseR1C.XB] + 
o[offset+j] = solution[sparseR1C.XC] + j++ + } + } + + + offset += nbConstraints + + for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0]) + l[offset+i] = s0 + r[offset+i] = s0 + o[offset+i] = s0 + } + + return l, r, o + +} + + + +// R1CSSolution represent a valid assignment to all the variables in the constraint system. +// The vector W such that Aw o Bw - Cw = 0 +type R1CSSolution struct { + W fr.Vector + A, B, C fr.Vector +} + +func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.W.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.A.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.B.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.C.WriteTo(w) + n += a + return n, err +} + +func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.W.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.A.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.B.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.C.ReadFrom(r) + n += a + return n, err +} + + + +// SparseR1CSSolution represent a valid assignment to all the variables in the constraint system. 
+type SparseR1CSSolution struct { + L, R, O fr.Vector +} + +func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.L.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.R.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.O.WriteTo(w) + n += a + return n, err + +} + +func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.L.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.R.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.O.ReadFrom(r) + n += a + return n, err +} + + +func getTagSet() cbor.TagSet { + // temporary for refactor + ts := cbor.NewTagSet() + // https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml + // 65536-15309735 Unassigned + tagNum := uint64(5309735) + addType := func(t reflect.Type) { + if err := ts.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + t, + tagNum, + ); err != nil { + panic(err) + } + tagNum++ + } + + addType(reflect.TypeOf(constraint.BlueprintGenericHint{})) + addType(reflect.TypeOf(constraint.BlueprintGenericR1C{})) + addType(reflect.TypeOf(constraint.BlueprintGenericSparseR1C{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CAdd{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CMul{})) + addType(reflect.TypeOf(constraint.BlueprintSparseR1CBool{})) + addType(reflect.TypeOf(constraint.BlueprintLookupHint{})) + addType(reflect.TypeOf(constraint.Groth16Commitments{})) + addType(reflect.TypeOf(constraint.PlonkCommitments{})) + + return ts +} + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} \ No newline at end of file diff --git a/internal/generator/backend/template/representations/tests/r1cs.go.tmpl b/internal/generator/backend/template/representations/tests/r1cs.go.tmpl index c312e2f48b..81eebde6bc 100644 --- a/internal/generator/backend/template/representations/tests/r1cs.go.tmpl +++ 
b/internal/generator/backend/template/representations/tests/r1cs.go.tmpl @@ -5,6 +5,7 @@ import ( "reflect" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "github.com/google/go-cmp/cmp" @@ -39,7 +40,7 @@ func TestSerialization(t *testing.T) { return } - // copmpile a second time to ensure determinism + // compile a second time to ensure determinism r1cs2, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) @@ -67,10 +68,10 @@ func TestSerialization(t *testing.T) { if diff := cmp.Diff(r1cs1, &reconstructed, cmpopts.IgnoreFields(cs.R1CS{}, "System.q", - "arithEngine", + "field", "CoeffTable.mCoeffs", "System.lbWireLevel", - "System.lbHints", + "System.genericHint", "System.SymbolTable", "System.lbOutputs", "System.bitLen")); diff != "" { @@ -139,11 +140,6 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { - var c circuit - ccs, err := frontend.Compile(fr.Modulus(),r1cs.NewBuilder, &c) - if err != nil { - b.Fatal(err) - } var w circuit w.X = 1 @@ -153,9 +149,32 @@ func BenchmarkSolve(b *testing.B) { b.Fatal(err) } + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(),scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(),r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ccs.IsSolved(witness) - } } \ No newline at end of file diff --git 
a/internal/generator/backend/template/zkpschemes/groth16/groth16.commitment.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.commitment.go.tmpl index da74e65321..aec745a7d6 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.commitment.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.commitment.go.tmpl @@ -5,7 +5,7 @@ import ( "math/big" ) -func solveCommitmentWire(commitmentInfo *constraint.Commitment, commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { - res, err := fr.Hash(commitmentInfo.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) +func solveCommitmentWire(commitment *curve.G1Affine, publicCommitted []*big.Int) (fr.Element, error) { + res, err := fr.Hash(constraint.SerializeCommitment(commitment.Marshal(), publicCommitted, (fr.Bits-1)/8+1), []byte(constraint.CommitmentDst), 1) return res[0], err } \ No newline at end of file diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.marshal.go.tmpl index 7414dab596..eaa9835d95 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.marshal.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.marshal.go.tmpl @@ -1,5 +1,7 @@ import ( {{ template "import_curve" . }} + {{ template "import_pedersen" . }} + "github.com/consensys/gnark/internal/utils" "io" ) @@ -61,14 +63,24 @@ func (proof *Proof) ReadFrom(r io.Reader) (n int64, err error) { // points are compressed // use WriteRawTo(...) 
to encode the key without point compression func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, false) + if n, err = vk.writeTo(w, false); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteTo(w) + return m + n, err } // WriteRawTo writes binary encoding of the key elements to writer // points are not compressed // use WriteTo(...) to encode the key with point compression func (vk *VerifyingKey) WriteRawTo(w io.Writer) (n int64, err error) { - return vk.writeTo(w, true) + if n, err = vk.writeTo(w, true); err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.WriteRawTo(w) + return m + n, err } // writeTo serialization format: @@ -108,6 +120,14 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { if err := enc.Encode(vk.G1.K); err != nil { return enc.BytesWritten(), err } + + if vk.PublicAndCommitmentCommitted == nil { + vk.PublicAndCommitmentCommitted = [][]int{} // only matters in tests + } + if err := enc.Encode(utils.IntSliceSliceToUint64SliceSlice(vk.PublicAndCommitmentCommitted)); err != nil { + return enc.BytesWritten(), err + } + return enc.BytesWritten(), nil } @@ -117,13 +137,25 @@ func (vk *VerifyingKey) writeTo(w io.Writer, raw bool) (int64, error) { // https://github.com/zkcrypto/bellman/blob/fa9be45588227a8c6ec34957de3f68705f07bd92/src/groth16/mod.rs#L143 // [α]1,[β]1,[β]2,[γ]2,[δ]1,[δ]2,uint32(len(Kvk)),[Kvk]1 func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r) + n, err := vk.readFrom(r) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.ReadFrom(r) + return m + n, err } // UnsafeReadFrom has the same behavior as ReadFrom, except that it will not check that decode points // are on the curve and in the correct subgroup. 
func (vk *VerifyingKey) UnsafeReadFrom(r io.Reader) (int64, error) { - return vk.readFrom(r, curve.NoSubgroupChecks()) + n, err := vk.readFrom(r, curve.NoSubgroupChecks()) + if err != nil { + return n, err + } + var m int64 + m, err = vk.CommitmentKey.UnsafeReadFrom(r) + return m + n, err } func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) (int64, error) { @@ -153,15 +185,16 @@ func (vk *VerifyingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder) if err := dec.Decode(&vk.G1.K); err != nil { return dec.BytesRead(), err } + var publicCommitted [][]uint64 + if err := dec.Decode(&publicCommitted); err != nil { + return dec.BytesRead(), err + } + vk.PublicAndCommitmentCommitted = utils.Uint64SliceSliceToIntSliceSlice(publicCommitted) // recompute vk.e (e(α, β)) and -[δ]2, -[γ]2 - var err error - vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) - if err != nil { - return dec.BytesRead(), err + if err := vk.Precompute(); err != nil { + return dec.BytesRead(), err } - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) return dec.BytesRead(), nil } @@ -213,6 +246,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { pk.NbInfinityB, pk.InfinityA, pk.InfinityB, + uint32(len(pk.CommitmentKeys)), } for _, v := range toEncode { @@ -221,6 +255,23 @@ func (pk *ProvingKey) writeTo(w io.Writer, raw bool) (int64, error) { } } + for i := range pk.CommitmentKeys { + var ( + n2 int64 + err error + ) + if raw { + n2, err = pk.CommitmentKeys[i].WriteRawTo(w) + } else { + n2, err = pk.CommitmentKeys[i].WriteTo(w) + } + + n += n2 + if err != nil { + return n, err + } + } + return n + enc.BytesWritten(), nil } @@ -248,6 +299,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) dec := curve.NewDecoder(r, decOptions...) 
var nbWires uint64 + var nbCommitments uint32 toDecode := []interface{}{ &pk.G1.Alpha, @@ -279,6 +331,18 @@ func (pk *ProvingKey) readFrom(r io.Reader, decOptions ...func(*curve.Decoder)) if err := dec.Decode(&pk.InfinityB); err != nil { return n + dec.BytesRead(), err } + if err := dec.Decode(&nbCommitments); err != nil { + return n + dec.BytesRead(), err + } + + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitments) + for i := range pk.CommitmentKeys { + n2, err := pk.CommitmentKeys[i].ReadFrom(r) + n += n2 + if err != nil { + return n, err + } + } return n + dec.BytesRead(), nil } diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl index 64987ebd1b..fa72568220 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl @@ -1,15 +1,19 @@ import ( - "fmt" {{- template "import_fr" . }} {{- template "import_curve" . }} {{- template "import_backend_cs" . }} {{- template "import_fft" . 
}} + {{- template "import_pedersen" .}} "runtime" "math/big" "time" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16/internal" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/logger" ) @@ -20,7 +24,8 @@ import ( type Proof struct { Ar, Krs curve.G1Affine Bs curve.G2Affine - Commitment, CommitmentPok curve.G1Affine + Commitments []curve.G1Affine // Pedersen commitments a la https://eprint.iacr.org/2022/1072 + CommitmentPok curve.G1Affine // Batched proof of knowledge of the above commitments } // isValid ensures proof elements are in the correct subgroup @@ -34,83 +39,89 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). -func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverConfig) (*Proof, error) { - // TODO @gbotrel witness size check is done by R1CS, doesn't mean we shouldn't sanitize here. 
- // if len(witness) != r1cs.NbPublicVariables-1+r1cs.NbSecretVariables { - // return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), r1cs.NbPublicVariables-1+r1cs.NbSecretVariables, r1cs.NbPublicVariables, r1cs.NbSecretVariables) - // } - - log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", len(r1cs.Constraints)).Str("backend", "groth16").Logger() - - // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) - - proof := &Proof{} - if r1cs.CommitmentInfo.Is() { - opt.HintFunctions[r1cs.CommitmentInfo.HintID] = func(_ *big.Int, in []*big.Int, out []*big.Int) error { - // Perf-TODO: Converting these values to big.Int and back may be a performance bottleneck. - // If that is the case, figure out a way to feed the solution vector into this function - if len(in) != r1cs.CommitmentInfo.NbCommitted() { // TODO: Remove - return fmt.Errorf("unexpected number of committed variables") - } - values := make([]fr.Element, r1cs.CommitmentInfo.NbPrivateCommitted) - nbPublicCommitted := len(in) - len(values) - inPrivate := in[nbPublicCommitted:] - for i, inI := range inPrivate { - values[i].SetBigInt(inI) - } +func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } + + log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() + + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + + proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} + + solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] - var err error - proof.Commitment, proof.CommitmentPok, err = pk.CommitmentKey.Commit(values) - if err != nil { + privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + solverOpts = append(solverOpts, solver.OverrideHint(commitmentInfo[i].HintID, func(i int) solver.Hint { + return func(_ *big.Int, in []*big.Int, out []*big.Int) error { + privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) + hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] + committed := in[len(hashed):] + for j, inJ := range committed { + privateCommittedValues[i][j].SetBigInt(inJ) + } + + var err error + if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { + return err + } + + var res fr.Element + res, err = solveCommitmentWire(&proof.Commitments[i], hashed) + res.BigInt(out[0]) return err } + }(i))) + } - var res fr.Element - res, err = solveCommitmentWire(&r1cs.CommitmentInfo, &proof.Commitment, in[:r1cs.CommitmentInfo.NbPublicCommitted()]) - res.BigInt(out[0]) //Perf-TODO: Regular (non-mont) hashToField to obviate this conversion? 
- return err - } + if r1cs.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + solverOpts = append(solverOpts, + solver.OverrideHint(r1cs.GkrInfo.SolveHintID, cs.GkrSolveHint(r1cs.GkrInfo, &gkrData)), + solver.OverrideHint(r1cs.GkrInfo.ProveHintID, cs.GkrProveHint(r1cs.GkrInfo.HashName, &gkrData))) } - var wireValues []fr.Element - var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill wireValues with random values else multi exps don't do much - var r fr.Element - _, _ = r.SetRandom() - for i := r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables(); i < len(wireValues); i++ { - wireValues[i] = r - r.Double(&r) - } - } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) + if err != nil { + return nil, err } + + solution := _solution.(*cs.R1CSSolution) + wireValues := []fr.Element(solution.W) + start := time.Now() + commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) + for i := range commitmentInfo { + copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) + } + + if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil { + return nil, err + } + // H (witness reduction / FFT part) var h []fr.Element chHDone := make(chan struct{}, 1) go func() { - h = computeH(a, b, c, &pk.Domain) - a = nil - b = nil - c = nil + h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + solution.A = nil + solution.B = nil + solution.C = nil chHDone <- struct{}{} }() // we need to copy and filter the wireValues for each multi exp // as pk.G1.A, pk.G1.B and pk.G2.B may have (a significant) number of point at infinity - var wireValuesA, wireValuesB []fr.Element - chWireValuesA, chWireValuesB := make(chan struct{}, 1) , make(chan struct{}, 1) + var wireValuesA, wireValuesB []fr.Element + chWireValuesA, chWireValuesB := make(chan struct{}, 1), make(chan 
struct{}, 1) go func() { - wireValuesA = make([]fr.Element , len(wireValues) - int(pk.NbInfinityA)) - for i,j :=0,0; j len(toRemove) -func filter(slice []fr.Element, toRemove []int) (r []fr.Element) { +// else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex +// this assumes len(slice) > len(toRemove) +// filterHeap modifies toRemove +func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { if len(toRemove) == 0 { return slice } - r = make([]fr.Element, 0, len(slice)-len(toRemove)) + + heap := utils.IntHeap(toRemove) + heap.Heapify() + + r = make([]fr.Element, 0, len(slice)) - j := 0 // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) for i:=0; i < len(slice);i++ { - if j < len(toRemove) && i == toRemove[j] { - j++ + if len(heap) > 0 && i+sliceFirstIndex == heap[0] { + for len(heap) > 0 && i+sliceFirstIndex == heap[0] { + heap.Pop() + } continue } r = append(r, slice[i]) } - return r + return } func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { @@ -317,9 +338,9 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { domain.FFTInverse(b, fft.DIF) domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, true) - domain.FFT(b, fft.DIT, true) - domain.FFT(c, fft.DIT, true) + domain.FFT(a, fft.DIT, fft.OnCoset()) + domain.FFT(b, fft.DIT, fft.OnCoset()) + domain.FFT(c, fft.DIT, fft.OnCoset()) var den, one fr.Element one.SetOne() @@ -327,7 +348,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) - // reusing a to avoid unecessary memalloc + // reusing a to avoid unnecessary memory allocation utils.Parallelize(n, func(start, end int) { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). 
@@ -337,7 +358,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { }) // ifft_coset - domain.FFTInverse(a, fft.DIF, true) + domain.FFTInverse(a, fft.DIF, fft.OnCoset()) return a } \ No newline at end of file diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl index 1c7ca7dbbb..1795c3987d 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl @@ -1,9 +1,11 @@ import ( + "errors" {{- template "import_fr" . }} {{- template "import_curve" . }} {{- template "import_backend_cs" . }} {{- template "import_fft" . }} {{- template "import_pedersen" .}} + "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/constraint" "math/big" @@ -16,15 +18,15 @@ type ProvingKey struct { // domain Domain fft.Domain - // [α]1, [β]1, [δ]1 - // [A(t)]1, [B(t)]1, [Kpk(t)]1, [Z(t)]1 + // [α]₁, [β]₁, [δ]₁ + // [A(t)]₁, [B(t)]₁, [Kpk(t)]₁, [Z(t)]₁ G1 struct { Alpha, Beta, Delta curve.G1Affine A, B, Z []curve.G1Affine K []curve.G1Affine // the indexes correspond to the private wires } - // [β]2, [δ]2, [B(t)]2 + // [β]₂, [δ]₂, [B(t)]₂ G2 struct { Beta, Delta curve.G2Affine B []curve.G2Affine @@ -34,21 +36,21 @@ type ProvingKey struct { InfinityA, InfinityB []bool NbInfinityA, NbInfinityB uint64 - CommitmentKey pedersen.Key + CommitmentKeys []pedersen.ProvingKey } // VerifyingKey is used by a Groth16 verifier to verify the validity of a proof and a statement // Notation follows Figure 4. 
in DIZK paper https://eprint.iacr.org/2018/691.pdf type VerifyingKey struct { - // [α]1, [Kvk]1 + // [α]₁, [Kvk]₁ G1 struct { Alpha curve.G1Affine Beta, Delta curve.G1Affine // unused, here for compatibility purposes K []curve.G1Affine // The indexes correspond to the public wires } - // [β]2, [δ]2, [γ]2, - // -[δ]2, -[γ]2: see proof.Verify() for more details + // [β]₂, [δ]₂, [γ]₂, + // -[δ]₂, -[γ]₂: see proof.Verify() for more details G2 struct { Beta, Delta, Gamma curve.G2Affine deltaNeg, gammaNeg curve.G2Affine // not serialized @@ -57,8 +59,8 @@ type VerifyingKey struct { // e(α, β) e curve.GT // not serialized - CommitmentKey pedersen.Key - CommitmentInfo constraint.Commitment // since the verifier doesn't input a constraint system, this needs to be provided here + CommitmentKey pedersen.VerifyingKey + PublicAndCommitmentCommitted [][]int // indexes of public/commitment committed variables } // Setup constructs the SRS @@ -75,17 +77,20 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbPrivateCommittedWires := r1cs.CommitmentInfo.NbPrivateCommitted - nbPublicWires := r1cs.GetNbPublicVariables() - nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - if r1cs.CommitmentInfo.Is() { // the commitment itself is defined by a hint so the prover considers it private - nbPublicWires++ // but the verifier will need to inject the value itself so on the groth16 - nbPrivateWires-- // level it must be considered public - } + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + commitmentWires := commitmentInfo.CommitmentIndexes() + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateCommittedWires := internal.NbElements(privateCommitted) + + // a commitment is itself defined by a hint so the prover considers it 
private + // but the verifier will need to inject the value itself so on the groth16 + // level it must be considered public + nbPublicWires := r1cs.GetNbPublicVariables() + len(commitmentInfo) + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - nbPrivateCommittedWires - len(commitmentInfo) // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints))) + domain := fft.NewDomain(uint64(r1cs.GetNbConstraints())) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -119,7 +124,10 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // compute scalars for pkK, vkK and ckK pkK := make([]fr.Element, nbPrivateWires) vkK := make([]fr.Element, nbPublicWires) - ckK := make([]fr.Element, nbPrivateCommittedWires) + ckK := make([][]fr.Element, len(commitmentInfo)) + for i := range commitmentInfo { + ckK[i] = make([]fr.Element, len(privateCommitted[i])) + } var t0, t1 fr.Element @@ -130,28 +138,42 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { Add(&t1, &C[i]). 
Mul(&t1, coeff) } - - vI, cI := 0, 0 - privateCommitted := r1cs.CommitmentInfo.PrivateCommitted() + vI := 0 // number of public wires seen so far + cI := make([]int, len(commitmentInfo)) // number of private committed wires seen so far for each commitment + nbPrivateCommittedSeen := 0 // = ∑ᵢ cI[i] + nbCommitmentsSeen := 0 for i := range A { - isCommittedPrivate := cI < len(privateCommitted) && i == privateCommitted[cI] - isCommitment := r1cs.CommitmentInfo.Is() && i == r1cs.CommitmentInfo.CommitmentIndex - isPublic := i < r1cs.GetNbPublicVariables() + commitment := -1 // index of the commitment that commits to this variable as a private or commitment value + var isCommitment, isPublic bool + if isPublic = i < r1cs.GetNbPublicVariables(); !isPublic { + if nbCommitmentsSeen < len(commitmentWires) && commitmentWires[nbCommitmentsSeen] == i { + isCommitment = true + nbCommitmentsSeen++ + } - if isPublic || isCommittedPrivate || isCommitment { + for j := range commitmentInfo { // does commitment j commit to i? + if cI[j] < len(privateCommitted[j]) && privateCommitted[j][cI[j]] == i { + commitment = j + break // frontend guarantees that no private variable is committed to more than once + } + } + } + + if isPublic || commitment != -1 || isCommitment { computeK(i, &toxicWaste.gammaInv) - if isCommittedPrivate { - ckK[cI] = t1 - cI++ - } else { + if isPublic || isCommitment { vkK[vI] = t1 vI++ + } else { // committed and private + ckK[commitment][cI[commitment]] = t1 + cI[commitment]++ + nbPrivateCommittedSeen++ } } else { computeK(i, &toxicWaste.deltaInv) - pkK[i-vI-cI] = t1 + pkK[i-vI-nbPrivateCommittedSeen] = t1 // vI = nbPublicSeen + nbCommitmentsSeen } } @@ -204,11 +226,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { g1Scalars = append(g1Scalars, Z...) g1Scalars = append(g1Scalars, vkK...) g1Scalars = append(g1Scalars, pkK...) - g1Scalars = append(g1Scalars, ckK...) + for i := range ckK { + g1Scalars = append(g1Scalars, ckK[i]...) 
+ } g1PointsAff := curve.BatchScalarMultiplicationG1(&g1, g1Scalars) - // sets pk: [α]1, [β]1, [δ]1 + // sets pk: [α]₁, [β]₁, [δ]₁ pk.G1.Alpha = g1PointsAff[0] pk.G1.Beta = g1PointsAff[1] pk.G1.Delta = g1PointsAff[2] @@ -220,8 +244,9 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G1.B = g1PointsAff[offset : offset+len(B)] offset += len(B) - pk.G1.Z = g1PointsAff[offset : offset+int(domain.Cardinality)] - bitReverse(pk.G1.Z) + bitReverse(g1PointsAff[offset : offset+int(domain.Cardinality)]) + sizeZ := int(domain.Cardinality)-1 // deg(H)=deg(A*B-C/X^n-1)=(n-1)+(n-1)-n=n-2 + pk.G1.Z = g1PointsAff[offset : offset+sizeZ] offset += int(domain.Cardinality) @@ -234,17 +259,22 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { // --------------------------------------------------------------------------------------------- // Commitment setup - if nbPrivateCommittedWires != 0 { - commitmentBasis := g1PointsAff[offset:] + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(ckK[i]) + commitmentBases[i] = g1PointsAff[offset : offset+size] + offset += size + } + if offset != len(g1PointsAff) { + return errors.New("didn't consume all G1 points") // TODO @Tabaie Remove this + } - vk.CommitmentKey, err = pedersen.Setup(commitmentBasis) - if err != nil { - return err - } - pk.CommitmentKey = vk.CommitmentKey + pk.CommitmentKeys, vk.CommitmentKey, err = pedersen.Setup(commitmentBases...) 
+ if err != nil { + return err } - vk.CommitmentInfo = r1cs.CommitmentInfo // unfortunate but necessary + vk.PublicAndCommitmentCommitted = commitmentInfo.GetPublicAndCommitmentCommitted(commitmentWires, r1cs.GetNbPublicVariables()) // --------------------------------------------------------------------------------------------- // G2 scalars @@ -261,15 +291,13 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { pk.G2.B = g2PointsAff[:len(B)] - // sets pk: [β]2, [δ]2 + // sets pk: [β]₂, [δ]₂ pk.G2.Beta = g2PointsAff[len(B)+0] pk.G2.Delta = g2PointsAff[len(B)+1] - // sets vk: [δ]2, [γ]2, -[δ]2, -[γ]2 + // sets vk: [δ]₂, [γ]₂ vk.G2.Delta = g2PointsAff[len(B)+1] vk.G2.Gamma = g2PointsAff[len(B)+2] - vk.G2.deltaNeg.Neg(&vk.G2.Delta) - vk.G2.gammaNeg.Neg(&vk.G2.Gamma) // --------------------------------------------------------------------------------------------- // Pairing: vk.e @@ -280,16 +308,29 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { vk.G1.Beta = pk.G1.Beta vk.G1.Delta = pk.G1.Delta - vk.e, err = curve.Pair([]curve.G1Affine{pk.G1.Alpha}, []curve.G2Affine{pk.G2.Beta}) - if err != nil { - return err + if err := vk.Precompute(); err != nil { + return err } + // set domain pk.Domain = *domain return nil } +// Precompute sets e, -[δ]₂, -[γ]₂ +// This is meant to be called internally during setup or deserialization. +func (vk *VerifyingKey) Precompute() error { + var err error + vk.e, err = curve.Pair([]curve.G1Affine{vk.G1.Alpha}, []curve.G2Affine{vk.G2.Beta}) + if err != nil { + return err + } + vk.G2.deltaNeg.Neg(&vk.G2.Delta) + vk.G2.gammaNeg.Neg(&vk.G2.Gamma) + return nil +} + func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr.Element, B []fr.Element, C []fr.Element) { nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() @@ -304,7 +345,7 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. 
var w fr.Element w.Set(&domain.Generator) wi := fr.One() - t := make([]fr.Element, len(r1cs.Constraints)+1) + t := make([]fr.Element, r1cs.GetNbConstraints()+1) for i := 0; i < len(t); i++ { t[i].Sub(&toxicWaste.t, &wi) wi.Mul(&wi, &w) // TODO this is already pre computed in fft.Domain @@ -348,8 +389,10 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // for each term appearing in the linear expression, // we compute term.Coefficient * L, and cumulate it in // A, B or C at the index of the variable - for i, c := range r1cs.Constraints { + j := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c!=nil; c = it.Next() { for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } @@ -362,9 +405,12 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // Li+1 = w*Li*(t-w^i)/(t-w^(i+1)) L.Mul(&L, &w) - L.Mul(&L, &t[i]) - L.Mul(&L, &tInv[i+1]) + L.Mul(&L, &t[j]) + L.Mul(&L, &tInv[j+1]) + + j++ } + return } @@ -418,7 +464,10 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.GetNbPublicVariables() + r1cs.GetNbSecretVariables() - nbConstraints := len(r1cs.Constraints) + nbConstraints := r1cs.GetNbConstraints() + commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) + privateCommitted := commitmentInfo.GetPrivateCommitted() + nbPrivateWires := r1cs.GetNbSecretVariables() + r1cs.NbInternalVariables - internal.NbElements(privateCommitted) - len(commitmentInfo) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints)) @@ -430,8 +479,8 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // initialize proving key pk.G1.A = make([]curve.G1Affine, nbWires-nbZeroesA) pk.G1.B = make([]curve.G1Affine, nbWires-nbZeroesB) - pk.G1.K = make([]curve.G1Affine, nbWires-r1cs.GetNbPublicVariables()) - pk.G1.Z = make([]curve.G1Affine, 
domain.Cardinality) + pk.G1.K = make([]curve.G1Affine, nbPrivateWires) + pk.G1.Z = make([]curve.G1Affine, domain.Cardinality-1) pk.G2.B = make([]curve.G2Affine, nbWires-nbZeroesB) // set infinity markers @@ -485,6 +534,22 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { pk.Domain = *domain + // --------------------------------------------------------------------------------------------- + // Commitment setup + commitmentBases := make([][]curve.G1Affine, len(commitmentInfo)) + for i := range commitmentBases { + size := len(privateCommitted[i]) + commitmentBases[i] = make([]curve.G1Affine, size) + for j := range commitmentBases[i] { + commitmentBases[i][j] = r1Aff + } + } + + pk.CommitmentKeys,_, err = pedersen.Setup(commitmentBases...) + if err != nil { + return err + } + return nil } @@ -496,7 +561,9 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) - for _, c := range r1cs.Constraints { + + it := r1cs.GetR1CIterator() + for c := it.Next(); c!=nil; c = it.Next() { for _, t := range c.L { A[t.WireID()] = true } @@ -504,6 +571,8 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { B[t.WireID()] = true } } + + for i := 0; i < nbWires; i++ { if !A[i] { nbZeroesA++ @@ -588,7 +657,7 @@ func (pk *ProvingKey) NbG2() int { return 2 + len(pk.G2.B) } -// bitRerverse permutation as in fft.BitReverse , but with []curve.G1Affine +// bitReverse permutation as in fft.BitReverse , but with []curve.G1Affine func bitReverse(a []curve.G1Affine) { n := uint(len(a)) nn := uint(bits.UintSize - bits.TrailingZeros(n)) diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.verify.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.verify.go.tmpl index 29a7202c08..f87b9197dd 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.verify.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.verify.go.tmpl @@ 
-6,10 +6,13 @@ import ( "errors" "time" "io" - "math/big" + {{- if eq .Curve "BN254"}} "text/template" {{- end}} + {{- template "import_pedersen" .}} + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/logger" ) @@ -21,10 +24,8 @@ var ( // Verify verifies a proof with given VerifyingKey and publicWitness func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - nbPublicVars := len(vk.G1.K) - if vk.CommitmentInfo.Is() { - nbPublicVars-- - } + nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted) + if len(publicWitness) != nbPublicVars-1 { return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(publicWitness), len(vk.G1.K) - 1) } @@ -47,21 +48,32 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { close(chDone) }() - if vk.CommitmentInfo.Is() { - - if err := vk.CommitmentKey.VerifyKnowledgeProof(proof.Commitment, proof.CommitmentPok); err != nil { - return err + maxNbPublicCommitted := 0 + for _, s := range vk.PublicAndCommitmentCommitted { // iterate over commitments + maxNbPublicCommitted = utils.Max(maxNbPublicCommitted, len(s)) + } + commitmentsSerialized := make([]byte, len(vk.PublicAndCommitmentCommitted)*fr.Bytes) + commitmentPrehashSerialized := make([]byte, curve.SizeOfG1AffineUncompressed+maxNbPublicCommitted*fr.Bytes) + for i := range vk.PublicAndCommitmentCommitted { // solveCommitmentWire + copy(commitmentPrehashSerialized, proof.Commitments[i].Marshal()) + offset := curve.SizeOfG1AffineUncompressed + for j := range vk.PublicAndCommitmentCommitted[i] { + copy(commitmentPrehashSerialized[offset:], publicWitness[vk.PublicAndCommitmentCommitted[i][j]-1].Marshal()) + offset += fr.Bytes } - - publicCommitted := make([]*big.Int, vk.CommitmentInfo.NbPublicCommitted()) - for i := range publicCommitted { - var b big.Int - publicWitness[vk.CommitmentInfo.Committed[i]-1].BigInt(&b) - publicCommitted[i] = &b + if 
res, err := fr.Hash(commitmentPrehashSerialized[:offset], []byte(constraint.CommitmentDst), 1); err != nil { + return err + } else { + publicWitness = append(publicWitness, res[0]) + copy(commitmentsSerialized[i*fr.Bytes:], res[0].Marshal()) } + } - if res, err := solveCommitmentWire(&vk.CommitmentInfo, &proof.Commitment, publicCommitted); err == nil { - publicWitness = append(publicWitness, res) + if folded, err := pedersen.FoldCommitments(proof.Commitments, commitmentsSerialized); err != nil { + return err + } else { + if err = vk.CommitmentKey.Verify(folded, proof.CommitmentPok); err != nil { + return err } } @@ -72,8 +84,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { } kSum.AddMixed(&vk.G1.K[0]) - if vk.CommitmentInfo.Is() { - kSum.AddMixed(&proof.Commitment) + for i := range proof.Commitments { + kSum.AddMixed(&proof.Commitments[i]) } var kSumAff curve.G1Affine diff --git a/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/lagrange.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/lagrange.go.tmpl new file mode 100644 index 0000000000..89b75eb44f --- /dev/null +++ b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/lagrange.go.tmpl @@ -0,0 +1,201 @@ +import ( + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + + {{- template "import_fr" . }} + {{- template "import_curve" . }} + {{- template "import_fft" . 
}} + "github.com/consensys/gnark/internal/utils" +) + + + +func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { + coeffs := make([]curve.G1Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { + coeffs := make([]curve.G2Affine, size) + copy(coeffs, powers[:size]) + domain := fft.NewDomain(uint64(size)) + numCPU := uint64(runtime.NumCPU()) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) + + difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + bitReverse(coeffs) + + var invBigint big.Int + domain.CardinalityInv.BigInt(&invBigint) + + utils.Parallelize(size, func(start, end int) { + for i := start; i < end; i++ { + coeffs[i].ScalarMultiplication(&coeffs[i], &invBigint) + } + }) + return coeffs +} + +func butterflyG1(a *curve.G1Affine, b *curve.G1Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +func butterflyG2(a *curve.G2Affine, b *curve.G2Affine) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G1(a []curve.G1Affine, twiddles [][]fr.Element, stage int) { + butterflyG1(&a[0], &a[4]) + butterflyG1(&a[1], &a[5]) + butterflyG1(&a[2], &a[6]) + butterflyG1(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) 
+ butterflyG1(&a[0], &a[2]) + butterflyG1(&a[1], &a[3]) + butterflyG1(&a[4], &a[6]) + butterflyG1(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG1(&a[0], &a[1]) + butterflyG1(&a[2], &a[3]) + butterflyG1(&a[4], &a[5]) + butterflyG1(&a[6], &a[7]) +} + +// kerDIF8 is a kernel that process a FFT of size 8 +func kerDIF8G2(a []curve.G2Affine, twiddles [][]fr.Element, stage int) { + butterflyG2(&a[0], &a[4]) + butterflyG2(&a[1], &a[5]) + butterflyG2(&a[2], &a[6]) + butterflyG2(&a[3], &a[7]) + + var twiddle big.Int + twiddles[stage+0][1].BigInt(&twiddle) + a[5].ScalarMultiplication(&a[5], &twiddle) + twiddles[stage+0][2].BigInt(&twiddle) + a[6].ScalarMultiplication(&a[6], &twiddle) + twiddles[stage+0][3].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[2]) + butterflyG2(&a[1], &a[3]) + butterflyG2(&a[4], &a[6]) + butterflyG2(&a[5], &a[7]) + twiddles[stage+1][1].BigInt(&twiddle) + a[3].ScalarMultiplication(&a[3], &twiddle) + twiddles[stage+1][1].BigInt(&twiddle) + a[7].ScalarMultiplication(&a[7], &twiddle) + butterflyG2(&a[0], &a[1]) + butterflyG2(&a[2], &a[3]) + butterflyG2(&a[4], &a[5]) + butterflyG2(&a[6], &a[7]) +} + +func difFFTG1(a []curve.G1Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G1(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG1(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + butterflyG1(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG1(a[0:m], twiddles, 
nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} +func difFFTG2(a []curve.G2Affine, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if n == 8 { + kerDIF8G2(a, twiddles, stage) + return + } + m := n >> 1 + + butterflyG2(&a[0], &a[m]) + + var twiddle big.Int + for i := 1; i < m; i++ { + butterflyG2(&a[i], &a[i+m]) + twiddles[stage][i].BigInt(&twiddle) + a[i+m].ScalarMultiplication(&a[i+m], &twiddle) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTG2(a[m:n], twiddles, nextStage, maxSplits, chDone) + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + <-chDone + } else { + difFFTG2(a[0:m], twiddles, nextStage, maxSplits, nil) + difFFTG2(a[m:n], twiddles, nextStage, maxSplits, nil) + } +} diff --git a/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/marshal.go.tmpl new file mode 100644 index 0000000000..3af994ae04 --- /dev/null +++ b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/marshal.go.tmpl @@ -0,0 +1,165 @@ +import ( + "io" + + + {{- template "import_curve" . 
}} +) + +// WriteTo implements io.WriterTo +func (phase1 *Phase1) WriteTo(writer io.Writer) (int64, error) { + n, err := phase1.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := writer.Write(phase1.Hash) + return int64(nBytes) + n, err +} + +func (phase1 *Phase1) writeTo(writer io.Writer) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + phase1.Parameters.G1.Tau, + phase1.Parameters.G1.AlphaTau, + phase1.Parameters.G1.BetaTau, + phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + enc := curve.NewEncoder(writer) + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (phase1 *Phase1) ReadFrom(reader io.Reader) (int64, error) { + toEncode := []interface{}{ + &phase1.PublicKeys.Tau.SG, + &phase1.PublicKeys.Tau.SXG, + &phase1.PublicKeys.Tau.XR, + &phase1.PublicKeys.Alpha.SG, + &phase1.PublicKeys.Alpha.SXG, + &phase1.PublicKeys.Alpha.XR, + &phase1.PublicKeys.Beta.SG, + &phase1.PublicKeys.Beta.SXG, + &phase1.PublicKeys.Beta.XR, + &phase1.Parameters.G1.Tau, + &phase1.Parameters.G1.AlphaTau, + &phase1.Parameters.G1.BetaTau, + &phase1.Parameters.G2.Tau, + &phase1.Parameters.G2.Beta, + } + + dec := curve.NewDecoder(reader) + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + phase1.Hash = make([]byte, 32) + nBytes, err := reader.Read(phase1.Hash) + return dec.BytesRead() + int64(nBytes), err +} + +// WriteTo implements io.WriterTo +func (phase2 *Phase2) WriteTo(writer io.Writer) (int64, error) { + n, err := phase2.writeTo(writer) + if err != nil { + return n, err + } + nBytes, err := 
writer.Write(phase2.Hash) + return int64(nBytes) + n, err +} + +func (c *Phase2) writeTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + c.Parameters.G1.L, + c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c *Phase2) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.PublicKey.SG, + &c.PublicKey.SXG, + &c.PublicKey.XR, + &c.Parameters.G1.Delta, + &c.Parameters.G1.L, + &c.Parameters.G1.Z, + &c.Parameters.G2.Delta, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + c.Hash = make([]byte, 32) + n, err := reader.Read(c.Hash) + return int64(n) + dec.BytesRead(), err + +} + +// WriteTo implements io.WriterTo +func (c *Phase2Evaluations) WriteTo(writer io.Writer) (int64, error) { + enc := curve.NewEncoder(writer) + toEncode := []interface{}{ + c.G1.A, + c.G1.B, + c.G2.B, + } + + for _, v := range toEncode { + if err := enc.Encode(v); err != nil { + return enc.BytesWritten(), err + } + } + + return enc.BytesWritten(), nil +} + +// ReadFrom implements io.ReaderFrom +func (c *Phase2Evaluations) ReadFrom(reader io.Reader) (int64, error) { + dec := curve.NewDecoder(reader) + toEncode := []interface{}{ + &c.G1.A, + &c.G1.B, + &c.G2.B, + } + + for _, v := range toEncode { + if err := dec.Decode(v); err != nil { + return dec.BytesRead(), err + } + } + + return dec.BytesRead(), nil +} diff --git a/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/marshal_test.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/marshal_test.go.tmpl new file mode 100644 index 0000000000..803b1863d9 --- 
/dev/null
+++ b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/marshal_test.go.tmpl
@@ -0,0 +1,63 @@
+import (
+	"bytes"
+	"io"
+	"reflect"
+	"testing"
+
+
+	{{- template "import_curve" . }}
+	{{- template "import_backend_cs" . }}
+	"github.com/consensys/gnark/frontend"
+	"github.com/consensys/gnark/frontend/cs/r1cs"
+	"github.com/stretchr/testify/require"
+)
+
+// TestContributionSerialization round-trips Phase1 and Phase2 contributions
+// through their WriterTo/ReaderFrom implementations.
+func TestContributionSerialization(t *testing.T) {
+	assert := require.New(t)
+
+	// Phase 1
+	srs1 := InitPhase1(9)
+	srs1.Contribute()
+	{
+		var reconstructed Phase1
+		roundTripCheck(t, &srs1, &reconstructed)
+	}
+
+	var myCircuit Circuit
+	ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit)
+	assert.NoError(err)
+
+	r1cs := ccs.(*cs.R1CS)
+
+	// Phase 2
+	srs2, _ := InitPhase2(r1cs, &srs1)
+	srs2.Contribute()
+
+	{
+		var reconstructed Phase2
+		roundTripCheck(t, &srs2, &reconstructed)
+	}
+}
+
+// roundTripCheck serializes `from`, deserializes into `reconstructed`, and
+// asserts both deep equality and matching byte counts.
+func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) {
+	t.Helper()
+
+	var buf bytes.Buffer
+	written, err := from.WriteTo(&buf)
+	if err != nil {
+		t.Fatal("couldn't serialize", err)
+	}
+
+	read, err := reconstructed.ReadFrom(&buf)
+	if err != nil {
+		t.Fatal("couldn't deserialize", err)
+	}
+
+	if !reflect.DeepEqual(from, reconstructed) {
+		t.Fatal("reconstructed object doesn't match original")
+	}
+
+	if written != read {
+		t.Fatal("bytes written / read don't match")
+	}
+}
diff --git a/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/phase1.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/phase1.go.tmpl
new file mode 100644
index 0000000000..7c8c6fa540
--- /dev/null
+++ b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/phase1.go.tmpl
@@ -0,0 +1,187 @@
+import (
+	"crypto/sha256"
+	"errors"
+	"math"
+	"math/big"
+
+
+	{{- template "import_fr" . }}
+	{{- template "import_curve" . }}
+)
+
+// Phase1 represents the Phase1 of the MPC described in
+// https://eprint.iacr.org/2017/1050.pdf
+//
+// Also known as "Powers of Tau"
+type Phase1 struct {
+	Parameters struct {
+		G1 struct {
+			Tau      []curve.G1Affine // {[τ⁰]₁, [τ¹]₁, [τ²]₁, …, [τ²ⁿ⁻²]₁}
+			AlphaTau []curve.G1Affine // {α[τ⁰]₁, α[τ¹]₁, α[τ²]₁, …, α[τⁿ⁻¹]₁}
+			BetaTau  []curve.G1Affine // {β[τ⁰]₁, β[τ¹]₁, β[τ²]₁, …, β[τⁿ⁻¹]₁}
+		}
+		G2 struct {
+			Tau  []curve.G2Affine // {[τ⁰]₂, [τ¹]₂, [τ²]₂, …, [τⁿ⁻¹]₂}
+			Beta curve.G2Affine   // [β]₂
+		}
+	}
+	PublicKeys struct {
+		Tau, Alpha, Beta PublicKey
+	}
+	Hash []byte // sha256 hash
+}
+
+// InitPhase1 initialize phase 1 of the MPC. This is called once by the coordinator before
+// any randomness contribution is made (see Contribute()).
+func InitPhase1(power int) (phase1 Phase1) {
+	N := int(math.Pow(2, float64(power)))
+
+	// Generate key pairs; the initial "contribution" uses the neutral scalar 1.
+	var tau, alpha, beta fr.Element
+	tau.SetOne()
+	alpha.SetOne()
+	beta.SetOne()
+	phase1.PublicKeys.Tau = newPublicKey(tau, nil, 1)
+	phase1.PublicKeys.Alpha = newPublicKey(alpha, nil, 2)
+	phase1.PublicKeys.Beta = newPublicKey(beta, nil, 3)
+
+	// First contribution use generators
+	_, _, g1, g2 := curve.Generators()
+	phase1.Parameters.G2.Beta.Set(&g2)
+	phase1.Parameters.G1.Tau = make([]curve.G1Affine, 2*N-1)
+	phase1.Parameters.G2.Tau = make([]curve.G2Affine, N)
+	phase1.Parameters.G1.AlphaTau = make([]curve.G1Affine, N)
+	phase1.Parameters.G1.BetaTau = make([]curve.G1Affine, N)
+	for i := 0; i < len(phase1.Parameters.G1.Tau); i++ {
+		phase1.Parameters.G1.Tau[i].Set(&g1)
+	}
+	for i := 0; i < len(phase1.Parameters.G2.Tau); i++ {
+		phase1.Parameters.G2.Tau[i].Set(&g2)
+		phase1.Parameters.G1.AlphaTau[i].Set(&g1)
+		phase1.Parameters.G1.BetaTau[i].Set(&g1)
+	}
+
+	// NOTE(review): G2.Beta was previously Set(&g2) a second time here; the redundant call was removed.
+
+	// Compute hash of Contribution
+	phase1.Hash = phase1.hash()
+
+	return
+}
+
+// Contribute contributes randomness to the phase1 object. This mutates phase1.
+func (phase1 *Phase1) Contribute() {
+	N := len(phase1.Parameters.G2.Tau)
+
+	// Generate key pairs
+	var tau, alpha, beta fr.Element
+	tau.SetRandom()
+	alpha.SetRandom()
+	beta.SetRandom()
+	phase1.PublicKeys.Tau = newPublicKey(tau, phase1.Hash[:], 1)
+	phase1.PublicKeys.Alpha = newPublicKey(alpha, phase1.Hash[:], 2)
+	phase1.PublicKeys.Beta = newPublicKey(beta, phase1.Hash[:], 3)
+
+	// Compute powers of τ, ατ, and βτ
+	taus := powers(tau, 2*N-1)
+	alphaTau := make([]fr.Element, N)
+	betaTau := make([]fr.Element, N)
+	for i := 0; i < N; i++ {
+		alphaTau[i].Mul(&taus[i], &alpha)
+		betaTau[i].Mul(&taus[i], &beta)
+	}
+
+	// Update using previous parameters
+	// TODO @gbotrel working with jacobian points here will help with perf.
+	scaleG1InPlace(phase1.Parameters.G1.Tau, taus)
+	scaleG2InPlace(phase1.Parameters.G2.Tau, taus[0:N])
+	scaleG1InPlace(phase1.Parameters.G1.AlphaTau, alphaTau)
+	scaleG1InPlace(phase1.Parameters.G1.BetaTau, betaTau)
+	var betaBI big.Int
+	beta.BigInt(&betaBI)
+	phase1.Parameters.G2.Beta.ScalarMultiplication(&phase1.Parameters.G2.Beta, &betaBI)
+
+	// Compute hash of Contribution
+	phase1.Hash = phase1.hash()
+}
+
+// VerifyPhase1 checks that each contribution in the chain c0 → c1 → c… is
+// validly derived from its predecessor.
+func VerifyPhase1(c0, c1 *Phase1, c ...*Phase1) error {
+	contribs := append([]*Phase1{c0, c1}, c...)
+	for i := 0; i < len(contribs)-1; i++ {
+		if err := verifyPhase1(contribs[i], contribs[i+1]); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// verifyPhase1 checks that a contribution is based on a known previous Phase1 state.
+func verifyPhase1(current, contribution *Phase1) error {
+	// Compute R for τ, α, β
+	tauR := genR(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, current.Hash[:], 1)
+	alphaR := genR(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, current.Hash[:], 2)
+	betaR := genR(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, current.Hash[:], 3)
+
+	// Check for knowledge of toxic parameters
+	if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.PublicKeys.Tau.XR, tauR) {
+		return errors.New("couldn't verify public key of τ")
+	}
+	if !sameRatio(contribution.PublicKeys.Alpha.SG, contribution.PublicKeys.Alpha.SXG, contribution.PublicKeys.Alpha.XR, alphaR) {
+		return errors.New("couldn't verify public key of α")
+	}
+	if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.PublicKeys.Beta.XR, betaR) {
+		return errors.New("couldn't verify public key of β")
+	}
+
+	// Check for valid updates using previous parameters
+	if !sameRatio(contribution.Parameters.G1.Tau[1], current.Parameters.G1.Tau[1], tauR, contribution.PublicKeys.Tau.XR) {
+		return errors.New("couldn't verify that [τ]₁ is based on previous contribution")
+	}
+	if !sameRatio(contribution.Parameters.G1.AlphaTau[0], current.Parameters.G1.AlphaTau[0], alphaR, contribution.PublicKeys.Alpha.XR) {
+		return errors.New("couldn't verify that [α]₁ is based on previous contribution")
+	}
+	if !sameRatio(contribution.Parameters.G1.BetaTau[0], current.Parameters.G1.BetaTau[0], betaR, contribution.PublicKeys.Beta.XR) {
+		return errors.New("couldn't verify that [β]₁ is based on previous contribution")
+	}
+	if !sameRatio(contribution.PublicKeys.Tau.SG, contribution.PublicKeys.Tau.SXG, contribution.Parameters.G2.Tau[1], current.Parameters.G2.Tau[1]) {
+		return errors.New("couldn't verify that [τ]₂ is based on previous contribution")
+	}
+	if !sameRatio(contribution.PublicKeys.Beta.SG, contribution.PublicKeys.Beta.SXG, contribution.Parameters.G2.Beta, current.Parameters.G2.Beta) {
+		return errors.New("couldn't verify that [β]₂ is based on previous contribution")
+	}
+
+	// Check for valid updates using powers of τ
+	_, _, g1, g2 := curve.Generators()
+	tauL1, tauL2 := linearCombinationG1(contribution.Parameters.G1.Tau)
+	if !sameRatio(tauL1, tauL2, contribution.Parameters.G2.Tau[1], g2) {
+		return errors.New("couldn't verify valid powers of τ in G₁")
+	}
+	alphaL1, alphaL2 := linearCombinationG1(contribution.Parameters.G1.AlphaTau)
+	if !sameRatio(alphaL1, alphaL2, contribution.Parameters.G2.Tau[1], g2) {
+		return errors.New("couldn't verify valid powers of α(τ) in G₁")
+	}
+	betaL1, betaL2 := linearCombinationG1(contribution.Parameters.G1.BetaTau)
+	if !sameRatio(betaL1, betaL2, contribution.Parameters.G2.Tau[1], g2) {
+		return errors.New("couldn't verify valid powers of β(τ) in G₁")
+	}
+	tau2L1, tau2L2 := linearCombinationG2(contribution.Parameters.G2.Tau)
+	if !sameRatio(contribution.Parameters.G1.Tau[1], g1, tau2L1, tau2L2) {
+		return errors.New("couldn't verify valid powers of τ in G₂")
+	}
+
+	// Check hash of the contribution
+	h := contribution.hash()
+	for i := 0; i < len(h); i++ {
+		if h[i] != contribution.Hash[i] {
+			return errors.New("couldn't verify hash of contribution")
+		}
+	}
+
+	return nil
+}
+
+func (phase1 *Phase1) hash() []byte {
+	sha := sha256.New()
+	phase1.writeTo(sha)
+	return sha.Sum(nil)
+}
diff --git a/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/phase2.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/phase2.go.tmpl
new file mode 100644
index 0000000000..0afb32db79
--- /dev/null
+++ b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/phase2.go.tmpl
@@ -0,0 +1,248 @@
+import (
+	"crypto/sha256"
+	"errors"
+	"math/big"
+
+	"github.com/consensys/gnark/constraint"
+
+
+	{{- template "import_fr" . }}
+	{{- template "import_curve" . 
}} + {{- template "import_backend_cs" . }} +) + +type Phase2Evaluations struct { + G1 struct { + A, B, VKK []curve.G1Affine + } + G2 struct { + B []curve.G2Affine + } +} + +type Phase2 struct { + Parameters struct { + G1 struct { + Delta curve.G1Affine + L, Z []curve.G1Affine + } + G2 struct { + Delta curve.G2Affine + } + } + PublicKey PublicKey + Hash []byte +} + +func InitPhase2(r1cs *cs.R1CS, srs1 *Phase1) (Phase2, Phase2Evaluations) { + srs := srs1.Parameters + size := len(srs.G1.AlphaTau) + if size < r1cs.GetNbConstraints() { + panic("Number of constraints is larger than expected") + } + + c2 := Phase2{} + + accumulateG1 := func(res *curve.G1Affine, t constraint.Term, value *curve.G1Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G1Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + accumulateG2 := func(res *curve.G2Affine, t constraint.Term, value *curve.G2Affine) { + cID := t.CoeffID() + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + res.Add(res, value) + case constraint.CoeffIdMinusOne: + res.Sub(res, value) + case constraint.CoeffIdTwo: + res.Add(res, value).Add(res, value) + default: + var tmp curve.G2Affine + var vBi big.Int + r1cs.Coefficients[cID].BigInt(&vBi) + tmp.ScalarMultiplication(value, &vBi) + res.Add(res, &tmp) + } + } + + // Prepare Lagrange coefficients of [τ...]₁, [τ...]₂, [ατ...]₁, [βτ...]₁ + coeffTau1 := lagrangeCoeffsG1(srs.G1.Tau, size) + coeffTau2 := lagrangeCoeffsG2(srs.G2.Tau, size) + coeffAlphaTau1 := lagrangeCoeffsG1(srs.G1.AlphaTau, size) + coeffBetaTau1 := lagrangeCoeffsG1(srs.G1.BetaTau, size) + + internal, secret, public := r1cs.GetNbVariables() + nWires := internal + 
secret + public + var evals Phase2Evaluations + evals.G1.A = make([]curve.G1Affine, nWires) + evals.G1.B = make([]curve.G1Affine, nWires) + evals.G2.B = make([]curve.G2Affine, nWires) + bA := make([]curve.G1Affine, nWires) + aB := make([]curve.G1Affine, nWires) + C := make([]curve.G1Affine, nWires) + + // TODO @gbotrel use constraint iterator when available. + + i := 0 + it := r1cs.GetR1CIterator() + for c := it.Next(); c!=nil; c = it.Next() { + // A + for _, t := range c.L { + accumulateG1(&evals.G1.A[t.WireID()], t, &coeffTau1[i]) + accumulateG1(&bA[t.WireID()], t, &coeffBetaTau1[i]) + } + // B + for _, t := range c.R { + accumulateG1(&evals.G1.B[t.WireID()], t, &coeffTau1[i]) + accumulateG2(&evals.G2.B[t.WireID()], t, &coeffTau2[i]) + accumulateG1(&aB[t.WireID()], t, &coeffAlphaTau1[i]) + } + // C + for _, t := range c.O { + accumulateG1(&C[t.WireID()], t, &coeffTau1[i]) + } + i++ + } + + // Prepare default contribution + _, _, g1, g2 := curve.Generators() + c2.Parameters.G1.Delta = g1 + c2.Parameters.G2.Delta = g2 + + // Build Z in PK as τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + // τⁱ(τⁿ - 1) = τ⁽ⁱ⁺ⁿ⁾ - τⁱ for i ∈ [0, n-2] + n := len(srs.G1.AlphaTau) + c2.Parameters.G1.Z = make([]curve.G1Affine, n) + for i := 0; i < n-1; i++ { + c2.Parameters.G1.Z[i].Sub(&srs.G1.Tau[i+n], &srs.G1.Tau[i]) + } + bitReverse(c2.Parameters.G1.Z) + c2.Parameters.G1.Z = c2.Parameters.G1.Z[:n-1] + + // Evaluate L + nPrivate := internal + secret + c2.Parameters.G1.L = make([]curve.G1Affine, nPrivate) + evals.G1.VKK = make([]curve.G1Affine, public) + offset := public + for i := 0; i < nWires; i++ { + var tmp curve.G1Affine + tmp.Add(&bA[i], &aB[i]) + tmp.Add(&tmp, &C[i]) + if i < public { + evals.G1.VKK[i].Set(&tmp) + } else { + c2.Parameters.G1.L[i-offset].Set(&tmp) + } + } + // Set δ public key + var delta fr.Element + delta.SetOne() + c2.PublicKey = newPublicKey(delta, nil, 1) + + // Hash initial contribution + c2.Hash = c2.hash() + return c2, evals +} + +func (c *Phase2) 
Contribute() {
+	// Sample toxic δ
+	var delta, deltaInv fr.Element
+	var deltaBI, deltaInvBI big.Int
+	delta.SetRandom()
+	deltaInv.Inverse(&delta)
+
+	delta.BigInt(&deltaBI)
+	deltaInv.BigInt(&deltaInvBI)
+
+	// Set δ public key
+	c.PublicKey = newPublicKey(delta, c.Hash, 1)
+
+	// Update δ
+	c.Parameters.G1.Delta.ScalarMultiplication(&c.Parameters.G1.Delta, &deltaBI)
+	c.Parameters.G2.Delta.ScalarMultiplication(&c.Parameters.G2.Delta, &deltaBI)
+
+	// Update Z using δ⁻¹
+	for i := 0; i < len(c.Parameters.G1.Z); i++ {
+		c.Parameters.G1.Z[i].ScalarMultiplication(&c.Parameters.G1.Z[i], &deltaInvBI)
+	}
+
+	// Update L using δ⁻¹
+	for i := 0; i < len(c.Parameters.G1.L); i++ {
+		c.Parameters.G1.L[i].ScalarMultiplication(&c.Parameters.G1.L[i], &deltaInvBI)
+	}
+
+	// 4. Hash contribution
+	c.Hash = c.hash()
+}
+
+// VerifyPhase2 checks that each contribution in the chain c0 → c1 → c… is
+// validly derived from its predecessor.
+func VerifyPhase2(c0, c1 *Phase2, c ...*Phase2) error {
+	contribs := append([]*Phase2{c0, c1}, c...)
+	for i := 0; i < len(contribs)-1; i++ {
+		if err := verifyPhase2(contribs[i], contribs[i+1]); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func verifyPhase2(current, contribution *Phase2) error {
+	// Compute R for δ
+	deltaR := genR(contribution.PublicKey.SG, contribution.PublicKey.SXG, current.Hash[:], 1)
+
+	// Check for knowledge of δ
+	if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.PublicKey.XR, deltaR) {
+		return errors.New("couldn't verify knowledge of δ")
+	}
+
+	// Check for valid updates using previous parameters
+	if !sameRatio(contribution.Parameters.G1.Delta, current.Parameters.G1.Delta, deltaR, contribution.PublicKey.XR) {
+		return errors.New("couldn't verify that [δ]₁ is based on previous contribution")
+	}
+	if !sameRatio(contribution.PublicKey.SG, contribution.PublicKey.SXG, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) {
+		return errors.New("couldn't verify that [δ]₂ is based on previous contribution")
+	}
+
+	// Check for valid updates of L and Z using δ⁻¹
+	L, prevL := merge(contribution.Parameters.G1.L, current.Parameters.G1.L)
+	if !sameRatio(L, prevL, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) {
+		return errors.New("couldn't verify valid updates of L using δ⁻¹")
+	}
+	Z, prevZ := merge(contribution.Parameters.G1.Z, current.Parameters.G1.Z)
+	if !sameRatio(Z, prevZ, contribution.Parameters.G2.Delta, current.Parameters.G2.Delta) {
+		return errors.New("couldn't verify valid updates of Z using δ⁻¹")
+	}
+
+	// Check hash of the contribution
+	h := contribution.hash()
+	for i := 0; i < len(h); i++ {
+		if h[i] != contribution.Hash[i] {
+			return errors.New("couldn't verify hash of contribution")
+		}
+	}
+
+	return nil
+}
+
+func (c *Phase2) hash() []byte {
+	sha := sha256.New()
+	c.writeTo(sha)
+	return sha.Sum(nil)
+}
diff --git a/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/setup.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/setup.go.tmpl
new file mode 100644
index 0000000000..e60410b467
--- /dev/null
+++ b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/setup.go.tmpl
@@ -0,0 +1,80 @@
+import (
+	groth16 "github.com/consensys/gnark/backend/groth16/{{toLower .Curve}}"
+
+	{{- template "import_curve" . }}
+	{{- template "import_fft" . 
}} +) + +func ExtractKeys(srs1 *Phase1, srs2 *Phase2, evals *Phase2Evaluations, nConstraints int) (pk groth16.ProvingKey, vk groth16.VerifyingKey) { + _, _, _, g2 := curve.Generators() + + // Initialize PK + pk.Domain = *fft.NewDomain(uint64(nConstraints)) + pk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + pk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + pk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + pk.G1.Z = srs2.Parameters.G1.Z + bitReverse(pk.G1.Z) + + pk.G1.K = srs2.Parameters.G1.L + pk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + pk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + + // Filter out infinity points + nWires := len(evals.G1.A) + pk.InfinityA = make([]bool, nWires) + A := make([]curve.G1Affine, nWires) + j := 0 + for i, e := range evals.G1.A { + if e.IsInfinity() { + pk.InfinityA[i] = true + continue + } + A[j] = evals.G1.A[i] + j++ + } + pk.G1.A = A[:j] + pk.NbInfinityA = uint64(nWires - j) + + pk.InfinityB = make([]bool, nWires) + B := make([]curve.G1Affine, nWires) + j = 0 + for i, e := range evals.G1.B { + if e.IsInfinity() { + pk.InfinityB[i] = true + continue + } + B[j] = evals.G1.B[i] + j++ + } + pk.G1.B = B[:j] + pk.NbInfinityB = uint64(nWires - j) + + B2 := make([]curve.G2Affine, nWires) + j = 0 + for i, e := range evals.G2.B { + if e.IsInfinity() { + // pk.InfinityB[i] = true should be the same as in B + continue + } + B2[j] = evals.G2.B[i] + j++ + } + pk.G2.B = B2[:j] + + // Initialize VK + vk.G1.Alpha.Set(&srs1.Parameters.G1.AlphaTau[0]) + vk.G1.Beta.Set(&srs1.Parameters.G1.BetaTau[0]) + vk.G1.Delta.Set(&srs2.Parameters.G1.Delta) + vk.G2.Beta.Set(&srs1.Parameters.G2.Beta) + vk.G2.Delta.Set(&srs2.Parameters.G2.Delta) + vk.G2.Gamma.Set(&g2) + vk.G1.K = evals.G1.VKK + + // sets e, -[δ]2, -[γ]2 + if err := vk.Precompute(); err != nil { + panic(err) + } + + return pk, vk +} diff --git a/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/setup_test.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/setup_test.go.tmpl 
new file mode 100644 index 0000000000..a36c0c1b9d --- /dev/null +++ b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/setup_test.go.tmpl @@ -0,0 +1,184 @@ +import ( + "testing" + + {{- template "import_fr" . }} + {{- template "import_curve" . }} + {{- template "import_backend_cs" . }} + + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/stretchr/testify/require" + + native_mimc "github.com/consensys/gnark-crypto/ecc/{{toLower .Curve}}/fr/mimc" +) + +func TestSetupCircuit(t *testing.T) { + {{- if ne (toLower .Curve) "bn254" }} + if testing.Short() { + t.Skip() + } + {{- end}} + const ( + nContributionsPhase1 = 3 + nContributionsPhase2 = 3 + power = 9 + ) + + assert := require.New(t) + + srs1 := InitPhase1(power) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase1; i++ { + // we clone test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. + prev := srs1.clone() + + srs1.Contribute() + assert.NoError(VerifyPhase1(&prev, &srs1)) + } + + // Compile the circuit + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + assert.NoError(err) + + var evals Phase2Evaluations + r1cs := ccs.(*cs.R1CS) + + // Prepare for phase-2 + srs2, evals := InitPhase2(r1cs, &srs1) + + // Make and verify contributions for phase1 + for i := 1; i < nContributionsPhase2; i++ { + // we clone for test purposes; but in practice, participant will receive a []byte, deserialize it, + // add his contribution and send back to coordinator. 
+ prev := srs2.clone() + + srs2.Contribute() + assert.NoError(VerifyPhase2(&prev, &srs2)) + } + + // Extract the proving and verifying keys + pk, vk := ExtractKeys(&srs1, &srs2, &evals, ccs.GetNbConstraints()) + + // Build the witness + var preImage, hash fr.Element + { + m := native_mimc.NewMiMC() + m.Write(preImage.Marshal()) + hash.SetBytes(m.Sum(nil)) + } + + witness, err := frontend.NewWitness(&Circuit{PreImage: preImage, Hash: hash}, curve.ID.ScalarField()) + assert.NoError(err) + + pubWitness, err := witness.Public() + assert.NoError(err) + + // groth16: ensure proof is verified + proof, err := groth16.Prove(ccs, &pk, witness) + assert.NoError(err) + + err = groth16.Verify(proof, &vk, pubWitness) + assert.NoError(err) +} + +func BenchmarkPhase1(b *testing.B) { + const power = 14 + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = InitPhase1(power) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs1 := InitPhase1(power) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs1.Contribute() + } + }) + +} + +func BenchmarkPhase2(b *testing.B) { + const power = 14 + srs1 := InitPhase1(power) + srs1.Contribute() + + var myCircuit Circuit + ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) + if err != nil { + b.Fatal(err) + } + + r1cs := ccs.(*cs.R1CS) + + b.Run("init", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = InitPhase2(r1cs, &srs1) + } + }) + + b.Run("contrib", func(b *testing.B) { + srs2, _ := InitPhase2(r1cs, &srs1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + srs2.Contribute() + } + }) + +} + +// Circuit defines a pre-image knowledge proof +// mimc(secret preImage) = public hash +type Circuit struct { + PreImage frontend.Variable + Hash frontend.Variable `gnark:",public"` +} + +// Define declares the circuit's constraints +// Hash = mimc(PreImage) +func (circuit *Circuit) Define(api frontend.API) error { + // hash function + mimc, _ := 
mimc.NewMiMC(api) + + // specify constraints + mimc.Write(circuit.PreImage) + api.AssertIsEqual(circuit.Hash, mimc.Sum()) + + return nil +} + +func (phase1 *Phase1) clone() Phase1 { + r := Phase1{} + r.Parameters.G1.Tau = append(r.Parameters.G1.Tau, phase1.Parameters.G1.Tau...) + r.Parameters.G1.AlphaTau = append(r.Parameters.G1.AlphaTau, phase1.Parameters.G1.AlphaTau...) + r.Parameters.G1.BetaTau = append(r.Parameters.G1.BetaTau, phase1.Parameters.G1.BetaTau...) + + r.Parameters.G2.Tau = append(r.Parameters.G2.Tau, phase1.Parameters.G2.Tau...) + r.Parameters.G2.Beta = phase1.Parameters.G2.Beta + + r.PublicKeys = phase1.PublicKeys + r.Hash = append(r.Hash, phase1.Hash...) + + return r +} + +func (phase2 *Phase2) clone() Phase2 { + r := Phase2{} + r.Parameters.G1.Delta = phase2.Parameters.G1.Delta + r.Parameters.G1.L = append(r.Parameters.G1.L, phase2.Parameters.G1.L...) + r.Parameters.G1.Z = append(r.Parameters.G1.Z, phase2.Parameters.G1.Z...) + r.Parameters.G2.Delta = phase2.Parameters.G2.Delta + r.PublicKey = phase2.PublicKey + r.Hash = append(r.Hash, phase2.Hash...) + + return r +} diff --git a/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/utils.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/utils.go.tmpl new file mode 100644 index 0000000000..f7c67c036d --- /dev/null +++ b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/utils.go.tmpl @@ -0,0 +1,153 @@ +import ( + "bytes" + "math/big" + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc" + + {{- template "import_fr" . }} + {{- template "import_curve" . 
}} + "github.com/consensys/gnark/internal/utils" +) + +type PublicKey struct { + SG curve.G1Affine + SXG curve.G1Affine + XR curve.G2Affine +} + +func newPublicKey(x fr.Element, challenge []byte, dst byte) PublicKey { + var pk PublicKey + _, _, g1, _ := curve.Generators() + + var s fr.Element + var sBi big.Int + s.SetRandom() + s.BigInt(&sBi) + pk.SG.ScalarMultiplication(&g1, &sBi) + + // compute x*sG1 + var xBi big.Int + x.BigInt(&xBi) + pk.SXG.ScalarMultiplication(&pk.SG, &xBi) + + // generate R based on sG1, sxG1, challenge, and domain separation tag (tau, alpha or beta) + R := genR(pk.SG, pk.SXG, challenge, dst) + + // compute x*spG2 + pk.XR.ScalarMultiplication(&R, &xBi) + return pk +} + +func bitReverse[T any](a []T) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + irev := bits.Reverse64(i) >> nn + if irev > i { + a[i], a[irev] = a[irev], a[i] + } + } +} + +// Returns [1, a, a², ..., aⁿ⁻¹ ] in Montgomery form +func powers(a fr.Element, n int) []fr.Element { + result := make([]fr.Element, n) + result[0] = fr.NewElement(1) + for i := 1; i < n; i++ { + result[i].Mul(&result[i-1], &a) + } + return result +} + +// Returns [aᵢAᵢ, ...] in G1 +func scaleG1InPlace(A []curve.G1Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Returns [aᵢAᵢ, ...] 
in G2 +func scaleG2InPlace(A []curve.G2Affine, a []fr.Element) { + utils.Parallelize(len(A), func(start, end int) { + var tmp big.Int + for i := start; i < end; i++ { + a[i].BigInt(&tmp) + A[i].ScalarMultiplication(&A[i], &tmp) + } + }) +} + +// Check e(a₁, a₂) = e(b₁, b₂) +func sameRatio(a1, b1 curve.G1Affine, a2, b2 curve.G2Affine) bool { + if !a1.IsInSubGroup() || !b1.IsInSubGroup() || !a2.IsInSubGroup() || !b2.IsInSubGroup() { + panic("invalid point not in subgroup") + } + var na2 curve.G2Affine + na2.Neg(&a2) + res, err := curve.PairingCheck( + []curve.G1Affine{a1, b1}, + []curve.G2Affine{na2, b2}) + if err != nil { + panic(err) + } + return res +} + +// returns a = ∑ rᵢAᵢ, b = ∑ rᵢBᵢ +func merge(A, B []curve.G1Affine) (a, b curve.G1Affine) { + nc := runtime.NumCPU() + r := make([]fr.Element, len(A)) + for i := 0; i < len(A); i++ { + r[i].SetRandom() + } + a.MultiExp(A, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + b.MultiExp(B, r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G1 +func linearCombinationG1(A []curve.G1Affine) (L1, L2 curve.G1Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// L1 = ∑ rᵢAᵢ, L2 = ∑ rᵢAᵢ₊₁ in G2 +func linearCombinationG2(A []curve.G2Affine) (L1, L2 curve.G2Affine) { + nc := runtime.NumCPU() + n := len(A) + r := make([]fr.Element, n-1) + for i := 0; i < n-1; i++ { + r[i].SetRandom() + } + L1.MultiExp(A[:n-1], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + L2.MultiExp(A[1:], r, ecc.MultiExpConfig{NbTasks: nc / 2}) + return +} + +// Generate R in G₂ as Hash(gˢ, gˢˣ, challenge, dst) +func genR(sG1, sxG1 curve.G1Affine, challenge []byte, dst byte) curve.G2Affine { + var buf bytes.Buffer + buf.Grow(len(challenge) + curve.SizeOfG1AffineUncompressed*2) + buf.Write(sG1.Marshal()) + 
buf.Write(sxG1.Marshal()) + buf.Write(challenge) + spG2, err := curve.HashToG2(buf.Bytes(), []byte{dst}) + if err != nil { + panic(err) + } + return spG2 +} diff --git a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.commitment.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.commitment.go.tmpl index 94d202d2b4..cc94c21ef1 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.commitment.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.commitment.go.tmpl @@ -1,4 +1,7 @@ import ( + "testing" + "fmt" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/witness" @@ -6,7 +9,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/assert" - "testing" ) type singleSecretCommittedCircuit struct { @@ -15,7 +17,11 @@ type singleSecretCommittedCircuit struct { func (c *singleSecretCommittedCircuit) Define(api frontend.API) error { api.AssertIsEqual(c.One, 1) - commit, err := api.Compiler().Commit(c.One) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One) if err != nil { return err } @@ -101,8 +107,11 @@ type oneSecretOnePublicCommittedCircuit struct { } func (c *oneSecretOnePublicCommittedCircuit) Define(api frontend.API) error { - - commit, err := api.Compiler().Commit(c.One, c.Two) + commitCompiler, ok := api.Compiler().(frontend.Committer) + if !ok { + return fmt.Errorf("compiler does not commit") + } + commit, err := commitCompiler.Commit(c.One, c.Two) if err != nil { return err } diff --git a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl index e5ad79d370..1366e7ac92 100644 --- 
a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl @@ -2,7 +2,10 @@ import ( {{ template "import_curve" . }} {{ template "import_fft" . }} - + {{ template "import_pedersen" . }} + "github.com/consensys/gnark/backend/groth16/internal/test_utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "bytes" "math/big" @@ -10,6 +13,7 @@ import ( "github.com/leanovate/gopter" "github.com/leanovate/gopter/prop" + "github.com/leanovate/gopter/gen" "testing" ) @@ -73,13 +77,9 @@ func TestProofSerialization(t *testing.T) { func TestVerifyingKeySerialization(t *testing.T) { - parameters := gopter.DefaultTestParameters() - parameters.MinSuccessfulTests = 10 - - properties := gopter.NewProperties(parameters) - properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + roundTrip := func(withCommitment bool) func(curve.G1Affine, curve.G2Affine) bool { + return func(p1 curve.G1Affine, p2 curve.G2Affine) bool { var vk, vkCompressed, vkRaw VerifyingKey // create a random vk @@ -107,6 +107,20 @@ func TestVerifyingKeySerialization(t *testing.T) { vk.G1.K[i] = p1 } + if withCommitment { + vk.PublicAndCommitmentCommitted = test_utils.Random2DIntSlice(5, 10) // TODO: Use gopter randomization + bases := make([][]curve.G1Affine, len(vk.PublicAndCommitmentCommitted)) + elem := p1 + for i := 0; i < len(vk.PublicAndCommitmentCommitted); i++ { + bases[i] = make([]curve.G1Affine, len(vk.PublicAndCommitmentCommitted[i])) + for j := range bases[i] { + bases[i][j] = elem + elem.Add(&elem, &p1) + } + } + _, vk.CommitmentKey, err = pedersen.Setup(bases...) 
+ assert.NoError(t, err) + } var bufCompressed bytes.Buffer written, err := vk.WriteTo(&bufCompressed) @@ -145,7 +159,22 @@ func TestVerifyingKeySerialization(t *testing.T) { } return reflect.DeepEqual(&vk, &vkCompressed) && reflect.DeepEqual(&vk, &vkRaw) - }, + } + } + + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 10 + + properties := gopter.NewProperties(parameters) + + properties.Property("VerifyingKey -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(false), + GenG1(), + GenG2(), + )) + + properties.Property("VerifyingKey (with commitments) -> writer -> reader -> VerifyingKey should stay constant", prop.ForAll( + roundTrip(true), GenG1(), GenG2(), )) @@ -162,7 +191,7 @@ func TestProvingKeySerialization(t *testing.T) { properties := gopter.NewProperties(parameters) properties.Property("ProvingKey -> writer -> reader -> ProvingKey should stay constant", prop.ForAll( - func(p1 curve.G1Affine, p2 curve.G2Affine) bool { + func(p1 curve.G1Affine, p2 curve.G2Affine, nbCommitment int) bool { var pk, pkCompressed, pkRaw ProvingKey // create a random pk @@ -189,7 +218,20 @@ func TestProvingKeySerialization(t *testing.T) { pk.NbInfinityA = 1 pk.InfinityA = make([]bool, nbWires) pk.InfinityB = make([]bool, nbWires) - pk.InfinityA[2] = true + pk.InfinityA[2] = true + + pedersenBasis := make([]curve.G1Affine, nbCommitment) + pedersenBases := make([][]curve.G1Affine, nbCommitment) + pk.CommitmentKeys = make([]pedersen.ProvingKey, nbCommitment) + for i := range pedersenBasis { + pedersenBasis[i] = p1 + pedersenBases[i] = pedersenBasis[:i+1] + } + { + var err error + pk.CommitmentKeys, _, err = pedersen.Setup(pedersenBases...) 
+ require.NoError(t, err) + } var bufCompressed bytes.Buffer written, err := pk.WriteTo(&bufCompressed) @@ -231,6 +273,7 @@ func TestProvingKeySerialization(t *testing.T) { }, GenG1(), GenG2(), + gen.IntRange(0, 2), )) properties.TestingRun(t, gopter.ConsoleReporter(false)) diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl index 760f22b3d1..6639c4434e 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl @@ -1,11 +1,13 @@ import ( {{ template "import_curve" . }} {{ template "import_fr" . }} + {{ template "import_kzg" . }} + "github.com/consensys/gnark-crypto/ecc/{{toLower .Curve}}/fr/iop" "io" "errors" ) -// WriteTo writes binary encoding of Proof to w without point compression +// WriteRawTo writes binary encoding of Proof to w without point compression func (proof *Proof) WriteRawTo(w io.Writer) (int64, error) { return proof.writeTo(w, curve.RawEncoding()) } @@ -30,6 +32,7 @@ func (proof *Proof) writeTo(w io.Writer, options ...func(*curve.Encoder)) (int64 proof.BatchedProof.ClaimedValues, &proof.ZShiftedOpening.H, &proof.ZShiftedOpening.ClaimedValue, + proof.Bsb22Commitments, } for _, v := range toEncode { @@ -38,7 +41,7 @@ func (proof *Proof) writeTo(w io.Writer, options ...func(*curve.Encoder)) (int64 } } - return enc.BytesWritten(), nil + return enc.BytesWritten(), nil } // ReadFrom reads binary representation of Proof from r @@ -56,6 +59,7 @@ func (proof *Proof) ReadFrom(r io.Reader) (int64, error) { &proof.BatchedProof.ClaimedValues, &proof.ZShiftedOpening.H, &proof.ZShiftedOpening.ClaimedValue, + &proof.Bsb22Commitments, } for _, v := range toDecode { @@ -64,13 +68,30 @@ func (proof *Proof) ReadFrom(r io.Reader) (int64, error) { } } - return dec.BytesRead(), nil + if proof.Bsb22Commitments == nil { + proof.Bsb22Commitments = 
[]kzg.Digest{} + } + + return dec.BytesRead(), nil } // WriteTo writes binary encoding of ProvingKey to w func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { + return pk.writeTo(w, true) +} + +// WriteRawTo writes binary encoding of ProvingKey to w without point compression +func (pk *ProvingKey) WriteRawTo(w io.Writer) (n int64, err error) { + return pk.writeTo(w, false) +} + +func (pk *ProvingKey) writeTo(w io.Writer, withCompression bool) (n int64, err error) { // encode the verifying key - n, err = pk.Vk.WriteTo(w) + if withCompression { + n, err = pk.Vk.WriteTo(w) + } else { + n, err = pk.Vk.WriteRawTo(w) + } if err != nil { return } @@ -88,8 +109,19 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } n += n2 + // KZG key + if withCompression { + n2, err = pk.Kzg.WriteTo(w) + } else { + n2, err = pk.Kzg.WriteRawTo(w) + } + if err != nil { + return + } + n += n2 + // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) - if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { + if len(pk.trace.S) != (3 * int(pk.Domain[0].Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } @@ -98,16 +130,17 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { // encode the size (nor does it convert from Montgomery to Regular form) // so we explicitly transmit []fr.Element toEncode := []interface{}{ - ([]fr.Element)(pk.Ql), - ([]fr.Element)(pk.Qr), - ([]fr.Element)(pk.Qm), - ([]fr.Element)(pk.Qo), - ([]fr.Element)(pk.CQk), - ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.S1Canonical), - ([]fr.Element)(pk.S2Canonical), - ([]fr.Element)(pk.S3Canonical), - pk.Permutation, + pk.trace.Ql.Coefficients(), + pk.trace.Qr.Coefficients(), + pk.trace.Qm.Coefficients(), + pk.trace.Qo.Coefficients(), + pk.trace.Qk.Coefficients(), + coefficients(pk.trace.Qcp), + pk.lQk.Coefficients(), + pk.trace.S1.Coefficients(), + pk.trace.S2.Coefficients(), + pk.trace.S3.Coefficients(), + 
pk.trace.S, } for _, v := range toEncode { @@ -121,46 +154,133 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { // ReadFrom reads from binary representation in r into ProvingKey func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { + return pk.readFrom(r, true) +} + +// UnsafeReadFrom reads from binary representation in r into ProvingKey without subgroup checks +func (pk *ProvingKey) UnsafeReadFrom(r io.Reader) (int64, error) { + return pk.readFrom(r, false) +} + +func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, error) { pk.Vk = &VerifyingKey{} n, err := pk.Vk.ReadFrom(r) if err != nil { return n, err } - n2, err := pk.Domain[0].ReadFrom(r) + n2, err, chDomain0 := pk.Domain[0].AsyncReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.Domain[1].ReadFrom(r) + n2, err, chDomain1 := pk.Domain[1].AsyncReadFrom(r) n += n2 if err != nil { return n, err } - pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) + if withSubgroupChecks { + n2, err = pk.Kzg.ReadFrom(r) + } else { + n2, err = pk.Kzg.UnsafeReadFrom(r) + } + n += n2 + if err != nil { + return n, err + } + + pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) dec := curve.NewDecoder(r) - toDecode := []interface{}{ - (*[]fr.Element)(&pk.Ql), - (*[]fr.Element)(&pk.Qr), - (*[]fr.Element)(&pk.Qm), - (*[]fr.Element)(&pk.Qo), - (*[]fr.Element)(&pk.CQk), - (*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.S1Canonical), - (*[]fr.Element)(&pk.S2Canonical), - (*[]fr.Element)(&pk.S3Canonical), - &pk.Permutation, + + var ql, qr, qm, qo, qk, lqk, s1, s2, s3 []fr.Element + var qcp [][]fr.Element + + // TODO @gbotrel: this is a bit ugly, we should probably refactor this. + // The order of the variables is important, as it matches the order in which they are + // encoded in the WriteTo(...) method. + + // Note: instead of calling dec.Decode(...) 
for each of the above variables, + // we call AsyncReadFrom when possible which allows to consume bytes from the reader + // and perform the decoding in parallel + + type v struct { + data *fr.Vector + chErr chan error } - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err + vectors := make([]v, 9) + vectors[0] = v{data: (*fr.Vector)(&ql)} + vectors[1] = v{data: (*fr.Vector)(&qr)} + vectors[2] = v{data: (*fr.Vector)(&qm)} + vectors[3] = v{data: (*fr.Vector)(&qo)} + vectors[4] = v{data: (*fr.Vector)(&qk)} + vectors[5] = v{data: (*fr.Vector)(&lqk)} + vectors[6] = v{data: (*fr.Vector)(&s1)} + vectors[7] = v{data: (*fr.Vector)(&s2)} + vectors[8] = v{data: (*fr.Vector)(&s3)} + + // read ql, qr, qm, qo, qk + for i := 0; i < 5; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err } + vectors[i].chErr = ch + } + + // read qcp + if err := dec.Decode(&qcp); err != nil { + return n + dec.BytesRead(), err } + // read lqk, s1, s2, s3 + for i := 5; i < 9; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read pk.Trace.S + if err := dec.Decode(&pk.trace.S); err != nil { + return n + dec.BytesRead(), err + } + + // wait for all AsyncReadFrom(...) 
to complete + for i := range vectors { + if err := <-vectors[i].chErr; err != nil { + return n, err + } + } + + canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} + pk.trace.Ql = iop.NewPolynomial(&ql, canReg) + pk.trace.Qr = iop.NewPolynomial(&qr, canReg) + pk.trace.Qm = iop.NewPolynomial(&qm, canReg) + pk.trace.Qo = iop.NewPolynomial(&qo, canReg) + pk.trace.Qk = iop.NewPolynomial(&qk, canReg) + pk.trace.S1 = iop.NewPolynomial(&s1, canReg) + pk.trace.S2 = iop.NewPolynomial(&s2, canReg) + pk.trace.S3 = iop.NewPolynomial(&s3, canReg) + + pk.trace.Qcp = make([]*iop.Polynomial, len(qcp)) + for i := range qcp { + pk.trace.Qcp[i] = iop.NewPolynomial(&qcp[i], canReg) + } + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + pk.lQk = iop.NewPolynomial(&lqk, lagReg) + + + // wait for FFT to be precomputed + <-chDomain0 + <-chDomain1 + pk.computeLagrangeCosetPolys() return n + dec.BytesRead(), nil @@ -169,6 +289,15 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { // WriteTo writes binary encoding of VerifyingKey to w func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { + return vk.writeTo(w) +} + +// WriteRawTo writes binary encoding of VerifyingKey to w without point compression +func (vk *VerifyingKey) WriteRawTo(w io.Writer) (int64, error) { + return vk.writeTo(w, curve.RawEncoding()) +} + +func (vk *VerifyingKey) writeTo(w io.Writer, options ...func(*curve.Encoder)) (n int64, err error) { enc := curve.NewEncoder(w) toEncode := []interface{}{ @@ -185,6 +314,11 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { &vk.Qm, &vk.Qo, &vk.Qk, + vk.Qcp, + &vk.Kzg.G1, + &vk.Kzg.G2[0], + &vk.Kzg.G2[1], + vk.CommitmentConstraintIndexes, } for _, v := range toEncode { @@ -213,6 +347,11 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { &vk.Qm, &vk.Qo, &vk.Qk, + &vk.Qcp, + &vk.Kzg.G1, + &vk.Kzg.G2[0], + &vk.Kzg.G2[1], + &vk.CommitmentConstraintIndexes, } for _, v := range toDecode { @@ -221,5 +360,9 
@@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { } } + if vk.Qcp == nil { + vk.Qcp = []kzg.Digest{} + } + return dec.BytesRead(), nil } \ No newline at end of file diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl index af921fded2..def44c21ec 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl @@ -5,6 +5,8 @@ import ( "time" "sync" + "github.com/consensys/gnark/backend/witness" + {{ template "import_fr" . }} {{ template "import_curve" . }} {{ template "import_kzg" . }} @@ -12,10 +14,12 @@ import ( {{ template "import_backend_cs" . }} "github.com/consensys/gnark-crypto/ecc/{{toLower .Curve}}/fr/iop" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/logger" "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/constraint" ) type Proof struct { @@ -29,17 +33,59 @@ type Proof struct { // Commitments to h1, h2, h3 such that h = h1 + Xh2 + X**2h3 is the quotient polynomial H [3]kzg.Digest - // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2 + Bsb22Commitments []kzg.Digest + + // Batch opening proof of h1 + zeta*h2 + zeta**2h3, linearizedPolynomial, l, r, o, s1, s2, qCPrime BatchedProof kzg.BatchOpeningProof // Opening proof of Z at zeta*mu ZShiftedOpening kzg.OpeningProof } -// Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { +// Computing and verifying Bsb22 multi-commits explained in https://hackmd.io/x8KsadW3RRyX7YTCFJIkHg +func bsb22ComputeCommitmentHint(spr *cs.SparseR1CS, pk *ProvingKey, proof *Proof, cCommitments []*iop.Polynomial, res *fr.Element, commDepth int) 
solver.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments)[commDepth] + committedValues := make([]fr.Element, pk.Domain[0].Cardinality) + offset := spr.GetNbPublicVariables() + for i := range ins { + committedValues[offset+commitmentInfo.Committed[i]].SetBigInt(ins[i]) + } + var ( + err error + hashRes []fr.Element + ) + if _, err = committedValues[offset+commitmentInfo.CommitmentIndex].SetRandom(); err != nil { // Commitment injection constraint has qcp = 0. Safe to use for blinding. + return err + } + if _, err = committedValues[offset+spr.GetNbConstraints()-1].SetRandom(); err != nil { // Last constraint has qcp = 0. Safe to use for blinding + return err + } + pi2iop := iop.NewPolynomial(&committedValues, iop.Form{Basis: iop.Lagrange, Layout: iop.Regular}) + cCommitments[commDepth] = pi2iop.ShallowClone() + cCommitments[commDepth].ToCanonical(&pk.Domain[0]).ToRegular() + if proof.Bsb22Commitments[commDepth], err = kzg.Commit(cCommitments[commDepth].Coefficients(), pk.Kzg); err != nil { + return err + } + if hashRes, err = fr.Hash(proof.Bsb22Commitments[commDepth].Marshal(), []byte("BSB22-Plonk"), 1); err != nil { + return err + } + res.Set(&hashRes[0]) // TODO @Tabaie use CommitmentIndex for this; create a new variable CommitmentConstraintIndex for other uses + res.BigInt(outs[0]) + + return nil + } +} + +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { + + log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", spr.GetNbConstraints()).Str("backend", "plonk").Logger() + + opt, err := backend.NewProverConfig(opts...) 
+ if err != nil { + return nil, err + } - log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", len(spr.Constraints)).Str("backend", "plonk").Logger() start := time.Now() // pick a hash function that will be used to derive the challenges hFunc := sha256.New() @@ -50,53 +96,108 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen // result proof := &Proof{} - // compute the constraint system solution - var solution []fr.Element - var err error - if solution, err = spr.Solve(fullWitness, opt); err != nil { - if !opt.Force { - return nil, err - } else { - // we need to fill solution with random values - var r fr.Element - _, _ = r.SetRandom() - for i := len(spr.Public) + len(spr.Secret); i < len(solution); i++ { - solution[i] = r - r.Double(&r) - } - } + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments) + commitmentVal := make([]fr.Element, len(commitmentInfo)) // TODO @Tabaie get rid of this + cCommitments := make([]*iop.Polynomial, len(commitmentInfo)) + proof.Bsb22Commitments = make([]kzg.Digest, len(commitmentInfo)) + for i := range commitmentInfo { + opt.SolverOpts = append(opt.SolverOpts, solver.OverrideHint(commitmentInfo[i].HintID, + bsb22ComputeCommitmentHint(spr, pk, proof, cCommitments, &commitmentVal[i], i))) + } + + if spr.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + opt.SolverOpts = append(opt.SolverOpts, + solver.OverrideHint(spr.GkrInfo.SolveHintID, cs.GkrSolveHint(spr.GkrInfo, &gkrData)), + solver.OverrideHint(spr.GkrInfo.ProveHintID, cs.GkrProveHint(spr.GkrInfo.HashName, &gkrData))) } // query l, r, o in Lagrange basis, not blinded - evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) + _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) 
+ if err != nil { + return nil, err + } + // TODO @gbotrel deal with that conversion lazily + lcCommitments := make([]*iop.Polynomial, len(cCommitments)) + for i := range cCommitments { + lcCommitments[i] = cCommitments[i].Clone(int(pk.Domain[1].Cardinality)).ToLagrangeCoset(&pk.Domain[1]) // lagrange coset form + } + solution := _solution.(*cs.SparseR1CSSolution) + evaluationLDomainSmall := []fr.Element(solution.L) + evaluationRDomainSmall := []fr.Element(solution.R) + evaluationODomainSmall := []fr.Element(solution.O) lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - liop := iop.NewPolynomial(&evaluationLDomainSmall,lagReg) - riop := iop.NewPolynomial(&evaluationRDomainSmall,lagReg) - oiop := iop.NewPolynomial(&evaluationODomainSmall,lagReg) - wliop := liop.ShallowClone() - wriop := riop.ShallowClone() - woiop := oiop.ShallowClone() - wliop.ToCanonical(&pk.Domain[0]).ToRegular() - wriop.ToCanonical(&pk.Domain[0]).ToRegular() - woiop.ToCanonical(&pk.Domain[0]).ToRegular() - - // Blind l, r, o before committing - // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. - bwliop := wliop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwriop := wriop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - bwoiop := woiop.Clone(int(pk.Domain[1].Cardinality)).Blind(1) - if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Vk.KZGSRS); err != nil { - return nil, err + // l, r, o and blinded versions + var ( + wliop, + wriop, + woiop, + bwliop, + bwriop, + bwoiop *iop.Polynomial + ) + var wgLRO sync.WaitGroup + wgLRO.Add(3) + go func() { + // we keep in lagrange regular form since iop.BuildRatioCopyConstraint prefers it in this form. + wliop = iop.NewPolynomial(&evaluationLDomainSmall, lagReg) + // we set the underlying slice capacity to domain[1].Cardinality to minimize mem moves. 
+ bwliop = wliop.Clone(int(pk.Domain[1].Cardinality)).ToCanonical(&pk.Domain[0]).ToRegular().Blind(1) + wgLRO.Done() + }() + go func() { + wriop = iop.NewPolynomial(&evaluationRDomainSmall, lagReg) + bwriop = wriop.Clone(int(pk.Domain[1].Cardinality)).ToCanonical(&pk.Domain[0]).ToRegular().Blind(1) + wgLRO.Done() + }() + go func() { + woiop = iop.NewPolynomial(&evaluationODomainSmall, lagReg) + bwoiop = woiop.Clone(int(pk.Domain[1].Cardinality)).ToCanonical(&pk.Domain[0]).ToRegular().Blind(1) + wgLRO.Done() + }() + + fw, ok := fullWitness.Vector().(fr.Vector) + if !ok { + return nil, witness.ErrInvalidWitness } + // start computing lcqk + var lcqk *iop.Polynomial + chLcqk := make(chan struct{}, 1) + go func() { + // compute qk in canonical basis, completed with the public inputs + // We copy the coeffs of qk to pk is not mutated + lqkcoef := pk.lQk.Coefficients() + qkCompletedCanonical := make([]fr.Element, len(lqkcoef)) + copy(qkCompletedCanonical, fw[:len(spr.Public)]) + copy(qkCompletedCanonical[len(spr.Public):], lqkcoef[len(spr.Public):]) + for i := range commitmentInfo { + qkCompletedCanonical[spr.GetNbPublicVariables()+commitmentInfo[i].CommitmentIndex] = commitmentVal[i] + } + pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) + fft.BitReverse(qkCompletedCanonical) + + canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} + lcqk = iop.NewPolynomial(&qkCompletedCanonical, canReg) + lcqk.ToLagrangeCoset(&pk.Domain[1]) + close(chLcqk) + }() + // The first challenge is derived using the public data: the commitments to the permutation, // the coefficients of the circuit, and the public inputs. 
// derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *pk.Vk, fullWitness[:len(spr.Public)]); err != nil { + if err := bindPublicData(&fs, "gamma", *pk.Vk, fw[:len(spr.Public)], proof.Bsb22Commitments); err != nil { + return nil, err + } + + // wait for polys to be blinded + wgLRO.Wait() + if err := commitToLRO(bwliop.Coefficients(), bwriop.Coefficients(), bwoiop.Coefficients(), proof, pk.Kzg); err != nil { return nil, err } - gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) + + gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) // TODO @Tabaie @ThomasPiellard add BSB commitment here? if err != nil { return nil, err } @@ -109,17 +210,30 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen var beta fr.Element beta.SetBytes(bbeta) + // l, r, o are already blinded + wgLRO.Add(3) + go func() { + bwliop.ToLagrangeCoset(&pk.Domain[1]) + wgLRO.Done() + }() + go func() { + bwriop.ToLagrangeCoset(&pk.Domain[1]) + wgLRO.Done() + }() + go func() { + bwoiop.ToLagrangeCoset(&pk.Domain[1]) + wgLRO.Done() + }() + // compute the copy constraint's ratio - // We copy liop, riop, oiop because they are fft'ed in the process. - // We could have not copied them at the cost of doing one more bit reverse - // per poly... + // note that wliop, wriop and woiop are fft'ed (mutated) in the process. 
ziop, err := iop.BuildRatioCopyConstraint( []*iop.Polynomial{ - liop.Clone(), - riop.Clone(), - oiop.Clone(), + wliop, + wriop, + woiop, }, - pk.Permutation, + pk.trace.S, beta, gamma, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}, @@ -130,73 +244,32 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen } // commit to the blinded version of z - bwziop := ziop // iop.NewWrappedPolynomial(&ziop) - bwziop.Blind(2) - proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Vk.KZGSRS, runtime.NumCPU()*2) - if err != nil { - return proof, err - } + chZ := make(chan error, 1) + var bwziop, bwsziop *iop.Polynomial + var alpha fr.Element + go func() { + bwziop = ziop // iop.NewWrappedPolynomial(&ziop) + bwziop.Blind(2) + proof.Z, err = kzg.Commit(bwziop.Coefficients(), pk.Kzg, runtime.NumCPU()*2) + if err != nil { + chZ <- err + } - // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) - alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) - if err != nil { - return proof, err - } + // derive alpha from the Comm(l), Comm(r), Comm(o), Com(Z) + alpha, err = deriveRandomness(&fs, "alpha", &proof.Z) + if err != nil { + chZ <- err + } - // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) - copy(qkCompletedCanonical, fullWitness[:len(spr.Public)]) - copy(qkCompletedCanonical[len(spr.Public):], pk.LQk[len(spr.Public):]) - pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) - fft.BitReverse(qkCompletedCanonical) - - // l, r, o are blinded here - bwliop.ToLagrangeCoset(&pk.Domain[1]) - bwriop.ToLagrangeCoset(&pk.Domain[1]) - bwoiop.ToLagrangeCoset(&pk.Domain[1]) - - lagrangeCosetBitReversed := iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse} - - // we don't mutate so no need to clone the coefficients from the proving key. 
- wqliop := iop.NewPolynomial(&pk.lQl, lagrangeCosetBitReversed) - wqriop := iop.NewPolynomial(&pk.lQr, lagrangeCosetBitReversed) - wqmiop := iop.NewPolynomial(&pk.lQm, lagrangeCosetBitReversed) - wqoiop := iop.NewPolynomial(&pk.lQo, lagrangeCosetBitReversed) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqkiop := iop.NewPolynomial(&qkCompletedCanonical, canReg) - wqkiop.ToLagrangeCoset(&pk.Domain[1]) - - // storing Id - id := make([]fr.Element, pk.Domain[1].Cardinality) - id[1].SetOne() - widiop := iop.NewPolynomial(&id, canReg) - widiop.ToLagrangeCoset(&pk.Domain[1]) - - // permutations in LagrangeCoset: we don't mutate so no need to clone the coefficients from the - // proving key. - ws1 := iop.NewPolynomial(&pk.lS1LagrangeCoset, lagrangeCosetBitReversed) - ws2 := iop.NewPolynomial(&pk.lS2LagrangeCoset, lagrangeCosetBitReversed) - ws3 := iop.NewPolynomial(&pk.lS3LagrangeCoset, lagrangeCosetBitReversed) - - // Store z(g*x), without reallocating a slice - bwsziop := bwziop.ShallowClone().Shift(1) - bwsziop.ToLagrangeCoset(&pk.Domain[1]) - - // L_{g^{0}} - cap := pk.Domain[1].Cardinality - if cap < pk.Domain[0].Cardinality { - cap = pk.Domain[0].Cardinality // sanity check - } - lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) - lone[0].SetOne() - loneiop := iop.NewPolynomial(&lone, lagReg) - wloneiop := loneiop.ToCanonical(&pk.Domain[0]). - ToRegular(). - ToLagrangeCoset(&pk.Domain[1]) + // Store z(g*x), without reallocating a slice + bwsziop = bwziop.ShallowClone().Shift(1) + bwsziop.ToLagrangeCoset(&pk.Domain[1]) + chZ <- nil + close(chZ) + }() // Full capture using latest gnark crypto... 
- fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element) fr.Element { + fic := func(fql, fqr, fqm, fqo, fqk, l, r, o fr.Element, pi2QcPrime []fr.Element) fr.Element { // TODO @Tabaie make use of the fact that qCPrime is a selector: sparse and binary var ic, tmp fr.Element @@ -207,20 +280,24 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen ic.Add(&ic, &tmp) tmp.Mul(&fqo, &o) ic.Add(&ic, &tmp).Add(&ic, &fqk) + nbComms := len(commitmentInfo) + for i := range commitmentInfo { + tmp.Mul(&pi2QcPrime[i], &pi2QcPrime[i+nbComms]) + ic.Add(&ic, &tmp) + } return ic } fo := func(l, r, o, fid, fs1, fs2, fs3, fz, fzs fr.Element) fr.Element { - var uu fr.Element - u := pk.Domain[0].FrMultiplicativeGen - uu.Mul(&u, &u) - + u := &pk.Domain[0].FrMultiplicativeGen var a, b, tmp fr.Element - a.Mul(&beta, &fid).Add(&a, &l).Add(&a, &gamma) - tmp.Mul(&beta, &u).Mul(&tmp, &fid).Add(&tmp, &r).Add(&tmp, &gamma) + b.Mul(&beta, &fid) + a.Add(&b, &l).Add(&a, &gamma) + b.Mul(&b, u) + tmp.Add(&b, &r).Add(&tmp, &gamma) a.Mul(&a, &tmp) - tmp.Mul(&beta, &uu).Mul(&tmp, &fid).Add(&tmp, &o).Add(&tmp, &gamma) + tmp.Mul(&b, u).Add(&tmp, &o).Add(&tmp, &gamma) a.Mul(&a, &tmp).Mul(&a, &fz) b.Mul(&beta, &fs1).Add(&b, &l).Add(&b, &gamma) @@ -240,11 +317,11 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen return one } - // 0 , 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 - // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk,lone + // 0 , 1 , 2, 3 , 4 , 5 , 6 , 7, 8 , 9 , 10, 11, 12, 13, 14, 15:15+nbComm , 15+nbComm:15+2×nbComm + // l , r , o, id, s1, s2, s3, z, zs, ql, qr, qm, qo, qk ,lone, Bsb22Commitments, qCPrime fm := func(x ...fr.Element) fr.Element { - a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2]) + a := fic(x[9], x[10], x[11], x[12], x[13], x[0], x[1], x[2], x[15:]) b := fo(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8]) c := fone(x[7], x[14]) @@ -252,27 +329,49 @@ func Prove(spr *cs.SparseR1CS, pk 
*ProvingKey, fullWitness fr.Vector, opt backen return c } - testEval, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, + + // wait for lcqk + <-chLcqk + + // wait for Z part + if err := <-chZ; err != nil { + return proof, err + } + + // wait for l, r o lagrange coset conversion + wgLRO.Wait() + + toEval := []*iop.Polynomial{ bwliop, bwriop, bwoiop, - widiop, - ws1, - ws2, - ws3, + pk.lcIdIOP, + pk.lcS1, + pk.lcS2, + pk.lcS3, bwziop, bwsziop, - wqliop, - wqriop, - wqmiop, - wqoiop, - wqkiop, - wloneiop, - ) + pk.lcQl, + pk.lcQr, + pk.lcQm, + pk.lcQo, + lcqk, + pk.lLoneIOP, + } + toEval = append(toEval, lcCommitments...) // TODO: Add this at beginning + toEval = append(toEval, pk.lcQcp...) + systemEvaluation, err := iop.Evaluate(fm, iop.Form{Basis: iop.LagrangeCoset, Layout: iop.BitReverse}, toEval...) if err != nil { return nil, err } - h, err := iop.DivideByXMinusOne(testEval, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) + // open blinded Z at zeta*z + chbwzIOP := make(chan struct{}, 1) + go func() { + bwziop.ToCanonical(&pk.Domain[1]).ToRegular() + close(chbwzIOP) + }() + + h, err := iop.DivideByXMinusOne(systemEvaluation, [2]*fft.Domain{&pk.Domain[0], &pk.Domain[1]}) // TODO Rename to DivideByXNMinusOne or DivideByVanishingPoly etc if err != nil { return nil, err } @@ -282,7 +381,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen h.Coefficients()[:pk.Domain[0].Cardinality+2], h.Coefficients()[pk.Domain[0].Cardinality+2:2*(pk.Domain[0].Cardinality+2)], h.Coefficients()[2*(pk.Domain[0].Cardinality+2):3*(pk.Domain[0].Cardinality+2)], - proof, pk.Vk.KZGSRS); err != nil { + proof, pk.Kzg); err != nil { return nil, err } @@ -292,45 +391,72 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen return nil, err } - // compute evaluations of (blinded version of) l, r, o, z at zeta + // compute evaluations of (blinded version of) l, r, o, z, qCPrime at zeta var blzeta, 
brzeta, bozeta fr.Element - + qcpzeta := make([]fr.Element, len(commitmentInfo)) + var wgEvals sync.WaitGroup wgEvals.Add(3) - - go func() { - bwliop.ToCanonical(&pk.Domain[1]).ToRegular() - blzeta = bwliop.Evaluate(zeta) + evalAtZeta := func(poly *iop.Polynomial, res *fr.Element) { + poly.ToCanonical(&pk.Domain[1]).ToRegular() + *res = poly.Evaluate(zeta) wgEvals.Done() - }() + } + go evalAtZeta(bwliop, &blzeta) + go evalAtZeta(bwriop, &brzeta) + go evalAtZeta(bwoiop, &bozeta) + evalQcpAtZeta := func(begin, end int) { + for i := begin; i < end; i++ { + qcpzeta[i] = pk.trace.Qcp[i].Evaluate(zeta) + } + } + utils.Parallelize(len(commitmentInfo), evalQcpAtZeta) - go func() { - bwriop.ToCanonical(&pk.Domain[1]).ToRegular() - brzeta = bwriop.Evaluate(zeta) - wgEvals.Done() - }() - - go func() { - bwoiop.ToCanonical(&pk.Domain[1]).ToRegular() - bozeta = bwoiop.Evaluate(zeta) - wgEvals.Done() - }() - - // open blinded Z at zeta*z - bwziop.ToCanonical(&pk.Domain[1]).ToRegular() var zetaShifted fr.Element zetaShifted.Mul(&zeta, &pk.Vk.Generator) + <-chbwzIOP proof.ZShiftedOpening, err = kzg.Open( bwziop.Coefficients()[:bwziop.BlindedSize()], zetaShifted, - pk.Vk.KZGSRS, + pk.Kzg, ) if err != nil { return nil, err } - // blinded z evaluated at u*zeta - bzuzeta := proof.ZShiftedOpening.ClaimedValue + // start to compute foldedH and foldedHDigest while computeLinearizedPolynomial runs. 
+ computeFoldedH := make(chan struct{}, 1) + var foldedH []fr.Element + var foldedHDigest kzg.Digest + go func() { + // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) + var bZetaPowerm, bSize big.Int + bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + var zetaPowerm fr.Element + zetaPowerm.Exp(zeta, &bSize) + zetaPowerm.BigInt(&bZetaPowerm) + foldedHDigest = proof.H[2] + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) + + // foldedH = h1 + ζ*h2 + ζ²*h3 + foldedH = h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] + h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] + h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] + utils.Parallelize(len(foldedH), func(start, end int) { + for i := start; i < end; i++ { + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² + foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 + } + }) + close(computeFoldedH) + }() + + wgEvals.Wait() // wait for the evaluations var ( linearizedPolynomialCanonical []fr.Element @@ -338,10 +464,12 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen errLPoly error ) - wgEvals.Wait() // wait for the evaluations + // blinded z evaluated at u*zeta + bzuzeta := proof.ZShiftedOpening.ClaimedValue // compute the linearization polynomial r at zeta // (goal: save committing separately to z, ql, qr, qm, qo, k + // note: we linearizedPolynomialCanonical reuses bwziop memory linearizedPolynomialCanonical = computeLinearizedPolynomial( blzeta, 
brzeta, @@ -351,66 +479,52 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen gamma, zeta, bzuzeta, + qcpzeta, bwziop.Coefficients()[:bwziop.BlindedSize()], + coefficients(cCommitments), pk, ) // TODO this commitment is only necessary to derive the challenge, we should // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) - - // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) - var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) - var zetaPowerm fr.Element - zetaPowerm.Exp(zeta, &bSize) - zetaPowerm.BigInt(&bZetaPowerm) - foldedHDigest := proof.H[2] - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - - // foldedH = h1 + ζ*h2 + ζ²*h3 - foldedH := h.Coefficients()[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - h2 := h.Coefficients()[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] - h1 := h.Coefficients()[:pk.Domain[0].Cardinality+2] - utils.Parallelize(len(foldedH), func(start, end int) { - for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² - foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 - } - }) - + linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Kzg, runtime.NumCPU()*2) if errLPoly != nil { return nil, errLPoly } + // wait for foldedH and foldedHDigest + <-computeFoldedH + // Batch open the first list of polynomials + 
polysQcp := coefficients(pk.trace.Qcp) + polysToOpen := make([][]fr.Element, 7+len(polysQcp)) + copy(polysToOpen[7:], polysQcp) + // offset := len(polysQcp) + polysToOpen[0] = foldedH + polysToOpen[1] = linearizedPolynomialCanonical + polysToOpen[2] = bwliop.Coefficients()[:bwliop.BlindedSize()] + polysToOpen[3] = bwriop.Coefficients()[:bwriop.BlindedSize()] + polysToOpen[4] = bwoiop.Coefficients()[:bwoiop.BlindedSize()] + polysToOpen[5] = pk.trace.S1.Coefficients() + polysToOpen[6] = pk.trace.S2.Coefficients() + + digestsToOpen := make([]curve.G1Affine, len(pk.Vk.Qcp)+7) + copy(digestsToOpen[7:], pk.Vk.Qcp) + // offset = len(pk.Vk.Qcp) + digestsToOpen[0] = foldedHDigest + digestsToOpen[1] = linearizedPolynomialDigest + digestsToOpen[2] = proof.LRO[0] + digestsToOpen[3] = proof.LRO[1] + digestsToOpen[4] = proof.LRO[2] + digestsToOpen[5] = pk.Vk.S[0] + digestsToOpen[6] = pk.Vk.S[1] + proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - [][]fr.Element{ - foldedH, - linearizedPolynomialCanonical, - bwliop.Coefficients()[:bwliop.BlindedSize()], - bwriop.Coefficients()[:bwriop.BlindedSize()], - bwoiop.Coefficients()[:bwoiop.BlindedSize()], - pk.S1Canonical, - pk.S2Canonical, - }, - []kzg.Digest{ - foldedHDigest, - linearizedPolynomialDigest, - proof.LRO[0], - proof.LRO[1], - proof.LRO[2], - pk.Vk.S[0], - pk.Vk.S[1], - }, + polysToOpen, + digestsToOpen, zeta, hFunc, - pk.Vk.KZGSRS, + pk.Kzg, ) log.Debug().Dur("took", time.Since(start)).Msg("prover done") @@ -423,21 +537,29 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backen } +func coefficients(p []*iop.Polynomial) [][]fr.Element { + res := make([][]fr.Element, len(p)) + for i, pI := range p { + res[i] = pI.Coefficients() + } + return res +} + // fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 +func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, kzgPk kzg.ProvingKey) 
error { + n := runtime.NumCPU() var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) chCommit1 := make(chan struct{}, 1) go func() { - proof.LRO[0], err0 = kzg.Commit(bcl, srs, n) + proof.LRO[0], err0 = kzg.Commit(bcl, kzgPk, n) close(chCommit0) }() go func() { - proof.LRO[1], err1 = kzg.Commit(bcr, srs, n) + proof.LRO[1], err1 = kzg.Commit(bcr, kzgPk, n) close(chCommit1) }() - if proof.LRO[2], err2 = kzg.Commit(bco, srs, n); err2 != nil { + if proof.LRO[2], err2 = kzg.Commit(bco, kzgPk, n); err2 != nil { return err2 } <-chCommit0 @@ -450,20 +572,20 @@ func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { return err1 } -func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { - n := runtime.NumCPU() / 2 +func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, kzgPk kzg.ProvingKey) error { + n := runtime.NumCPU() var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) chCommit1 := make(chan struct{}, 1) go func() { - proof.H[0], err0 = kzg.Commit(h1, srs, n) + proof.H[0], err0 = kzg.Commit(h1, kzgPk, n) close(chCommit0) }() go func() { - proof.H[1], err1 = kzg.Commit(h2, srs, n) + proof.H[1], err1 = kzg.Commit(h2, kzgPk, n) close(chCommit1) }() - if proof.H[2], err2 = kzg.Commit(h3, srs, n); err2 != nil { + if proof.H[2], err2 = kzg.Commit(h3, kzgPk, n); err2 != nil { return err2 } <-chCommit0 @@ -476,41 +598,6 @@ func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error return err1 } -// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
-// solution = [ public | secret | internal ] -func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - - s := int(pk.Domain[0].Cardinality) - - var l, r, o []fr.Element - l = make([]fr.Element, s) - r = make([]fr.Element, s) - o = make([]fr.Element, s) - s0 := solution[0] - - for i := 0; i < len(spr.Public); i++ { // placeholders - l[i] = solution[i] - r[i] = s0 - o[i] = s0 - } - offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // constraints - l[offset+i] = solution[spr.Constraints[i].L.WireID()] - r[offset+i] = solution[spr.Constraints[i].R.WireID()] - o[offset+i] = solution[spr.Constraints[i].O.WireID()] - } - offset += len(spr.Constraints) - - for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solution[0]) - l[offset+i] = s0 - r[offset+i] = s0 - o[offset+i] = s0 - } - - return l, r, o - -} - // computeLinearizedPolynomial computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. 
// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta @@ -522,7 +609,7 @@ func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.El // α²*L₁(ζ)*Z(X) // + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) // + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { +func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { // first part: individual constraints var rl fr.Element @@ -533,13 +620,12 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - ps1 := iop.NewPolynomial(&pk.S1Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - s1 = ps1.Evaluate(zeta) // s1(ζ) + s1 = pk.trace.S1.Evaluate(zeta) // s1(ζ) s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() - ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := ps2.Evaluate(zeta) // s2(ζ) + // ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) + tmp := pk.trace.S2.Evaluate(zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) @@ -570,43 +656,51 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, Mul(&lagrangeZeta, &alpha). 
Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - linPol := make([]fr.Element, len(blindedZCanonical)) - copy(linPol, blindedZCanonical) - - utils.Parallelize(len(linPol), func(start, end int) { + s3canonical := pk.trace.S3.Coefficients() + utils.Parallelize(len(blindedZCanonical), func(start, end int) { - var t0, t1 fr.Element + var t, t0, t1 fr.Element for i := start; i < end; i++ { - linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + t.Mul(&blindedZCanonical[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - if i < len(pk.S3Canonical) { + if i < len(s3canonical) { - t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) + t0.Mul(&s3canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) - linPol[i].Add(&linPol[i], &t0) + t.Add(&t, &t0) } - linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) + t.Mul(&t, &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) - if i < len(pk.Qm) { + cql := pk.trace.Ql.Coefficients() + cqr := pk.trace.Qr.Coefficients() + cqm := pk.trace.Qm.Coefficients() + cqo := pk.trace.Qo.Coefficients() + cqk := pk.trace.Qk.Coefficients() + if i < len(cqm) { - t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) - t0.Mul(&pk.Ql[i], &lZeta) + t1.Mul(&cqm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) + t0.Mul(&cql[i], &lZeta) t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) + t.Add(&t, &t0) // linPol = linPol + l(ζ)*Ql(X) + + t0.Mul(&cqr[i], &rZeta) + t.Add(&t, &t0) // linPol = linPol + r(ζ)*Qr(X) - t0.Mul(&pk.Qr[i], &rZeta) - linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) + t0.Mul(&cqo[i], &oZeta).Add(&t0, &cqk[i]) + t.Add(&t, &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) - t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) - 
linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) + for j := range qcpZeta { + t0.Mul(&pi2Canonical[j][i], &qcpZeta[j]) + t.Add(&t, &t0) + } } t0.Mul(&blindedZCanonical[i], &lagrangeZeta) - linPol[i].Add(&linPol[i], &t0) // finish the computation + blindedZCanonical[i].Add(&t, &t0) // finish the computation } }) - return linPol + return blindedZCanonical } diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl index 810078c132..d2f6a0ed60 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl @@ -5,48 +5,32 @@ import ( {{- template "import_fft" . }} {{- template "import_backend_cs" . }} "github.com/consensys/gnark-crypto/ecc/{{toLower .Curve}}/fr/iop" - - kzgg "github.com/consensys/gnark-crypto/kzg" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/plonk/internal" + "github.com/consensys/gnark/constraint" + "sync" ) -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation -type ProvingKey struct { - // Verifying Key is embedded into the proving key (needed by Prove) - Vk *VerifyingKey - - // TODO store iop.Polynomial here, not []fr.Element for more "type safety" - - // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo []fr.Element - - // qr,ql,qm,qo (in lagrange coset basis) --> these are not serialized, but computed from Ql, Qr, Qm, Qo once. 
- lQl, lQr, lQm, lQo []fr.Element - - // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. - // Storing LQk in Lagrange basis saves a fft... - CQk, LQk []fr.Element - - // Domains used for the FFTs. - // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain - // Domain[0], Domain[1] fft.Domain - - // Permutation polynomials - S1Canonical, S2Canonical, S3Canonical []fr.Element - - // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. - lS1LagrangeCoset, lS2LagrangeCoset, lS3LagrangeCoset []fr.Element - - // position -> permuted position (position in [0,3*sizeSystem-1]) - Permutation []int64 +// Trace stores a plonk trace as columns +type Trace struct { + + // Constants describing a plonk circuit. The first entries + // of LQk (whose index correspond to the public inputs) are set to 0, and are to be + // completed by the prover. At those indices i (so from 0 to nb_public_variables), LQl[i]=-1 + // so the first nb_public_variables constraints look like this: + // -1*Wire[i] + 0* + 0 . It is zero when the constant coefficient is replaced by Wire[i]. + Ql, Qr, Qm, Qo, Qk *iop.Polynomial + Qcp []*iop.Polynomial + + // Polynomials representing the splitted permutation. The full permutation's support is 3*N where N=nb wires. + // The set of interpolation is of size N, so to represent the permutation S we let S acts on the + // set A=(, u*, u^{2}*) of size 3*N, where u is outside (its use is to shift the set ). + // We obtain a permutation of A, A'. We split A' in 3 (A'_{1}, A'_{2}, A'_{3}), and S1, S2, S3 are + // respectively the interpolation of A'_{1}, A'_{2}, A'_{3} on . 
+ S1, S2, S3 *iop.Polynomial + + // S full permutation, i -> S[i] + S []int64 } // VerifyingKey stores the data needed to verify a proof: @@ -55,6 +39,7 @@ type ProvingKey struct { // * Commitments of qr, qm, qo, qk prepended with as many zeroes as there are public inputs // * Commitments to S1, S2, S3 type VerifyingKey struct { + // Size circuit Size uint64 SizeInv fr.Element @@ -62,7 +47,7 @@ type VerifyingKey struct { NbPublicVariables uint64 // Commitment scheme that is used for an instantiation of PLONK - KZGSRS *kzg.SRS + Kzg kzg.VerifyingKey // cosetShift generator of the coset on the small domain CosetShift fr.Element @@ -70,120 +55,300 @@ type VerifyingKey struct { // S commitments to S1, S2, S3 S [3]kzg.Digest - // Commitments to ql, qr, qm, qo prepended with as many zeroes (ones for l) as there are public inputs. + // Commitments to ql, qr, qm, qo, qcp prepended with as many zeroes (ones for l) as there are public inputs. // In particular Qk is not complete. Ql, Qr, Qm, Qo, Qk kzg.Digest + Qcp []kzg.Digest + + CommitmentConstraintIndexes []uint64 +} + +// ProvingKey stores the data needed to generate a proof: +// * the commitment scheme +// * ql, prepended with as many ones as they are public inputs +// * qr, qm, qo prepended with as many zeroes as there are public inputs. +// * qk, prepended with as many zeroes as public inputs, to be completed by the prover +// with the list of public inputs. +// * sigma_1, sigma_2, sigma_3 in both basis +// * the copy constraint permutation +type ProvingKey struct { + + // stores ql, qr, qm, qo, qk (-> to be completed by the prover) + // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used + // for computing the opening proofs (hence the canonical form). The canonical version + // of qk incomplete is used in the linearisation polynomial. + // The polynomials in trace are in canonical basis. 
+ trace Trace + + Kzg kzg.ProvingKey + + // Verifying Key is embedded into the proving key (needed by Prove) + Vk *VerifyingKey + + // qr,ql,qm,qo,qcp in LagrangeCoset --> these are not serialized, but computed from Ql, Qr, Qm, Qo, Qcp once. + lcQl, lcQr, lcQm, lcQo *iop.Polynomial + lcQcp []*iop.Polynomial + + // LQk qk in Lagrange form -> to be completed by the prover. After being completed, + lQk *iop.Polynomial + + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + + // in lagrange coset basis --> these are not serialized, but computed from S1Canonical, S2Canonical, S3Canonical once. + lcS1, lcS2, lcS3 *iop.Polynomial + + // in lagrange coset basis --> not serialized id and L_{g^{0}} + lcIdIOP, lLoneIOP *iop.Polynomial } -// Setup sets proving and verifying keys -func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) { +func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { + var pk ProvingKey var vk VerifyingKey - - // The verifying key shares data with the proving key pk.Vk = &vk + vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) - nbConstraints := len(spr.Constraints) + // step 0: set the fft domains + pk.initDomains(spr) - // fft domains - sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) + // step 1: set the verifying key pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases - // except when n<6. 
- if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(len(spr.Public)) + if len(kzgSrs.Pk.G1) < int(vk.Size) { + return nil, nil, errors.New("kzg srs is too small") + } + pk.Kzg = kzgSrs.Pk + vk.Kzg = kzgSrs.Vk - if err := pk.InitKZG(srs); err != nil { + // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis + BuildTrace(spr, &pk.trace) + + // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. + // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &pk.trace, nbVariables) + s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) + pk.trace.S1 = s[0] + pk.trace.S2 = s[1] + pk.trace.S3 = s[2] + + // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. + // All the above polynomials are expressed in canonical basis afterwards. This is why + // we save lqk before, because the prover needs to complete it in Lagrange form, and + // then express it on the Lagrange coset basis. 
+ pk.lQk = pk.trace.Qk.Clone() // it will be completed by the prover, and the evaluated on the coset + err := commitTrace(&pk.trace, &pk) + if err != nil { return nil, nil, err } - // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) - pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) - pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) - pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) - - for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant - pk.Ql[i].SetOne().Neg(&pk.Ql[i]) - pk.Qr[i].SetZero() - pk.Qm[i].SetZero() - pk.Qo[i].SetZero() - pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // → to be completed by the prover + // step 5: evaluate ql, qr, qm, qo, s1, s2, s3 on LagrangeCoset (NOT qk) + // we clone them, because the canonical versions are going to be used in + // the opening proof + pk.computeLagrangeCosetPolys() + + return &pk, &vk, nil +} + +// computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset +// basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
+func (pk *ProvingKey) computeLagrangeCosetPolys() { + var wg sync.WaitGroup + wg.Add(7 + len(pk.trace.Qcp)) + n1 := int(pk.Domain[1].Cardinality) + pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) + for i, qcpI := range pk.trace.Qcp { + go func(i int, qcpI *iop.Polynomial) { + pk.lcQcp[i] = qcpI.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }(i, qcpI) + } + go func() { + pk.lcQl = pk.trace.Ql.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQr = pk.trace.Qr.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQm = pk.trace.Qm.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQo = pk.trace.Qo.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS1 = pk.trace.S1.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS2 = pk.trace.S2.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS3 = pk.trace.S3.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + // storing Id + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + id := make([]fr.Element, pk.Domain[1].Cardinality) + id[0].Set(&pk.Domain[1].FrMultiplicativeGen) + for i := 1; i < int(pk.Domain[1].Cardinality); i++ { + id[i].Mul(&id[i-1], &pk.Domain[1].Generator) + } + pk.lcIdIOP = iop.NewPolynomial(&id, lagReg) + + // L_{g^{0}} + cap := pk.Domain[1].Cardinality + if cap < pk.Domain[0].Cardinality { + cap = pk.Domain[0].Cardinality // sanity check + } + lone := make([]fr.Element, pk.Domain[0].Cardinality, cap) + lone[0].SetOne() + pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). + ToRegular(). 
+ ToLagrangeCoset(&pk.Domain[1]) + + wg.Wait() +} + + +// NbPublicWitness returns the expected public witness size (number of field elements) +func (vk *VerifyingKey) NbPublicWitness() int { + return int(vk.NbPublicVariables) +} + +// VerifyingKey returns pk.Vk +func (pk *ProvingKey) VerifyingKey() interface{} { + return pk.Vk +} + +// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. +// Size is the size of the system that is nb_constraints+nb_public_variables +func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) + size := ecc.NextPowerOfTwo(sizeSystem) + commitmentInfo := spr.CommitmentInfo.(constraint.PlonkCommitments) + + ql := make([]fr.Element, size) + qr := make([]fr.Element, size) + qm := make([]fr.Element, size) + qo := make([]fr.Element, size) + qk := make([]fr.Element, size) + qcp := make([][]fr.Element, len(commitmentInfo)) + + for i := 0; i < len(spr.Public); i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistent + ql[i].SetOne().Neg(&ql[i]) + qr[i].SetZero() + qm[i].SetZero() + qo[i].SetZero() + qk[i].SetZero() // → to be completed by the prover } offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.Ql[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.Qr[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.Qm[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.Qm[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.Qo[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.CQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) + + + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c!=nil; c = it.Next() { + ql[offset+j].Set(&spr.Coefficients[c.QL]) + qr[offset+j].Set(&spr.Coefficients[c.QR]) + qm[offset+j].Set(&spr.Coefficients[c.QM]) + qo[offset+j].Set(&spr.Coefficients[c.QO]) + qk[offset+j].Set(&spr.Coefficients[c.QC]) + j++ } - pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) - pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) - pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) - fft.BitReverse(pk.Ql) - fft.BitReverse(pk.Qr) - fft.BitReverse(pk.Qm) - fft.BitReverse(pk.Qo) - fft.BitReverse(pk.CQk) + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + + pt.Ql = iop.NewPolynomial(&ql, lagReg) + pt.Qr = iop.NewPolynomial(&qr, lagReg) + pt.Qm = iop.NewPolynomial(&qm, lagReg) + pt.Qo = iop.NewPolynomial(&qo, lagReg) + pt.Qk = iop.NewPolynomial(&qk, lagReg) + pt.Qcp = make([]*iop.Polynomial, len(qcp)) - // build permutation. Note: at this stage, the permutation takes in account the placeholders - buildPermutation(spr, &pk) + for i := range commitmentInfo { + qcp[i] = make([]fr.Element, size) + for _, committed := range commitmentInfo[i].Committed { + qcp[i][offset+committed].SetOne() + } + pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + } +} - // set s1, s2, s3 - ccomputePermutationPolynomials(&pk) +// commitTrace commits to every polynomial in the trace, and put +// the commitments int the verifying key. 
+func commitTrace(trace *Trace, pk *ProvingKey) error { - // compute the lagrange coset basis versions (not serialized) - pk.computeLagrangeCosetPolys() + trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() + trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete + trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() + trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() - // Commit to the polynomials to set up the verifying key var err error - if vk.Ql, err = kzg.Commit(pk.Ql, vk.KZGSRS); err != nil { - return nil, nil, err + pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + for i := range trace.Qcp { + trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() + if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + return err + } } - if vk.Qr, err = kzg.Commit(pk.Qr, vk.KZGSRS); err != nil { - return nil, nil, err + if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + return err } - if vk.Qm, err = kzg.Commit(pk.Qm, vk.KZGSRS); err != nil { - return nil, nil, err + if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + return err } - if vk.Qo, err = kzg.Commit(pk.Qo, vk.KZGSRS); err != nil { - return nil, nil, err + if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + return err } - if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { - return nil, nil, err + if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + return err } - if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { - return nil, nil, err + if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + return err } - if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { - return nil, nil, err + if 
pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + return err } - if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { - return nil, nil, err + if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + return err + } + if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + return err } + return nil +} - return &pk, &vk, nil +func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { + + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints + pk.Domain[0] = *fft.NewDomain(sizeSystem) + + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) + } else { + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) + } } @@ -191,38 +356,45 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // // The permutation s is composed of cycles of maximum length such that // -// s. (l∥r∥o) = (l∥r∥o) +// s. (l∥r∥o) = (l∥r∥o) // -//, where l∥r∥o is the concatenation of the indices of l, r, o in +// , where l∥r∥o is the concatenation of the indices of l, r, o in // ql.l+qr.r+qm.l.r+qo.O+k = 0. 
// // The permutation is encoded as a slice s of size 3*size(l), where the // i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { +func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := int(pk.Domain[0].Cardinality) + // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + sizeSolution := len(pt.Ql.Coefficients()) + sizePermutation := 3 * sizeSolution // init permutation - pk.Permutation = make([]int64, 3*sizeSolution) - for i := 0; i < len(pk.Permutation); i++ { - pk.Permutation[i] = -1 + permutation := make([]int64, sizePermutation) + for i := 0; i < len(permutation); i++ { + permutation[i] = -1 } // init LRO position -> variable_ID - lro := make([]int, 3*sizeSolution) // position -> variable_ID + lro := make([]int, sizePermutation) // position -> variable_ID for i := 0; i < len(spr.Public); i++ { lro[i] = i // IDs of LRO associated to placeholders (only L needs to be taken care of) } offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c!=nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + + j++ } + // init cycle: // map ID -> last position the ID was seen cycle := make([]int64, nbVariables) @@ -234,92 +406,54 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { if cycle[lro[i]] != -1 { // if != -1, it means we already encountered this value // so we need to set the corresponding permutation index. 
- pk.Permutation[i] = cycle[lro[i]] + permutation[i] = cycle[lro[i]] } cycle[lro[i]] = int64(i) } // complete the Permutation by filling the first IDs encountered - for i := 0; i < len(pk.Permutation); i++ { - if pk.Permutation[i] == -1 { - pk.Permutation[i] = cycle[lro[i]] + for i := 0; i < sizePermutation; i++ { + if permutation[i] == -1 { + permutation[i] = cycle[lro[i]] } } -} -func (pk *ProvingKey) computeLagrangeCosetPolys() { - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - wqliop := iop.NewPolynomial(clone(pk.Ql, pk.Domain[1].Cardinality), canReg) - wqriop := iop.NewPolynomial(clone(pk.Qr, pk.Domain[1].Cardinality), canReg) - wqmiop := iop.NewPolynomial(clone(pk.Qm, pk.Domain[1].Cardinality), canReg) - wqoiop := iop.NewPolynomial(clone(pk.Qo, pk.Domain[1].Cardinality), canReg) - - ws1 := iop.NewPolynomial(clone(pk.S1Canonical, pk.Domain[1].Cardinality), canReg) - ws2 := iop.NewPolynomial(clone(pk.S2Canonical, pk.Domain[1].Cardinality), canReg) - ws3 := iop.NewPolynomial(clone(pk.S3Canonical, pk.Domain[1].Cardinality), canReg) - - wqliop.ToLagrangeCoset(&pk.Domain[1]) - wqriop.ToLagrangeCoset(&pk.Domain[1]) - wqmiop.ToLagrangeCoset(&pk.Domain[1]) - wqoiop.ToLagrangeCoset(&pk.Domain[1]) - - ws1.ToLagrangeCoset(&pk.Domain[1]) - ws2.ToLagrangeCoset(&pk.Domain[1]) - ws3.ToLagrangeCoset(&pk.Domain[1]) - - pk.lQl = wqliop.Coefficients() - pk.lQr = wqriop.Coefficients() - pk.lQm = wqmiop.Coefficients() - pk.lQo = wqoiop.Coefficients() - - pk.lS1LagrangeCoset = ws1.Coefficients() - pk.lS2LagrangeCoset = ws2.Coefficients() - pk.lS3LagrangeCoset = ws3.Coefficients() + pt.S = permutation } -func clone(input []fr.Element, capacity uint64) *[]fr.Element { - res := make([]fr.Element, len(input), capacity) - copy(res, input) - return &res -} +// computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. +// We let the permutation act on || u || u^{2}, split the result in 3 parts, +// and interpolate each of the 3 parts on . 
+func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { -// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations -// s1, s2, s3. -// -// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | -// | -// | Permutation -// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. s3n v -// \---------------/ \--------------------/ \------------------------/ -// s1 (LDE) s2 (LDE) s3 (LDE) -func ccomputePermutationPolynomials(pk *ProvingKey) { + nbElmts := int(domain.Cardinality) - nbElmts := int(pk.Domain[0].Cardinality) + var res [3]*iop.Polynomial // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) + evaluationIDSmallDomain := getSupportPermutation(domain) // Lagrange form of S1, S2, S3 - pk.S1Canonical = make([]fr.Element, nbElmts) - pk.S2Canonical = make([]fr.Element, nbElmts) - pk.S3Canonical = make([]fr.Element, nbElmts) + s1Canonical := make([]fr.Element, nbElmts) + s2Canonical := make([]fr.Element, nbElmts) + s3Canonical := make([]fr.Element, nbElmts) for i := 0; i < nbElmts; i++ { - pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) - pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) - pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) + s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) } - // Canonical form of S1, S2, S3 - pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) - pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) - fft.BitReverse(pk.S1Canonical) - fft.BitReverse(pk.S2Canonical) - fft.BitReverse(pk.S3Canonical) + lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} + res[0] = iop.NewPolynomial(&s1Canonical, lagReg) + res[1] = iop.NewPolynomial(&s2Canonical, lagReg) + res[2] = iop.NewPolynomial(&s3Canonical, 
lagReg) + + return res } -// getIDSmallDomain returns the Lagrange form of ID on the small domain -func getIDSmallDomain(domain *fft.Domain) []fr.Element { +// getSupportPermutation returns the support on which the permutation acts, it is +// || u || u^{2} +func getSupportPermutation(domain *fft.Domain) []fr.Element { res := make([]fr.Element, 3*domain.Cardinality) @@ -334,39 +468,4 @@ func getIDSmallDomain(domain *fft.Domain) []fr.Element { } return res -} - -// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS -// -// This should be used after deserializing a ProvingKey -// as pk.Vk.KZG is NOT serialized -func (pk *ProvingKey) InitKZG(srs kzgg.SRS) error { - return pk.Vk.InitKZG(srs) -} - -// InitKZG inits vk.KZG using provided SRS -// -// This should be used after deserializing a VerifyingKey -// as vk.KZG is NOT serialized -// -// Note that this instantiate a new FFT domain using vk.Size -func (vk *VerifyingKey) InitKZG(srs kzgg.SRS) error { - _srs := srs.(*kzg.SRS) - - if len(_srs.G1) < int(vk.Size) { - return errors.New("kzg srs is too small") - } - vk.KZGSRS = _srs - - return nil -} - -// NbPublicWitness returns the expected public witness size (number of field elements) -func (vk *VerifyingKey) NbPublicWitness() int { - return int(vk.NbPublicVariables) -} - -// VerifyingKey returns pk.Vk -func (pk *ProvingKey) VerifyingKey() interface{} { - return pk.Vk -} +} \ No newline at end of file diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl index 1f7f8b83fc..43e5bd3c6f 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl @@ -5,6 +5,9 @@ import ( "time" "io" {{ template "import_fr" . }} + {{if eq .Curve "BN254"}} + "github.com/consensys/gnark-crypto/ecc/bn254/fp" + {{end}} {{ template "import_kzg" . 
}} {{ template "import_curve" . }} {{if eq .Curve "BN254"}} @@ -20,7 +23,7 @@ var ( ) func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { - log := logger.Logger().With().Str("curve", "{{ toLower .CurveID }}").Str("backend", "plonk").Logger() + log := logger.Logger().With().Str("curve", "{{ toLower .Curve }}").Str("backend", "plonk").Logger() start := time.Now() // pick a hash function to derive the challenge (the same as in the prover) @@ -32,7 +35,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { // The first challenge is derived using the public data: the commitments to the permutation, // the coefficients of the circuit, and the public inputs. // derive gamma from the Comm(blinded cl), Comm(blinded cr), Comm(blinded co) - if err := bindPublicData(&fs, "gamma", *vk, publicWitness); err != nil { + if err := bindPublicData(&fs, "gamma", *vk, publicWitness, proof.Bsb22Commitments); err != nil { return err } gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -66,25 +69,51 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = ∑_{i to be completed by the prover } offset := len(spr.Public) - for i := 0; i < nbConstraints; i++ { // constraints - - pk.EvaluationQlDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].L.CoeffID()]) - pk.EvaluationQrDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].R.CoeffID()]) - pk.EvaluationQmDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].M[0].CoeffID()]). 
- Mul(&pk.EvaluationQmDomainBigBitReversed[offset+i], &spr.Coefficients[spr.Constraints[i].M[1].CoeffID()]) - pk.EvaluationQoDomainBigBitReversed[offset+i].Set(&spr.Coefficients[spr.Constraints[i].O.CoeffID()]) - pk.LQkIncompleteDomainSmall[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) - pk.CQkIncomplete[offset+i].Set(&pk.LQkIncompleteDomainSmall[offset+i]) + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c!=nil; c = it.Next() { + pk.EvaluationQlDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QL]) + pk.EvaluationQrDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QR]) + pk.EvaluationQmDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QM]) + pk.EvaluationQoDomainBigBitReversed[offset+j].Set(&spr.Coefficients[c.QO]) + pk.LQkIncompleteDomainSmall[offset+j].Set(&spr.Coefficients[c.QC]) + pk.CQkIncomplete[offset+j].Set(&pk.LQkIncompleteDomainSmall[offset+j]) + + j++ } + pk.Domain[0].FFTInverse(pk.EvaluationQlDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationQrDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationQmDomainBigBitReversed[:pk.Domain[0].Cardinality], fft.DIF) @@ -192,10 +197,10 @@ func Setup(spr *cs.SparseR1CS) (*ProvingKey, *VerifyingKey, error) { return &pk, &vk, err } - pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationQlDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQrDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQmDomainBigBitReversed, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationQoDomainBigBitReversed, fft.DIF, fft.OnCoset()) // build permutation. 
Note: at this stage, the permutation takes in account the placeholders buildPermutation(spr, &pk) @@ -240,10 +245,14 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } offset := len(spr.Public) - for i := 0; i < len(spr.Constraints); i++ { // IDs of LRO associated to constraints - lro[offset+i] = spr.Constraints[i].L.WireID() - lro[sizeSolution+offset+i] = spr.Constraints[i].R.WireID() - lro[2*sizeSolution+offset+i] = spr.Constraints[i].O.WireID() + + j := 0 + it := spr.GetSparseR1CIterator() + for c := it.Next(); c!=nil; c = it.Next() { + lro[offset+j] = int(c.XA) + lro[sizeSolution+offset+j] = int(c.XB) + lro[2*sizeSolution+offset+j] = int(c.XC) + j++ } // init cycle: @@ -329,9 +338,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationId1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationId3BigDomain, fft.DIF, fft.OnCoset()) pk.Domain[0].FFTInverse(pk.EvaluationS1BigDomain[:pk.Domain[0].Cardinality], fft.DIF) pk.Domain[0].FFTInverse(pk.EvaluationS2BigDomain[:pk.Domain[0].Cardinality], fft.DIF) @@ -359,9 +368,9 @@ func computePermutationPolynomials(pk *ProvingKey, vk *VerifyingKey) error { if err != nil { return err } - pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, true) - pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationS1BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS2BigDomain, fft.DIF, fft.OnCoset()) + pk.Domain[1].FFT(pk.EvaluationS3BigDomain, fft.DIF, fft.OnCoset()) return nil diff --git a/internal/kvstore/kvstore.go b/internal/kvstore/kvstore.go new file mode 100644 
index 0000000000..6c580cc2fd --- /dev/null +++ b/internal/kvstore/kvstore.go @@ -0,0 +1,38 @@ +// Package kvstore implements simple key-value store +// +// It is without synchronization and allows any comparable keys. The main use of +// this package is for sharing singletons when building a circuit. +package kvstore + +import ( + "reflect" +) + +type Store interface { + SetKeyValue(key, value any) + GetKeyValue(key any) (value any) +} + +type impl struct { + db map[any]any +} + +func New() Store { + return &impl{ + db: make(map[any]any), + } +} + +func (c *impl) SetKeyValue(key, value any) { + if !reflect.TypeOf(key).Comparable() { + panic("key type not comparable") + } + c.db[key] = value +} + +func (c *impl) GetKeyValue(key any) any { + if !reflect.TypeOf(key).Comparable() { + panic("key type not comparable") + } + return c.db[key] +} diff --git a/internal/stats/latest.stats b/internal/stats/latest.stats index be2ff23c98..9046cc8a50 100644 Binary files a/internal/stats/latest.stats and b/internal/stats/latest.stats differ diff --git a/internal/stats/snippet.go b/internal/stats/snippet.go index 1f6ed7764c..88b147fb80 100644 --- a/internal/stats/snippet.go +++ b/internal/stats/snippet.go @@ -7,8 +7,8 @@ import ( "github.com/consensys/gnark" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/sw_bls12377" - "github.com/consensys/gnark/std/algebra/sw_bls24315" + "github.com/consensys/gnark/std/algebra/native/sw_bls12377" + "github.com/consensys/gnark/std/algebra/native/sw_bls24315" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/std/math/bits" "github.com/consensys/gnark/std/math/emulated" @@ -116,10 +116,8 @@ func initSnippets() { dummyG2.Y.A1 = newVariable() // e(psi0, -gamma)*e(-πC, -δ)*e(πA, πB) - resMillerLoop, _ := sw_bls12377.MillerLoop(api, []sw_bls12377.G1Affine{dummyG1}, []sw_bls12377.G2Affine{dummyG2}) + _, _ = sw_bls12377.Pair(api, []sw_bls12377.G1Affine{dummyG1}, 
[]sw_bls12377.G2Affine{dummyG2}) - // performs the final expo - _ = sw_bls12377.FinalExponentiation(api, resMillerLoop) }, ecc.BW6_761) registerSnippet("pairing_bls24315", func(api frontend.API, newVariable func() frontend.Variable) { @@ -138,10 +136,8 @@ func initSnippets() { dummyG2.Y.B1.A1 = newVariable() // e(psi0, -gamma)*e(-πC, -δ)*e(πA, πB) - resMillerLoop, _ := sw_bls24315.MillerLoop(api, []sw_bls24315.G1Affine{dummyG1}, []sw_bls24315.G2Affine{dummyG2}) + _, _ = sw_bls24315.Pair(api, []sw_bls24315.G1Affine{dummyG1}, []sw_bls24315.G2Affine{dummyG2}) - // performs the final expo - _ = sw_bls24315.FinalExponentiation(api, resMillerLoop) }, ecc.BW6_633) } diff --git a/internal/tinyfield/element.go b/internal/tinyfield/element.go index ec15e0514b..85d289e751 100644 --- a/internal/tinyfield/element.go +++ b/internal/tinyfield/element.go @@ -27,6 +27,7 @@ import ( "strconv" "strings" + "github.com/bits-and-blooms/bitset" "github.com/consensys/gnark-crypto/field/hash" "github.com/consensys/gnark-crypto/field/pool" ) @@ -299,7 +300,7 @@ func (z *Element) SetRandom() (*Element, error) { return nil, err } - // Clear unused bits in in the most signicant byte to increase probability + // Clear unused bits in in the most significant byte to increase probability // that the candidate is < q. bytes[k-1] &= uint8(int(1<= 0; i-- { - if zeroes[i] { + if zeroes.Test(uint(i)) { continue } res[i].Mul(&res[i], &accumulator) @@ -676,6 +677,11 @@ func (z *Element) Marshal() []byte { return b[:] } +// Unmarshal is an alias for SetBytes, it sets z to the value of e. +func (z *Element) Unmarshal(e []byte) { + z.SetBytes(e) +} + // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. 
func (z *Element) SetBytes(e []byte) *Element { diff --git a/internal/tinyfield/element_test.go b/internal/tinyfield/element_test.go index d6c74b751d..93664bddb2 100644 --- a/internal/tinyfield/element_test.go +++ b/internal/tinyfield/element_test.go @@ -317,7 +317,8 @@ func TestElementReduce(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, s := range testValues { + for i := range testValues { + s := testValues[i] expected := s reduce(&s) _reduceGeneric(&expected) @@ -701,7 +702,8 @@ func TestElementAdd(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, r := range testValues { + for i := range testValues { + r := testValues[i] var d, e, rb big.Int r.BigInt(&rb) @@ -736,11 +738,12 @@ func TestElementAdd(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) - for _, b := range testValues { - + for j := range testValues { + b := testValues[j] var bBig, d, e big.Int b.BigInt(&bBig) @@ -810,7 +813,8 @@ func TestElementSub(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, r := range testValues { + for i := range testValues { + r := testValues[i] var d, e, rb big.Int r.BigInt(&rb) @@ -845,11 +849,12 @@ func TestElementSub(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) - for _, b := range testValues { - + for j := range testValues { + b := testValues[j] var bBig, d, e big.Int b.BigInt(&bBig) @@ -919,7 +924,8 @@ func TestElementMul(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, r := range 
testValues { + for i := range testValues { + r := testValues[i] var d, e, rb big.Int r.BigInt(&rb) @@ -973,11 +979,12 @@ func TestElementMul(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) - for _, b := range testValues { - + for j := range testValues { + b := testValues[j] var bBig, d, e big.Int b.BigInt(&bBig) @@ -1055,7 +1062,8 @@ func TestElementDiv(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, r := range testValues { + for i := range testValues { + r := testValues[i] var d, e, rb big.Int r.BigInt(&rb) @@ -1091,11 +1099,12 @@ func TestElementDiv(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) - for _, b := range testValues { - + for j := range testValues { + b := testValues[j] var bBig, d, e big.Int b.BigInt(&bBig) @@ -1166,7 +1175,8 @@ func TestElementExp(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, r := range testValues { + for i := range testValues { + r := testValues[i] var d, e, rb big.Int r.BigInt(&rb) @@ -1201,11 +1211,12 @@ func TestElementExp(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) - for _, b := range testValues { - + for j := range testValues { + b := testValues[j] var bBig, d, e big.Int b.BigInt(&bBig) @@ -1277,7 +1288,8 @@ func TestElementSquare(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range 
testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) var c Element @@ -1349,7 +1361,8 @@ func TestElementInverse(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) var c Element @@ -1421,7 +1434,8 @@ func TestElementSqrt(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) var c Element @@ -1493,7 +1507,8 @@ func TestElementDouble(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) var c Element @@ -1565,7 +1580,8 @@ func TestElementNeg(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int a.BigInt(&aBig) var c Element diff --git a/internal/tinyfield/vector.go b/internal/tinyfield/vector.go index 4f9e9a15e4..9ef47d3cda 100644 --- a/internal/tinyfield/vector.go +++ b/internal/tinyfield/vector.go @@ -19,8 +19,13 @@ package tinyfield import ( "bytes" "encoding/binary" + "fmt" "io" + "runtime" "strings" + "sync" + "sync/atomic" + "unsafe" ) // Vector represents a slice of Element. @@ -73,6 +78,66 @@ func (vector *Vector) WriteTo(w io.Writer) (int64, error) { return n, nil } +// AsyncReadFrom reads a vector of big endian encoded Element. +// Length of the vector must be encoded as a uint32 on the first 4 bytes. +// It consumes the needed bytes from the reader and returns the number of bytes read and an error if any. +// It also returns a channel that will be closed when the validation is done. 
+// The validation consist of checking that the elements are smaller than the modulus, and +// converting them to montgomery form. +func (vector *Vector) AsyncReadFrom(r io.Reader) (int64, error, chan error) { + chErr := make(chan error, 1) + var buf [Bytes]byte + if read, err := io.ReadFull(r, buf[:4]); err != nil { + close(chErr) + return int64(read), err, chErr + } + sliceLen := binary.BigEndian.Uint32(buf[:4]) + + n := int64(4) + (*vector) = make(Vector, sliceLen) + if sliceLen == 0 { + close(chErr) + return n, nil, chErr + } + + bSlice := unsafe.Slice((*byte)(unsafe.Pointer(&(*vector)[0])), sliceLen*Bytes) + read, err := io.ReadFull(r, bSlice) + n += int64(read) + if err != nil { + close(chErr) + return n, err, chErr + } + + go func() { + var cptErrors uint64 + // process the elements in parallel + execute(int(sliceLen), func(start, end int) { + + var z Element + for i := start; i < end; i++ { + // we have to set vector[i] + bstart := i * Bytes + bend := bstart + Bytes + b := bSlice[bstart:bend] + z[0] = binary.BigEndian.Uint64(b[0:8]) + + if !z.smallerThanModulus() { + atomic.AddUint64(&cptErrors, 1) + return + } + z.toMont() + (*vector)[i] = z + } + }) + + if cptErrors > 0 { + chErr <- fmt.Errorf("async read: %d elements failed validation", cptErrors) + } + close(chErr) + }() + return n, nil, chErr +} + // ReadFrom implements io.ReaderFrom and reads a vector of big endian encoded Element. // Length of the vector must be encoded as a uint32 on the first 4 bytes. func (vector *Vector) ReadFrom(r io.Reader) (int64, error) { @@ -130,3 +195,56 @@ func (vector Vector) Less(i, j int) bool { func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } + +// TODO @gbotrel make a public package out of that. +// execute executes the work function in parallel. 
+// this is copy paste from internal/parallel/parallel.go +// as we don't want to generate code importing internal/ +func execute(nbIterations int, work func(int, int), maxCpus ...int) { + + nbTasks := runtime.NumCPU() + if len(maxCpus) == 1 { + nbTasks = maxCpus[0] + if nbTasks < 1 { + nbTasks = 1 + } else if nbTasks > 512 { + nbTasks = 512 + } + } + + if nbTasks == 1 { + // no go routines + work(0, nbIterations) + return + } + + nbIterationsPerCpus := nbIterations / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = nbIterations + } + + var wg sync.WaitGroup + + extraTasks := nbIterations - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + go func() { + work(_start, _end) + wg.Done() + }() + } + + wg.Wait() +} diff --git a/internal/tinyfield/vector_test.go b/internal/tinyfield/vector_test.go index e1db416306..68a98e5fa9 100644 --- a/internal/tinyfield/vector_test.go +++ b/internal/tinyfield/vector_test.go @@ -17,6 +17,7 @@ package tinyfield import ( + "bytes" "github.com/stretchr/testify/require" "reflect" "sort" @@ -47,12 +48,16 @@ func TestVectorRoundTrip(t *testing.T) { b, err := v1.MarshalBinary() assert.NoError(err) - var v2 Vector + var v2, v3 Vector err = v2.UnmarshalBinary(b) assert.NoError(err) + err = v3.unmarshalBinaryAsync(b) + assert.NoError(err) + assert.True(reflect.DeepEqual(v1, v2)) + assert.True(reflect.DeepEqual(v3, v2)) } func TestVectorEmptyRoundTrip(t *testing.T) { @@ -63,10 +68,23 @@ func TestVectorEmptyRoundTrip(t *testing.T) { b, err := v1.MarshalBinary() assert.NoError(err) - var v2 Vector + var v2, v3 Vector err = v2.UnmarshalBinary(b) assert.NoError(err) + err = v3.unmarshalBinaryAsync(b) + assert.NoError(err) + 
assert.True(reflect.DeepEqual(v1, v2)) + assert.True(reflect.DeepEqual(v3, v2)) +} + +func (vector *Vector) unmarshalBinaryAsync(data []byte) error { + r := bytes.NewReader(data) + _, err, chErr := vector.AsyncReadFrom(r) + if err != nil { + return err + } + return <-chErr } diff --git a/internal/utils/convert.go b/internal/utils/convert.go index 7ca9ba1674..3bc19c6758 100644 --- a/internal/utils/convert.go +++ b/internal/utils/convert.go @@ -15,6 +15,7 @@ package utils import ( + "math" "math/big" "reflect" ) @@ -60,7 +61,7 @@ func FromInterface(input interface{}) big.Int { case int32: r.SetInt64(int64(v)) case int64: - r.SetInt64(int64(v)) + r.SetInt64(v) case int: r.SetInt64(int64(v)) case string: @@ -87,3 +88,31 @@ func FromInterface(input interface{}) big.Int { return r } + +func IntSliceSliceToUint64SliceSlice(in [][]int) [][]uint64 { + res := make([][]uint64, len(in)) + for i := range in { + res[i] = make([]uint64, len(in[i])) + for j := range in[i] { + if in[i][j] < 0 { + panic("negative value in int slice") + } + res[i][j] = uint64(in[i][j]) + } + } + return res +} + +func Uint64SliceSliceToIntSliceSlice(in [][]uint64) [][]int { + res := make([][]int, len(in)) + for i := range in { + res[i] = make([]int, len(in[i])) + for j := range in[i] { + if in[i][j] >= math.MaxInt { + panic("too large") + } + res[i][j] = int(in[i][j]) + } + } + return res +} diff --git a/internal/utils/heap.go b/internal/utils/heap.go new file mode 100644 index 0000000000..944bcb5e03 --- /dev/null +++ b/internal/utils/heap.go @@ -0,0 +1,63 @@ +package utils + +// An IntHeap is a min-heap of linear expressions. It facilitates merging k-linear expressions. +// +// The code is identical to https://pkg.go.dev/container/heap but replaces interfaces with concrete +// type to avoid memory overhead. 
+type IntHeap []int + +func (h *IntHeap) less(i, j int) bool { return (*h)[i] < (*h)[j] } +func (h *IntHeap) swap(i, j int) { (*h)[i], (*h)[j] = (*h)[j], (*h)[i] } + +// Heapify establishes the heap invariants required by the other routines in this package. +// Heapify is idempotent with respect to the heap invariants +// and may be called whenever the heap invariants may have been invalidated. +// The complexity is O(n) where n = len(*h). +func (h *IntHeap) Heapify() { + // heapify + n := len(*h) + for i := n/2 - 1; i >= 0; i-- { + h.down(i, n) + } +} + +// Pop removes and returns the minimum element (according to Less) from the heap. +// The complexity is O(log n) where n = len(*h). +// Pop is equivalent to Remove(h, 0). +func (h *IntHeap) Pop() { + n := len(*h) - 1 + h.swap(0, n) + h.down(0, n) + *h = (*h)[0:n] +} + +func (h *IntHeap) up(j int) { + for { + i := (j - 1) / 2 // parent + if i == j || !h.less(j, i) { + break + } + h.swap(i, j) + j = i + } +} + +func (h *IntHeap) down(i0, n int) bool { + i := i0 + for { + j1 := 2*i + 1 + if j1 >= n || j1 < 0 { // j1 < 0 after int overflow + break + } + j := j1 // left child + if j2 := j1 + 1; j2 < n && h.less(j2, j1) { + j = j2 // = 2*i + 2 // right child + } + if !h.less(j, i) { + break + } + h.swap(i, j) + i = j + } + return i > i0 +} diff --git a/internal/utils/search.go b/internal/utils/search.go new file mode 100644 index 0000000000..fd9bcb460b --- /dev/null +++ b/internal/utils/search.go @@ -0,0 +1,26 @@ +package utils + +import "sort" + +// FindInSlice attempts to find the target in increasing slice x. +// If not found, returns false and the index where the target would be inserted. +func FindInSlice(x []int, target int) (int, bool) { + return sort.Find(len(x), func(i int) int { + return target - x[i] + }) +} + +// MultiListSeeker looks up increasing integers in a list of increasing lists of integers. 
+type MultiListSeeker [][]int + +// Seek returns the index of the earliest list where n is found, or -1 if not found. +func (s MultiListSeeker) Seek(n int) int { + for i, l := range s { + j, found := FindInSlice(l, n) + s[i] = l[j:] + if found { + return i + } + } + return -1 +} diff --git a/profile/profile.go b/profile/profile.go index 3d3778d995..fe3e58f3c2 100644 --- a/profile/profile.go +++ b/profile/profile.go @@ -92,7 +92,7 @@ func Start(options ...Option) *Profile { log := logger.Logger() if p.filePath == "" { - log.Warn().Msg("gnark profiling enabled [not writting to disk]") + log.Warn().Msg("gnark profiling enabled [not writing to disk]") } else { log.Info().Str("path", p.filePath).Msg("gnark profiling enabled") } @@ -131,7 +131,7 @@ func (p *Profile) Stop() { f.Close() log.Info().Str("path", p.filePath).Msg("gnark profiling disabled") } else { - log.Warn().Msg("gnark profiling disabled [not writting to disk]") + log.Warn().Msg("gnark profiling disabled [not writing to disk]") } } @@ -144,7 +144,7 @@ func (p *Profile) NbConstraints() int { // Top return a similar output than pprof top command func (p *Profile) Top() string { r := report.NewDefault(&p.pprof, report.Options{ - OutputFormat: report.Text, + OutputFormat: report.Tree, CompactLabels: true, NodeFraction: 0.005, EdgeFraction: 0.001, diff --git a/profile/profile_test.go b/profile/profile_test.go index 0505331a90..f198f370e1 100644 --- a/profile/profile_test.go +++ b/profile/profile_test.go @@ -15,8 +15,18 @@ type Circuit struct { A frontend.Variable } +type obj struct { +} + func (circuit *Circuit) Define(api frontend.API) error { - api.AssertIsEqual(api.Mul(circuit.A, circuit.A), circuit.A) + var o obj + o.Define(api, circuit.A) + // api.AssertIsEqual(api.Mul(circuit.A, circuit.A), circuit.A) + return nil +} + +func (o *obj) Define(api frontend.API, A frontend.Variable) error { + api.AssertIsEqual(api.Mul(A, A), A) return nil } @@ -29,12 +39,23 @@ func Example() { _, _ = 
frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &Circuit{}) p.Stop() + // expected output fmt.Println(p.Top()) + const _ = `Showing nodes accounting for 2, 100% of 2 total +----------------------------------------------------------+------------- + flat flat% sum% cum cum% calls calls% + context +----------------------------------------------------------+------------- + 1 100% | profile_test.(*Circuit).Define profile/profile_test.go:21 + 1 50.00% 50.00% 1 50.00% | r1cs.(*builder).AssertIsEqual frontend/cs/r1cs/api_assertions.go:37 +----------------------------------------------------------+------------- + 1 100% | profile_test.(*Circuit).Define profile/profile_test.go:21 + 1 50.00% 100% 1 50.00% | r1cs.(*builder).Mul frontend/cs/r1cs/api.go:221 +----------------------------------------------------------+------------- + 0 0% 100% 2 100% | profile_test.(*Circuit).Define profile/profile_test.go:21 + 1 50.00% | r1cs.(*builder).AssertIsEqual frontend/cs/r1cs/api_assertions.go:37 + 1 50.00% | r1cs.(*builder).Mul frontend/cs/r1cs/api.go:221 +----------------------------------------------------------+-------------` + fmt.Println(p.NbConstraints()) - fmt.Println(p.Top()) // Output: // 2 - // Showing nodes accounting for 2, 100% of 2 total - // flat flat% sum% cum cum% - // 1 50.00% 50.00% 2 100% profile_test.(*Circuit).Define profile/profile_test.go:19 - // 1 50.00% 100% 1 50.00% r1cs.(*builder).AssertIsEqual frontend/cs/r1cs/api_assertions.go:37 } diff --git a/profile/profile_worker.go b/profile/profile_worker.go index b63499475b..97817fe898 100644 --- a/profile/profile_worker.go +++ b/profile/profile_worker.go @@ -5,6 +5,7 @@ import ( "strings" "sync" "sync/atomic" + "unicode" "github.com/google/pprof/profile" ) @@ -63,31 +64,18 @@ func collectSample(pc []uintptr) { for { frame, more := frames.Next() - if strings.HasSuffix(frame.Function, ".func1") { - // TODO @gbotrel filter anonymous func better - continue - } - - // to avoid aving a location that concentrates 
99% of the calls, we transfer the "addConstraint" - // occuring in Mul to the previous level in the stack - if strings.Contains(frame.Function, "github.com/consensys/gnark/frontend/cs/r1cs.(*builder).Mul") { - continue + if strings.Contains(frame.Function, "frontend.parseCircuit") { + // we stop; previous frame was the .Define definition of the circuit + break } - if strings.HasPrefix(frame.Function, "github.com/consensys/gnark/frontend/cs/scs.(*scs).Mul") { + if strings.HasSuffix(frame.Function, ".func1") { + // TODO @gbotrel filter anonymous func better continue } - if strings.HasPrefix(frame.Function, "github.com/consensys/gnark/frontend/cs/scs.(*scs).split") { - continue - } - - // with scs.Builder (Plonk) Add and Sub always add a constraint --> we record the caller as the constraint adder - // but in the future we may record a different type of sample for these - if strings.HasPrefix(frame.Function, "github.com/consensys/gnark/frontend/cs/scs.(*scs).Add") { - continue - } - if strings.HasPrefix(frame.Function, "github.com/consensys/gnark/frontend/cs/scs.(*scs).Sub") { + // filter internal builder functions + if filterSCSPrivateFunc(frame.Function) || filterR1CSPrivateFunc(frame.Function) { continue } @@ -107,6 +95,7 @@ func collectSample(pc []uintptr) { sessions[i].onceSetName.Do(func() { // once per profile session, we set the "name of the binary" // here we grep the struct name where "Define" exist: hopefully the circuit Name + // note: this won't work well for nested Define calls. fe := strings.Split(frame.Function, "/") circuitName := strings.TrimSuffix(fe[len(fe)-1], ".Define") sessions[i].pprof.Mapping = []*profile.Mapping{ @@ -114,7 +103,7 @@ func collectSample(pc []uintptr) { } }) } - break + // break --> we break when we hit frontend.parseCircuit; in case we have nested Define calls in the stack. 
} } @@ -123,3 +112,27 @@ func collectSample(pc []uintptr) { } } + +func filterSCSPrivateFunc(f string) bool { + const scsPrefix = "github.com/consensys/gnark/frontend/cs/scs.(*builder)." + if strings.HasPrefix(f, scsPrefix) && len(f) > len(scsPrefix) { + // filter plonk frontend private APIs from the trace. + c := []rune(f)[len(scsPrefix)] + if unicode.IsLower(c) { + return true + } + } + return false +} + +func filterR1CSPrivateFunc(f string) bool { + const r1csPrefix = "github.com/consensys/gnark/frontend/cs/r1cs.(*builder)." + if strings.HasPrefix(f, r1csPrefix) && len(f) > len(r1csPrefix) { + // filter r1cs frontend private APIs from the trace. + c := []rune(f)[len(r1csPrefix)] + if unicode.IsLower(c) { + return true + } + } + return false +} diff --git a/std/accumulator/merkle/verify.go b/std/accumulator/merkle/verify.go index 35caff2ce8..3ec92412d6 100644 --- a/std/accumulator/merkle/verify.go +++ b/std/accumulator/merkle/verify.go @@ -62,7 +62,7 @@ type MerkleProof struct { // leafSum returns the hash created from data inserted to form a leaf. // Without domain separation. -func leafSum(api frontend.API, h hash.Hash, data frontend.Variable) frontend.Variable { +func leafSum(api frontend.API, h hash.FieldHasher, data frontend.Variable) frontend.Variable { h.Reset() h.Write(data) @@ -73,7 +73,7 @@ func leafSum(api frontend.API, h hash.Hash, data frontend.Variable) frontend.Var // nodeSum returns the hash created from data inserted to form a leaf. // Without domain separation. -func nodeSum(api frontend.API, h hash.Hash, a, b frontend.Variable) frontend.Variable { +func nodeSum(api frontend.API, h hash.FieldHasher, a, b frontend.Variable) frontend.Variable { h.Reset() h.Write(a, b) @@ -86,7 +86,7 @@ func nodeSum(api frontend.API, h hash.Hash, a, b frontend.Variable) frontend.Var // true if the first element of the proof set is a leaf of data in the Merkle // root. False is returned if the proof set or Merkle root is nil, and if // 'numLeaves' equals 0. 
-func (mp *MerkleProof) VerifyProof(api frontend.API, h hash.Hash, leaf frontend.Variable) { +func (mp *MerkleProof) VerifyProof(api frontend.API, h hash.FieldHasher, leaf frontend.Variable) { depth := len(mp.Path) - 1 sum := leafSum(api, h, mp.Path[0]) diff --git a/std/accumulator/merkle/verify_test.go b/std/accumulator/merkle/verify_test.go index 2b650a2ca8..51582df727 100644 --- a/std/accumulator/merkle/verify_test.go +++ b/std/accumulator/merkle/verify_test.go @@ -19,17 +19,18 @@ package merkle import ( "bytes" "crypto/rand" + "os" + "testing" + "github.com/consensys/gnark-crypto/accumulator/merkletree" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/logger" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" - "os" - "testing" ) // MerkleProofTest used for testing only @@ -99,7 +100,7 @@ func TestVerify(t *testing.T) { os.Exit(-1) } - // verfiy the proof in plain go + // verify the proof in plain go verified := merkletree.VerifyProof(hGo, merkleRoot, proofPath, proofIndex, numLeaves) if !verified { t.Fatal("The merkle proof in plain go should pass") @@ -119,7 +120,7 @@ func TestVerify(t *testing.T) { t.Fatal(err) } logger.SetOutput(os.Stdout) - err = cc.IsSolved(w, backend.IgnoreSolverError(), backend.WithCircuitLogger(logger.Logger())) + err = cc.IsSolved(w, solver.WithLogger(logger.Logger())) if err != nil { t.Fatal(err) } diff --git a/std/algebra/doc.go b/std/algebra/doc.go new file mode 100644 index 0000000000..18e70aeac5 --- /dev/null +++ b/std/algebra/doc.go @@ -0,0 +1,14 @@ +// Package algebra implements: +// - base finite field 𝔽p arithmetic, +// - extension finite fields arithmetic (𝔽p², 𝔽p⁴, 𝔽p⁶, 𝔽p¹², 𝔽p²⁴), +// - short Weierstrass curve arithmetic over G1 (E/𝔽p) and G2 (Eₜ/𝔽p² or Eₜ/𝔽p⁴) +// - 
twisted Edwards curve arithmetic +// +// These arithmetic operations are implemented +// - using native field via the 2-chains BLS12-377/BW6-761 and BLS24-315/BW-633 +// (`native/`) or associated twisted Edwards (e.g. Jubjub/BLS12-381) and +// - using nonnative field via field emulation (`emulated/`). This allows to +// use any curve over any (SNARK) field (e.g. secp256k1 curve arithmetic over +// BN254 SNARK field or BN254 pairing over BN254 SNARK field). The drawback +// of this approach is the extreme cost of the operations. +package algebra diff --git a/std/algebra/emulated/fields_bls12381/doc.go b/std/algebra/emulated/fields_bls12381/doc.go new file mode 100644 index 0000000000..c94c08b683 --- /dev/null +++ b/std/algebra/emulated/fields_bls12381/doc.go @@ -0,0 +1,7 @@ +// Package fields_bls12381 implements the fields arithmetic of the Fp12 tower +// used to compute the pairing over the BLS12-381 curve. +// +// 𝔽p²[u] = 𝔽p/u²+1 +// 𝔽p⁶[v] = 𝔽p²/v³-1-u +// 𝔽p¹²[w] = 𝔽p⁶/w²-v +package fields_bls12381 diff --git a/std/algebra/emulated/fields_bls12381/e12.go b/std/algebra/emulated/fields_bls12381/e12.go new file mode 100644 index 0000000000..d452c4f0cb --- /dev/null +++ b/std/algebra/emulated/fields_bls12381/e12.go @@ -0,0 +1,199 @@ +package fields_bls12381 + +import ( + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" +) + +type E12 struct { + C0, C1 E6 +} + +type Ext12 struct { + *Ext6 +} + +func NewExt12(api frontend.API) *Ext12 { + return &Ext12{Ext6: NewExt6(api)} +} + +func (e Ext12) Add(x, y *E12) *E12 { + z0 := e.Ext6.Add(&x.C0, &y.C0) + z1 := e.Ext6.Add(&x.C1, &y.C1) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +func (e Ext12) Sub(x, y *E12) *E12 { + z0 := e.Ext6.Sub(&x.C0, &y.C0) + z1 := e.Ext6.Sub(&x.C1, &y.C1) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +func (e Ext12) Conjugate(x *E12) *E12 { + z1 := e.Ext6.Neg(&x.C1) + return &E12{ + C0: x.C0, + C1: *z1, + } +} + +func (e Ext12) Mul(x, y *E12) *E12 
{ + a := e.Ext6.Add(&x.C0, &x.C1) + b := e.Ext6.Add(&y.C0, &y.C1) + a = e.Ext6.Mul(a, b) + b = e.Ext6.Mul(&x.C0, &y.C0) + c := e.Ext6.Mul(&x.C1, &y.C1) + z1 := e.Ext6.Sub(a, b) + z1 = e.Ext6.Sub(z1, c) + z0 := e.Ext6.MulByNonResidue(c) + z0 = e.Ext6.Add(z0, b) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +func (e Ext12) Zero() *E12 { + zero := e.fp.Zero() + return &E12{ + C0: E6{ + B0: E2{A0: *zero, A1: *zero}, + B1: E2{A0: *zero, A1: *zero}, + B2: E2{A0: *zero, A1: *zero}, + }, + C1: E6{ + B0: E2{A0: *zero, A1: *zero}, + B1: E2{A0: *zero, A1: *zero}, + B2: E2{A0: *zero, A1: *zero}, + }, + } +} + +func (e Ext12) One() *E12 { + z000 := e.fp.One() + zero := e.fp.Zero() + return &E12{ + C0: E6{ + B0: E2{A0: *z000, A1: *zero}, + B1: E2{A0: *zero, A1: *zero}, + B2: E2{A0: *zero, A1: *zero}, + }, + C1: E6{ + B0: E2{A0: *zero, A1: *zero}, + B1: E2{A0: *zero, A1: *zero}, + B2: E2{A0: *zero, A1: *zero}, + }, + } +} + +func (e Ext12) IsZero(z *E12) frontend.Variable { + c0 := e.Ext6.IsZero(&z.C0) + c1 := e.Ext6.IsZero(&z.C1) + return e.api.And(c0, c1) +} + +func (e Ext12) Square(x *E12) *E12 { + c0 := e.Ext6.Sub(&x.C0, &x.C1) + c3 := e.Ext6.MulByNonResidue(&x.C1) + c3 = e.Ext6.Neg(c3) + c3 = e.Ext6.Add(&x.C0, c3) + c2 := e.Ext6.Mul(&x.C0, &x.C1) + c0 = e.Ext6.Mul(c0, c3) + c0 = e.Ext6.Add(c0, c2) + z1 := e.Ext6.Double(c2) + c2 = e.Ext6.MulByNonResidue(c2) + z0 := e.Ext6.Add(c0, c2) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +func (e Ext12) AssertIsEqual(x, y *E12) { + e.Ext6.AssertIsEqual(&x.C0, &y.C0) + e.Ext6.AssertIsEqual(&x.C1, &y.C1) +} + +func FromE12(y *bls12381.E12) E12 { + return E12{ + C0: FromE6(&y.C0), + C1: FromE6(&y.C1), + } + +} + +func (e Ext12) Inverse(x *E12) *E12 { + res, err := e.fp.NewHint(inverseE12Hint, 12, &x.C0.B0.A0, &x.C0.B0.A1, &x.C0.B1.A0, &x.C0.B1.A1, &x.C0.B2.A0, &x.C0.B2.A1, &x.C1.B0.A0, &x.C1.B0.A1, &x.C1.B1.A0, &x.C1.B1.A1, &x.C1.B2.A0, &x.C1.B2.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } 
+ + inv := E12{ + C0: E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + }, + C1: E6{ + B0: E2{A0: *res[6], A1: *res[7]}, + B1: E2{A0: *res[8], A1: *res[9]}, + B2: E2{A0: *res[10], A1: *res[11]}, + }, + } + + one := e.One() + + // 1 == inv * x + _one := e.Mul(&inv, x) + e.AssertIsEqual(one, _one) + + return &inv + +} + +func (e Ext12) DivUnchecked(x, y *E12) *E12 { + res, err := e.fp.NewHint(divE12Hint, 12, &x.C0.B0.A0, &x.C0.B0.A1, &x.C0.B1.A0, &x.C0.B1.A1, &x.C0.B2.A0, &x.C0.B2.A1, &x.C1.B0.A0, &x.C1.B0.A1, &x.C1.B1.A0, &x.C1.B1.A1, &x.C1.B2.A0, &x.C1.B2.A1, &y.C0.B0.A0, &y.C0.B0.A1, &y.C0.B1.A0, &y.C0.B1.A1, &y.C0.B2.A0, &y.C0.B2.A1, &y.C1.B0.A0, &y.C1.B0.A1, &y.C1.B1.A0, &y.C1.B1.A1, &y.C1.B2.A0, &y.C1.B2.A1) + + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + div := E12{ + C0: E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + }, + C1: E6{ + B0: E2{A0: *res[6], A1: *res[7]}, + B1: E2{A0: *res[8], A1: *res[9]}, + B2: E2{A0: *res[10], A1: *res[11]}, + }, + } + + // x == div * y + _x := e.Mul(&div, y) + e.AssertIsEqual(x, _x) + + return &div +} + +func (e Ext12) Select(selector frontend.Variable, z1, z0 *E12) *E12 { + c0 := e.Ext6.Select(selector, &z1.C0, &z0.C0) + c1 := e.Ext6.Select(selector, &z1.C1, &z0.C1) + return &E12{C0: *c0, C1: *c1} +} + +func (e Ext12) Lookup2(s1, s2 frontend.Variable, a, b, c, d *E12) *E12 { + c0 := e.Ext6.Lookup2(s1, s2, &a.C0, &b.C0, &c.C0, &d.C0) + c1 := e.Ext6.Lookup2(s1, s2, &a.C1, &b.C1, &c.C1, &d.C1) + return &E12{C0: *c0, C1: *c1} +} diff --git a/std/algebra/emulated/fields_bls12381/e12_pairing.go b/std/algebra/emulated/fields_bls12381/e12_pairing.go new file mode 100644 index 0000000000..c096b0d2fd --- /dev/null +++ b/std/algebra/emulated/fields_bls12381/e12_pairing.go @@ -0,0 +1,280 @@ +package fields_bls12381 + +import 
"github.com/consensys/gnark/std/math/emulated" + +func (e Ext12) nSquareTorus(z *E6, n int) *E6 { + for i := 0; i < n; i++ { + z = e.SquareTorus(z) + } + return z +} + +// ExptHalfTorus set z to x^(t/2) in E6 and return z +// const t/2 uint64 = 7566188111470821376 // negative +func (e Ext12) ExptHalfTorus(x *E6) *E6 { + // FixedExp computation is derived from the addition chain: + // + // _10 = 2*1 + // _11 = 1 + _10 + // _1100 = _11 << 2 + // _1101 = 1 + _1100 + // _1101000 = _1101 << 3 + // _1101001 = 1 + _1101000 + // return ((_1101001 << 9 + 1) << 32 + 1) << 15 + // + // Operations: 62 squares 5 multiplies + // + // Generated by github.com/mmcloughlin/addchain v0.4.0. + + // Step 1: z = x^0x2 + z := e.SquareTorus(x) + + // Step 2: z = x^0x3 + z = e.MulTorus(x, z) + + z = e.SquareTorus(z) + z = e.SquareTorus(z) + + // Step 5: z = x^0xd + z = e.MulTorus(x, z) + + // Step 8: z = x^0x68 + z = e.nSquareTorus(z, 3) + + // Step 9: z = x^0x69 + z = e.MulTorus(x, z) + + // Step 18: z = x^0xd200 + z = e.nSquareTorus(z, 9) + + // Step 19: z = x^0xd201 + z = e.MulTorus(x, z) + + // Step 51: z = x^0xd20100000000 + z = e.nSquareTorus(z, 32) + + // Step 52: z = x^0xd20100000001 + z = e.MulTorus(x, z) + + // Step 67: z = x^0x6900800000008000 + z = e.nSquareTorus(z, 15) + + z = e.InverseTorus(z) // because tAbsVal is negative + + return z +} + +// ExptTorus set z to xᵗ in E6 and return z +// const t uint64 = 15132376222941642752 // negative +func (e Ext12) ExptTorus(x *E6) *E6 { + z := e.ExptHalfTorus(x) + z = e.SquareTorus(z) + return z +} + +// MulBy014 multiplies z by an E12 sparse element of the form +// +// E12{ +// C0: E6{B0: c0, B1: c1, B2: 0}, +// C1: E6{B0: 0, B1: 1, B2: 0}, +// } +func (e *Ext12) MulBy014(z *E12, c0, c1 *E2) *E12 { + + a := z.C0 + a = *e.MulBy01(&a, c0, c1) + + var b E6 + // Mul by E6{0, 1, 0} + b.B0 = *e.Ext2.MulByNonResidue(&z.C1.B2) + b.B2 = z.C1.B1 + b.B1 = z.C1.B0 + + one := e.Ext2.One() + d := e.Ext2.Add(c1, one) + + zC1 := e.Ext6.Add(&z.C1, 
&z.C0) + zC1 = e.Ext6.MulBy01(zC1, c0, d) + zC1 = e.Ext6.Sub(zC1, &a) + zC1 = e.Ext6.Sub(zC1, &b) + zC0 := e.Ext6.MulByNonResidue(&b) + zC0 = e.Ext6.Add(zC0, &a) + + return &E12{ + C0: *zC0, + C1: *zC1, + } +} + +// multiplies two E12 sparse element of the form: +// +// E12{ +// C0: E6{B0: c0, B1: c1, B2: 0}, +// C1: E6{B0: 0, B1: 1, B2: 0}, +// } +// +// and +// +// E12{ +// C0: E6{B0: d0, B1: d1, B2: 0}, +// C1: E6{B0: 0, B1: 1, B2: 0}, +// } +func (e Ext12) Mul014By014(d0, d1, c0, c1 *E2) *[5]E2 { + one := e.Ext2.One() + x0 := e.Ext2.Mul(c0, d0) + x1 := e.Ext2.Mul(c1, d1) + tmp := e.Ext2.Add(c0, one) + x04 := e.Ext2.Add(d0, one) + x04 = e.Ext2.Mul(x04, tmp) + x04 = e.Ext2.Sub(x04, x0) + x04 = e.Ext2.Sub(x04, one) + tmp = e.Ext2.Add(c0, c1) + x01 := e.Ext2.Add(d0, d1) + x01 = e.Ext2.Mul(x01, tmp) + x01 = e.Ext2.Sub(x01, x0) + x01 = e.Ext2.Sub(x01, x1) + tmp = e.Ext2.Add(c1, one) + x14 := e.Ext2.Add(d1, one) + x14 = e.Ext2.Mul(x14, tmp) + x14 = e.Ext2.Sub(x14, x1) + x14 = e.Ext2.Sub(x14, one) + + zC0B0 := e.Ext2.NonResidue() + zC0B0 = e.Ext2.Add(zC0B0, x0) + + return &[5]E2{*zC0B0, *x01, *x1, *x04, *x14} +} + +// MulBy01245 multiplies z by an E12 sparse element of the form +// +// E12{ +// C0: E6{B0: c0, B1: c1, B2: c2}, +// C1: E6{B0: 0, B1: c4, B2: c5}, +// } +func (e *Ext12) MulBy01245(z *E12, x *[5]E2) *E12 { + c0 := &E6{B0: x[0], B1: x[1], B2: x[2]} + c1 := &E6{B0: *e.Ext2.Zero(), B1: x[3], B2: x[4]} + a := e.Ext6.Add(&z.C0, &z.C1) + b := e.Ext6.Add(c0, c1) + a = e.Ext6.Mul(a, b) + b = e.Ext6.Mul(&z.C0, c0) + c := e.Ext6.MulBy12(&z.C1, &x[3], &x[4]) + z1 := e.Ext6.Sub(a, b) + z1 = e.Ext6.Sub(z1, c) + z0 := e.Ext6.MulByNonResidue(c) + z0 = e.Ext6.Add(z0, b) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +// Torus-based arithmetic: +// +// After the easy part of the final exponentiation the elements are in a proper +// subgroup of Fpk (E12) that coincides with some algebraic tori. 
The elements +// are in the torus Tk(Fp) and thus in each torus Tk/d(Fp^d) for d|k, d≠k. We +// take d=6. So the elements are in T2(Fp6). +// Let G_{q,2} = {m ∈ Fq^2 | m^(q+1) = 1} where q = p^6. +// When m.C1 = 0, then m.C0 must be 1 or −1. +// +// We recall the tower construction: +// +// 𝔽p²[u] = 𝔽p/u²+1 +// 𝔽p⁶[v] = 𝔽p²/v³-1-u +// 𝔽p¹²[w] = 𝔽p⁶/w²-v + +// CompressTorus compresses x ∈ E12 to (x.C0 + 1)/x.C1 ∈ E6 +func (e Ext12) CompressTorus(x *E12) *E6 { + // x ∈ G_{q,2} \ {-1,1} + y := e.Ext6.Add(&x.C0, e.Ext6.One()) + y = e.Ext6.DivUnchecked(y, &x.C1) + return y +} + +// DecompressTorus decompresses y ∈ E6 to (y+w)/(y-w) ∈ E12 +func (e Ext12) DecompressTorus(y *E6) *E12 { + var n, d E12 + one := e.Ext6.One() + n.C0 = *y + n.C1 = *one + d.C0 = *y + d.C1 = *e.Ext6.Neg(one) + + x := e.DivUnchecked(&n, &d) + return x +} + +// MulTorus multiplies two compressed elements y1, y2 ∈ E6 +// and returns (y1 * y2 + v)/(y1 + y2) +// N.B.: we use MulTorus in the final exponentiation throughout y1 ≠ -y2 always. +func (e Ext12) MulTorus(y1, y2 *E6) *E6 { + n := e.Ext6.Mul(y1, y2) + n.B1 = *e.Ext2.Add(&n.B1, e.Ext2.One()) + d := e.Ext6.Add(y1, y2) + y3 := e.Ext6.DivUnchecked(n, d) + return y3 +} + +// InverseTorus inverses a compressed elements y ∈ E6 +// and returns -y +func (e Ext12) InverseTorus(y *E6) *E6 { + return e.Ext6.Neg(y) +} + +// SquareTorus squares a compressed elements y ∈ E6 +// and returns (y + v/y)/2 +// +// It uses a hint to verify that (2x-y)y = v saving one E6 AssertIsEqual. 
+func (e Ext12) SquareTorus(y *E6) *E6 { + res, err := e.fp.NewHint(squareTorusHint, 6, &y.B0.A0, &y.B0.A1, &y.B1.A0, &y.B1.A1, &y.B2.A0, &y.B2.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + sq := E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + } + + // v = (2x-y)y + v := e.Ext6.Double(&sq) + v = e.Ext6.Sub(v, y) + v = e.Ext6.Mul(v, y) + + _v := E6{B0: *e.Ext2.Zero(), B1: *e.Ext2.One(), B2: *e.Ext2.Zero()} + e.Ext6.AssertIsEqual(v, &_v) + + return &sq + +} + +// FrobeniusTorus raises a compressed elements y ∈ E6 to the modulus p +// and returns y^p / v^((p-1)/2) +func (e Ext12) FrobeniusTorus(y *E6) *E6 { + t0 := e.Ext2.Conjugate(&y.B0) + t1 := e.Ext2.Conjugate(&y.B1) + t2 := e.Ext2.Conjugate(&y.B2) + t1 = e.Ext2.MulByNonResidue1Power2(t1) + t2 = e.Ext2.MulByNonResidue1Power4(t2) + + v0 := E2{emulated.ValueOf[emulated.BLS12381Fp]("877076961050607968509681729531255177986764537961432449499635504522207616027455086505066378536590128544573588734230"), emulated.ValueOf[emulated.BLS12381Fp]("877076961050607968509681729531255177986764537961432449499635504522207616027455086505066378536590128544573588734230")} + res := &E6{B0: *t0, B1: *t1, B2: *t2} + res = e.Ext6.MulBy0(res, &v0) + + return res +} + +// FrobeniusSquareTorus raises a compressed elements y ∈ E6 to the square modulus p^2 +// and returns y^(p^2) / v^((p^2-1)/2) +func (e Ext12) FrobeniusSquareTorus(y *E6) *E6 { + v0 := emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437") + t0 := e.Ext2.MulByElement(&y.B0, &v0) + t1 := e.Ext2.MulByNonResidue2Power2(&y.B1) + t1 = e.Ext2.MulByElement(t1, &v0) + t2 := e.Ext2.MulByNonResidue2Power4(&y.B2) + t2 = e.Ext2.MulByElement(t2, &v0) + + return &E6{B0: *t0, B1: *t1, B2: *t2} +} diff --git a/std/algebra/emulated/fields_bls12381/e12_test.go 
b/std/algebra/emulated/fields_bls12381/e12_test.go new file mode 100644 index 0000000000..cba62ef2be --- /dev/null +++ b/std/algebra/emulated/fields_bls12381/e12_test.go @@ -0,0 +1,581 @@ +package fields_bls12381 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type e12Add struct { + A, B, C E12 +} + +func (circuit *e12Add) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Add(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestAddFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Add(&a, &b) + + witness := e12Add{ + A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Add{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Sub struct { + A, B, C E12 +} + +func (circuit *e12Sub) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Sub(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSubFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Sub(&a, &b) + + witness := e12Sub{ + A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Sub{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Mul struct { + A, B, C E12 +} + +func (circuit *e12Mul) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Mul(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestMulFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Mul(&a, &b) + + witness := e12Mul{ + 
A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Mul{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Div struct { + A, B, C E12 +} + +func (circuit *e12Div) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.DivUnchecked(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestDivFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Div(&a, &b) + + witness := e12Div{ + A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Div{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Square struct { + A, C E12 +} + +func (circuit *e12Square) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Square(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSquareFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E12 + _, _ = a.SetRandom() + c.Square(&a) + + witness := e12Square{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Square{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Conjugate struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *e12Conjugate) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Conjugate(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestConjugateFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E12 + _, _ = a.SetRandom() + c.Conjugate(&a) + + witness := e12Conjugate{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Conjugate{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e12Inverse struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *e12Inverse) Define(api frontend.API) error { + e := 
NewExt12(api) + expected := e.Inverse(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestInverseFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E12 + _, _ = a.SetRandom() + c.Inverse(&a) + + witness := e12Inverse{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Inverse{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e12ExptTorus struct { + A E6 + C E12 `gnark:",public"` +} + +func (circuit *e12ExptTorus) Define(api frontend.API) error { + e := NewExt12(api) + z := e.ExptTorus(&circuit.A) + expected := e.DecompressTorus(z) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestFp12ExptTorus(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + var tmp bls12381.E12 + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + c.Expt(&a) + _a, _ := a.CompressTorus() + witness := e12ExptTorus{ + A: FromE6(&_a), + C: FromE12(&c), + } + + err := test.IsSolved(&e12ExptTorus{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e12MulBy014 struct { + A E12 `gnark:",public"` + W E12 + B, C E2 +} + +func (circuit *e12MulBy014) Define(api frontend.API) error { + e := NewExt12(api) + res := e.MulBy014(&circuit.A, &circuit.B, &circuit.C) + e.AssertIsEqual(res, &circuit.W) + return nil +} + +func TestFp12MulBy014(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, w bls12381.E12 + _, _ = a.SetRandom() + var one, b, c bls12381.E2 + one.SetOne() + _, _ = b.SetRandom() + _, _ = c.SetRandom() + w.Set(&a) + w.MulBy014(&b, &c, &one) + + witness := e12MulBy014{ + A: FromE12(&a), + B: FromE2(&b), + C: FromE2(&c), + W: FromE12(&w), + } + + err := test.IsSolved(&e12MulBy014{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +// Torus-based arithmetic +type 
torusCompress struct { + A E12 + C E6 `gnark:",public"` +} + +func (circuit *torusCompress) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.CompressTorus(&circuit.A) + e.Ext6.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusCompress(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a bls12381.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + var tmp bls12381.E12 + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + c, _ := a.CompressTorus() + + witness := torusCompress{ + A: FromE12(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&torusCompress{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusDecompress struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusDecompress) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusDecompress(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a bls12381.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + var tmp bls12381.E12 + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + d, _ := a.CompressTorus() + c := d.DecompressTorus() + + witness := torusDecompress{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusDecompress{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusMul struct { + A E12 + B E12 + C E12 `gnark:",public"` +} + +func (circuit *torusMul) Define(api frontend.API) error { + e := NewExt12(api) + compressedA := e.CompressTorus(&circuit.A) + compressedB := e.CompressTorus(&circuit.B) + compressedAB := e.MulTorus(compressedA, compressedB) + expected := e.DecompressTorus(compressedAB) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func 
TestTorusMul(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c, tmp bls12381.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + // put b in the cyclotomic subgroup + tmp.Conjugate(&b) + b.Inverse(&b) + tmp.Mul(&tmp, &b) + b.FrobeniusSquare(&tmp).Mul(&b, &tmp) + + // uncompressed mul + c.Mul(&a, &b) + + witness := torusMul{ + A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := test.IsSolved(&torusMul{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusInverse struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusInverse) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + compressed = e.InverseTorus(compressed) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusInverse(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c, tmp bls12381.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + // uncompressed inverse + c.Inverse(&a) + + witness := torusInverse{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusInverse{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusFrobenius struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusFrobenius) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + compressed = e.FrobeniusTorus(compressed) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusFrobenius(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c, tmp bls12381.E12 + _, _ = a.SetRandom() + + // put a in the 
cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + // uncompressed frobenius + c.Frobenius(&a) + + witness := torusFrobenius{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusFrobenius{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusFrobeniusSquare struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusFrobeniusSquare) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + compressed = e.FrobeniusSquareTorus(compressed) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusFrobeniusSquare(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c, tmp bls12381.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + // uncompressed frobeniusSquare + c.FrobeniusSquare(&a) + + witness := torusFrobeniusSquare{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusFrobeniusSquare{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusSquare struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusSquare) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + compressed = e.SquareTorus(compressed) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusSquare(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c, tmp bls12381.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + // uncompressed square + c.Square(&a) + + witness := torusSquare{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := 
test.IsSolved(&torusSquare{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/emulated/fields_bls12381/e2.go b/std/algebra/emulated/fields_bls12381/e2.go new file mode 100644 index 0000000000..e36178b215 --- /dev/null +++ b/std/algebra/emulated/fields_bls12381/e2.go @@ -0,0 +1,344 @@ +package fields_bls12381 + +import ( + "math/big" + + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" +) + +type curveF = emulated.Field[emulated.BLS12381Fp] +type baseEl = emulated.Element[emulated.BLS12381Fp] + +type E2 struct { + A0, A1 baseEl +} + +type Ext2 struct { + api frontend.API + fp *curveF + nonResidues map[int]map[int]*E2 +} + +func NewExt2(api frontend.API) *Ext2 { + fp, err := emulated.NewField[emulated.BLS12381Fp](api) + if err != nil { + panic(err) + } + pwrs := map[int]map[int]struct { + A0 string + A1 string + }{ + 1: { + 1: {"3850754370037169011952147076051364057158807420970682438676050522613628423219637725072182697113062777891589506424760", "151655185184498381465642749684540099398075398968325446656007613510403227271200139370504932015952886146304766135027"}, + 2: {"0", "4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436"}, + 3: {"1028732146235106349975324479215795277384839936929757896155643118032610843298655225875571310552543014690878354869257", "1028732146235106349975324479215795277384839936929757896155643118032610843298655225875571310552543014690878354869257"}, + 4: {"4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437", "0"}, + 5: {"877076961050607968509681729531255177986764537961432449499635504522207616027455086505066378536590128544573588734230", "3125332594171059424908108096204648978570118281977575435832422631601824034463382777937621250592425535493320683825557"}, + }, + 2: { + 1: 
{"793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620351", "0"}, + 2: {"793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620350", "0"}, + 3: {"4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559786", "0"}, + 4: {"4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436", "0"}, + 5: {"4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437", "0"}, + }, + } + nonResidues := make(map[int]map[int]*E2) + for pwr, v := range pwrs { + for coeff, v := range v { + el := E2{emulated.ValueOf[emulated.BLS12381Fp](v.A0), emulated.ValueOf[emulated.BLS12381Fp](v.A1)} + if nonResidues[pwr] == nil { + nonResidues[pwr] = make(map[int]*E2) + } + nonResidues[pwr][coeff] = &el + } + } + return &Ext2{api: api, fp: fp, nonResidues: nonResidues} +} + +func (e Ext2) MulByElement(x *E2, y *baseEl) *E2 { + z0 := e.fp.MulMod(&x.A0, y) + z1 := e.fp.MulMod(&x.A1, y) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) MulByConstElement(x *E2, y *big.Int) *E2 { + z0 := e.fp.MulConst(&x.A0, y) + z1 := e.fp.MulConst(&x.A1, y) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Conjugate(x *E2) *E2 { + z0 := x.A0 + z1 := e.fp.Neg(&x.A1) + return &E2{ + A0: z0, + A1: *z1, + } +} + +func (e Ext2) MulByNonResidueGeneric(x *E2, power, coef int) *E2 { + y := e.nonResidues[power][coef] + z := e.Mul(x, y) + return z +} + +// MulByNonResidue returns x*(1+u) +func (e Ext2) MulByNonResidue(x *E2) *E2 { + a := e.fp.Sub(&x.A0, &x.A1) + b := e.fp.Add(&x.A0, &x.A1) + + return &E2{ + A0: *a, + A1: *b, + } +} + +// MulByNonResidue1Power1 returns x*(1+u)^(1*(p^1-1)/6) +func (e Ext2) MulByNonResidue1Power1(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 1, 1) +} + +// MulByNonResidue1Power2 returns x*(1+u)^(2*(p^1-1)/6) +func (e Ext2) 
MulByNonResidue1Power2(x *E2) *E2 { + element := emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436") + a := e.fp.MulMod(&x.A1, &element) + a = e.fp.Neg(a) + b := e.fp.MulMod(&x.A0, &element) + return &E2{ + A0: *a, + A1: *b, + } +} + +// MulByNonResidue1Power3 returns x*(1+u)^(3*(p^1-1)/6) +func (e Ext2) MulByNonResidue1Power3(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 1, 3) +} + +// MulByNonResidue1Power4 returns x*(1+u)^(4*(p^1-1)/6) +func (e Ext2) MulByNonResidue1Power4(x *E2) *E2 { + element := emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue1Power5 returns x*(1+u)^(5*(p^1-1)/6) +func (e Ext2) MulByNonResidue1Power5(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 1, 5) +} + +// MulByNonResidue2Power1 returns x*(1+u)^(1*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power1(x *E2) *E2 { + element := emulated.ValueOf[emulated.BLS12381Fp]("793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620351") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue2Power2 returns x*(1+u)^(2*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power2(x *E2) *E2 { + element := emulated.ValueOf[emulated.BLS12381Fp]("793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620350") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue2Power3 returns x*(1+u)^(3*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power3(x *E2) *E2 { + element := 
emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559786") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue2Power4 returns x*(1+u)^(4*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power4(x *E2) *E2 { + element := emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue2Power5 returns x*(1+u)^(5*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power5(x *E2) *E2 { + element := emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +func (e Ext2) Mul(x, y *E2) *E2 { + a := e.fp.Add(&x.A0, &x.A1) + b := e.fp.Add(&y.A0, &y.A1) + a = e.fp.MulMod(a, b) + b = e.fp.MulMod(&x.A0, &y.A0) + c := e.fp.MulMod(&x.A1, &y.A1) + z1 := e.fp.Sub(a, b) + z1 = e.fp.Sub(z1, c) + z0 := e.fp.Sub(b, c) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Add(x, y *E2) *E2 { + z0 := e.fp.Add(&x.A0, &y.A0) + z1 := e.fp.Add(&x.A1, &y.A1) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Sub(x, y *E2) *E2 { + z0 := e.fp.Sub(&x.A0, &y.A0) + z1 := e.fp.Sub(&x.A1, &y.A1) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Neg(x *E2) *E2 { + z0 := e.fp.Neg(&x.A0) + z1 := e.fp.Neg(&x.A1) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) One() *E2 { + z0 := e.fp.One() + z1 := e.fp.Zero() + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Zero() *E2 { + z0 := e.fp.Zero() + z1 := e.fp.Zero() + return &E2{ + A0: *z0, + A1: *z1, + } +} +func (e Ext2) IsZero(z *E2) frontend.Variable { + a0 := 
e.fp.IsZero(&z.A0) + a1 := e.fp.IsZero(&z.A1) + return e.api.And(a0, a1) +} + +// returns 1+u +func (e Ext2) NonResidue() *E2 { + one := e.fp.One() + return &E2{ + A0: *one, + A1: *one, + } +} + +func (e Ext2) Square(x *E2) *E2 { + a := e.fp.Add(&x.A0, &x.A1) + b := e.fp.Sub(&x.A0, &x.A1) + a = e.fp.MulMod(a, b) + b = e.fp.MulMod(&x.A0, &x.A1) + b = e.fp.MulConst(b, big.NewInt(2)) + return &E2{ + A0: *a, + A1: *b, + } +} + +func (e Ext2) Double(x *E2) *E2 { + two := big.NewInt(2) + z0 := e.fp.MulConst(&x.A0, two) + z1 := e.fp.MulConst(&x.A1, two) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) AssertIsEqual(x, y *E2) { + e.fp.AssertIsEqual(&x.A0, &y.A0) + e.fp.AssertIsEqual(&x.A1, &y.A1) +} + +func FromE2(y *bls12381.E2) E2 { + return E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp](y.A0), + A1: emulated.ValueOf[emulated.BLS12381Fp](y.A1), + } +} + +func (e Ext2) Inverse(x *E2) *E2 { + res, err := e.fp.NewHint(inverseE2Hint, 2, &x.A0, &x.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + inv := E2{ + A0: *res[0], + A1: *res[1], + } + one := e.One() + + // 1 == inv * x + _one := e.Mul(&inv, x) + e.AssertIsEqual(one, _one) + + return &inv + +} + +func (e Ext2) DivUnchecked(x, y *E2) *E2 { + res, err := e.fp.NewHint(divE2Hint, 2, &x.A0, &x.A1, &y.A0, &y.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + div := E2{ + A0: *res[0], + A1: *res[1], + } + + // x == div * y + _x := e.Mul(&div, y) + e.AssertIsEqual(x, _x) + + return &div +} + +func (e Ext2) Select(selector frontend.Variable, z1, z0 *E2) *E2 { + a0 := e.fp.Select(selector, &z1.A0, &z0.A0) + a1 := e.fp.Select(selector, &z1.A1, &z0.A1) + return &E2{A0: *a0, A1: *a1} +} + +func (e Ext2) Lookup2(s1, s2 frontend.Variable, a, b, c, d *E2) *E2 { + a0 := e.fp.Lookup2(s1, s2, &a.A0, &b.A0, &c.A0, &d.A0) + a1 := e.fp.Lookup2(s1, s2, &a.A1, &b.A1, &c.A1, &d.A1) + return &E2{A0: *a0, A1: *a1} +} diff --git 
a/std/algebra/emulated/fields_bls12381/e2_test.go b/std/algebra/emulated/fields_bls12381/e2_test.go new file mode 100644 index 0000000000..4065af437c --- /dev/null +++ b/std/algebra/emulated/fields_bls12381/e2_test.go @@ -0,0 +1,351 @@ +package fields_bls12381 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/test" +) + +type e2Add struct { + A, B, C E2 +} + +func (circuit *e2Add) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Add(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestAddFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Add(&a, &b) + + witness := e2Add{ + A: FromE2(&a), + B: FromE2(&b), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Add{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Sub struct { + A, B, C E2 +} + +func (circuit *e2Sub) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Sub(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSubFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Sub(&a, &b) + + witness := e2Sub{ + A: FromE2(&a), + B: FromE2(&b), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Sub{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Double struct { + A, C E2 +} + +func (circuit *e2Double) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Double(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestDoubleFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness 
values + var a, b, c bls12381.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Double(&a) + + witness := e2Double{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Double{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Mul struct { + A, B, C E2 +} + +func (circuit *e2Mul) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Mul(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestMulFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Mul(&a, &b) + + witness := e2Mul{ + A: FromE2(&a), + B: FromE2(&b), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Mul{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Square struct { + A, C E2 +} + +func (circuit *e2Square) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Square(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSquareFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E2 + _, _ = a.SetRandom() + c.Square(&a) + + witness := e2Square{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Square{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Div struct { + A, B, C E2 +} + +func (circuit *e2Div) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.DivUnchecked(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestDivFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Div(&a, &b) + + witness := e2Div{ + A: FromE2(&a), + B: FromE2(&b), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Div{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2MulByElement struct { + A E2 + B baseEl + C E2 `gnark:",public"` 
+} + +func (circuit *e2MulByElement) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.MulByElement(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulByElement(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E2 + var b fp.Element + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.MulByElement(&a, &b) + + witness := e2MulByElement{ + A: FromE2(&a), + B: emulated.ValueOf[emulated.BLS12381Fp](b), + C: FromE2(&c), + } + + err := test.IsSolved(&e2MulByElement{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2MulByNonResidue struct { + A E2 + C E2 `gnark:",public"` +} + +func (circuit *e2MulByNonResidue) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.MulByNonResidue(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulFp2ByNonResidue(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E2 + _, _ = a.SetRandom() + c.MulByNonResidue(&a) + + witness := e2MulByNonResidue{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2MulByNonResidue{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Neg struct { + A E2 + C E2 `gnark:",public"` +} + +func (circuit *e2Neg) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Neg(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestNegFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E2 + _, _ = a.SetRandom() + c.Neg(&a) + + witness := e2Neg{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Neg{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e2Conjugate struct { + A E2 + C E2 `gnark:",public"` +} + +func (circuit *e2Conjugate) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Conjugate(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + 
return nil +} + +func TestConjugateFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E2 + _, _ = a.SetRandom() + c.Conjugate(&a) + + witness := e2Conjugate{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Conjugate{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e2Inverse struct { + A E2 + C E2 `gnark:",public"` +} + +func (circuit *e2Inverse) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Inverse(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestInverseFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E2 + _, _ = a.SetRandom() + c.Inverse(&a) + + witness := e2Inverse{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Inverse{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/emulated/fields_bls12381/e6.go b/std/algebra/emulated/fields_bls12381/e6.go new file mode 100644 index 0000000000..85e522eeb4 --- /dev/null +++ b/std/algebra/emulated/fields_bls12381/e6.go @@ -0,0 +1,315 @@ +package fields_bls12381 + +import ( + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" +) + +type E6 struct { + B0, B1, B2 E2 +} + +type Ext6 struct { + *Ext2 +} + +func NewExt6(api frontend.API) *Ext6 { + return &Ext6{Ext2: NewExt2(api)} +} + +func (e Ext6) One() *E6 { + z0 := e.Ext2.One() + z1 := e.Ext2.Zero() + z2 := e.Ext2.Zero() + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Zero() *E6 { + z0 := e.Ext2.Zero() + z1 := e.Ext2.Zero() + z2 := e.Ext2.Zero() + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) IsZero(z *E6) frontend.Variable { + b0 := e.Ext2.IsZero(&z.B0) + b1 := e.Ext2.IsZero(&z.B1) + b2 := e.Ext2.IsZero(&z.B2) + return e.api.And(e.api.And(b0, b1), b2) +} + +func (e Ext6) Add(x, y *E6) *E6 { + z0 := e.Ext2.Add(&x.B0, &y.B0) + z1 := e.Ext2.Add(&x.B1, 
&y.B1) + z2 := e.Ext2.Add(&x.B2, &y.B2) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Neg(x *E6) *E6 { + z0 := e.Ext2.Neg(&x.B0) + z1 := e.Ext2.Neg(&x.B1) + z2 := e.Ext2.Neg(&x.B2) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Sub(x, y *E6) *E6 { + z0 := e.Ext2.Sub(&x.B0, &y.B0) + z1 := e.Ext2.Sub(&x.B1, &y.B1) + z2 := e.Ext2.Sub(&x.B2, &y.B2) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Mul(x, y *E6) *E6 { + t0 := e.Ext2.Mul(&x.B0, &y.B0) + t1 := e.Ext2.Mul(&x.B1, &y.B1) + t2 := e.Ext2.Mul(&x.B2, &y.B2) + c0 := e.Ext2.Add(&x.B1, &x.B2) + tmp := e.Ext2.Add(&y.B1, &y.B2) + c0 = e.Ext2.Mul(c0, tmp) + c0 = e.Ext2.Sub(c0, t1) + c0 = e.Ext2.Sub(c0, t2) + c0 = e.Ext2.MulByNonResidue(c0) + c0 = e.Ext2.Add(c0, t0) + c1 := e.Ext2.Add(&x.B0, &x.B1) + tmp = e.Ext2.Add(&y.B0, &y.B1) + c1 = e.Ext2.Mul(c1, tmp) + c1 = e.Ext2.Sub(c1, t0) + c1 = e.Ext2.Sub(c1, t1) + tmp = e.Ext2.MulByNonResidue(t2) + c1 = e.Ext2.Add(c1, tmp) + tmp = e.Ext2.Add(&x.B0, &x.B2) + c2 := e.Ext2.Add(&y.B0, &y.B2) + c2 = e.Ext2.Mul(c2, tmp) + c2 = e.Ext2.Sub(c2, t0) + c2 = e.Ext2.Sub(c2, t2) + c2 = e.Ext2.Add(c2, t1) + return &E6{ + B0: *c0, + B1: *c1, + B2: *c2, + } +} + +func (e Ext6) Double(x *E6) *E6 { + z0 := e.Ext2.Double(&x.B0) + z1 := e.Ext2.Double(&x.B1) + z2 := e.Ext2.Double(&x.B2) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Square(x *E6) *E6 { + c4 := e.Ext2.Mul(&x.B0, &x.B1) + c4 = e.Ext2.Double(c4) + c5 := e.Ext2.Square(&x.B2) + c1 := e.Ext2.MulByNonResidue(c5) + c1 = e.Ext2.Add(c1, c4) + c2 := e.Ext2.Sub(c4, c5) + c3 := e.Ext2.Square(&x.B0) + c4 = e.Ext2.Sub(&x.B0, &x.B1) + c4 = e.Ext2.Add(c4, &x.B2) + c5 = e.Ext2.Mul(&x.B1, &x.B2) + c5 = e.Ext2.Double(c5) + c4 = e.Ext2.Square(c4) + c0 := e.Ext2.MulByNonResidue(c5) + c0 = e.Ext2.Add(c0, c3) + z2 := e.Ext2.Add(c2, c4) + z2 = e.Ext2.Add(z2, c5) + z2 = e.Ext2.Sub(z2, c3) + z0 := c0 + z1 := c1 + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } 
+} + +func (e Ext6) MulByE2(x *E6, y *E2) *E6 { + z0 := e.Ext2.Mul(&x.B0, y) + z1 := e.Ext2.Mul(&x.B1, y) + z2 := e.Ext2.Mul(&x.B2, y) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +// MulBy12 multiplication by sparse element (0,b1,b2) +func (e Ext6) MulBy12(x *E6, b1, b2 *E2) *E6 { + t1 := e.Ext2.Mul(&x.B1, b1) + t2 := e.Ext2.Mul(&x.B2, b2) + c0 := e.Ext2.Add(&x.B1, &x.B2) + tmp := e.Ext2.Add(b1, b2) + c0 = e.Ext2.Mul(c0, tmp) + c0 = e.Ext2.Sub(c0, t1) + c0 = e.Ext2.Sub(c0, t2) + c0 = e.Ext2.MulByNonResidue(c0) + c1 := e.Ext2.Add(&x.B0, &x.B1) + c1 = e.Ext2.Mul(c1, b1) + c1 = e.Ext2.Sub(c1, t1) + tmp = e.Ext2.MulByNonResidue(t2) + c1 = e.Ext2.Add(c1, tmp) + tmp = e.Ext2.Add(&x.B0, &x.B2) + c2 := e.Ext2.Mul(b2, tmp) + c2 = e.Ext2.Sub(c2, t2) + c2 = e.Ext2.Add(c2, t1) + return &E6{ + B0: *c0, + B1: *c1, + B2: *c2, + } +} + +// MulBy0 multiplies z by an E6 sparse element of the form +// +// E6{ +// B0: c0, +// B1: 0, +// B2: 0, +// } +func (e Ext6) MulBy0(z *E6, c0 *E2) *E6 { + a := e.Ext2.Mul(&z.B0, c0) + tmp := e.Ext2.Add(&z.B0, &z.B2) + t2 := e.Ext2.Mul(c0, tmp) + t2 = e.Ext2.Sub(t2, a) + tmp = e.Ext2.Add(&z.B0, &z.B1) + t1 := e.Ext2.Mul(c0, tmp) + t1 = e.Ext2.Sub(t1, a) + return &E6{ + B0: *a, + B1: *t1, + B2: *t2, + } +} + +// MulBy01 multiplication by sparse element (c0,c1,0) +func (e Ext6) MulBy01(z *E6, c0, c1 *E2) *E6 { + a := e.Ext2.Mul(&z.B0, c0) + b := e.Ext2.Mul(&z.B1, c1) + tmp := e.Ext2.Add(&z.B1, &z.B2) + t0 := e.Ext2.Mul(c1, tmp) + t0 = e.Ext2.Sub(t0, b) + t0 = e.Ext2.MulByNonResidue(t0) + t0 = e.Ext2.Add(t0, a) + tmp = e.Ext2.Add(&z.B0, &z.B2) + t2 := e.Ext2.Mul(c0, tmp) + t2 = e.Ext2.Sub(t2, a) + t2 = e.Ext2.Add(t2, b) + t1 := e.Ext2.Add(c0, c1) + tmp = e.Ext2.Add(&z.B0, &z.B1) + t1 = e.Ext2.Mul(t1, tmp) + t1 = e.Ext2.Sub(t1, a) + t1 = e.Ext2.Sub(t1, b) + return &E6{ + B0: *t0, + B1: *t1, + B2: *t2, + } +} + +func (e Ext6) MulByNonResidue(x *E6) *E6 { + z2, z1, z0 := &x.B1, &x.B0, &x.B2 + z0 = e.Ext2.MulByNonResidue(z0) + return &E6{ + 
B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) AssertIsEqual(x, y *E6) { + e.Ext2.AssertIsEqual(&x.B0, &y.B0) + e.Ext2.AssertIsEqual(&x.B1, &y.B1) + e.Ext2.AssertIsEqual(&x.B2, &y.B2) +} + +func FromE6(y *bls12381.E6) E6 { + return E6{ + B0: FromE2(&y.B0), + B1: FromE2(&y.B1), + B2: FromE2(&y.B2), + } + +} + +func (e Ext6) Inverse(x *E6) *E6 { + res, err := e.fp.NewHint(inverseE6Hint, 6, &x.B0.A0, &x.B0.A1, &x.B1.A0, &x.B1.A1, &x.B2.A0, &x.B2.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + inv := E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + } + + one := e.One() + + // 1 == inv * x + _one := e.Mul(&inv, x) + e.AssertIsEqual(one, _one) + + return &inv + +} + +func (e Ext6) DivUnchecked(x, y *E6) *E6 { + res, err := e.fp.NewHint(divE6Hint, 6, &x.B0.A0, &x.B0.A1, &x.B1.A0, &x.B1.A1, &x.B2.A0, &x.B2.A1, &y.B0.A0, &y.B0.A1, &y.B1.A0, &y.B1.A1, &y.B2.A0, &y.B2.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + div := E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + } + + // x == div * y + _x := e.Mul(&div, y) + e.AssertIsEqual(x, _x) + + return &div +} + +func (e Ext6) Select(selector frontend.Variable, z1, z0 *E6) *E6 { + b0 := e.Ext2.Select(selector, &z1.B0, &z0.B0) + b1 := e.Ext2.Select(selector, &z1.B1, &z0.B1) + b2 := e.Ext2.Select(selector, &z1.B2, &z0.B2) + return &E6{B0: *b0, B1: *b1, B2: *b2} +} + +func (e Ext6) Lookup2(s1, s2 frontend.Variable, a, b, c, d *E6) *E6 { + b0 := e.Ext2.Lookup2(s1, s2, &a.B0, &b.B0, &c.B0, &d.B0) + b1 := e.Ext2.Lookup2(s1, s2, &a.B1, &b.B1, &c.B1, &d.B1) + b2 := e.Ext2.Lookup2(s1, s2, &a.B2, &b.B2, &c.B2, &d.B2) + return &E6{B0: *b0, B1: *b1, B2: *b2} +} diff --git a/std/algebra/emulated/fields_bls12381/e6_test.go b/std/algebra/emulated/fields_bls12381/e6_test.go new file mode 100644 index 
0000000000..7964833bdf --- /dev/null +++ b/std/algebra/emulated/fields_bls12381/e6_test.go @@ -0,0 +1,327 @@ +package fields_bls12381 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type e6Add struct { + A, B, C E6 +} + +func (circuit *e6Add) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Add(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestAddFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E6 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Add(&a, &b) + + witness := e6Add{ + A: FromE6(&a), + B: FromE6(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Add{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Sub struct { + A, B, C E6 +} + +func (circuit *e6Sub) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Sub(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSubFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E6 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Sub(&a, &b) + + witness := e6Sub{ + A: FromE6(&a), + B: FromE6(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Sub{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Mul struct { + A, B, C E6 +} + +func (circuit *e6Mul) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Mul(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestMulFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E6 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Mul(&a, &b) + + witness := e6Mul{ + A: FromE6(&a), + B: FromE6(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Mul{}, &witness, 
ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Square struct { + A, C E6 +} + +func (circuit *e6Square) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Square(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSquareFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E6 + _, _ = a.SetRandom() + c.Square(&a) + + witness := e6Square{ + A: FromE6(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Square{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Div struct { + A, B, C E6 +} + +func (circuit *e6Div) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.DivUnchecked(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestDivFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bls12381.E6 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Div(&a, &b) + + witness := e6Div{ + A: FromE6(&a), + B: FromE6(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Div{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6MulByNonResidue struct { + A E6 + C E6 `gnark:",public"` +} + +func (circuit *e6MulByNonResidue) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.MulByNonResidue(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulFp6ByNonResidue(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E6 + _, _ = a.SetRandom() + c.MulByNonResidue(&a) + + witness := e6MulByNonResidue{ + A: FromE6(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&e6MulByNonResidue{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6MulByE2 struct { + A E6 + B E2 + C E6 `gnark:",public"` +} + +func (circuit *e6MulByE2) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.MulByE2(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, 
&circuit.C) + + return nil +} + +func TestMulFp6ByE2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E6 + var b bls12381.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.MulByE2(&a, &b) + + witness := e6MulByE2{ + A: FromE6(&a), + B: FromE2(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6MulByE2{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6MulBy01 struct { + A E6 + C0, C1 E2 + C E6 `gnark:",public"` +} + +func (circuit *e6MulBy01) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.MulBy01(&circuit.A, &circuit.C0, &circuit.C1) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulFp6By01(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E6 + var C0, C1 bls12381.E2 + _, _ = a.SetRandom() + _, _ = C0.SetRandom() + _, _ = C1.SetRandom() + c.Set(&a) + c.MulBy01(&C0, &C1) + + witness := e6MulBy01{ + A: FromE6(&a), + C0: FromE2(&C0), + C1: FromE2(&C1), + C: FromE6(&c), + } + + err := test.IsSolved(&e6MulBy01{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Neg struct { + A E6 + C E6 `gnark:",public"` +} + +func (circuit *e6Neg) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Neg(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestNegFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bls12381.E6 + _, _ = a.SetRandom() + c.Neg(&a) + + witness := e6Neg{ + A: FromE6(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Neg{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e6Inverse struct { + A E6 + C E6 `gnark:",public"` +} + +func (circuit *e6Inverse) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Inverse(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestInverseFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, 
c bls12381.E6 + _, _ = a.SetRandom() + c.Inverse(&a) + + witness := e6Inverse{ + A: FromE6(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Inverse{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/emulated/fields_bls12381/hints.go b/std/algebra/emulated/fields_bls12381/hints.go new file mode 100644 index 0000000000..c8455f40cf --- /dev/null +++ b/std/algebra/emulated/fields_bls12381/hints.go @@ -0,0 +1,238 @@ +package fields_bls12381 + +import ( + "math/big" + + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/math/emulated" +) + +func init() { + solver.RegisterHint(GetHints()...) +} + +// GetHints returns all hint functions used in the package. +func GetHints() []solver.Hint { + return []solver.Hint{ + // E2 + divE2Hint, + inverseE2Hint, + // E6 + divE6Hint, + inverseE6Hint, + squareTorusHint, + // E12 + divE12Hint, + inverseE12Hint, + } +} + +func inverseE2Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, c bls12381.E2 + + a.A0.SetBigInt(inputs[0]) + a.A1.SetBigInt(inputs[1]) + + c.Inverse(&a) + + c.A0.BigInt(outputs[0]) + c.A1.BigInt(outputs[1]) + + return nil + }) +} + +func divE2Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, b, c bls12381.E2 + + a.A0.SetBigInt(inputs[0]) + a.A1.SetBigInt(inputs[1]) + b.A0.SetBigInt(inputs[2]) + b.A1.SetBigInt(inputs[3]) + + c.Inverse(&b).Mul(&c, &a) + + c.A0.BigInt(outputs[0]) + c.A1.BigInt(outputs[1]) + + return nil + }) +} + +// E6 hints +func inverseE6Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod 
*big.Int, inputs, outputs []*big.Int) error { + var a, c bls12381.E6 + + a.B0.A0.SetBigInt(inputs[0]) + a.B0.A1.SetBigInt(inputs[1]) + a.B1.A0.SetBigInt(inputs[2]) + a.B1.A1.SetBigInt(inputs[3]) + a.B2.A0.SetBigInt(inputs[4]) + a.B2.A1.SetBigInt(inputs[5]) + + c.Inverse(&a) + + c.B0.A0.BigInt(outputs[0]) + c.B0.A1.BigInt(outputs[1]) + c.B1.A0.BigInt(outputs[2]) + c.B1.A1.BigInt(outputs[3]) + c.B2.A0.BigInt(outputs[4]) + c.B2.A1.BigInt(outputs[5]) + + return nil + }) +} + +func divE6Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, b, c bls12381.E6 + + a.B0.A0.SetBigInt(inputs[0]) + a.B0.A1.SetBigInt(inputs[1]) + a.B1.A0.SetBigInt(inputs[2]) + a.B1.A1.SetBigInt(inputs[3]) + a.B2.A0.SetBigInt(inputs[4]) + a.B2.A1.SetBigInt(inputs[5]) + + b.B0.A0.SetBigInt(inputs[6]) + b.B0.A1.SetBigInt(inputs[7]) + b.B1.A0.SetBigInt(inputs[8]) + b.B1.A1.SetBigInt(inputs[9]) + b.B2.A0.SetBigInt(inputs[10]) + b.B2.A1.SetBigInt(inputs[11]) + + c.Inverse(&b).Mul(&c, &a) + + c.B0.A0.BigInt(outputs[0]) + c.B0.A1.BigInt(outputs[1]) + c.B1.A0.BigInt(outputs[2]) + c.B1.A1.BigInt(outputs[3]) + c.B2.A0.BigInt(outputs[4]) + c.B2.A1.BigInt(outputs[5]) + + return nil + }) +} + +func squareTorusHint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, c bls12381.E6 + + a.B0.A0.SetBigInt(inputs[0]) + a.B0.A1.SetBigInt(inputs[1]) + a.B1.A0.SetBigInt(inputs[2]) + a.B1.A1.SetBigInt(inputs[3]) + a.B2.A0.SetBigInt(inputs[4]) + a.B2.A1.SetBigInt(inputs[5]) + + _c := a.DecompressTorus() + _c.CyclotomicSquare(&_c) + c, _ = _c.CompressTorus() + + c.B0.A0.BigInt(outputs[0]) + c.B0.A1.BigInt(outputs[1]) + c.B1.A0.BigInt(outputs[2]) + c.B1.A1.BigInt(outputs[3]) + c.B2.A0.BigInt(outputs[4]) + c.B2.A1.BigInt(outputs[5]) + + return 
nil + }) +} + +// E12 hints +func inverseE12Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, c bls12381.E12 + + a.C0.B0.A0.SetBigInt(inputs[0]) + a.C0.B0.A1.SetBigInt(inputs[1]) + a.C0.B1.A0.SetBigInt(inputs[2]) + a.C0.B1.A1.SetBigInt(inputs[3]) + a.C0.B2.A0.SetBigInt(inputs[4]) + a.C0.B2.A1.SetBigInt(inputs[5]) + a.C1.B0.A0.SetBigInt(inputs[6]) + a.C1.B0.A1.SetBigInt(inputs[7]) + a.C1.B1.A0.SetBigInt(inputs[8]) + a.C1.B1.A1.SetBigInt(inputs[9]) + a.C1.B2.A0.SetBigInt(inputs[10]) + a.C1.B2.A1.SetBigInt(inputs[11]) + + c.Inverse(&a) + + c.C0.B0.A0.BigInt(outputs[0]) + c.C0.B0.A1.BigInt(outputs[1]) + c.C0.B1.A0.BigInt(outputs[2]) + c.C0.B1.A1.BigInt(outputs[3]) + c.C0.B2.A0.BigInt(outputs[4]) + c.C0.B2.A1.BigInt(outputs[5]) + c.C1.B0.A0.BigInt(outputs[6]) + c.C1.B0.A1.BigInt(outputs[7]) + c.C1.B1.A0.BigInt(outputs[8]) + c.C1.B1.A1.BigInt(outputs[9]) + c.C1.B2.A0.BigInt(outputs[10]) + c.C1.B2.A1.BigInt(outputs[11]) + + return nil + }) +} + +func divE12Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, b, c bls12381.E12 + + a.C0.B0.A0.SetBigInt(inputs[0]) + a.C0.B0.A1.SetBigInt(inputs[1]) + a.C0.B1.A0.SetBigInt(inputs[2]) + a.C0.B1.A1.SetBigInt(inputs[3]) + a.C0.B2.A0.SetBigInt(inputs[4]) + a.C0.B2.A1.SetBigInt(inputs[5]) + a.C1.B0.A0.SetBigInt(inputs[6]) + a.C1.B0.A1.SetBigInt(inputs[7]) + a.C1.B1.A0.SetBigInt(inputs[8]) + a.C1.B1.A1.SetBigInt(inputs[9]) + a.C1.B2.A0.SetBigInt(inputs[10]) + a.C1.B2.A1.SetBigInt(inputs[11]) + + b.C0.B0.A0.SetBigInt(inputs[12]) + b.C0.B0.A1.SetBigInt(inputs[13]) + b.C0.B1.A0.SetBigInt(inputs[14]) + b.C0.B1.A1.SetBigInt(inputs[15]) + b.C0.B2.A0.SetBigInt(inputs[16]) + b.C0.B2.A1.SetBigInt(inputs[17]) + b.C1.B0.A0.SetBigInt(inputs[18]) + 
b.C1.B0.A1.SetBigInt(inputs[19]) + b.C1.B1.A0.SetBigInt(inputs[20]) + b.C1.B1.A1.SetBigInt(inputs[21]) + b.C1.B2.A0.SetBigInt(inputs[22]) + b.C1.B2.A1.SetBigInt(inputs[23]) + + c.Inverse(&b).Mul(&c, &a) + + c.C0.B0.A0.BigInt(outputs[0]) + c.C0.B0.A1.BigInt(outputs[1]) + c.C0.B1.A0.BigInt(outputs[2]) + c.C0.B1.A1.BigInt(outputs[3]) + c.C0.B2.A0.BigInt(outputs[4]) + c.C0.B2.A1.BigInt(outputs[5]) + c.C1.B0.A0.BigInt(outputs[6]) + c.C1.B0.A1.BigInt(outputs[7]) + c.C1.B1.A0.BigInt(outputs[8]) + c.C1.B1.A1.BigInt(outputs[9]) + c.C1.B2.A0.BigInt(outputs[10]) + c.C1.B2.A1.BigInt(outputs[11]) + + return nil + }) +} diff --git a/std/algebra/emulated/fields_bn254/doc.go b/std/algebra/emulated/fields_bn254/doc.go new file mode 100644 index 0000000000..ae94b3c243 --- /dev/null +++ b/std/algebra/emulated/fields_bn254/doc.go @@ -0,0 +1,7 @@ +// Package fields_bn254 implements the fields arithmetic of the Fp12 tower +// used to compute the pairing over the BN254 curve. +// +// 𝔽p²[u] = 𝔽p/u²+1 +// 𝔽p⁶[v] = 𝔽p²/v³-9-u +// 𝔽p¹²[w] = 𝔽p⁶/w²-v +package fields_bn254 diff --git a/std/algebra/emulated/fields_bn254/e12.go b/std/algebra/emulated/fields_bn254/e12.go new file mode 100644 index 0000000000..abe8bd9a9e --- /dev/null +++ b/std/algebra/emulated/fields_bn254/e12.go @@ -0,0 +1,199 @@ +package fields_bn254 + +import ( + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/frontend" +) + +type E12 struct { + C0, C1 E6 +} + +type Ext12 struct { + *Ext6 +} + +func NewExt12(api frontend.API) *Ext12 { + return &Ext12{Ext6: NewExt6(api)} +} + +func (e Ext12) Add(x, y *E12) *E12 { + z0 := e.Ext6.Add(&x.C0, &y.C0) + z1 := e.Ext6.Add(&x.C1, &y.C1) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +func (e Ext12) Sub(x, y *E12) *E12 { + z0 := e.Ext6.Sub(&x.C0, &y.C0) + z1 := e.Ext6.Sub(&x.C1, &y.C1) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +func (e Ext12) Conjugate(x *E12) *E12 { + z1 := e.Ext6.Neg(&x.C1) + return &E12{ + C0: x.C0, + C1: *z1, + } +} + +func (e Ext12) 
Mul(x, y *E12) *E12 { + a := e.Ext6.Add(&x.C0, &x.C1) + b := e.Ext6.Add(&y.C0, &y.C1) + a = e.Ext6.Mul(a, b) + b = e.Ext6.Mul(&x.C0, &y.C0) + c := e.Ext6.Mul(&x.C1, &y.C1) + z1 := e.Ext6.Sub(a, b) + z1 = e.Ext6.Sub(z1, c) + z0 := e.Ext6.MulByNonResidue(c) + z0 = e.Ext6.Add(z0, b) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +func (e Ext12) Zero() *E12 { + zero := e.fp.Zero() + return &E12{ + C0: E6{ + B0: E2{A0: *zero, A1: *zero}, + B1: E2{A0: *zero, A1: *zero}, + B2: E2{A0: *zero, A1: *zero}, + }, + C1: E6{ + B0: E2{A0: *zero, A1: *zero}, + B1: E2{A0: *zero, A1: *zero}, + B2: E2{A0: *zero, A1: *zero}, + }, + } +} + +func (e Ext12) One() *E12 { + z000 := e.fp.One() + zero := e.fp.Zero() + return &E12{ + C0: E6{ + B0: E2{A0: *z000, A1: *zero}, + B1: E2{A0: *zero, A1: *zero}, + B2: E2{A0: *zero, A1: *zero}, + }, + C1: E6{ + B0: E2{A0: *zero, A1: *zero}, + B1: E2{A0: *zero, A1: *zero}, + B2: E2{A0: *zero, A1: *zero}, + }, + } +} + +func (e Ext12) IsZero(z *E12) frontend.Variable { + c0 := e.Ext6.IsZero(&z.C0) + c1 := e.Ext6.IsZero(&z.C1) + return e.api.And(c0, c1) +} + +func (e Ext12) Square(x *E12) *E12 { + c0 := e.Ext6.Sub(&x.C0, &x.C1) + c3 := e.Ext6.MulByNonResidue(&x.C1) + c3 = e.Ext6.Neg(c3) + c3 = e.Ext6.Add(&x.C0, c3) + c2 := e.Ext6.Mul(&x.C0, &x.C1) + c0 = e.Ext6.Mul(c0, c3) + c0 = e.Ext6.Add(c0, c2) + z1 := e.Ext6.Double(c2) + c2 = e.Ext6.MulByNonResidue(c2) + z0 := e.Ext6.Add(c0, c2) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +func (e Ext12) AssertIsEqual(x, y *E12) { + e.Ext6.AssertIsEqual(&x.C0, &y.C0) + e.Ext6.AssertIsEqual(&x.C1, &y.C1) +} + +func FromE12(y *bn254.E12) E12 { + return E12{ + C0: FromE6(&y.C0), + C1: FromE6(&y.C1), + } + +} + +func (e Ext12) Inverse(x *E12) *E12 { + res, err := e.fp.NewHint(inverseE12Hint, 12, &x.C0.B0.A0, &x.C0.B0.A1, &x.C0.B1.A0, &x.C0.B1.A1, &x.C0.B2.A0, &x.C0.B2.A1, &x.C1.B0.A0, &x.C1.B0.A1, &x.C1.B1.A0, &x.C1.B1.A1, &x.C1.B2.A0, &x.C1.B2.A1) + if err != nil { + // err is non-nil only for invalid number of inputs 
+ panic(err) + } + + inv := E12{ + C0: E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + }, + C1: E6{ + B0: E2{A0: *res[6], A1: *res[7]}, + B1: E2{A0: *res[8], A1: *res[9]}, + B2: E2{A0: *res[10], A1: *res[11]}, + }, + } + + one := e.One() + + // 1 == inv * x + _one := e.Mul(&inv, x) + e.AssertIsEqual(one, _one) + + return &inv + +} + +func (e Ext12) DivUnchecked(x, y *E12) *E12 { + res, err := e.fp.NewHint(divE12Hint, 12, &x.C0.B0.A0, &x.C0.B0.A1, &x.C0.B1.A0, &x.C0.B1.A1, &x.C0.B2.A0, &x.C0.B2.A1, &x.C1.B0.A0, &x.C1.B0.A1, &x.C1.B1.A0, &x.C1.B1.A1, &x.C1.B2.A0, &x.C1.B2.A1, &y.C0.B0.A0, &y.C0.B0.A1, &y.C0.B1.A0, &y.C0.B1.A1, &y.C0.B2.A0, &y.C0.B2.A1, &y.C1.B0.A0, &y.C1.B0.A1, &y.C1.B1.A0, &y.C1.B1.A1, &y.C1.B2.A0, &y.C1.B2.A1) + + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + div := E12{ + C0: E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + }, + C1: E6{ + B0: E2{A0: *res[6], A1: *res[7]}, + B1: E2{A0: *res[8], A1: *res[9]}, + B2: E2{A0: *res[10], A1: *res[11]}, + }, + } + + // x == div * y + _x := e.Mul(&div, y) + e.AssertIsEqual(x, _x) + + return &div +} + +func (e Ext12) Select(selector frontend.Variable, z1, z0 *E12) *E12 { + c0 := e.Ext6.Select(selector, &z1.C0, &z0.C0) + c1 := e.Ext6.Select(selector, &z1.C1, &z0.C1) + return &E12{C0: *c0, C1: *c1} +} + +func (e Ext12) Lookup2(s1, s2 frontend.Variable, a, b, c, d *E12) *E12 { + c0 := e.Ext6.Lookup2(s1, s2, &a.C0, &b.C0, &c.C0, &d.C0) + c1 := e.Ext6.Lookup2(s1, s2, &a.C1, &b.C1, &c.C1, &d.C1) + return &E12{C0: *c0, C1: *c1} +} diff --git a/std/algebra/emulated/fields_bn254/e12_pairing.go b/std/algebra/emulated/fields_bn254/e12_pairing.go new file mode 100644 index 0000000000..678b438eb8 --- /dev/null +++ b/std/algebra/emulated/fields_bn254/e12_pairing.go @@ -0,0 +1,350 @@ +package fields_bn254 + +import ( + 
"github.com/consensys/gnark/std/math/emulated" +) + +func (e Ext12) nSquareTorus(z *E6, n int) *E6 { + for i := 0; i < n; i++ { + z = e.SquareTorus(z) + } + return z +} + +// Exponentiation by the seed t=4965661367192848881 +// The computations are performed on E6 compressed form using Torus-based arithmetic. +func (e Ext12) ExptTorus(x *E6) *E6 { + // Expt computation is derived from the addition chain: + // + // _10 = 2*1 + // _100 = 2*_10 + // _1000 = 2*_100 + // _10000 = 2*_1000 + // _10001 = 1 + _10000 + // _10011 = _10 + _10001 + // _10100 = 1 + _10011 + // _11001 = _1000 + _10001 + // _100010 = 2*_10001 + // _100111 = _10011 + _10100 + // _101001 = _10 + _100111 + // i27 = (_100010 << 6 + _100 + _11001) << 7 + _11001 + // i44 = (i27 << 8 + _101001 + _10) << 6 + _10001 + // i70 = ((i44 << 8 + _101001) << 6 + _101001) << 10 + // return (_100111 + i70) << 6 + _101001 + _1000 + // + // Operations: 62 squares 17 multiplies + // + // Generated by github.com/mmcloughlin/addchain v0.4.0. 
+ + t3 := e.SquareTorus(x) + t5 := e.SquareTorus(t3) + result := e.SquareTorus(t5) + t0 := e.SquareTorus(result) + t2 := e.MulTorus(x, t0) + t0 = e.MulTorus(t3, t2) + t1 := e.MulTorus(x, t0) + t4 := e.MulTorus(result, t2) + t6 := e.SquareTorus(t2) + t1 = e.MulTorus(t0, t1) + t0 = e.MulTorus(t3, t1) + t6 = e.nSquareTorus(t6, 6) + t5 = e.MulTorus(t5, t6) + t5 = e.MulTorus(t4, t5) + t5 = e.nSquareTorus(t5, 7) + t4 = e.MulTorus(t4, t5) + t4 = e.nSquareTorus(t4, 8) + t4 = e.MulTorus(t0, t4) + t3 = e.MulTorus(t3, t4) + t3 = e.nSquareTorus(t3, 6) + t2 = e.MulTorus(t2, t3) + t2 = e.nSquareTorus(t2, 8) + t2 = e.MulTorus(t0, t2) + t2 = e.nSquareTorus(t2, 6) + t2 = e.MulTorus(t0, t2) + t2 = e.nSquareTorus(t2, 10) + t1 = e.MulTorus(t1, t2) + t1 = e.nSquareTorus(t1, 6) + t0 = e.MulTorus(t0, t1) + z := e.MulTorus(result, t0) + return z +} + +// Square034 squares an E12 sparse element of the form +// +// E12{ +// C0: E6{B0: 1, B1: 0, B2: 0}, +// C1: E6{B0: c3, B1: c4, B2: 0}, +// } +func (e *Ext12) Square034(x *E12) *E12 { + c0 := E6{ + B0: *e.Ext2.Sub(&x.C0.B0, &x.C1.B0), + B1: *e.Ext2.Neg(&x.C1.B1), + B2: *e.Ext2.Zero(), + } + + c3 := E6{ + B0: x.C0.B0, + B1: *e.Ext2.Neg(&x.C1.B0), + B2: *e.Ext2.Neg(&x.C1.B1), + } + + c2 := E6{ + B0: x.C1.B0, + B1: x.C1.B1, + B2: *e.Ext2.Zero(), + } + c3 = *e.MulBy01(&c3, &c0.B0, &c0.B1) + c3 = *e.Ext6.Add(&c3, &c2) + + var z E12 + z.C1.B0 = *e.Ext2.Add(&c2.B0, &c2.B0) + z.C1.B1 = *e.Ext2.Add(&c2.B1, &c2.B1) + + z.C0.B0 = c3.B0 + z.C0.B1 = *e.Ext2.Add(&c3.B1, &c2.B0) + z.C0.B2 = *e.Ext2.Add(&c3.B2, &c2.B1) + + return &z +} + +// MulBy034 multiplies z by an E12 sparse element of the form +// +// E12{ +// C0: E6{B0: 1, B1: 0, B2: 0}, +// C1: E6{B0: c3, B1: c4, B2: 0}, +// } +func (e *Ext12) MulBy034(z *E12, c3, c4 *E2) *E12 { + + a := z.C0 + b := z.C1 + b = *e.MulBy01(&b, c3, c4) + c3 = e.Ext2.Add(e.Ext2.One(), c3) + d := e.Ext6.Add(&z.C0, &z.C1) + d = e.MulBy01(d, c3, c4) + + zC1 := e.Ext6.Add(&a, &b) + zC1 = e.Ext6.Neg(zC1) + zC1 = 
e.Ext6.Add(zC1, d) + zC0 := e.Ext6.MulByNonResidue(&b) + zC0 = e.Ext6.Add(zC0, &a) + + return &E12{ + C0: *zC0, + C1: *zC1, + } +} + +// multiplies two E12 sparse element of the form: +// +// E12{ +// C0: E6{B0: 1, B1: 0, B2: 0}, +// C1: E6{B0: c3, B1: c4, B2: 0}, +// } +// +// and +// +// E12{ +// C0: E6{B0: 1, B1: 0, B2: 0}, +// C1: E6{B0: d3, B1: d4, B2: 0}, +// } +func (e *Ext12) Mul034By034(d3, d4, c3, c4 *E2) *[5]E2 { + x3 := e.Ext2.Mul(c3, d3) + x4 := e.Ext2.Mul(c4, d4) + x04 := e.Ext2.Add(c4, d4) + x03 := e.Ext2.Add(c3, d3) + tmp := e.Ext2.Add(c3, c4) + x34 := e.Ext2.Add(d3, d4) + x34 = e.Ext2.Mul(x34, tmp) + x34 = e.Ext2.Sub(x34, x3) + x34 = e.Ext2.Sub(x34, x4) + + zC0B0 := e.Ext2.MulByNonResidue(x4) + zC0B0 = e.Ext2.Add(zC0B0, e.Ext2.One()) + zC0B1 := x3 + zC0B2 := x34 + zC1B0 := x03 + zC1B1 := x04 + + return &[5]E2{*zC0B0, *zC0B1, *zC0B2, *zC1B0, *zC1B1} +} + +// MulBy01234 multiplies z by an E12 sparse element of the form +// +// E12{ +// C0: E6{B0: c0, B1: c1, B2: c2}, +// C1: E6{B0: c3, B1: c4, B2: 0}, +// } +func (e *Ext12) MulBy01234(z *E12, x *[5]E2) *E12 { + c0 := &E6{B0: x[0], B1: x[1], B2: x[2]} + c1 := &E6{B0: x[3], B1: x[4], B2: *e.Ext2.Zero()} + a := e.Ext6.Add(&z.C0, &z.C1) + b := e.Ext6.Add(c0, c1) + a = e.Ext6.Mul(a, b) + b = e.Ext6.Mul(&z.C0, c0) + c := e.Ext6.MulBy01(&z.C1, &x[3], &x[4]) + z1 := e.Ext6.Sub(a, b) + z1 = e.Ext6.Sub(z1, c) + z0 := e.Ext6.MulByNonResidue(c) + z0 = e.Ext6.Add(z0, b) + return &E12{ + C0: *z0, + C1: *z1, + } +} + +// multiplies two E12 sparse element of the form: +// +// E12{ +// C0: E6{B0: x0, B1: x1, B2: x2}, +// C1: E6{B0: x3, B1: x4, B2: 0}, +// } +// +// and +// +// E12{ +// C0: E6{B0: 1, B1: 0, B2: 0}, +// C1: E6{B0: z3, B1: z4, B2: 0}, +// } +func (e *Ext12) Mul01234By034(x *[5]E2, z3, z4 *E2) *E12 { + c0 := &E6{B0: x[0], B1: x[1], B2: x[2]} + c1 := &E6{B0: x[3], B1: x[4], B2: *e.Ext2.Zero()} + a := e.Ext6.Add(e.Ext6.One(), &E6{B0: *z3, B1: *z4, B2: *e.Ext2.Zero()}) + b := e.Ext6.Add(c0, c1) + a = 
e.Ext6.Mul(a, b)
+	c := e.Ext6.Mul01By01(z3, z4, &x[3], &x[4])
+	z1 := e.Ext6.Sub(a, c0)
+	z1 = e.Ext6.Sub(z1, c)
+	z0 := e.Ext6.MulByNonResidue(c)
+	z0 = e.Ext6.Add(z0, c0)
+	return &E12{
+		C0: *z0,
+		C1: *z1,
+	}
+}
+
+// Torus-based arithmetic:
+//
+// After the easy part of the final exponentiation the elements are in a proper
+// subgroup of Fpk (E12) that coincides with some algebraic tori. The elements
+// are in the torus Tk(Fp) and thus in each torus Tk/d(Fp^d) for d|k, d≠k. We
+// take d=6. So the elements are in T2(Fp6).
+// Let G_{q,2} = {m ∈ Fq^2 | m^(q+1) = 1} where q = p^6.
+// When m.C1 = 0, then m.C0 must be 1 or −1.
+//
+// We recall the tower construction:
+//
+//	𝔽p²[u] = 𝔽p/u²+1
+//	𝔽p⁶[v] = 𝔽p²/v³-9-u
+//	𝔽p¹²[w] = 𝔽p⁶/w²-v

+// CompressTorus compresses x ∈ E12 to (x.C0 + 1)/x.C1 ∈ E6
+func (e Ext12) CompressTorus(x *E12) *E6 {
+	// x ∈ G_{q,2} \ {-1,1}
+	y := e.Ext6.Add(&x.C0, e.Ext6.One())
+	y = e.Ext6.DivUnchecked(y, &x.C1)
+	return y
+}
+
+// DecompressTorus decompresses y ∈ E6 to (y+w)/(y-w) ∈ E12
+func (e Ext12) DecompressTorus(y *E6) *E12 {
+	var n, d E12
+	one := e.Ext6.One()
+	n.C0 = *y
+	n.C1 = *one
+	d.C0 = *y
+	d.C1 = *e.Ext6.Neg(one)
+
+	x := e.DivUnchecked(&n, &d)
+	return x
+}
+
+// MulTorus multiplies two compressed elements y1, y2 ∈ E6
+// and returns (y1 * y2 + v)/(y1 + y2)
+// N.B.: we use MulTorus in the final exponentiation, where y1 ≠ -y2 always.
+func (e Ext12) MulTorus(y1, y2 *E6) *E6 {
+	n := e.Ext6.Mul(y1, y2)
+	n.B1 = *e.Ext2.Add(&n.B1, e.Ext2.One())
+	d := e.Ext6.Add(y1, y2)
+	y3 := e.Ext6.DivUnchecked(n, d)
+	return y3
+}
+
+// InverseTorus inverts a compressed element y ∈ E6
+// and returns -y
+func (e Ext12) InverseTorus(y *E6) *E6 {
+	return e.Ext6.Neg(y)
+}
+
+// SquareTorus squares a compressed element y ∈ E6
+// and returns (y + v/y)/2
+//
+// It uses a hint to verify that (2x-y)y = v saving one E6 AssertIsEqual.
+func (e Ext12) SquareTorus(y *E6) *E6 {
+	res, err := e.fp.NewHint(squareTorusHint, 6, &y.B0.A0, &y.B0.A1, &y.B1.A0, &y.B1.A1, &y.B2.A0, &y.B2.A1)
+	if err != nil {
+		// err is non-nil only for invalid number of inputs
+		panic(err)
+	}
+
+	sq := E6{
+		B0: E2{A0: *res[0], A1: *res[1]},
+		B1: E2{A0: *res[2], A1: *res[3]},
+		B2: E2{A0: *res[4], A1: *res[5]},
+	}
+
+	// v = (2x-y)y
+	v := e.Ext6.Double(&sq)
+	v = e.Ext6.Sub(v, y)
+	v = e.Ext6.Mul(v, y)
+
+	_v := E6{B0: *e.Ext2.Zero(), B1: *e.Ext2.One(), B2: *e.Ext2.Zero()}
+	e.Ext6.AssertIsEqual(v, &_v)
+
+	return &sq
+
+}
+
+// FrobeniusTorus raises a compressed element y ∈ E6 to the modulus p
+// and returns y^p / v^((p-1)/2)
+func (e Ext12) FrobeniusTorus(y *E6) *E6 {
+	t0 := e.Ext2.Conjugate(&y.B0)
+	t1 := e.Ext2.Conjugate(&y.B1)
+	t2 := e.Ext2.Conjugate(&y.B2)
+	t1 = e.Ext2.MulByNonResidue1Power2(t1)
+	t2 = e.Ext2.MulByNonResidue1Power4(t2)
+
+	v0 := E2{emulated.ValueOf[emulated.BN254Fp]("18566938241244942414004596690298913868373833782006617400804628704885040364344"), emulated.ValueOf[emulated.BN254Fp]("5722266937896532885780051958958348231143373700109372999374820235121374419868")}
+	res := &E6{B0: *t0, B1: *t1, B2: *t2}
+	res = e.Ext6.MulBy0(res, &v0)
+
+	return res
+}
+
+// FrobeniusSquareTorus raises a compressed element y ∈ E6 to the square modulus p^2
+// and returns y^(p^2) / v^((p^2-1)/2)
+func (e Ext12) FrobeniusSquareTorus(y *E6) *E6 {
+	v0 := emulated.ValueOf[emulated.BN254Fp]("2203960485148121921418603742825762020974279258880205651967")
+	t0 := e.Ext2.MulByElement(&y.B0, &v0)
+	t1 := e.Ext2.MulByNonResidue2Power2(&y.B1)
+	t1 = e.Ext2.MulByElement(t1, &v0)
+	t2 := e.Ext2.MulByNonResidue2Power4(&y.B2)
+	t2 = e.Ext2.MulByElement(t2, &v0)
+
+	return &E6{B0: *t0, B1: *t1, B2: *t2}
+}
+
+// FrobeniusCubeTorus raises a compressed element y ∈ E6 to the cube modulus p^3
+// and returns y^(p^3) / v^((p^3-1)/2)
+func (e Ext12) FrobeniusCubeTorus(y *E6) *E6 {
+	t0 := e.Ext2.Conjugate(&y.B0)
+	t1 :=
e.Ext2.Conjugate(&y.B1) + t2 := e.Ext2.Conjugate(&y.B2) + t1 = e.Ext2.MulByNonResidue3Power2(t1) + t2 = e.Ext2.MulByNonResidue3Power4(t2) + + v0 := E2{emulated.ValueOf[emulated.BN254Fp]("10190819375481120917420622822672549775783927716138318623895010788866272024264"), emulated.ValueOf[emulated.BN254Fp]("303847389135065887422783454877609941456349188919719272345083954437860409601")} + res := &E6{B0: *t0, B1: *t1, B2: *t2} + res = e.Ext6.MulBy0(res, &v0) + + return res +} diff --git a/std/algebra/emulated/fields_bn254/e12_test.go b/std/algebra/emulated/fields_bn254/e12_test.go new file mode 100644 index 0000000000..a3289b4698 --- /dev/null +++ b/std/algebra/emulated/fields_bn254/e12_test.go @@ -0,0 +1,620 @@ +package fields_bn254 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type e12Add struct { + A, B, C E12 +} + +func (circuit *e12Add) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Add(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestAddFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Add(&a, &b) + + witness := e12Add{ + A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Add{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Sub struct { + A, B, C E12 +} + +func (circuit *e12Sub) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Sub(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSubFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Sub(&a, &b) + + witness := e12Sub{ + A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := 
test.IsSolved(&e12Sub{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Mul struct { + A, B, C E12 +} + +func (circuit *e12Mul) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Mul(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestMulFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Mul(&a, &b) + + witness := e12Mul{ + A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Mul{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Div struct { + A, B, C E12 +} + +func (circuit *e12Div) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.DivUnchecked(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestDivFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Div(&a, &b) + + witness := e12Div{ + A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Div{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Square struct { + A, C E12 +} + +func (circuit *e12Square) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Square(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSquareFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E12 + _, _ = a.SetRandom() + c.Square(&a) + + witness := e12Square{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Square{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e12Conjugate struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *e12Conjugate) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Conjugate(&circuit.A) + e.AssertIsEqual(expected, 
&circuit.C) + + return nil +} + +func TestConjugateFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E12 + _, _ = a.SetRandom() + c.Conjugate(&a) + + witness := e12Conjugate{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Conjugate{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e12Inverse struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *e12Inverse) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.Inverse(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestInverseFp12(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E12 + _, _ = a.SetRandom() + c.Inverse(&a) + + witness := e12Inverse{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&e12Inverse{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e12ExptTorus struct { + A E6 + C E12 `gnark:",public"` +} + +func (circuit *e12ExptTorus) Define(api frontend.API) error { + e := NewExt12(api) + z := e.ExptTorus(&circuit.A) + expected := e.DecompressTorus(z) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestFp12ExptTorus(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + var tmp bn254.E12 + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + c.Expt(&a) + _a, _ := a.CompressTorus() + witness := e12ExptTorus{ + A: FromE6(&_a), + C: FromE12(&c), + } + + err := test.IsSolved(&e12ExptTorus{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e12MulBy034 struct { + A E12 `gnark:",public"` + W E12 + B, C E2 +} + +func (circuit *e12MulBy034) Define(api frontend.API) error { + e := NewExt12(api) + res := e.MulBy034(&circuit.A, &circuit.B, &circuit.C) + e.AssertIsEqual(res, &circuit.W) + return nil +} + +func TestFp12MulBy034(t 
*testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, w bn254.E12 + _, _ = a.SetRandom() + var one, b, c bn254.E2 + one.SetOne() + _, _ = b.SetRandom() + _, _ = c.SetRandom() + w.Set(&a) + w.MulBy034(&one, &b, &c) + + witness := e12MulBy034{ + A: FromE12(&a), + B: FromE2(&b), + C: FromE2(&c), + W: FromE12(&w), + } + + err := test.IsSolved(&e12MulBy034{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +// Torus-based arithmetic +type torusCompress struct { + A E12 + C E6 `gnark:",public"` +} + +func (circuit *torusCompress) Define(api frontend.API) error { + e := NewExt12(api) + expected := e.CompressTorus(&circuit.A) + e.Ext6.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusCompress(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a bn254.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + var tmp bn254.E12 + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + c, _ := a.CompressTorus() + + witness := torusCompress{ + A: FromE12(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&torusCompress{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusDecompress struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusDecompress) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusDecompress(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a bn254.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + var tmp bn254.E12 + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + d, _ := a.CompressTorus() + c := d.DecompressTorus() + + witness := torusDecompress{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusDecompress{}, &witness, 
ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusMul struct { + A E12 + B E12 + C E12 `gnark:",public"` +} + +func (circuit *torusMul) Define(api frontend.API) error { + e := NewExt12(api) + compressedA := e.CompressTorus(&circuit.A) + compressedB := e.CompressTorus(&circuit.B) + compressedAB := e.MulTorus(compressedA, compressedB) + expected := e.DecompressTorus(compressedAB) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusMul(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c, tmp bn254.E12 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + // put b in the cyclotomic subgroup + tmp.Conjugate(&b) + b.Inverse(&b) + tmp.Mul(&tmp, &b) + b.FrobeniusSquare(&tmp).Mul(&b, &tmp) + + // uncompressed mul + c.Mul(&a, &b) + + witness := torusMul{ + A: FromE12(&a), + B: FromE12(&b), + C: FromE12(&c), + } + + err := test.IsSolved(&torusMul{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusInverse struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusInverse) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + compressed = e.InverseTorus(compressed) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusInverse(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c, tmp bn254.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + // uncompressed inverse + c.Inverse(&a) + + witness := torusInverse{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusInverse{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusFrobenius struct { + A E12 + C E12 
`gnark:",public"` +} + +func (circuit *torusFrobenius) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + compressed = e.FrobeniusTorus(compressed) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusFrobenius(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c, tmp bn254.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + // uncompressed frobenius + c.Frobenius(&a) + + witness := torusFrobenius{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusFrobenius{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusFrobeniusSquare struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusFrobeniusSquare) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + compressed = e.FrobeniusSquareTorus(compressed) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusFrobeniusSquare(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c, tmp bn254.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + // uncompressed frobeniusSquare + c.FrobeniusSquare(&a) + + witness := torusFrobeniusSquare{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusFrobeniusSquare{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusFrobeniusCube struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusFrobeniusCube) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + compressed = e.FrobeniusCubeTorus(compressed) + expected := e.DecompressTorus(compressed) + 
e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusFrobeniusCube(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c, tmp bn254.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + // uncompressed frobeniusCube + c.FrobeniusCube(&a) + + witness := torusFrobeniusCube{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusFrobeniusCube{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type torusSquare struct { + A E12 + C E12 `gnark:",public"` +} + +func (circuit *torusSquare) Define(api frontend.API) error { + e := NewExt12(api) + compressed := e.CompressTorus(&circuit.A) + compressed = e.SquareTorus(compressed) + expected := e.DecompressTorus(compressed) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestTorusSquare(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c, tmp bn254.E12 + _, _ = a.SetRandom() + + // put a in the cyclotomic subgroup + tmp.Conjugate(&a) + a.Inverse(&a) + tmp.Mul(&tmp, &a) + a.FrobeniusSquare(&tmp).Mul(&a, &tmp) + + // uncompressed square + c.Square(&a) + + witness := torusSquare{ + A: FromE12(&a), + C: FromE12(&c), + } + + err := test.IsSolved(&torusSquare{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/emulated/fields_bn254/e2.go b/std/algebra/emulated/fields_bn254/e2.go new file mode 100644 index 0000000000..161712f197 --- /dev/null +++ b/std/algebra/emulated/fields_bn254/e2.go @@ -0,0 +1,353 @@ +package fields_bn254 + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" +) + +type curveF = emulated.Field[emulated.BN254Fp] +type baseEl = emulated.Element[emulated.BN254Fp] + +type E2 struct { + A0, A1 baseEl +} + +type Ext2 struct { + api frontend.API + fp 
*curveF + nonResidues map[int]map[int]*E2 +} + +func NewExt2(api frontend.API) *Ext2 { + fp, err := emulated.NewField[emulated.BN254Fp](api) + if err != nil { + // TODO: we start returning errors when generifying + panic(err) + } + pwrs := map[int]map[int]struct { + A0 string + A1 string + }{ + 1: { + 1: {"8376118865763821496583973867626364092589906065868298776909617916018768340080", "16469823323077808223889137241176536799009286646108169935659301613961712198316"}, + 2: {"21575463638280843010398324269430826099269044274347216827212613867836435027261", "10307601595873709700152284273816112264069230130616436755625194854815875713954"}, + 3: {"2821565182194536844548159561693502659359617185244120367078079554186484126554", "3505843767911556378687030309984248845540243509899259641013678093033130930403"}, + 4: {"2581911344467009335267311115468803099551665605076196740867805258568234346338", "19937756971775647987995932169929341994314640652964949448313374472400716661030"}, + 5: {"685108087231508774477564247770172212460312782337200605669322048753928464687", "8447204650696766136447902020341177575205426561248465145919723016860428151883"}, + }, + 3: { + 1: {"11697423496358154304825782922584725312912383441159505038794027105778954184319", "303847389135065887422783454877609941456349188919719272345083954437860409601"}, + 2: {"3772000881919853776433695186713858239009073593817195771773381919316419345261", "2236595495967245188281701248203181795121068902605861227855261137820944008926"}, + 3: {"19066677689644738377698246183563772429336693972053703295610958340458742082029", "18382399103927718843559375435273026243156067647398564021675359801612095278180"}, + 4: {"5324479202449903542726783395506214481928257762400643279780343368557297135718", "16208900380737693084919495127334387981393726419856888799917914180988844123039"}, + 5: {"8941241848238582420466759817324047081148088512956452953208002715982955420483", "10338197737521362862238855242243140895517409139741313354160881284257516364953"}, + }, + } 
+ nonResidues := make(map[int]map[int]*E2) + for pwr, v := range pwrs { + for coeff, v := range v { + el := E2{emulated.ValueOf[emulated.BN254Fp](v.A0), emulated.ValueOf[emulated.BN254Fp](v.A1)} + if nonResidues[pwr] == nil { + nonResidues[pwr] = make(map[int]*E2) + } + nonResidues[pwr][coeff] = &el + } + } + return &Ext2{api: api, fp: fp, nonResidues: nonResidues} +} + +func (e Ext2) MulByElement(x *E2, y *baseEl) *E2 { + z0 := e.fp.MulMod(&x.A0, y) + z1 := e.fp.MulMod(&x.A1, y) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) MulByConstElement(x *E2, y *big.Int) *E2 { + z0 := e.fp.MulConst(&x.A0, y) + z1 := e.fp.MulConst(&x.A1, y) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Conjugate(x *E2) *E2 { + z0 := x.A0 + z1 := e.fp.Neg(&x.A1) + return &E2{ + A0: z0, + A1: *z1, + } +} + +func (e Ext2) MulByNonResidueGeneric(x *E2, power, coef int) *E2 { + y := e.nonResidues[power][coef] + z := e.Mul(x, y) + return z +} + +// MulByNonResidue return x*(9+u) +func (e Ext2) MulByNonResidue(x *E2) *E2 { + nine := big.NewInt(9) + a := e.fp.MulConst(&x.A0, nine) + a = e.fp.Sub(a, &x.A1) + b := e.fp.MulConst(&x.A1, nine) + b = e.fp.Add(b, &x.A0) + return &E2{ + A0: *a, + A1: *b, + } +} + +// MulByNonResidue1Power1 returns x*(9+u)^(1*(p^1-1)/6) +func (e Ext2) MulByNonResidue1Power1(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 1, 1) +} + +// MulByNonResidue1Power2 returns x*(9+u)^(2*(p^1-1)/6) +func (e Ext2) MulByNonResidue1Power2(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 1, 2) +} + +// MulByNonResidue1Power3 returns x*(9+u)^(3*(p^1-1)/6) +func (e Ext2) MulByNonResidue1Power3(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 1, 3) +} + +// MulByNonResidue1Power4 returns x*(9+u)^(4*(p^1-1)/6) +func (e Ext2) MulByNonResidue1Power4(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 1, 4) +} + +// MulByNonResidue1Power5 returns x*(9+u)^(5*(p^1-1)/6) +func (e Ext2) MulByNonResidue1Power5(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 1, 5) +} + 
+// MulByNonResidue2Power1 returns x*(9+u)^(1*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power1(x *E2) *E2 { + element := emulated.ValueOf[emulated.BN254Fp]("21888242871839275220042445260109153167277707414472061641714758635765020556617") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue2Power2 returns x*(9+u)^(2*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power2(x *E2) *E2 { + element := emulated.ValueOf[emulated.BN254Fp]("21888242871839275220042445260109153167277707414472061641714758635765020556616") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue2Power3 returns x*(9+u)^(3*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power3(x *E2) *E2 { + element := emulated.ValueOf[emulated.BN254Fp]("21888242871839275222246405745257275088696311157297823662689037894645226208582") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue2Power4 returns x*(9+u)^(4*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power4(x *E2) *E2 { + element := emulated.ValueOf[emulated.BN254Fp]("2203960485148121921418603742825762020974279258880205651966") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue2Power5 returns x*(9+u)^(5*(p^2-1)/6) +func (e Ext2) MulByNonResidue2Power5(x *E2) *E2 { + element := emulated.ValueOf[emulated.BN254Fp]("2203960485148121921418603742825762020974279258880205651967") + return &E2{ + A0: *e.fp.MulMod(&x.A0, &element), + A1: *e.fp.MulMod(&x.A1, &element), + } +} + +// MulByNonResidue3Power1 returns x*(9+u)^(1*(p^3-1)/6) +func (e Ext2) MulByNonResidue3Power1(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 3, 1) +} + +// MulByNonResidue3Power2 returns x*(9+u)^(2*(p^3-1)/6) +func (e Ext2) MulByNonResidue3Power2(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 3, 2) +} + +// MulByNonResidue3Power3 returns 
x*(9+u)^(3*(p^3-1)/6) +func (e Ext2) MulByNonResidue3Power3(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 3, 3) +} + +// MulByNonResidue3Power4 returns x*(9+u)^(4*(p^3-1)/6) +func (e Ext2) MulByNonResidue3Power4(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 3, 4) +} + +// MulByNonResidue3Power5 returns x*(9+u)^(5*(p^3-1)/6) +func (e Ext2) MulByNonResidue3Power5(x *E2) *E2 { + return e.MulByNonResidueGeneric(x, 3, 5) +} + +func (e Ext2) Mul(x, y *E2) *E2 { + a := e.fp.Add(&x.A0, &x.A1) + b := e.fp.Add(&y.A0, &y.A1) + a = e.fp.MulMod(a, b) + b = e.fp.MulMod(&x.A0, &y.A0) + c := e.fp.MulMod(&x.A1, &y.A1) + z1 := e.fp.Sub(a, b) + z1 = e.fp.Sub(z1, c) + z0 := e.fp.Sub(b, c) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Add(x, y *E2) *E2 { + z0 := e.fp.Add(&x.A0, &y.A0) + z1 := e.fp.Add(&x.A1, &y.A1) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Sub(x, y *E2) *E2 { + z0 := e.fp.Sub(&x.A0, &y.A0) + z1 := e.fp.Sub(&x.A1, &y.A1) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Neg(x *E2) *E2 { + z0 := e.fp.Neg(&x.A0) + z1 := e.fp.Neg(&x.A1) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) One() *E2 { + z0 := e.fp.One() + z1 := e.fp.Zero() + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) Zero() *E2 { + z0 := e.fp.Zero() + z1 := e.fp.Zero() + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) IsZero(z *E2) frontend.Variable { + a0 := e.fp.IsZero(&z.A0) + a1 := e.fp.IsZero(&z.A1) + return e.api.And(a0, a1) +} + +func (e Ext2) Square(x *E2) *E2 { + a := e.fp.Add(&x.A0, &x.A1) + b := e.fp.Sub(&x.A0, &x.A1) + a = e.fp.MulMod(a, b) + b = e.fp.MulMod(&x.A0, &x.A1) + b = e.fp.MulConst(b, big.NewInt(2)) + return &E2{ + A0: *a, + A1: *b, + } +} + +func (e Ext2) Double(x *E2) *E2 { + two := big.NewInt(2) + z0 := e.fp.MulConst(&x.A0, two) + z1 := e.fp.MulConst(&x.A1, two) + return &E2{ + A0: *z0, + A1: *z1, + } +} + +func (e Ext2) AssertIsEqual(x, y *E2) { + e.fp.AssertIsEqual(&x.A0, &y.A0) + 
e.fp.AssertIsEqual(&x.A1, &y.A1) +} + +func FromE2(y *bn254.E2) E2 { + return E2{ + A0: emulated.ValueOf[emulated.BN254Fp](y.A0), + A1: emulated.ValueOf[emulated.BN254Fp](y.A1), + } +} + +func (e Ext2) Inverse(x *E2) *E2 { + res, err := e.fp.NewHint(inverseE2Hint, 2, &x.A0, &x.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + inv := E2{ + A0: *res[0], + A1: *res[1], + } + one := e.One() + + // 1 == inv * x + _one := e.Mul(&inv, x) + e.AssertIsEqual(one, _one) + + return &inv + +} + +func (e Ext2) DivUnchecked(x, y *E2) *E2 { + res, err := e.fp.NewHint(divE2Hint, 2, &x.A0, &x.A1, &y.A0, &y.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + div := E2{ + A0: *res[0], + A1: *res[1], + } + + // x == div * y + _x := e.Mul(&div, y) + e.AssertIsEqual(x, _x) + + return &div +} + +func (e Ext2) Select(selector frontend.Variable, z1, z0 *E2) *E2 { + a0 := e.fp.Select(selector, &z1.A0, &z0.A0) + a1 := e.fp.Select(selector, &z1.A1, &z0.A1) + return &E2{A0: *a0, A1: *a1} +} + +func (e Ext2) Lookup2(s1, s2 frontend.Variable, a, b, c, d *E2) *E2 { + a0 := e.fp.Lookup2(s1, s2, &a.A0, &b.A0, &c.A0, &d.A0) + a1 := e.fp.Lookup2(s1, s2, &a.A1, &b.A1, &c.A1, &d.A1) + return &E2{A0: *a0, A1: *a1} +} diff --git a/std/algebra/emulated/fields_bn254/e2_test.go b/std/algebra/emulated/fields_bn254/e2_test.go new file mode 100644 index 0000000000..55c2564b02 --- /dev/null +++ b/std/algebra/emulated/fields_bn254/e2_test.go @@ -0,0 +1,351 @@ +package fields_bn254 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fp" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/test" +) + +type e2Add struct { + A, B, C E2 +} + +func (circuit *e2Add) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Add(&circuit.A, 
&circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestAddFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Add(&a, &b) + + witness := e2Add{ + A: FromE2(&a), + B: FromE2(&b), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Add{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Sub struct { + A, B, C E2 +} + +func (circuit *e2Sub) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Sub(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSubFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Sub(&a, &b) + + witness := e2Sub{ + A: FromE2(&a), + B: FromE2(&b), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Sub{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Double struct { + A, C E2 +} + +func (circuit *e2Double) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Double(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestDoubleFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Double(&a) + + witness := e2Double{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Double{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Mul struct { + A, B, C E2 +} + +func (circuit *e2Mul) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Mul(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestMulFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Mul(&a, &b) + + witness := e2Mul{ + A: FromE2(&a), + B: FromE2(&b), + C: FromE2(&c), + } + + err 
:= test.IsSolved(&e2Mul{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Square struct { + A, C E2 +} + +func (circuit *e2Square) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Square(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSquareFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E2 + _, _ = a.SetRandom() + c.Square(&a) + + witness := e2Square{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Square{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Div struct { + A, B, C E2 +} + +func (circuit *e2Div) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.DivUnchecked(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestDivFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Div(&a, &b) + + witness := e2Div{ + A: FromE2(&a), + B: FromE2(&b), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Div{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2MulByElement struct { + A E2 + B baseEl + C E2 `gnark:",public"` +} + +func (circuit *e2MulByElement) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.MulByElement(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulByElement(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E2 + var b fp.Element + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.MulByElement(&a, &b) + + witness := e2MulByElement{ + A: FromE2(&a), + B: emulated.ValueOf[emulated.BN254Fp](b), + C: FromE2(&c), + } + + err := test.IsSolved(&e2MulByElement{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2MulByNonResidue struct { + A E2 + C E2 `gnark:",public"` +} + +func (circuit *e2MulByNonResidue) Define(api 
frontend.API) error { + e := NewExt2(api) + expected := e.MulByNonResidue(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulFp2ByNonResidue(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E2 + _, _ = a.SetRandom() + c.MulByNonResidue(&a) + + witness := e2MulByNonResidue{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2MulByNonResidue{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e2Neg struct { + A E2 + C E2 `gnark:",public"` +} + +func (circuit *e2Neg) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Neg(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestNegFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E2 + _, _ = a.SetRandom() + c.Neg(&a) + + witness := e2Neg{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Neg{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e2Conjugate struct { + A E2 + C E2 `gnark:",public"` +} + +func (circuit *e2Conjugate) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Conjugate(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestConjugateFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E2 + _, _ = a.SetRandom() + c.Conjugate(&a) + + witness := e2Conjugate{ + A: FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Conjugate{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e2Inverse struct { + A E2 + C E2 `gnark:",public"` +} + +func (circuit *e2Inverse) Define(api frontend.API) error { + e := NewExt2(api) + expected := e.Inverse(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestInverseFp2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E2 + _, _ = a.SetRandom() + c.Inverse(&a) + + witness := e2Inverse{ + A: 
FromE2(&a), + C: FromE2(&c), + } + + err := test.IsSolved(&e2Inverse{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/emulated/fields_bn254/e6.go b/std/algebra/emulated/fields_bn254/e6.go new file mode 100644 index 0000000000..4330fa49fb --- /dev/null +++ b/std/algebra/emulated/fields_bn254/e6.go @@ -0,0 +1,338 @@ +package fields_bn254 + +import ( + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/frontend" +) + +type E6 struct { + B0, B1, B2 E2 +} + +type Ext6 struct { + *Ext2 +} + +func NewExt6(api frontend.API) *Ext6 { + return &Ext6{Ext2: NewExt2(api)} +} + +func (e Ext6) One() *E6 { + z0 := e.Ext2.One() + z1 := e.Ext2.Zero() + z2 := e.Ext2.Zero() + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Zero() *E6 { + z0 := e.Ext2.Zero() + z1 := e.Ext2.Zero() + z2 := e.Ext2.Zero() + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) IsZero(z *E6) frontend.Variable { + b0 := e.Ext2.IsZero(&z.B0) + b1 := e.Ext2.IsZero(&z.B1) + b2 := e.Ext2.IsZero(&z.B2) + return e.api.And(e.api.And(b0, b1), b2) +} + +func (e Ext6) Add(x, y *E6) *E6 { + z0 := e.Ext2.Add(&x.B0, &y.B0) + z1 := e.Ext2.Add(&x.B1, &y.B1) + z2 := e.Ext2.Add(&x.B2, &y.B2) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Neg(x *E6) *E6 { + z0 := e.Ext2.Neg(&x.B0) + z1 := e.Ext2.Neg(&x.B1) + z2 := e.Ext2.Neg(&x.B2) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Sub(x, y *E6) *E6 { + z0 := e.Ext2.Sub(&x.B0, &y.B0) + z1 := e.Ext2.Sub(&x.B1, &y.B1) + z2 := e.Ext2.Sub(&x.B2, &y.B2) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Mul(x, y *E6) *E6 { + t0 := e.Ext2.Mul(&x.B0, &y.B0) + t1 := e.Ext2.Mul(&x.B1, &y.B1) + t2 := e.Ext2.Mul(&x.B2, &y.B2) + c0 := e.Ext2.Add(&x.B1, &x.B2) + tmp := e.Ext2.Add(&y.B1, &y.B2) + c0 = e.Ext2.Mul(c0, tmp) + c0 = e.Ext2.Sub(c0, t1) + c0 = e.Ext2.Sub(c0, t2) + c0 = e.Ext2.MulByNonResidue(c0) + c0 = e.Ext2.Add(c0, 
t0) + c1 := e.Ext2.Add(&x.B0, &x.B1) + tmp = e.Ext2.Add(&y.B0, &y.B1) + c1 = e.Ext2.Mul(c1, tmp) + c1 = e.Ext2.Sub(c1, t0) + c1 = e.Ext2.Sub(c1, t1) + tmp = e.Ext2.MulByNonResidue(t2) + c1 = e.Ext2.Add(c1, tmp) + tmp = e.Ext2.Add(&x.B0, &x.B2) + c2 := e.Ext2.Add(&y.B0, &y.B2) + c2 = e.Ext2.Mul(c2, tmp) + c2 = e.Ext2.Sub(c2, t0) + c2 = e.Ext2.Sub(c2, t2) + c2 = e.Ext2.Add(c2, t1) + return &E6{ + B0: *c0, + B1: *c1, + B2: *c2, + } +} + +func (e Ext6) Double(x *E6) *E6 { + z0 := e.Ext2.Double(&x.B0) + z1 := e.Ext2.Double(&x.B1) + z2 := e.Ext2.Double(&x.B2) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) Square(x *E6) *E6 { + c4 := e.Ext2.Mul(&x.B0, &x.B1) + c4 = e.Ext2.Double(c4) + c5 := e.Ext2.Square(&x.B2) + c1 := e.Ext2.MulByNonResidue(c5) + c1 = e.Ext2.Add(c1, c4) + c2 := e.Ext2.Sub(c4, c5) + c3 := e.Ext2.Square(&x.B0) + c4 = e.Ext2.Sub(&x.B0, &x.B1) + c4 = e.Ext2.Add(c4, &x.B2) + c5 = e.Ext2.Mul(&x.B1, &x.B2) + c5 = e.Ext2.Double(c5) + c4 = e.Ext2.Square(c4) + c0 := e.Ext2.MulByNonResidue(c5) + c0 = e.Ext2.Add(c0, c3) + z2 := e.Ext2.Add(c2, c4) + z2 = e.Ext2.Add(z2, c5) + z2 = e.Ext2.Sub(z2, c3) + z0 := c0 + z1 := c1 + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) MulByE2(x *E6, y *E2) *E6 { + z0 := e.Ext2.Mul(&x.B0, y) + z1 := e.Ext2.Mul(&x.B1, y) + z2 := e.Ext2.Mul(&x.B2, y) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +// MulBy0 multiplies z by an E6 sparse element of the form +// +// E6{ +// B0: c0, +// B1: 0, +// B2: 0, +// } +func (e Ext6) MulBy0(z *E6, c0 *E2) *E6 { + a := e.Ext2.Mul(&z.B0, c0) + tmp := e.Ext2.Add(&z.B0, &z.B2) + t2 := e.Ext2.Mul(c0, tmp) + t2 = e.Ext2.Sub(t2, a) + tmp = e.Ext2.Add(&z.B0, &z.B1) + t1 := e.Ext2.Mul(c0, tmp) + t1 = e.Ext2.Sub(t1, a) + return &E6{ + B0: *a, + B1: *t1, + B2: *t2, + } +} + +// MulBy01 multiplies z by an E6 sparse element of the form +// +// E6{ +// B0: c0, +// B1: c1, +// B2: 0, +// } +func (e Ext6) MulBy01(z *E6, c0, c1 *E2) *E6 { + a := 
e.Ext2.Mul(&z.B0, c0) + b := e.Ext2.Mul(&z.B1, c1) + tmp := e.Ext2.Add(&z.B1, &z.B2) + t0 := e.Ext2.Mul(c1, tmp) + t0 = e.Ext2.Sub(t0, b) + t0 = e.Ext2.MulByNonResidue(t0) + t0 = e.Ext2.Add(t0, a) + tmp = e.Ext2.Add(&z.B0, &z.B2) + t2 := e.Ext2.Mul(c0, tmp) + t2 = e.Ext2.Sub(t2, a) + t2 = e.Ext2.Add(t2, b) + t1 := e.Ext2.Add(c0, c1) + tmp = e.Ext2.Add(&z.B0, &z.B1) + t1 = e.Ext2.Mul(t1, tmp) + t1 = e.Ext2.Sub(t1, a) + t1 = e.Ext2.Sub(t1, b) + return &E6{ + B0: *t0, + B1: *t1, + B2: *t2, + } +} + +// Mul01By01 multiplies two E6 sparse element of the form: +// +// E6{ +// B0: c0, +// B1: c1, +// B2: 0, +// } +// +// and +// +// E6{ +// B0: d0, +// B1: d1, +// B2: 0, +// } +func (e Ext6) Mul01By01(c0, c1, d0, d1 *E2) *E6 { + a := e.Ext2.Mul(d0, c0) + b := e.Ext2.Mul(d1, c1) + t0 := e.Ext2.Mul(c1, d1) + t0 = e.Ext2.Sub(t0, b) + t0 = e.Ext2.MulByNonResidue(t0) + t0 = e.Ext2.Add(t0, a) + t2 := e.Ext2.Mul(c0, d0) + t2 = e.Ext2.Sub(t2, a) + t2 = e.Ext2.Add(t2, b) + t1 := e.Ext2.Add(c0, c1) + tmp := e.Ext2.Add(d0, d1) + t1 = e.Ext2.Mul(t1, tmp) + t1 = e.Ext2.Sub(t1, a) + t1 = e.Ext2.Sub(t1, b) + return &E6{ + B0: *t0, + B1: *t1, + B2: *t2, + } +} + +func (e Ext6) MulByNonResidue(x *E6) *E6 { + z2, z1, z0 := &x.B1, &x.B0, &x.B2 + z0 = e.Ext2.MulByNonResidue(z0) + return &E6{ + B0: *z0, + B1: *z1, + B2: *z2, + } +} + +func (e Ext6) FrobeniusSquare(x *E6) *E6 { + z01 := e.Ext2.MulByNonResidue2Power2(&x.B1) + z02 := e.Ext2.MulByNonResidue2Power4(&x.B2) + return &E6{B0: x.B0, B1: *z01, B2: *z02} +} + +func (e Ext6) AssertIsEqual(x, y *E6) { + e.Ext2.AssertIsEqual(&x.B0, &y.B0) + e.Ext2.AssertIsEqual(&x.B1, &y.B1) + e.Ext2.AssertIsEqual(&x.B2, &y.B2) +} + +func FromE6(y *bn254.E6) E6 { + return E6{ + B0: FromE2(&y.B0), + B1: FromE2(&y.B1), + B2: FromE2(&y.B2), + } + +} + +func (e Ext6) Inverse(x *E6) *E6 { + res, err := e.fp.NewHint(inverseE6Hint, 6, &x.B0.A0, &x.B0.A1, &x.B1.A0, &x.B1.A1, &x.B2.A0, &x.B2.A1) + if err != nil { + // err is non-nil only for invalid number of inputs 
+ panic(err) + } + + inv := E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + } + + one := e.One() + + // 1 == inv * x + _one := e.Mul(&inv, x) + e.AssertIsEqual(one, _one) + + return &inv + +} + +func (e Ext6) DivUnchecked(x, y *E6) *E6 { + res, err := e.fp.NewHint(divE6Hint, 6, &x.B0.A0, &x.B0.A1, &x.B1.A0, &x.B1.A1, &x.B2.A0, &x.B2.A1, &y.B0.A0, &y.B0.A1, &y.B1.A0, &y.B1.A1, &y.B2.A0, &y.B2.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + div := E6{ + B0: E2{A0: *res[0], A1: *res[1]}, + B1: E2{A0: *res[2], A1: *res[3]}, + B2: E2{A0: *res[4], A1: *res[5]}, + } + + // x == div * y + _x := e.Mul(&div, y) + e.AssertIsEqual(x, _x) + + return &div +} + +func (e Ext6) Select(selector frontend.Variable, z1, z0 *E6) *E6 { + b0 := e.Ext2.Select(selector, &z1.B0, &z0.B0) + b1 := e.Ext2.Select(selector, &z1.B1, &z0.B1) + b2 := e.Ext2.Select(selector, &z1.B2, &z0.B2) + return &E6{B0: *b0, B1: *b1, B2: *b2} +} + +func (e Ext6) Lookup2(s1, s2 frontend.Variable, a, b, c, d *E6) *E6 { + b0 := e.Ext2.Lookup2(s1, s2, &a.B0, &b.B0, &c.B0, &d.B0) + b1 := e.Ext2.Lookup2(s1, s2, &a.B1, &b.B1, &c.B1, &d.B1) + b2 := e.Ext2.Lookup2(s1, s2, &a.B2, &b.B2, &c.B2, &d.B2) + return &E6{B0: *b0, B1: *b1, B2: *b2} +} diff --git a/std/algebra/emulated/fields_bn254/e6_test.go b/std/algebra/emulated/fields_bn254/e6_test.go new file mode 100644 index 0000000000..6c46bb7d46 --- /dev/null +++ b/std/algebra/emulated/fields_bn254/e6_test.go @@ -0,0 +1,363 @@ +package fields_bn254 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type e6Add struct { + A, B, C E6 +} + +func (circuit *e6Add) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Add(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func 
TestAddFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E6 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Add(&a, &b) + + witness := e6Add{ + A: FromE6(&a), + B: FromE6(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Add{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Sub struct { + A, B, C E6 +} + +func (circuit *e6Sub) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Sub(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSubFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E6 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Sub(&a, &b) + + witness := e6Sub{ + A: FromE6(&a), + B: FromE6(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Sub{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Mul struct { + A, B, C E6 +} + +func (circuit *e6Mul) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Mul(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestMulFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E6 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Mul(&a, &b) + + witness := e6Mul{ + A: FromE6(&a), + B: FromE6(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Mul{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Square struct { + A, C E6 +} + +func (circuit *e6Square) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Square(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestSquareFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E6 + _, _ = a.SetRandom() + c.Square(&a) + + witness := e6Square{ + A: FromE6(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Square{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Div 
struct { + A, B, C E6 +} + +func (circuit *e6Div) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.DivUnchecked(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + return nil +} + +func TestDivFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b, c bn254.E6 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.Div(&a, &b) + + witness := e6Div{ + A: FromE6(&a), + B: FromE6(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Div{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6MulByNonResidue struct { + A E6 + C E6 `gnark:",public"` +} + +func (circuit *e6MulByNonResidue) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.MulByNonResidue(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulFp6ByNonResidue(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E6 + _, _ = a.SetRandom() + c.MulByNonResidue(&a) + + witness := e6MulByNonResidue{ + A: FromE6(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&e6MulByNonResidue{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6MulByE2 struct { + A E6 + B E2 + C E6 `gnark:",public"` +} + +func (circuit *e6MulByE2) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.MulByE2(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulFp6ByE2(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E6 + var b bn254.E2 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + c.MulByE2(&a, &b) + + witness := e6MulByE2{ + A: FromE6(&a), + B: FromE2(&b), + C: FromE6(&c), + } + + err := test.IsSolved(&e6MulByE2{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6MulBy01 struct { + A E6 + C0, C1 E2 + C E6 `gnark:",public"` +} + +func (circuit *e6MulBy01) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.MulBy01(&circuit.A, 
&circuit.C0, &circuit.C1) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulFp6By01(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E6 + var C0, C1 bn254.E2 + _, _ = a.SetRandom() + _, _ = C0.SetRandom() + _, _ = C1.SetRandom() + c.Set(&a) + c.MulBy01(&C0, &C1) + + witness := e6MulBy01{ + A: FromE6(&a), + C0: FromE2(&C0), + C1: FromE2(&C1), + C: FromE6(&c), + } + + err := test.IsSolved(&e6MulBy01{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6MulBy0 struct { + A E6 + C0 E2 + C E6 `gnark:",public"` +} + +func (circuit *e6MulBy0) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.MulBy0(&circuit.A, &circuit.C0) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestMulFp6By0(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E6 + var C0, zero bn254.E2 + _, _ = a.SetRandom() + _, _ = C0.SetRandom() + c.Set(&a) + c.MulBy01(&C0, &zero) + + witness := e6MulBy0{ + A: FromE6(&a), + C0: FromE2(&C0), + C: FromE6(&c), + } + + err := test.IsSolved(&e6MulBy0{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +type e6Neg struct { + A E6 + C E6 `gnark:",public"` +} + +func (circuit *e6Neg) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Neg(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestNegFp6(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bn254.E6 + _, _ = a.SetRandom() + c.Neg(&a) + + witness := e6Neg{ + A: FromE6(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Neg{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e6Inverse struct { + A E6 + C E6 `gnark:",public"` +} + +func (circuit *e6Inverse) Define(api frontend.API) error { + e := NewExt6(api) + expected := e.Inverse(&circuit.A) + e.AssertIsEqual(expected, &circuit.C) + + return nil +} + +func TestInverseFp6(t *testing.T) { + + assert := 
test.NewAssert(t) + // witness values + var a, c bn254.E6 + _, _ = a.SetRandom() + c.Inverse(&a) + + witness := e6Inverse{ + A: FromE6(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&e6Inverse{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/emulated/fields_bn254/hints.go b/std/algebra/emulated/fields_bn254/hints.go new file mode 100644 index 0000000000..b08d6ed977 --- /dev/null +++ b/std/algebra/emulated/fields_bn254/hints.go @@ -0,0 +1,238 @@ +package fields_bn254 + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/math/emulated" +) + +func init() { + solver.RegisterHint(GetHints()...) +} + +// GetHints returns all hint functions used in the package. +func GetHints() []solver.Hint { + return []solver.Hint{ + // E2 + divE2Hint, + inverseE2Hint, + // E6 + divE6Hint, + inverseE6Hint, + squareTorusHint, + // E12 + divE12Hint, + inverseE12Hint, + } +} + +func inverseE2Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, c bn254.E2 + + a.A0.SetBigInt(inputs[0]) + a.A1.SetBigInt(inputs[1]) + + c.Inverse(&a) + + c.A0.BigInt(outputs[0]) + c.A1.BigInt(outputs[1]) + + return nil + }) +} + +func divE2Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, b, c bn254.E2 + + a.A0.SetBigInt(inputs[0]) + a.A1.SetBigInt(inputs[1]) + b.A0.SetBigInt(inputs[2]) + b.A1.SetBigInt(inputs[3]) + + c.Inverse(&b).Mul(&c, &a) + + c.A0.BigInt(outputs[0]) + c.A1.BigInt(outputs[1]) + + return nil + }) +} + +// E6 hints +func inverseE6Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + 
func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, c bn254.E6 + + a.B0.A0.SetBigInt(inputs[0]) + a.B0.A1.SetBigInt(inputs[1]) + a.B1.A0.SetBigInt(inputs[2]) + a.B1.A1.SetBigInt(inputs[3]) + a.B2.A0.SetBigInt(inputs[4]) + a.B2.A1.SetBigInt(inputs[5]) + + c.Inverse(&a) + + c.B0.A0.BigInt(outputs[0]) + c.B0.A1.BigInt(outputs[1]) + c.B1.A0.BigInt(outputs[2]) + c.B1.A1.BigInt(outputs[3]) + c.B2.A0.BigInt(outputs[4]) + c.B2.A1.BigInt(outputs[5]) + + return nil + }) +} + +func divE6Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, b, c bn254.E6 + + a.B0.A0.SetBigInt(inputs[0]) + a.B0.A1.SetBigInt(inputs[1]) + a.B1.A0.SetBigInt(inputs[2]) + a.B1.A1.SetBigInt(inputs[3]) + a.B2.A0.SetBigInt(inputs[4]) + a.B2.A1.SetBigInt(inputs[5]) + + b.B0.A0.SetBigInt(inputs[6]) + b.B0.A1.SetBigInt(inputs[7]) + b.B1.A0.SetBigInt(inputs[8]) + b.B1.A1.SetBigInt(inputs[9]) + b.B2.A0.SetBigInt(inputs[10]) + b.B2.A1.SetBigInt(inputs[11]) + + c.Inverse(&b).Mul(&c, &a) + + c.B0.A0.BigInt(outputs[0]) + c.B0.A1.BigInt(outputs[1]) + c.B1.A0.BigInt(outputs[2]) + c.B1.A1.BigInt(outputs[3]) + c.B2.A0.BigInt(outputs[4]) + c.B2.A1.BigInt(outputs[5]) + + return nil + }) +} + +func squareTorusHint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, c bn254.E6 + + a.B0.A0.SetBigInt(inputs[0]) + a.B0.A1.SetBigInt(inputs[1]) + a.B1.A0.SetBigInt(inputs[2]) + a.B1.A1.SetBigInt(inputs[3]) + a.B2.A0.SetBigInt(inputs[4]) + a.B2.A1.SetBigInt(inputs[5]) + + _c := a.DecompressTorus() + _c.CyclotomicSquare(&_c) + c, _ = _c.CompressTorus() + + c.B0.A0.BigInt(outputs[0]) + c.B0.A1.BigInt(outputs[1]) + c.B1.A0.BigInt(outputs[2]) + c.B1.A1.BigInt(outputs[3]) + c.B2.A0.BigInt(outputs[4]) + c.B2.A1.BigInt(outputs[5]) + + return 
nil + }) +} + +// E12 hints +func inverseE12Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, c bn254.E12 + + a.C0.B0.A0.SetBigInt(inputs[0]) + a.C0.B0.A1.SetBigInt(inputs[1]) + a.C0.B1.A0.SetBigInt(inputs[2]) + a.C0.B1.A1.SetBigInt(inputs[3]) + a.C0.B2.A0.SetBigInt(inputs[4]) + a.C0.B2.A1.SetBigInt(inputs[5]) + a.C1.B0.A0.SetBigInt(inputs[6]) + a.C1.B0.A1.SetBigInt(inputs[7]) + a.C1.B1.A0.SetBigInt(inputs[8]) + a.C1.B1.A1.SetBigInt(inputs[9]) + a.C1.B2.A0.SetBigInt(inputs[10]) + a.C1.B2.A1.SetBigInt(inputs[11]) + + c.Inverse(&a) + + c.C0.B0.A0.BigInt(outputs[0]) + c.C0.B0.A1.BigInt(outputs[1]) + c.C0.B1.A0.BigInt(outputs[2]) + c.C0.B1.A1.BigInt(outputs[3]) + c.C0.B2.A0.BigInt(outputs[4]) + c.C0.B2.A1.BigInt(outputs[5]) + c.C1.B0.A0.BigInt(outputs[6]) + c.C1.B0.A1.BigInt(outputs[7]) + c.C1.B1.A0.BigInt(outputs[8]) + c.C1.B1.A1.BigInt(outputs[9]) + c.C1.B2.A0.BigInt(outputs[10]) + c.C1.B2.A1.BigInt(outputs[11]) + + return nil + }) +} + +func divE12Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHint(nativeInputs, nativeOutputs, + func(mod *big.Int, inputs, outputs []*big.Int) error { + var a, b, c bn254.E12 + + a.C0.B0.A0.SetBigInt(inputs[0]) + a.C0.B0.A1.SetBigInt(inputs[1]) + a.C0.B1.A0.SetBigInt(inputs[2]) + a.C0.B1.A1.SetBigInt(inputs[3]) + a.C0.B2.A0.SetBigInt(inputs[4]) + a.C0.B2.A1.SetBigInt(inputs[5]) + a.C1.B0.A0.SetBigInt(inputs[6]) + a.C1.B0.A1.SetBigInt(inputs[7]) + a.C1.B1.A0.SetBigInt(inputs[8]) + a.C1.B1.A1.SetBigInt(inputs[9]) + a.C1.B2.A0.SetBigInt(inputs[10]) + a.C1.B2.A1.SetBigInt(inputs[11]) + + b.C0.B0.A0.SetBigInt(inputs[12]) + b.C0.B0.A1.SetBigInt(inputs[13]) + b.C0.B1.A0.SetBigInt(inputs[14]) + b.C0.B1.A1.SetBigInt(inputs[15]) + b.C0.B2.A0.SetBigInt(inputs[16]) + b.C0.B2.A1.SetBigInt(inputs[17]) + b.C1.B0.A0.SetBigInt(inputs[18]) + 
b.C1.B0.A1.SetBigInt(inputs[19]) + b.C1.B1.A0.SetBigInt(inputs[20]) + b.C1.B1.A1.SetBigInt(inputs[21]) + b.C1.B2.A0.SetBigInt(inputs[22]) + b.C1.B2.A1.SetBigInt(inputs[23]) + + c.Inverse(&b).Mul(&c, &a) + + c.C0.B0.A0.BigInt(outputs[0]) + c.C0.B0.A1.BigInt(outputs[1]) + c.C0.B1.A0.BigInt(outputs[2]) + c.C0.B1.A1.BigInt(outputs[3]) + c.C0.B2.A0.BigInt(outputs[4]) + c.C0.B2.A1.BigInt(outputs[5]) + c.C1.B0.A0.BigInt(outputs[6]) + c.C1.B0.A1.BigInt(outputs[7]) + c.C1.B1.A0.BigInt(outputs[8]) + c.C1.B1.A1.BigInt(outputs[9]) + c.C1.B2.A0.BigInt(outputs[10]) + c.C1.B2.A1.BigInt(outputs[11]) + + return nil + }) +} diff --git a/std/algebra/emulated/sw_bls12381/doc.go b/std/algebra/emulated/sw_bls12381/doc.go new file mode 100644 index 0000000000..948f61e24f --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/doc.go @@ -0,0 +1,6 @@ +// Package sw_bls12381 implements G1 and G2 arithmetics and pairing computation over BLS12-381 curve. +// +// The implementation follows [Housni22]: "Pairings in Rank-1 Constraint Systems". 
+// +// [Housni22]: https://eprint.iacr.org/2022/1162 +package sw_bls12381 diff --git a/std/algebra/emulated/sw_bls12381/doc_test.go b/std/algebra/emulated/sw_bls12381/doc_test.go new file mode 100644 index 0000000000..a1ce0f5ca5 --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/doc_test.go @@ -0,0 +1,105 @@ +package sw_bls12381_test + +import ( + "crypto/rand" + "fmt" + + "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/algebra/emulated/sw_bls12381" +) + +type PairCircuit struct { + InG1 sw_bls12381.G1Affine + InG2 sw_bls12381.G2Affine + Res sw_bls12381.GTEl +} + +func (c *PairCircuit) Define(api frontend.API) error { + pairing, err := sw_bls12381.NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + // Pair method does not check that the points are in the proper groups. 
+ pairing.AssertIsOnG1(&c.InG1) + pairing.AssertIsOnG2(&c.InG2) + // Compute the pairing + res, err := pairing.Pair([]*sw_bls12381.G1Affine{&c.InG1}, []*sw_bls12381.G2Affine{&c.InG2}) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + pairing.AssertIsEqual(res, &c.Res) + return nil +} + +func ExamplePairing() { + p, q, err := randomG1G2Affines() + if err != nil { + panic(err) + } + res, err := bls12381.Pair([]bls12381.G1Affine{p}, []bls12381.G2Affine{q}) + if err != nil { + panic(err) + } + circuit := PairCircuit{} + witness := PairCircuit{ + InG1: sw_bls12381.NewG1Affine(p), + InG2: sw_bls12381.NewG2Affine(q), + Res: sw_bls12381.NewGTEl(res), + } + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + panic(err) + } else { + fmt.Println("compiled") + } + pk, vk, err := groth16.Setup(ccs) + if err != nil { + panic(err) + } else { + fmt.Println("setup done") + } + secretWitness, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } else { + fmt.Println("secret witness") + } + publicWitness, err := secretWitness.Public() + if err != nil { + panic(err) + } else { + fmt.Println("public witness") + } + proof, err := groth16.Prove(ccs, pk, secretWitness) + if err != nil { + panic(err) + } else { + fmt.Println("proof") + } + err = groth16.Verify(proof, vk, publicWitness) + if err != nil { + panic(err) + } else { + fmt.Println("verify") + } +} + +func randomG1G2Affines() (p bls12381.G1Affine, q bls12381.G2Affine, err error) { + _, _, G1AffGen, G2AffGen := bls12381.Generators() + mod := bls12381.ID.ScalarField() + s1, err := rand.Int(rand.Reader, mod) + if err != nil { + return p, q, err + } + s2, err := rand.Int(rand.Reader, mod) + if err != nil { + return p, q, err + } + p.ScalarMultiplication(&G1AffGen, s1) + q.ScalarMultiplication(&G2AffGen, s2) + return +} diff --git a/std/algebra/emulated/sw_bls12381/g1.go b/std/algebra/emulated/sw_bls12381/g1.go new file mode 100644 
index 0000000000..f7908b5ded --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/g1.go @@ -0,0 +1,45 @@ +package sw_bls12381 + +import ( + "fmt" + + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" +) + +type G1Affine = sw_emulated.AffinePoint[emulated.BLS12381Fp] + +func NewG1Affine(v bls12381.G1Affine) G1Affine { + return G1Affine{ + X: emulated.ValueOf[emulated.BLS12381Fp](v.X), + Y: emulated.ValueOf[emulated.BLS12381Fp](v.Y), + } +} + +type G1 struct { + curveF *emulated.Field[emulated.BLS12381Fp] + w *emulated.Element[emulated.BLS12381Fp] +} + +func NewG1(api frontend.API) (*G1, error) { + ba, err := emulated.NewField[emulated.BLS12381Fp](api) + if err != nil { + return nil, fmt.Errorf("new base api: %w", err) + } + w := emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436") + return &G1{ + curveF: ba, + w: &w, + }, nil +} + +func (g1 *G1) phi(q *G1Affine) *G1Affine { + x := g1.curveF.Mul(&q.X, g1.w) + + return &G1Affine{ + X: *x, + Y: q.Y, + } +} diff --git a/std/algebra/emulated/sw_bls12381/g2.go b/std/algebra/emulated/sw_bls12381/g2.go new file mode 100644 index 0000000000..45077a6cba --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/g2.go @@ -0,0 +1,219 @@ +package sw_bls12381 + +import ( + "math/big" + + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/fields_bls12381" + "github.com/consensys/gnark/std/math/emulated" +) + +type G2 struct { + *fields_bls12381.Ext2 + u1, w *emulated.Element[emulated.BLS12381Fp] + v *fields_bls12381.E2 +} + +type G2Affine struct { + X, Y fields_bls12381.E2 +} + +func NewG2(api frontend.API) *G2 { + w := 
emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436") + u1 := emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437") + v := fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp]("2973677408986561043442465346520108879172042883009249989176415018091420807192182638567116318576472649347015917690530"), + A1: emulated.ValueOf[emulated.BLS12381Fp]("1028732146235106349975324479215795277384839936929757896155643118032610843298655225875571310552543014690878354869257"), + } + return &G2{ + Ext2: fields_bls12381.NewExt2(api), + w: &w, + u1: &u1, + v: &v, + } +} + +func NewG2Affine(v bls12381.G2Affine) G2Affine { + return G2Affine{ + X: fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp](v.X.A0), + A1: emulated.ValueOf[emulated.BLS12381Fp](v.X.A1), + }, + Y: fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp](v.Y.A0), + A1: emulated.ValueOf[emulated.BLS12381Fp](v.Y.A1), + }, + } +} + +func (g2 *G2) psi(q *G2Affine) *G2Affine { + x := g2.Ext2.MulByElement(&q.X, g2.u1) + y := g2.Ext2.Conjugate(&q.Y) + y = g2.Ext2.Mul(y, g2.v) + + return &G2Affine{ + X: fields_bls12381.E2{A0: x.A1, A1: x.A0}, + Y: *y, + } +} + +func (g2 *G2) scalarMulBySeed(q *G2Affine) *G2Affine { + + z := g2.triple(q) + z = g2.double(z) + z = g2.doubleAndAdd(z, q) + z = g2.doubleN(z, 2) + z = g2.doubleAndAdd(z, q) + z = g2.doubleN(z, 8) + z = g2.doubleAndAdd(z, q) + z = g2.doubleN(z, 31) + z = g2.doubleAndAdd(z, q) + z = g2.doubleN(z, 16) + + return g2.neg(z) +} + +func (g2 G2) add(p, q *G2Affine) *G2Affine { + // compute λ = (q.y-p.y)/(q.x-p.x) + qypy := g2.Ext2.Sub(&q.Y, &p.Y) + qxpx := g2.Ext2.Sub(&q.X, &p.X) + λ := g2.Ext2.DivUnchecked(qypy, qxpx) + + // xr = λ²-p.x-q.x + λλ := g2.Ext2.Square(λ) + qxpx = g2.Ext2.Add(&p.X, &q.X) + xr := g2.Ext2.Sub(λλ, qxpx) + + // p.y = 
λ(p.x-r.x) - p.y + pxrx := g2.Ext2.Sub(&p.X, xr) + λpxrx := g2.Ext2.Mul(λ, pxrx) + yr := g2.Ext2.Sub(λpxrx, &p.Y) + + return &G2Affine{ + X: *xr, + Y: *yr, + } +} + +func (g2 G2) neg(p *G2Affine) *G2Affine { + xr := &p.X + yr := g2.Ext2.Neg(&p.Y) + return &G2Affine{ + X: *xr, + Y: *yr, + } +} + +func (g2 G2) sub(p, q *G2Affine) *G2Affine { + qNeg := g2.neg(q) + return g2.add(p, qNeg) +} + +func (g2 *G2) double(p *G2Affine) *G2Affine { + // compute λ = (3p.x²)/(2p.y) + xx3a := g2.Square(&p.X) + xx3a = g2.MulByConstElement(xx3a, big.NewInt(3)) + y2 := g2.Double(&p.Y) + λ := g2.DivUnchecked(xx3a, y2) + + // xr = λ²-2p.x + x2 := g2.Double(&p.X) + λλ := g2.Square(λ) + xr := g2.Sub(λλ, x2) + + // yr = λ(p.x-xr) - p.y + pxrx := g2.Sub(&p.X, xr) + λpxrx := g2.Mul(λ, pxrx) + yr := g2.Sub(λpxrx, &p.Y) + + return &G2Affine{ + X: *xr, + Y: *yr, + } +} + +func (g2 *G2) doubleN(p *G2Affine, n int) *G2Affine { + pn := p + for s := 0; s < n; s++ { + pn = g2.double(pn) + } + return pn +} + +func (g2 G2) triple(p *G2Affine) *G2Affine { + + // compute λ1 = (3p.x²)/2p.y + xx := g2.Square(&p.X) + xx = g2.MulByConstElement(xx, big.NewInt(3)) + y2 := g2.Double(&p.Y) + λ1 := g2.DivUnchecked(xx, y2) + + // xr = λ1²-2p.x + x2 := g2.MulByConstElement(&p.X, big.NewInt(2)) + λ1λ1 := g2.Square(λ1) + x2 = g2.Sub(λ1λ1, x2) + + // omit y2 computation, and + // compute λ2 = 2p.y/(x2 − p.x) − λ1. 
+ x1x2 := g2.Sub(&p.X, x2) + λ2 := g2.DivUnchecked(y2, x1x2) + λ2 = g2.Sub(λ2, λ1) + + // xr = λ²-p.x-x2 + λ2λ2 := g2.Square(λ2) + qxrx := g2.Add(x2, &p.X) + xr := g2.Sub(λ2λ2, qxrx) + + // yr = λ(p.x-xr) - p.y + pxrx := g2.Sub(&p.X, xr) + λ2pxrx := g2.Mul(λ2, pxrx) + yr := g2.Sub(λ2pxrx, &p.Y) + + return &G2Affine{ + X: *xr, + Y: *yr, + } +} + +func (g2 G2) doubleAndAdd(p, q *G2Affine) *G2Affine { + + // compute λ1 = (q.y-p.y)/(q.x-p.x) + yqyp := g2.Ext2.Sub(&q.Y, &p.Y) + xqxp := g2.Ext2.Sub(&q.X, &p.X) + λ1 := g2.Ext2.DivUnchecked(yqyp, xqxp) + + // compute x2 = λ1²-p.x-q.x + λ1λ1 := g2.Ext2.Square(λ1) + xqxp = g2.Ext2.Add(&p.X, &q.X) + x2 := g2.Ext2.Sub(λ1λ1, xqxp) + + // ommit y2 computation + // compute λ2 = -λ1-2*p.y/(x2-p.x) + ypyp := g2.Ext2.Add(&p.Y, &p.Y) + x2xp := g2.Ext2.Sub(x2, &p.X) + λ2 := g2.Ext2.DivUnchecked(ypyp, x2xp) + λ2 = g2.Ext2.Add(λ1, λ2) + λ2 = g2.Ext2.Neg(λ2) + + // compute x3 =λ2²-p.x-x3 + λ2λ2 := g2.Ext2.Square(λ2) + x3 := g2.Ext2.Sub(λ2λ2, &p.X) + x3 = g2.Ext2.Sub(x3, x2) + + // compute y3 = λ2*(p.x - x3)-p.y + y3 := g2.Ext2.Sub(&p.X, x3) + y3 = g2.Ext2.Mul(λ2, y3) + y3 = g2.Ext2.Sub(y3, &p.Y) + + return &G2Affine{ + X: *x3, + Y: *y3, + } +} + +// AssertIsEqual asserts that p and q are the same point. 
+func (g2 *G2) AssertIsEqual(p, q *G2Affine) { + g2.Ext2.AssertIsEqual(&p.X, &q.X) + g2.Ext2.AssertIsEqual(&p.Y, &q.Y) +} diff --git a/std/algebra/emulated/sw_bls12381/g2_test.go b/std/algebra/emulated/sw_bls12381/g2_test.go new file mode 100644 index 0000000000..9d4a90d0e4 --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/g2_test.go @@ -0,0 +1,120 @@ +package sw_bls12381 + +import ( + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type addG2Circuit struct { + In1, In2 G2Affine + Res G2Affine +} + +func (c *addG2Circuit) Define(api frontend.API) error { + g2 := NewG2(api) + res := g2.add(&c.In1, &c.In2) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestAddG2TestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + _, in2 := randomG1G2Affines() + var res bls12381.G2Affine + res.Add(&in1, &in2) + witness := addG2Circuit{ + In1: NewG2Affine(in1), + In2: NewG2Affine(in2), + Res: NewG2Affine(res), + } + err := test.IsSolved(&addG2Circuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type doubleG2Circuit struct { + In1 G2Affine + Res G2Affine +} + +func (c *doubleG2Circuit) Define(api frontend.API) error { + g2 := NewG2(api) + res := g2.double(&c.In1) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestDoubleG2TestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + var res bls12381.G2Affine + var in1Jac, resJac bls12381.G2Jac + in1Jac.FromAffine(&in1) + resJac.Double(&in1Jac) + res.FromJacobian(&resJac) + witness := doubleG2Circuit{ + In1: NewG2Affine(in1), + Res: NewG2Affine(res), + } + err := test.IsSolved(&doubleG2Circuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type doubleAndAddG2Circuit struct { + In1, In2 G2Affine + Res G2Affine +} + +func (c *doubleAndAddG2Circuit) 
Define(api frontend.API) error { + g2 := NewG2(api) + res := g2.doubleAndAdd(&c.In1, &c.In2) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestDoubleAndAddG2TestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + _, in2 := randomG1G2Affines() + var res bls12381.G2Affine + res.Double(&in1). + Add(&res, &in2) + witness := doubleAndAddG2Circuit{ + In1: NewG2Affine(in1), + In2: NewG2Affine(in2), + Res: NewG2Affine(res), + } + err := test.IsSolved(&doubleAndAddG2Circuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type scalarMulG2BySeedCircuit struct { + In1 G2Affine + Res G2Affine +} + +func (c *scalarMulG2BySeedCircuit) Define(api frontend.API) error { + g2 := NewG2(api) + res := g2.scalarMulBySeed(&c.In1) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestScalarMulG2BySeedTestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + var res bls12381.G2Affine + x0, _ := new(big.Int).SetString("15132376222941642752", 10) + res.ScalarMultiplication(&in1, x0).Neg(&res) + witness := scalarMulG2BySeedCircuit{ + In1: NewG2Affine(in1), + Res: NewG2Affine(res), + } + err := test.IsSolved(&scalarMulG2BySeedCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/emulated/sw_bls12381/pairing.go b/std/algebra/emulated/sw_bls12381/pairing.go new file mode 100644 index 0000000000..72331dca92 --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/pairing.go @@ -0,0 +1,837 @@ +package sw_bls12381 + +import ( + "errors" + "fmt" + "math/big" + + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/fields_bls12381" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" +) + +type Pairing struct { + api frontend.API + *fields_bls12381.Ext12 + curveF *emulated.Field[emulated.BLS12381Fp] + g2 *G2 + 
g1 *G1 + curve *sw_emulated.Curve[emulated.BLS12381Fp, emulated.BLS12381Fr] + bTwist *fields_bls12381.E2 + lines [4][63]fields_bls12381.E2 +} + +type GTEl = fields_bls12381.E12 + +func NewGTEl(v bls12381.GT) GTEl { + return GTEl{ + C0: fields_bls12381.E6{ + B0: fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp](v.C0.B0.A0), + A1: emulated.ValueOf[emulated.BLS12381Fp](v.C0.B0.A1), + }, + B1: fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp](v.C0.B1.A0), + A1: emulated.ValueOf[emulated.BLS12381Fp](v.C0.B1.A1), + }, + B2: fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp](v.C0.B2.A0), + A1: emulated.ValueOf[emulated.BLS12381Fp](v.C0.B2.A1), + }, + }, + C1: fields_bls12381.E6{ + B0: fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp](v.C1.B0.A0), + A1: emulated.ValueOf[emulated.BLS12381Fp](v.C1.B0.A1), + }, + B1: fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp](v.C1.B1.A0), + A1: emulated.ValueOf[emulated.BLS12381Fp](v.C1.B1.A1), + }, + B2: fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp](v.C1.B2.A0), + A1: emulated.ValueOf[emulated.BLS12381Fp](v.C1.B2.A1), + }, + }, + } +} + +func NewPairing(api frontend.API) (*Pairing, error) { + ba, err := emulated.NewField[emulated.BLS12381Fp](api) + if err != nil { + return nil, fmt.Errorf("new base api: %w", err) + } + curve, err := sw_emulated.New[emulated.BLS12381Fp, emulated.BLS12381Fr](api, sw_emulated.GetBLS12381Params()) + if err != nil { + return nil, fmt.Errorf("new curve: %w", err) + } + bTwist := fields_bls12381.E2{ + A0: emulated.ValueOf[emulated.BLS12381Fp]("4"), + A1: emulated.ValueOf[emulated.BLS12381Fp]("4"), + } + g1, err := NewG1(api) + if err != nil { + return nil, fmt.Errorf("new G1 struct: %w", err) + } + return &Pairing{ + api: api, + Ext12: fields_bls12381.NewExt12(api), + curveF: ba, + curve: curve, + g1: g1, + g2: NewG2(api), + bTwist: &bTwist, + lines: getPrecomputedLines(), + }, nil +} + +// FinalExponentiation computes the 
exponentiation (∏ᵢ zᵢ)ᵈ where +// +// d = (p¹²-1)/r = (p¹²-1)/Φ₁₂(p) ⋅ Φ₁₂(p)/r = (p⁶-1)(p²+1)(p⁴ - p² +1)/r +// +// we use instead +// +// d=s ⋅ (p⁶-1)(p²+1)(p⁴ - p² +1)/r +// +// where s is the cofactor 3 (Hayashida et al.). +// +// This is the safe version of the method where e may be {-1,1}. If it is known +// that e ≠ {-1,1} then using the unsafe version of the method saves +// considerable amount of constraints. When called with the result of +// [MillerLoop], then current method is applicable when length of the inputs to +// Miller loop is 1. +func (pr Pairing) FinalExponentiation(e *GTEl) *GTEl { + return pr.finalExponentiation(e, false) +} + +// FinalExponentiationUnsafe computes the exponentiation (∏ᵢ zᵢ)ᵈ where +// +// d = (p¹²-1)/r = (p¹²-1)/Φ₁₂(p) ⋅ Φ₁₂(p)/r = (p⁶-1)(p²+1)(p⁴ - p² +1)/r +// +// we use instead +// +// d=s ⋅ (p⁶-1)(p²+1)(p⁴ - p² +1)/r +// +// where s is the cofactor 3 (Hayashida et al.). +// +// This is the unsafe version of the method where e may NOT be {-1,1}. If e ∈ +// {-1, 1}, then there exists no valid solution to the circuit. This method is +// applicable when called with the result of [MillerLoop] method when the length +// of the inputs to Miller loop is 1. +func (pr Pairing) FinalExponentiationUnsafe(e *GTEl) *GTEl { + return pr.finalExponentiation(e, true) +} + +// finalExponentiation computes the exponentiation (∏ᵢ zᵢ)ᵈ where +// +// d = (p¹²-1)/r = (p¹²-1)/Φ₁₂(p) ⋅ Φ₁₂(p)/r = (p⁶-1)(p²+1)(p⁴ - p² +1)/r +// +// we use instead +// +// d=s ⋅ (p⁶-1)(p²+1)(p⁴ - p² +1)/r +// +// where s is the cofactor 3 (Hayashida et al.). +func (pr Pairing) finalExponentiation(e *GTEl, unsafe bool) *GTEl { + + // 1. Easy part + // (p⁶-1)(p²+1) + var selector1, selector2 frontend.Variable + _dummy := pr.Ext6.One() + + if unsafe { + // The Miller loop result is ≠ {-1,1}, otherwise this means P and Q are + // linearly dependant and not from G1 and G2 respectively. + // So e ∈ G_{q,2} \ {-1,1} and hence e.C1 ≠ 0. + // Nothing to do. 
+ } else { + // However, for a product of Miller loops (n>=2) this might happen. If this is + // the case, the result is 1 in the torus. We assign a dummy value (1) to e.C1 + // and proceed further. + selector1 = pr.Ext6.IsZero(&e.C1) + e.C1 = *pr.Ext6.Select(selector1, _dummy, &e.C1) + } + + // Torus compression absorbed: + // Raising e to (p⁶-1) is + // e^(p⁶) / e = (e.C0 - w*e.C1) / (e.C0 + w*e.C1) + // = (-e.C0/e.C1 + w) / (-e.C0/e.C1 - w) + // So the fraction -e.C0/e.C1 is already in the torus. + // This absorbs the torus compression in the easy part. + c := pr.Ext6.DivUnchecked(&e.C0, &e.C1) + c = pr.Ext6.Neg(c) + t0 := pr.FrobeniusSquareTorus(c) + c = pr.MulTorus(t0, c) + + // 2. Hard part (up to permutation) + // 3(p⁴-p²+1)/r + // Daiki Hayashida, Kenichiro Hayasaka and Tadanori Teruya + // https://eprint.iacr.org/2020/875.pdf + // performed in torus compressed form + t0 = pr.SquareTorus(c) + t1 := pr.ExptHalfTorus(t0) + t2 := pr.InverseTorus(c) + t1 = pr.MulTorus(t1, t2) + t2 = pr.ExptTorus(t1) + t1 = pr.InverseTorus(t1) + t1 = pr.MulTorus(t1, t2) + t2 = pr.ExptTorus(t1) + t1 = pr.FrobeniusTorus(t1) + t1 = pr.MulTorus(t1, t2) + c = pr.MulTorus(c, t0) + t0 = pr.ExptTorus(t1) + t2 = pr.ExptTorus(t0) + t0 = pr.FrobeniusSquareTorus(t1) + t1 = pr.InverseTorus(t1) + t1 = pr.MulTorus(t1, t2) + t1 = pr.MulTorus(t1, t0) + + var result GTEl + // MulTorus(c, t1) requires c ≠ -t1. When c = -t1, it means the + // product is 1 in the torus. + if unsafe { + // For a single pairing, this does not happen because the pairing is non-degenerate. + result = *pr.DecompressTorus(pr.MulTorus(c, t1)) + } else { + // For a product of pairings this might happen when the result is expected to be 1. + // We assign a dummy value (1) to t1 and proceed further. + // Finally we do a select on both edge cases: + // - Only if selector1=0 and selector2=0, we return MulTorus(c, t1) decompressed. + // - Otherwise, we return 1. 
+ _sum := pr.Ext6.Add(c, t1) + selector2 = pr.Ext6.IsZero(_sum) + t1 = pr.Ext6.Select(selector2, _dummy, t1) + selector := pr.api.Mul(pr.api.Sub(1, selector1), pr.api.Sub(1, selector2)) + result = *pr.Select(selector, pr.DecompressTorus(pr.MulTorus(c, t1)), pr.One()) + } + + return &result +} + +// lineEvaluation represents a sparse Fp12 Elmt (result of the line evaluation) +// line: 1 - R0(x/y) - R1(1/y) = 0 instead of R0'*y - R1'*x - R2' = 0 This +// makes the multiplication by lines (MulBy014) and between lines (Mul014By014) +// circuit-efficient. +type lineEvaluation struct { + R0, R1 fields_bls12381.E2 +} + +// Pair calculates the reduced pairing for a set of points +// ∏ᵢ e(Pᵢ, Qᵢ). +// +// This function doesn't check that the inputs are in the correct subgroups. +func (pr Pairing) Pair(P []*G1Affine, Q []*G2Affine) (*GTEl, error) { + res, err := pr.MillerLoop(P, Q) + if err != nil { + return nil, fmt.Errorf("miller loop: %w", err) + } + res = pr.finalExponentiation(res, len(P) == 1) + return res, nil +} + +// PairingCheck calculates the reduced pairing for a set of points and asserts if the result is One +// ∏ᵢ e(Pᵢ, Qᵢ) =? 1 +// +// This function doesn't check that the inputs are in the correct subgroups. 
+func (pr Pairing) PairingCheck(P []*G1Affine, Q []*G2Affine) error { + f, err := pr.Pair(P, Q) + if err != nil { + return err + + } + one := pr.One() + pr.AssertIsEqual(f, one) + + return nil +} + +func (pr Pairing) AssertIsEqual(x, y *GTEl) { + pr.Ext12.AssertIsEqual(x, y) +} + +func (pr Pairing) AssertIsOnCurve(P *G1Affine) { + pr.curve.AssertIsOnCurve(P) +} + +func (pr Pairing) AssertIsOnTwist(Q *G2Affine) { + // Twist: Y² == X³ + aX + b, where a=0 and b=4(1+u) + // (X,Y) ∈ {Y² == X³ + aX + b} U (0,0) + + // if Q=(0,0) we assign b=0 otherwise 4*(1+u), and continue + selector := pr.api.And(pr.Ext2.IsZero(&Q.X), pr.Ext2.IsZero(&Q.Y)) + + b := pr.Ext2.Select(selector, pr.Ext2.Zero(), pr.bTwist) + + left := pr.Ext2.Square(&Q.Y) + right := pr.Ext2.Square(&Q.X) + right = pr.Ext2.Mul(right, &Q.X) + right = pr.Ext2.Add(right, b) + pr.Ext2.AssertIsEqual(left, right) +} + +func (pr Pairing) AssertIsOnG1(P *G1Affine) { + // 1- Check P is on the curve + pr.AssertIsOnCurve(P) + + // 2- Check P has the right subgroup order + // TODO: add phi and scalarMulBySeedSquare to g1.go + // [x²]ϕ(P) + phiP := pr.g1.phi(P) + seedSquare := emulated.ValueOf[emulated.BLS12381Fr]("228988810152649578064853576960394133504") + // TODO: use addchain to construct a fixed-scalar ScalarMul + _P := pr.curve.ScalarMul(phiP, &seedSquare) + _P = pr.curve.Neg(_P) + + // [r]Q == 0 <==> P = -[x²]ϕ(P) + pr.curve.AssertIsEqual(_P, P) +} + +func (pr Pairing) AssertIsOnG2(Q *G2Affine) { + // 1- Check Q is on the curve + pr.AssertIsOnTwist(Q) + + // 2- Check Q has the right subgroup order + // [x₀]Q + xQ := pr.g2.scalarMulBySeed(Q) + // ψ(Q) + psiQ := pr.g2.psi(Q) + + // [r]Q == 0 <==> ψ(Q) == [x₀]Q + pr.g2.AssertIsEqual(xQ, psiQ) +} + +// loopCounter = seed in binary +// +// seed=-15132376222941642752 +var loopCounter = [64]int8{ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, + 0, 
0, 1, 0, 0, 1, 0, 1, 1, +} + +// MillerLoop computes the multi-Miller loop +// ∏ᵢ { fᵢ_{u,Q}(P) } +func (pr Pairing) MillerLoop(P []*G1Affine, Q []*G2Affine) (*GTEl, error) { + // check input size match + n := len(P) + if n == 0 || n != len(Q) { + return nil, errors.New("invalid inputs sizes") + } + + res := pr.Ext12.One() + + var l1, l2 *lineEvaluation + Qacc := make([]*G2Affine, n) + yInv := make([]*emulated.Element[emulated.BLS12381Fp], n) + xOverY := make([]*emulated.Element[emulated.BLS12381Fp], n) + + for k := 0; k < n; k++ { + Qacc[k] = Q[k] + // P and Q are supposed to be on G1 and G2 respectively of prime order r. + // The point (x,0) is of order 2. But this function does not check + // subgroup membership. + // Anyway (x,0) cannot be on BLS12-381 because -4 is a cubic non-residue in Fp. + // so, 1/y is well defined for all points P's + yInv[k] = pr.curveF.Inverse(&P[k].Y) + xOverY[k] = pr.curveF.MulMod(&P[k].X, yInv[k]) + } + + // Compute ∏ᵢ { fᵢ_{x₀,Q}(P) } + + // i = 62, separately to avoid an E12 Square + // (Square(res) = 1² = 1) + + // k = 0, separately to avoid MulBy034 (res × ℓ) + + // Qacc[k] ← 3Qacc[k], + // l1 the tangent ℓ to 2Q[k] + // l2 the line ℓ passing 2Q[k] and Q[k] + Qacc[0], l1, l2 = pr.tripleStep(Qacc[0]) + // line evaluation at P[0] + // and assign line to res (R1, R0, 0, 0, 1, 0) + res.C0.B1 = *pr.MulByElement(&l1.R0, xOverY[0]) + res.C0.B0 = *pr.MulByElement(&l1.R1, yInv[0]) + res.C1.B1 = *pr.Ext2.One() + // line evaluation at P[0] + l2.R0 = *pr.MulByElement(&l2.R0, xOverY[0]) + l2.R1 = *pr.MulByElement(&l2.R1, yInv[0]) + // res = ℓ × ℓ + prodLines := *pr.Mul014By014(&l2.R1, &l2.R0, &res.C0.B0, &res.C0.B1) + res.C0.B0 = prodLines[0] + res.C0.B1 = prodLines[1] + res.C0.B2 = prodLines[2] + res.C1.B1 = prodLines[3] + res.C1.B2 = prodLines[4] + + for k := 1; k < n; k++ { + // Qacc[k] ← 3Qacc[k], + // l1 the tangent ℓ to 2Q[k] + // l2 the line ℓ passing 2Q[k] and Q[k] + Qacc[k], l1, l2 = pr.tripleStep(Qacc[k]) + // line evaluation at 
P[k] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[k]) + // line evaluation at P[k] + l2.R0 = *pr.MulByElement(&l2.R0, xOverY[k]) + l2.R1 = *pr.MulByElement(&l2.R1, yInv[k]) + // ℓ × ℓ + prodLines = *pr.Mul014By014(&l1.R1, &l1.R0, &l2.R1, &l2.R0) + // (ℓ × ℓ) × res + res = pr.MulBy01245(res, &prodLines) + + } + + // Compute ∏ᵢ { fᵢ_{u,Q}(P) } + for i := 61; i >= 1; i-- { + // mutualize the square among n Miller loops + // (∏ᵢfᵢ)² + res = pr.Square(res) + + if loopCounter[i] == 0 { + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k] and l1 the tangent ℓ passing 2Qacc[k] + Qacc[k], l1 = pr.doubleStep(Qacc[k]) + // line evaluation at P[k] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[k]) + // ℓ × res + res = pr.MulBy014(res, &l1.R1, &l1.R0) + } + } else { + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k]+Q[k], + // l1 the line ℓ passing Qacc[k] and Q[k] + // l2 the line ℓ passing (Qacc[k]+Q[k]) and Qacc[k] + Qacc[k], l1, l2 = pr.doubleAndAddStep(Qacc[k], Q[k]) + // line evaluation at P[k] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[k]) + // line evaluation at P[k] + l2.R0 = *pr.MulByElement(&l2.R0, xOverY[k]) + l2.R1 = *pr.MulByElement(&l2.R1, yInv[k]) + // ℓ × ℓ + prodLines = *pr.Mul014By014(&l1.R1, &l1.R0, &l2.R1, &l2.R0) + // (ℓ × ℓ) × res + res = pr.MulBy01245(res, &prodLines) + } + } + } + + // i = 0, separately to avoid a point doubling + res = pr.Square(res) + for k := 0; k < n; k++ { + // l1 the tangent ℓ passing 2Qacc[k] + l1 = pr.tangentCompute(Qacc[k]) + // line evaluation at P[k] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[k]) + // ℓ × res + res = pr.MulBy014(res, &l1.R1, &l1.R0) + } + + // negative x₀ + res = pr.Ext12.Conjugate(res) + + return res, nil +} + +// doubleAndAddStep doubles p1 and adds p2 to the result in affine coordinates, and evaluates the line in Miller loop +// 
https://eprint.iacr.org/2022/1162 (Section 6.1) +func (pr Pairing) doubleAndAddStep(p1, p2 *G2Affine) (*G2Affine, *lineEvaluation, *lineEvaluation) { + + var line1, line2 lineEvaluation + var p G2Affine + + // compute λ1 = (y2-y1)/(x2-x1) + n := pr.Ext2.Sub(&p1.Y, &p2.Y) + d := pr.Ext2.Sub(&p1.X, &p2.X) + l1 := pr.Ext2.DivUnchecked(n, d) + + // compute x3 =λ1²-x1-x2 + x3 := pr.Ext2.Square(l1) + x3 = pr.Ext2.Sub(x3, &p1.X) + x3 = pr.Ext2.Sub(x3, &p2.X) + + // omit y3 computation + + // compute line1 + line1.R0 = *pr.Ext2.Neg(l1) + line1.R1 = *pr.Ext2.Mul(l1, &p1.X) + line1.R1 = *pr.Ext2.Sub(&line1.R1, &p1.Y) + + // compute λ2 = -λ1-2y1/(x3-x1) + n = pr.Ext2.Double(&p1.Y) + d = pr.Ext2.Sub(x3, &p1.X) + l2 := pr.Ext2.DivUnchecked(n, d) + l2 = pr.Ext2.Add(l2, l1) + l2 = pr.Ext2.Neg(l2) + + // compute x4 = λ2²-x1-x3 + x4 := pr.Ext2.Square(l2) + x4 = pr.Ext2.Sub(x4, &p1.X) + x4 = pr.Ext2.Sub(x4, x3) + + // compute y4 = λ2(x1 - x4)-y1 + y4 := pr.Ext2.Sub(&p1.X, x4) + y4 = pr.Ext2.Mul(l2, y4) + y4 = pr.Ext2.Sub(y4, &p1.Y) + + p.X = *x4 + p.Y = *y4 + + // compute line2 + line2.R0 = *pr.Ext2.Neg(l2) + line2.R1 = *pr.Ext2.Mul(l2, &p1.X) + line2.R1 = *pr.Ext2.Sub(&line2.R1, &p1.Y) + + return &p, &line1, &line2 +} + +// doubleStep doubles a point in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func (pr Pairing) doubleStep(p1 *G2Affine) (*G2Affine, *lineEvaluation) { + + var p G2Affine + var line lineEvaluation + + // λ = 3x²/2y + n := pr.Ext2.Square(&p1.X) + three := big.NewInt(3) + n = pr.Ext2.MulByConstElement(n, three) + d := pr.Ext2.Double(&p1.Y) + λ := pr.Ext2.DivUnchecked(n, d) + + // xr = λ²-2x + xr := pr.Ext2.Square(λ) + xr = pr.Ext2.Sub(xr, &p1.X) + xr = pr.Ext2.Sub(xr, &p1.X) + + // yr = λ(x-xr)-y + yr := pr.Ext2.Sub(&p1.X, xr) + yr = pr.Ext2.Mul(λ, yr) + yr = pr.Ext2.Sub(yr, &p1.Y) + + p.X = *xr + p.Y = *yr + + line.R0 = *pr.Ext2.Neg(λ) + line.R1 = *pr.Ext2.Mul(λ, &p1.X) + line.R1 = 
*pr.Ext2.Sub(&line.R1, &p1.Y) + + return &p, &line + +} + +// addStep adds two points in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func (pr Pairing) addStep(p1, p2 *G2Affine) (*G2Affine, *lineEvaluation) { + + // compute λ = (y2-y1)/(x2-x1) + p2ypy := pr.Ext2.Sub(&p2.Y, &p1.Y) + p2xpx := pr.Ext2.Sub(&p2.X, &p1.X) + λ := pr.Ext2.DivUnchecked(p2ypy, p2xpx) + + // xr = λ²-x1-x2 + λλ := pr.Ext2.Square(λ) + p2xpx = pr.Ext2.Add(&p1.X, &p2.X) + xr := pr.Ext2.Sub(λλ, p2xpx) + + // yr = λ(x1-xr) - y1 + pxrx := pr.Ext2.Sub(&p1.X, xr) + λpxrx := pr.Ext2.Mul(λ, pxrx) + yr := pr.Ext2.Sub(λpxrx, &p1.Y) + + var res G2Affine + res.X = *xr + res.Y = *yr + + var line lineEvaluation + line.R0 = *pr.Ext2.Neg(λ) + line.R1 = *pr.Ext2.Mul(λ, &p1.X) + line.R1 = *pr.Ext2.Sub(&line.R1, &p1.Y) + + return &res, &line + +} + +// tripleStep triples p1 in affine coordinates, and evaluates the line in Miller loop +func (pr Pairing) tripleStep(p1 *G2Affine) (*G2Affine, *lineEvaluation, *lineEvaluation) { + + var line1, line2 lineEvaluation + var res G2Affine + + // λ1 = 3x²/2y + n := pr.Ext2.Square(&p1.X) + three := big.NewInt(3) + n = pr.Ext2.MulByConstElement(n, three) + d := pr.Ext2.Double(&p1.Y) + λ1 := pr.Ext2.DivUnchecked(n, d) + + // compute line1 + line1.R0 = *pr.Ext2.Neg(λ1) + line1.R1 = *pr.Ext2.Mul(λ1, &p1.X) + line1.R1 = *pr.Ext2.Sub(&line1.R1, &p1.Y) + + // x2 = λ1²-2x + x2 := pr.Ext2.Square(λ1) + x2 = pr.Ext2.Sub(x2, &p1.X) + x2 = pr.Ext2.Sub(x2, &p1.X) + + // ommit yr computation, and + // compute λ2 = 2y/(x2 − x) − λ1. 
+ x1x2 := pr.Ext2.Sub(&p1.X, x2) + λ2 := pr.Ext2.DivUnchecked(d, x1x2) + λ2 = pr.Ext2.Sub(λ2, λ1) + + // compute line2 + line2.R0 = *pr.Ext2.Neg(λ2) + line2.R1 = *pr.Ext2.Mul(λ2, &p1.X) + line2.R1 = *pr.Ext2.Sub(&line2.R1, &p1.Y) + + // xr = λ²-p.x-x2 + λ2λ2 := pr.Ext2.Mul(λ2, λ2) + qxrx := pr.Ext2.Add(x2, &p1.X) + xr := pr.Ext2.Sub(λ2λ2, qxrx) + + // yr = λ(p.x-xr) - p.y + pxrx := pr.Ext2.Sub(&p1.X, xr) + λ2pxrx := pr.Ext2.Mul(λ2, pxrx) + yr := pr.Ext2.Sub(λ2pxrx, &p1.Y) + + res.X = *xr + res.Y = *yr + + return &res, &line1, &line2 +} + +// tangentCompute computes the line that goes through p1 and p2 but does not compute p1+p2 +func (pr Pairing) tangentCompute(p1 *G2Affine) *lineEvaluation { + + // λ = 3x²/2y + n := pr.Ext2.Square(&p1.X) + three := big.NewInt(3) + n = pr.Ext2.MulByConstElement(n, three) + d := pr.Ext2.Double(&p1.Y) + λ := pr.Ext2.DivUnchecked(n, d) + + var line lineEvaluation + line.R0 = *pr.Ext2.Neg(λ) + line.R1 = *pr.Ext2.Mul(λ, &p1.X) + line.R1 = *pr.Ext2.Sub(&line.R1, &p1.Y) + + return &line + +} + +// ---------------------------- +// Fixed-argument pairing +// ---------------------------- +// +// The second argument Q is g2 the fixed canonical generator of G2. +// +// g2.X.A0 = 0x24aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8 +// g2.X.A1 = 0x13e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e +// g2.Y.A0 = 0xce5d527727d6e118cc9cdc6da2e351aadfd9baa8cbdd3a76d429a695160d12c923ac9cc3baca289e193548608b82801 +// g2.Y.A1 = 0x606c4a02ea734cc32acd2b02bc28b99cb3e287e85a763af267492ab572e99ab3f370d275cec1da1aaa9075ff05f79be + +// MillerLoopFixed computes the single Miller loop +// fᵢ_{u,g2}(P), where g2 is fixed. +func (pr Pairing) MillerLoopFixedQ(P *G1Affine) (*GTEl, error) { + + res := pr.Ext12.One() + + var yInv, xOverY *emulated.Element[emulated.BLS12381Fp] + + // P and Q are supposed to be on G1 and G2 respectively of prime order r. 
+ // The point (x,0) is of order 2. But this function does not check + // subgroup membership. + // Anyway (x,0) cannot be on BLS12-381 because -4 is a cubic non-residue in Fp. + // so, 1/y is well defined for all points P's + yInv = pr.curveF.Inverse(&P.Y) + xOverY = pr.curveF.MulMod(&P.X, yInv) + + // Compute ∏ᵢ { fᵢ_{x₀,Q}(P) } + + // i = 62, separately to avoid an E12 Square + // (Square(res) = 1² = 1) + res = pr.MulBy014(res, + pr.MulByElement(&pr.lines[1][62], yInv), + pr.MulByElement(&pr.lines[0][62], xOverY), + ) + res = pr.MulBy014(res, + pr.MulByElement(&pr.lines[3][62], yInv), + pr.MulByElement(&pr.lines[2][62], xOverY), + ) + + // Compute ∏ᵢ { fᵢ_{u,Q}(P) } + for i := 61; i >= 0; i-- { + // mutualize the square among n Miller loops + // (∏ᵢfᵢ)² + res = pr.Square(res) + + if loopCounter[i] == 0 { + res = pr.MulBy014(res, + pr.MulByElement(&pr.lines[1][i], yInv), + pr.MulByElement(&pr.lines[0][i], xOverY), + ) + } else { + res = pr.MulBy014(res, + pr.MulByElement(&pr.lines[1][i], yInv), + pr.MulByElement(&pr.lines[0][i], xOverY), + ) + res = pr.MulBy014(res, + pr.MulByElement(&pr.lines[3][i], yInv), + pr.MulByElement(&pr.lines[2][i], xOverY), + ) + } + } + + // negative x₀ + res = pr.Ext12.Conjugate(res) + + return res, nil +} + +// DoubleMillerLoopFixedQ computes the double Miller loop +// fᵢ_{u,g2}(T) * fᵢ_{u,Q}(P), where g2 is fixed. 
+func (pr Pairing) DoubleMillerLoopFixedQ(P, T *G1Affine, Q *G2Affine) (*GTEl, error) { + res := pr.Ext12.One() + + var l1, l2 *lineEvaluation + var Qacc *G2Affine + Qacc = Q + var yInv, xOverY, y2Inv, x2OverY2 *emulated.Element[emulated.BLS12381Fp] + yInv = pr.curveF.Inverse(&P.Y) + xOverY = pr.curveF.MulMod(&P.X, yInv) + y2Inv = pr.curveF.Inverse(&T.Y) + x2OverY2 = pr.curveF.MulMod(&T.X, y2Inv) + + // i = 62, separately to avoid an E12 Square + // (Square(res) = 1² = 1) + + // Qacc ← 3Qacc, + // l1 the tangent ℓ to 2Q + // l2 the line ℓ passing 2Q and Q + Qacc, l1, l2 = pr.tripleStep(Qacc) + // line evaluation at P + // and assign line to res (R1, R0, 0, 0, 1, 0) + res.C0.B1 = *pr.MulByElement(&l1.R0, xOverY) + res.C0.B0 = *pr.MulByElement(&l1.R1, yInv) + res.C1.B1 = *pr.Ext2.One() + // line evaluation at P + l2.R0 = *pr.MulByElement(&l2.R0, xOverY) + l2.R1 = *pr.MulByElement(&l2.R1, yInv) + // res = ℓ × ℓ + prodLines := *pr.Mul014By014(&l2.R1, &l2.R0, &res.C0.B0, &res.C0.B1) + res.C0.B0 = prodLines[0] + res.C0.B1 = prodLines[1] + res.C0.B2 = prodLines[2] + res.C1.B1 = prodLines[3] + res.C1.B2 = prodLines[4] + + res = pr.MulBy014(res, + pr.MulByElement(&pr.lines[1][62], y2Inv), + pr.MulByElement(&pr.lines[0][62], x2OverY2), + ) + res = pr.MulBy014(res, + pr.MulByElement(&pr.lines[3][62], y2Inv), + pr.MulByElement(&pr.lines[2][62], x2OverY2), + ) + + // Compute ∏ᵢ { fᵢ_{u,G2}(T) } + for i := 61; i >= 1; i-- { + // mutualize the square among n Miller loops + // (∏ᵢfᵢ)² + res = pr.Square(res) + + if loopCounter[i] == 0 { + res = pr.MulBy014(res, + pr.MulByElement(&pr.lines[1][i], y2Inv), + pr.MulByElement(&pr.lines[0][i], x2OverY2), + ) + // Qacc ← 2Qacc and l1 the tangent ℓ passing 2Qacc + Qacc, l1 = pr.doubleStep(Qacc) + // line evaluation at P + l1.R0 = *pr.MulByElement(&l1.R0, xOverY) + l1.R1 = *pr.MulByElement(&l1.R1, yInv) + // ℓ × res + res = pr.MulBy014(res, &l1.R1, &l1.R0) + } else { + res = pr.MulBy014(res, + pr.MulByElement(&pr.lines[1][i], y2Inv), + 
pr.MulByElement(&pr.lines[0][i], x2OverY2), + ) + res = pr.MulBy014(res, + pr.MulByElement(&pr.lines[3][i], y2Inv), + pr.MulByElement(&pr.lines[2][i], x2OverY2), + ) + // Qacc ← 2Qacc+Q, + // l1 the line ℓ passing Qacc and Q + // l2 the line ℓ passing (Qacc+Q) and Qacc + Qacc, l1, l2 = pr.doubleAndAddStep(Qacc, Q) + // line evaluation at P + l1.R0 = *pr.MulByElement(&l1.R0, xOverY) + l1.R1 = *pr.MulByElement(&l1.R1, yInv) + // line evaluation at P + l2.R0 = *pr.MulByElement(&l2.R0, xOverY) + l2.R1 = *pr.MulByElement(&l2.R1, yInv) + // ℓ × ℓ + prodLines = *pr.Mul014By014(&l1.R1, &l1.R0, &l2.R1, &l2.R0) + // (ℓ × ℓ) × res + res = pr.MulBy01245(res, &prodLines) + + } + } + + // i = 0, separately to avoid a point doubling + res = pr.Square(res) + // l1 the tangent ℓ passing 2Qacc + l1 = pr.tangentCompute(Qacc) + // line evaluation at P + l1.R0 = *pr.MulByElement(&l1.R0, xOverY) + l1.R1 = *pr.MulByElement(&l1.R1, yInv) + // ℓ × ℓ + prodLines = *pr.Mul014By014( + &l1.R1, + &l1.R0, + pr.MulByElement(&pr.lines[1][0], y2Inv), + pr.MulByElement(&pr.lines[0][0], x2OverY2), + ) + // (ℓ × ℓ) × res + res = pr.MulBy01245(res, &prodLines) + + // negative x₀ + res = pr.Ext12.Conjugate(res) + + return res, nil +} + +// PairFixedQ calculates the reduced pairing for a set of points +// e(P, g2), where g2 is fixed. +// +// This function doesn't check that the inputs are in the correct subgroups. +func (pr Pairing) PairFixedQ(P *G1Affine) (*GTEl, error) { + res, err := pr.MillerLoopFixedQ(P) + if err != nil { + return nil, fmt.Errorf("miller loop: %w", err) + } + res = pr.finalExponentiation(res, true) + return res, nil +} + +// DoublePairFixedQ calculates the reduced pairing for a set of points +// e(P, Q) * e(T, g2), where g2 is fixed. +// +// This function doesn't check that the inputs are in the correct subgroups. 
+func (pr Pairing) DoublePairFixedQ(P, T *G1Affine, Q *G2Affine) (*GTEl, error) { + res, err := pr.DoubleMillerLoopFixedQ(P, T, Q) + if err != nil { + return nil, fmt.Errorf("double miller loop: %w", err) + } + res = pr.finalExponentiation(res, false) + return res, nil +} diff --git a/std/algebra/emulated/sw_bls12381/pairing_test.go b/std/algebra/emulated/sw_bls12381/pairing_test.go new file mode 100644 index 0000000000..39feaa30f7 --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/pairing_test.go @@ -0,0 +1,382 @@ +package sw_bls12381 + +import ( + "bytes" + "crypto/rand" + "fmt" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/test" +) + +func randomG1G2Affines() (bls12381.G1Affine, bls12381.G2Affine) { + _, _, G1AffGen, G2AffGen := bls12381.Generators() + mod := bls12381.ID.ScalarField() + s1, err := rand.Int(rand.Reader, mod) + if err != nil { + panic(err) + } + s2, err := rand.Int(rand.Reader, mod) + if err != nil { + panic(err) + } + var p bls12381.G1Affine + p.ScalarMultiplication(&G1AffGen, s1) + var q bls12381.G2Affine + q.ScalarMultiplication(&G2AffGen, s2) + return p, q +} + +type FinalExponentiationCircuit struct { + InGt GTEl + Res GTEl +} + +func (c *FinalExponentiationCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + res1 := pairing.FinalExponentiation(&c.InGt) + pairing.AssertIsEqual(res1, &c.Res) + res2 := pairing.FinalExponentiationUnsafe(&c.InGt) + pairing.AssertIsEqual(res2, &c.Res) + return nil +} + +func TestFinalExponentiationTestSolve(t *testing.T) { + assert := test.NewAssert(t) + var gt bls12381.GT + gt.SetRandom() + res := bls12381.FinalExponentiation(>) + witness := 
FinalExponentiationCircuit{ + InGt: NewGTEl(gt), + Res: NewGTEl(res), + } + err := test.IsSolved(&FinalExponentiationCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type PairCircuit struct { + InG1 G1Affine + InG2 G2Affine + Res GTEl +} + +func (c *PairCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + pairing.AssertIsOnG1(&c.InG1) + pairing.AssertIsOnG2(&c.InG2) + res, err := pairing.Pair([]*G1Affine{&c.InG1}, []*G2Affine{&c.InG2}) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + pairing.AssertIsEqual(res, &c.Res) + return nil +} + +func TestPairTestSolve(t *testing.T) { + assert := test.NewAssert(t) + p, q := randomG1G2Affines() + res, err := bls12381.Pair([]bls12381.G1Affine{p}, []bls12381.G2Affine{q}) + assert.NoError(err) + witness := PairCircuit{ + InG1: NewG1Affine(p), + InG2: NewG2Affine(q), + Res: NewGTEl(res), + } + err = test.IsSolved(&PairCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type MultiPairCircuit struct { + In1G1 G1Affine + In2G1 G1Affine + In1G2 G2Affine + In2G2 G2Affine + Res GTEl +} + +func (c *MultiPairCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + pairing.AssertIsOnG1(&c.In1G1) + pairing.AssertIsOnG1(&c.In2G1) + pairing.AssertIsOnG2(&c.In1G2) + pairing.AssertIsOnG2(&c.In2G2) + res, err := pairing.Pair([]*G1Affine{&c.In1G1, &c.In1G1, &c.In2G1, &c.In2G1}, []*G2Affine{&c.In1G2, &c.In2G2, &c.In1G2, &c.In2G2}) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + pairing.AssertIsEqual(res, &c.Res) + return nil +} + +func TestMultiPairTestSolve(t *testing.T) { + assert := test.NewAssert(t) + p1, q1 := randomG1G2Affines() + p2, q2 := randomG1G2Affines() + res, err := bls12381.Pair([]bls12381.G1Affine{p1, p1, p2, p2}, []bls12381.G2Affine{q1, q2, q1, q2}) + assert.NoError(err) + witness := 
MultiPairCircuit{ + In1G1: NewG1Affine(p1), + In1G2: NewG2Affine(q1), + In2G1: NewG1Affine(p2), + In2G2: NewG2Affine(q2), + Res: NewGTEl(res), + } + err = test.IsSolved(&MultiPairCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type PairingCheckCircuit struct { + In1G1 G1Affine + In2G1 G1Affine + In1G2 G2Affine + In2G2 G2Affine +} + +func (c *PairingCheckCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + err = pairing.PairingCheck([]*G1Affine{&c.In1G1, &c.In1G1, &c.In2G1, &c.In2G1}, []*G2Affine{&c.In1G2, &c.In2G2, &c.In1G2, &c.In2G2}) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + return nil +} + +func TestPairingCheckTestSolve(t *testing.T) { + assert := test.NewAssert(t) + p1, q1 := randomG1G2Affines() + _, q2 := randomG1G2Affines() + var p2 bls12381.G1Affine + p2.Neg(&p1) + witness := PairingCheckCircuit{ + In1G1: NewG1Affine(p1), + In1G2: NewG2Affine(q1), + In2G1: NewG1Affine(p2), + In2G2: NewG2Affine(q2), + } + err := test.IsSolved(&PairingCheckCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type FinalExponentiationSafeCircuit struct { + P1, P2 G1Affine + Q1, Q2 G2Affine +} + +func (c *FinalExponentiationSafeCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return err + } + res, err := pairing.MillerLoop([]*G1Affine{&c.P1, &c.P2}, []*G2Affine{&c.Q1, &c.Q2}) + if err != nil { + return err + } + res2 := pairing.FinalExponentiation(res) + one := pairing.Ext12.One() + pairing.AssertIsEqual(one, res2) + return nil +} + +func TestFinalExponentiationSafeCircuit(t *testing.T) { + assert := test.NewAssert(t) + _, _, p1, q1 := bls12381.Generators() + var p2 bls12381.G1Affine + var q2 bls12381.G2Affine + p2.Neg(&p1) + q2.Set(&q1) + err := test.IsSolved(&FinalExponentiationSafeCircuit{}, &FinalExponentiationSafeCircuit{ + P1: NewG1Affine(p1), + P2: NewG1Affine(p2), + Q1: 
NewG2Affine(q1), + Q2: NewG2Affine(q2), + }, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type GroupMembershipCircuit struct { + InG1 G1Affine + InG2 G2Affine +} + +func (c *GroupMembershipCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + pairing.AssertIsOnG1(&c.InG1) + pairing.AssertIsOnG2(&c.InG2) + return nil +} + +func TestGroupMembershipSolve(t *testing.T) { + assert := test.NewAssert(t) + p, q := randomG1G2Affines() + witness := GroupMembershipCircuit{ + InG1: NewG1Affine(p), + InG2: NewG2Affine(q), + } + err := test.IsSolved(&GroupMembershipCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +// ---------------------------- +// Fixed-argument pairing +// ---------------------------- +// +// The second argument Q is the fixed canonical generator of G2. +// +// Q.X.A0 = 0x24aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8 +// Q.X.A1 = 0x13e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e +// Q.Y.A0 = 0xce5d527727d6e118cc9cdc6da2e351aadfd9baa8cbdd3a76d429a695160d12c923ac9cc3baca289e193548608b82801 +// Q.Y.A1 = 0x606c4a02ea734cc32acd2b02bc28b99cb3e287e85a763af267492ab572e99ab3f370d275cec1da1aaa9075ff05f79be + +type PairFixedCircuit struct { + InG1 G1Affine + Res GTEl +} + +func (c *PairFixedCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + res, err := pairing.PairFixedQ(&c.InG1) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + pairing.AssertIsEqual(res, &c.Res) + return nil +} + +func TestPairFixedTestSolve(t *testing.T) { + assert := test.NewAssert(t) + p, _ := randomG1G2Affines() + _, _, _, G2AffGen := bls12381.Generators() + res, err := bls12381.Pair([]bls12381.G1Affine{p}, []bls12381.G2Affine{G2AffGen}) + assert.NoError(err) + witness := 
PairFixedCircuit{ + InG1: NewG1Affine(p), + Res: NewGTEl(res), + } + err = test.IsSolved(&PairFixedCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type DoublePairFixedCircuit struct { + In1G1 G1Affine + In2G1 G1Affine + In1G2 G2Affine + Res GTEl +} + +func (c *DoublePairFixedCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + res, err := pairing.DoublePairFixedQ(&c.In1G1, &c.In2G1, &c.In1G2) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + pairing.AssertIsEqual(res, &c.Res) + return nil +} + +func TestDoublePairFixedTestSolve(t *testing.T) { + assert := test.NewAssert(t) + p, q := randomG1G2Affines() + _, _, _, G2AffGen := bls12381.Generators() + res, err := bls12381.Pair([]bls12381.G1Affine{p, p}, []bls12381.G2Affine{q, G2AffGen}) + assert.NoError(err) + witness := DoublePairFixedCircuit{ + In1G1: NewG1Affine(p), + In2G1: NewG1Affine(p), + In1G2: NewG2Affine(q), + Res: NewGTEl(res), + } + err = test.IsSolved(&DoublePairFixedCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +// bench +func BenchmarkPairing(b *testing.B) { + + p1, q1 := randomG1G2Affines() + _, q2 := randomG1G2Affines() + var p2 bls12381.G1Affine + p2.Neg(&p1) + witness := PairingCheckCircuit{ + In1G1: NewG1Affine(p1), + In1G2: NewG2Affine(q1), + In2G1: NewG1Affine(p2), + In2G2: NewG2Affine(q2), + } + w, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + b.Fatal(err) + } + var ccs constraint.ConstraintSystem + b.Run("compile scs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if ccs, err = frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &PairingCheckCircuit{}); err != nil { + b.Fatal(err) + } + } + }) + var buf bytes.Buffer + _, err = ccs.WriteTo(&buf) + if err != nil { + b.Fatal(err) + } + b.Logf("scs size: %d (bytes), nb constraints %d, nbInstructions: %d", buf.Len(), 
ccs.GetNbConstraints(), ccs.GetNbInstructions()) + b.Run("solve scs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := ccs.Solve(w); err != nil { + b.Fatal(err) + } + } + }) + b.Run("compile r1cs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if ccs, err = frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &PairingCheckCircuit{}); err != nil { + b.Fatal(err) + } + } + }) + buf.Reset() + _, err = ccs.WriteTo(&buf) + if err != nil { + b.Fatal(err) + } + b.Logf("r1cs size: %d (bytes), nb constraints %d, nbInstructions: %d", buf.Len(), ccs.GetNbConstraints(), ccs.GetNbInstructions()) + + b.Run("solve r1cs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := ccs.Solve(w); err != nil { + b.Fatal(err) + } + } + }) +} diff --git a/std/algebra/emulated/sw_bls12381/precomputations.go b/std/algebra/emulated/sw_bls12381/precomputations.go new file mode 100644 index 0000000000..2d3b45baf3 --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/precomputations.go @@ -0,0 +1,368 @@ +package sw_bls12381 + +import ( + "sync" + + "github.com/consensys/gnark/std/algebra/emulated/fields_bls12381" + "github.com/consensys/gnark/std/math/emulated" +) + +// precomputed lines going through Q and multiples of Q +// where Q is the fixed canonical generator of G2 +// +// Q.X.A0 = 0x24aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8 +// Q.X.A1 = 0x13e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e +// Q.Y.A0 = 0xce5d527727d6e118cc9cdc6da2e351aadfd9baa8cbdd3a76d429a695160d12c923ac9cc3baca289e193548608b82801 +// Q.Y.A1 = 0x606c4a02ea734cc32acd2b02bc28b99cb3e287e85a763af267492ab572e99ab3f370d275cec1da1aaa9075ff05f79be + +var precomputedLines [4][63]fields_bls12381.E2 +var precomputedLinesOnce sync.Once + +func getPrecomputedLines() [4][63]fields_bls12381.E2 { + precomputedLinesOnce.Do(func() { + precomputedLines = 
computePrecomputedLines() + }) + return precomputedLines +} + +func computePrecomputedLines() [4][63]fields_bls12381.E2 { + var PrecomputedLines [4][63]fields_bls12381.E2 + // i = 62 + PrecomputedLines[0][62].A0 = emulated.ValueOf[emulated.BLS12381Fp]("669548411974166141085976469385723866436472370505741175753098277111398794822779266854375792083840680157046782393346") + PrecomputedLines[0][62].A1 = emulated.ValueOf[emulated.BLS12381Fp]("485573966283411892033555951945153738534613704502175152712350069145458247051515827246727572080547567638993006152517") + PrecomputedLines[1][62].A0 = emulated.ValueOf[emulated.BLS12381Fp]("732314894289793391974712446827048498840404262185434961420778947287791823788298610205403980970134873973855179358473") + PrecomputedLines[1][62].A1 = emulated.ValueOf[emulated.BLS12381Fp]("812131749855875379265953416260136615241951347320592877965506695109494376177747525031326561409216449039455993646505") + PrecomputedLines[2][62].A0 = emulated.ValueOf[emulated.BLS12381Fp]("812813015191018354207694813555941619677019645169301374927303155179580064739571240876370786966001693001968716962848") + PrecomputedLines[2][62].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3422194989631851684366263045255187543169850434295111246095065735670714951846252356682387297924035567868374078778836") + PrecomputedLines[3][62].A0 = emulated.ValueOf[emulated.BLS12381Fp]("157456865461253934305228427450035810208837931422473369102772999374350410389521892324920360694666903623614196805494") + PrecomputedLines[3][62].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3411528111412942541370177329760558674332800117796166840462360536268624996289772462833524354828959884682972709883122") + // i = 61 + PrecomputedLines[0][61].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3167355389248931702681920381216519418222081367984652005711617948531982787239184999431930575874115723238533491899165") + PrecomputedLines[0][61].A1 = 
emulated.ValueOf[emulated.BLS12381Fp]("3447267127263115595642074965719097743009312894047812306615936074422000773646907921701766406039366119066738894122723") + PrecomputedLines[1][61].A0 = emulated.ValueOf[emulated.BLS12381Fp]("688058980511165006458564368448205274272822292909762824504151377882157830132026052687556627274968720237714650160120") + PrecomputedLines[1][61].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3135587461047004958207254072262540591320988757795402058246794065635354508091637006911264790495996819610302037162717") + // i = 60 + PrecomputedLines[0][60].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1920944633235921513988518596278041022953356985015778043050763086272582601493666355019095852609650813854128544305181") + PrecomputedLines[0][60].A1 = emulated.ValueOf[emulated.BLS12381Fp]("860492157934473231829635744790033660958584147179311700394847602967819509512077660148051140335608956838465539095273") + PrecomputedLines[1][60].A0 = emulated.ValueOf[emulated.BLS12381Fp]("154735886397992143544806730614061370678971837934333975977429679584966012416707029831585725093568219539315233881659") + PrecomputedLines[1][60].A1 = emulated.ValueOf[emulated.BLS12381Fp]("376818648643968640397954810189258513612631887956420766739477434016962285151114935114922630691454914927141730374752") + PrecomputedLines[2][60].A0 = emulated.ValueOf[emulated.BLS12381Fp]("641527568955787391077163687908538469505081470399460441692301288737222490738061487415717406714694506016714777293130") + PrecomputedLines[2][60].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3744267661141807834055307161213218479864096716748396680737728943139485914722174229482692458759736989024915141720741") + PrecomputedLines[3][60].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1979277100806290341353932143559540821614946287031881379641593473632280348718200753415872393941708408115665834864981") + PrecomputedLines[3][60].A1 = 
emulated.ValueOf[emulated.BLS12381Fp]("3524079793475292610139160896458047061630390153266629174975238784247652400355301532647529891858932671865719593746519") + // i = 59 + PrecomputedLines[0][59].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2872686259828813821901776510081648106707612949930121757477603035304773855306148547185864242118276240939276501471912") + PrecomputedLines[0][59].A1 = emulated.ValueOf[emulated.BLS12381Fp]("920463563137047419952256970631145852261809899713740156757053940295816615150445509541799135283673190039502817405204") + PrecomputedLines[1][59].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3229873723050285868589420275301455820861051245243344751135273165330871017410167055454147819278539515196847338970604") + PrecomputedLines[1][59].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3250551647670882224581247559778513579667585374140502287850693829340751460016290130863912315736341727657711566097737") + // i = 58 + PrecomputedLines[0][58].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3714619193743885419585832708131169264803874205331637930392693795744061038893458100694631895928325428608821091186201") + PrecomputedLines[0][58].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2803750130323180936787347666144632907990389635882238646661111781401631109359498707800960039585964112969347050693538") + PrecomputedLines[1][58].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1472897540326065072059132746588167179515249703686025418792853111589229339540168375181352277937001709294618169294768") + PrecomputedLines[1][58].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3135206684301227064375208594383629110896364928784129558817953479498838576564714827663870236430119199289098913046652") + // i = 57 + PrecomputedLines[0][57].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3809995718071284042151900973024488933103030049864031067071220848309199706820588472594891801557974102218497622700071") + PrecomputedLines[0][57].A1 = 
emulated.ValueOf[emulated.BLS12381Fp]("3912165225431699058345901061393683854947363816159330562713945608955852170345560967680689194395447164868631283050372") + PrecomputedLines[1][57].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2612385026547803629760760529969977825484378330987722663235351236102353757267355059756029342891246063337681516513239") + PrecomputedLines[1][57].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3138408358796246446673337034537560252797484869470651011585286460917182785261415450828915245077625596387266880447102") + PrecomputedLines[2][57].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2710175849436261653096354516342437318921319381660945851344881244999005767888550326063765862016499488689860335247904") + PrecomputedLines[2][57].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3299315682062156295281559161247281711935174505100911525914443559341485746169616283628132512486441613155889807522683") + PrecomputedLines[3][57].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3885957143056127103039913083706920216817048993469158176611416475280944971176966597194215917398415610118443429941518") + PrecomputedLines[3][57].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1172972047837757111565152162191536831837373537643921973246599500915199390098182165856110670906629269605489393487725") + // i = 57 + PrecomputedLines[0][56].A0 = emulated.ValueOf[emulated.BLS12381Fp]("622347879759205768011022247468749680855686299123085419284274056276134429911385606501233967974766218644366424518736") + PrecomputedLines[0][56].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1500740749189462026269065089655113278077195444343693251448994052320674556195186935156668341304929570114492129622057") + PrecomputedLines[1][56].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2795735917000119332024849594147759966503120883167292404312386146916548928506191157798369716404301202392840237584010") + PrecomputedLines[1][56].A1 = 
emulated.ValueOf[emulated.BLS12381Fp]("574044548578100921001269599195979818060436038680486863462612989496313429331211929276265366750107217926566428464671") + // i = 56 + PrecomputedLines[0][55].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1683807955380296529179464277356150583410685997285060844955332063027770486188723773848801053602792684661825764852985") + PrecomputedLines[0][55].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2719558810791192088850367207927359344511225919945451359115915371928444741396004265400249643841105922892835790005551") + PrecomputedLines[1][55].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2969634126445395157111080166389580842022545720354905778776357619927566473203993743030834911454836656475683368211894") + PrecomputedLines[1][55].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3170175285127781392809786438941043126306291403685156491643424647115525372973022802000864505005625723904690328063907") + // i = 54 + PrecomputedLines[0][54].A0 = emulated.ValueOf[emulated.BLS12381Fp]("375779067730788644310925153474821811448285005848537210975950783963851371044263319992105676186250155417651098156170") + PrecomputedLines[0][54].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3733966523158115416788023430991316807003829120352154621710514483269358061039041785724604609643071428687478543399533") + PrecomputedLines[1][54].A0 = emulated.ValueOf[emulated.BLS12381Fp]("750499460814355297922927817774908099053943604885008295706164614840446320259237167343783111403152276091379252684616") + PrecomputedLines[1][54].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2825318853533025835349605808545710541034664742563654571824138682643441922126128300908285862303401700954453127481714") + // i = 53 + PrecomputedLines[0][53].A0 = emulated.ValueOf[emulated.BLS12381Fp]("627679888100276358357176940007265575606442637514150679111460077647327322461845493796232024954906194988731412742990") + PrecomputedLines[0][53].A1 = 
emulated.ValueOf[emulated.BLS12381Fp]("1640474007465603807593221117315813768631350077835933129371268729206573832387041087097824928540207001461881471678365") + PrecomputedLines[1][53].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3023647540724840221448340194237062576117700682390969333781535156475098745856085799973057471235608525769873178028191") + PrecomputedLines[1][53].A1 = emulated.ValueOf[emulated.BLS12381Fp]("486526745657630431509941094224758430578307862719388319027652292841508433960493074790673321073294175358344563135416") + // i = 52 + PrecomputedLines[0][52].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3650296230206170331461948170412665436446748701218003831277887253654064277753511532962960553265563620360134226064014") + PrecomputedLines[0][52].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2409561215938887896965536259947303617818362458239899554408099496657146779228013561109106195165888510812287420438371") + PrecomputedLines[1][52].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2519665618777597708805426117626873044241626169780728743913542033986762322452960126595474642740412111194932498442424") + PrecomputedLines[1][52].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2995288065509063655204878433593032419361001555657280878228932810763402933565472618640755136903486140985217860614430") + // i = 51 + PrecomputedLines[0][51].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3965943146012523012436886123935901259641842878949589183385440733070650071752022700848687298008706652773189789571156") + PrecomputedLines[0][51].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3781587587887830119678014027926478229242655143168314612514566854608758381016640445590396766447463238149599614345643") + PrecomputedLines[1][51].A0 = emulated.ValueOf[emulated.BLS12381Fp]("281493173741755688481134681014554614648591125791327736577228157760720931895294993775883808637897570982349481278455") + PrecomputedLines[1][51].A1 = 
emulated.ValueOf[emulated.BLS12381Fp]("3993652227334123342020529370134577009762551708442580124648039753538411625652437681601773709227840139276954956023720") + // i = 50 + PrecomputedLines[0][50].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2516524990890051114557830172624952846622696966741855268858924752689012638050134098337291599300347049762631249499009") + PrecomputedLines[0][50].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2838277724440483542250261902488325609796685855313384724751745262627242917164815168643675221659225278029646010530898") + PrecomputedLines[1][50].A0 = emulated.ValueOf[emulated.BLS12381Fp]("205814431017841781529711167473014574588315078838308781682081636487986677792974415834575517630013694964226548873843") + PrecomputedLines[1][50].A1 = emulated.ValueOf[emulated.BLS12381Fp]("457345270056658240314731954295238544045524550279216572405272871385901601878145910922764232048405028749348986319946") + PrecomputedLines[0][50].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2516524990890051114557830172624952846622696966741855268858924752689012638050134098337291599300347049762631249499009") + // i = 49 + PrecomputedLines[0][49].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1020346919919454689345686298025664492420578570709403850078404276953749437271628605608591033953577550721312654894751") + PrecomputedLines[0][49].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2035233040573767290854203089478489933625643223841740794072513917724855861269094032057797894488383395594601648501710") + PrecomputedLines[1][49].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1327235758160253855593076914415418147974794848908239888587061087772903111782401424731516402162647994151356947511796") + PrecomputedLines[1][49].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2345983157885705725975284492081588828805085626463839465716669020430923654178549024685469937397030677495963662130017") + // i = 48 + PrecomputedLines[0][48].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("3999794500002364705194387996540043058996671076749672420994976952879632456695054734701262124117567376514305999316548") + PrecomputedLines[0][48].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2037523712914198976262251711094301512285508453137565038665483625864472133289646125530201428506022430707543797731060") + PrecomputedLines[1][48].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1634161835818922950487208355001947642560543372116279175990598010448156542079631096605929066390913044814879119808353") + PrecomputedLines[1][48].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2554166153844771975465431059358548130366205693353150926329014635702290148449003780762406943828228518337363360267088") + PrecomputedLines[2][48].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2774236537775390183063657058358755572488017173035729341596763163103225009018936586885492696339030774916510005182013") + PrecomputedLines[2][48].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1173654495128890249932151437147000173891141885237377296676843456320239874043343036348838009942995118588197235041811") + PrecomputedLines[3][48].A0 = emulated.ValueOf[emulated.BLS12381Fp]("319716182789485809633400119894079442154863170302317662942572724308602882847075525211093447184144874467659787199054") + PrecomputedLines[3][48].A1 = emulated.ValueOf[emulated.BLS12381Fp]("665774523119220953421023874996858698673633880237440435786173568687987025043665032919009942945066009450203009618903") + // i = 47 + PrecomputedLines[0][47].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1352228522597441161137908156105710498714604453933397629408772872193339853796874615484054546448392540809392766374868") + PrecomputedLines[0][47].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1672392011071112149342136031100062933670641454779307380594553264807703694958126664626067442230569492865463262059558") + PrecomputedLines[1][47].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("1599374044309101711396661090243685761282731445460192675445448884679532421917337092819806700750615911921798461360121") + PrecomputedLines[1][47].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2512650952901556745731974003978943420505200981672715597095524880689593099413101445780656721815785475053889081706678") + // i = 46 + PrecomputedLines[0][46].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3471020969331868339274745424267673875484897140558715059509458909017470303217040695247408058206798726472147178123123") + PrecomputedLines[0][46].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2819047325620898479367544526754144985358723664205955650131575490077399551405667341723146320896830720801098582857709") + PrecomputedLines[1][46].A0 = emulated.ValueOf[emulated.BLS12381Fp]("630925931142976074685573795766009010345551824202186317510723152511213269382926514663889003322958192458264917883739") + PrecomputedLines[1][46].A1 = emulated.ValueOf[emulated.BLS12381Fp]("730708946329622469595393173442148447611685401673527273593418546020927684409753133837520124894637951916365839386160") + // i = 45 + PrecomputedLines[0][45].A0 = emulated.ValueOf[emulated.BLS12381Fp]("198412400636007802811072148882119615622478984710784674204538043831911964641933587009134635150884704808699172988667") + PrecomputedLines[0][45].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1235538121757394673935269190943945445103307027490144781712140275173202783709742505176430745360811309204020795810303") + PrecomputedLines[1][45].A0 = emulated.ValueOf[emulated.BLS12381Fp]("127449723165767618160544163158081887167750270799206910156801724420258731910725471222101164978860927062135819778920") + PrecomputedLines[1][45].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1980335922546437651709751133736030207175349461879996979062693453362277374709777569183416036000958323716762049291744") + // i = 44 + PrecomputedLines[0][44].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("3542992614482057236559067235605651529999889272871178993863126553661986885573156923893097305492580034291218414473014") + PrecomputedLines[0][44].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3584527953490094615607685914397103924239362747279962542688598060206411425604001638641737039186638967633649455194291") + PrecomputedLines[1][44].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2933771789553990042831343608346228744924936740235468281265740746084993415654144367825141102522105790771732605662673") + PrecomputedLines[1][44].A1 = emulated.ValueOf[emulated.BLS12381Fp]("357221450466112351467892138625473814367000466631434933520256036228321581439111528221088566194593092448380607628937") + // i = 43 + PrecomputedLines[0][43].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1428505519858414724963788585080455020779142968543396706742348738430903946057154444325924558405113243287649847547028") + PrecomputedLines[0][43].A1 = emulated.ValueOf[emulated.BLS12381Fp]("255469568805094053943811439244838133421341964456430345396115388192792360855759659375210457840330885549905265166174") + PrecomputedLines[1][43].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2841916087237451089871919956880030230404765495311921549225109128734951911064362232331455183143163787624459526919204") + PrecomputedLines[1][43].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2080277921400703811750799004531700060990459649936071284451991088891151958553636754908447208794532902985981717385197") + // i = 42 + PrecomputedLines[0][42].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1161813519543100832448485023726218217911506899980751063568982100444610240283155458678085859357761886310777271380218") + PrecomputedLines[0][42].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2024543557518754018225240025183007210142579342901765863468909024213131985052250019822553016575047190637324971172371") + PrecomputedLines[1][42].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("1300348585886929751726227943801215133444838387886469335147731950649896068046555404665420122667116820807544395572991") + PrecomputedLines[1][42].A1 = emulated.ValueOf[emulated.BLS12381Fp]("132553990561974819643301422280568738175710696265452865499319629218595397851655434827834869071879254702329949577177") + // i = 41 + PrecomputedLines[0][41].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3663938147630410348590845529884105686512799097554693485461263550839609650459905357149747443181106900675222183621502") + PrecomputedLines[0][41].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1167353691883536575912237192065757084296471281152540061033126407319851295132511138053309418864479978963783870392046") + PrecomputedLines[1][41].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2502366363958385636575096731742781445822573495751060987471063288395705513008577494849055381410929299192520021143963") + PrecomputedLines[1][41].A1 = emulated.ValueOf[emulated.BLS12381Fp]("168012017786556155916365424936030459963969329890011871555391199195524501359161804719804475541392820589754269324473") + // i = 40 + PrecomputedLines[0][40].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3422147736366383751754174069648796681157726214083354570593473083443425258197833570132918683247736376295040893339583") + PrecomputedLines[0][40].A1 = emulated.ValueOf[emulated.BLS12381Fp]("11337927206202880824871358790200052804337474988256558556739237234008272644532973070128389538534846278251601901582") + PrecomputedLines[1][40].A0 = emulated.ValueOf[emulated.BLS12381Fp]("937059775967306522240360587341180755614199042950625504697937131942983650460054919541269959395433778244986419803777") + PrecomputedLines[1][40].A1 = emulated.ValueOf[emulated.BLS12381Fp]("452377712017692674398905129564653333739311305247415895016831258644194298528119821127996385402381262854234490259250") + // i = 39 + PrecomputedLines[0][39].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("3612997970192046360277946172414595493705800672421085458639618548707813391029436946048064626238111588216813929361632") + PrecomputedLines[0][39].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1878416626790773315050527185017103501757993372372008108763726391343181754684270883343531865692543923966766615449574") + PrecomputedLines[1][39].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1064399826825526943533915590474819784350899403610383371512627205253691058435433480324428528977455879008913187859489") + PrecomputedLines[1][39].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1493858117017686382025018246020080696784248677063660248506323896583625254078001520559429449267479753774555640656115") + // i = 38 + PrecomputedLines[0][38].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1735664856890207310169253204399524178813326485311425355248369023944857598179113751774794101343852837913801486402934") + PrecomputedLines[0][38].A1 = emulated.ValueOf[emulated.BLS12381Fp]("257987757689159452904715641284385152828467715986517600177681521531352535362174538747360373817060560772313858089722") + PrecomputedLines[1][38].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3946465293088454781014388324927134871913047898475854876800085625862621433529365655612871749364409826568968642462682") + PrecomputedLines[1][38].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3514516048833750160654616198372533402722421230509993177084104711179394585961253245647565370709206323382607579115108") + // i = 37 + PrecomputedLines[0][37].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1401520457984247779943479808418140600316441439198119727504755495446774329464642961053357870501084879285839953943267") + PrecomputedLines[0][37].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3801631184852487334845090396139279606066156434973587251458920274531296829855683500529675477873800149183554843368547") + PrecomputedLines[1][37].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("3793843723599766510128907097460302748913564383405034491079209933664814101598005798700867405427941067839702387132551") + PrecomputedLines[1][37].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3306940590973833542721039863858362748858705070089020543718304082366722282663705234043233941444934273726393427156475") + // i = 36 + PrecomputedLines[0][36].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1386206426371721430883261200223512381487749770457988721883157746964420160482463384316863927260168499246730796790725") + PrecomputedLines[0][36].A1 = emulated.ValueOf[emulated.BLS12381Fp]("105258678333128666521354740403587795176240437377501517632246986958341935429854720947092812955927086837476976184408") + PrecomputedLines[1][36].A0 = emulated.ValueOf[emulated.BLS12381Fp]("932703267338824377065600514218801996331026291359027400436063614448672472594233465820977170328221080390337491527475") + PrecomputedLines[1][36].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1997691216719219792140629951813670927631278876322624455128068353052974430392275280843680022909857531980837305445355") + // i = 35 + PrecomputedLines[0][35].A0 = emulated.ValueOf[emulated.BLS12381Fp]("643376698540159636558778051726667023044743522340636476360479502393735290090800846463555049440499660661597709502815") + PrecomputedLines[0][35].A1 = emulated.ValueOf[emulated.BLS12381Fp]("273152130480947176848926276607476536605977478035570463953153325427234725841518087669121533072492261967650512345505") + PrecomputedLines[1][35].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2884756623806302783060959589833895994361358453330946927482503242534883207284024199766671460228049002959968955556544") + PrecomputedLines[1][35].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1104408845625748157498570748739992116575830335623545342349641509468821031457722373198091225304845345864484821816520") + // i = 34 + PrecomputedLines[0][34].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("2355394091472044936279825599359959938355474687645308765411687460086621236866142452518153528913315194632130423391438") + PrecomputedLines[0][34].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1514545828569599419557358915925562383958111322715051805092136729868625115938669404676383501473562085370690718143393") + PrecomputedLines[1][34].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2342810081178981235002861700946741808510377823228526922452112941039257271124158122048943685128217560845825533768234") + PrecomputedLines[1][34].A1 = emulated.ValueOf[emulated.BLS12381Fp]("503445239041840170015602950815430394154806858030319353226841813139639784228695779691934743598652740721050230266959") + // i = 33 + PrecomputedLines[0][33].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2657084885954644398746184778816531794363952823725527101620555492372505275656368937266127716153751518258248254165182") + PrecomputedLines[0][33].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2950867322457899876246688634040402014645144799326166449560101370115712320898754509054993314536417523566518362136876") + PrecomputedLines[1][33].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2243909572679435242054837525492887593116525671610805132367202750543401085997072583661340355528884399865342689858236") + PrecomputedLines[1][33].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3479899708432584697662617257428806293863393731873139669667415845090665473410343628859481184318696801412878547924785") + // i = 32 + PrecomputedLines[0][32].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3132063229600917670613891136589922564588967418494067208567202658993587612721110746333109603494327519775953076104661") + PrecomputedLines[0][32].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1275373167524064736373612114593617177217617901632801154941510388392008224603906384641920546232940016648324112284352") + PrecomputedLines[1][32].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("1951281490834827378430403879656313795704179094690571238995959924602423201300568801597213153338556343940337082564163") + PrecomputedLines[1][32].A1 = emulated.ValueOf[emulated.BLS12381Fp]("874739535077666291108709093692947749877527789473195415417438679809595046593449884845499149689457088132546512367865") + // i = 31 + PrecomputedLines[0][31].A0 = emulated.ValueOf[emulated.BLS12381Fp]("86053440242818054368053994258080131501108197513542421797885566985997737396079590383222681554582578024765554005575") + PrecomputedLines[0][31].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3230634811109724887953064878272037919734346474897928436938220980193233349186994534217923136215960889197535862790414") + PrecomputedLines[1][31].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1578872153108485586740648390134492363858521364266818906972887979762709384757209232552002419238395578707827072494146") + PrecomputedLines[1][31].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3346182034080587137847082053075136548616290128507231616924000212671337956860278775275340632986689511391294754917950") + // i = 30 + PrecomputedLines[0][30].A0 = emulated.ValueOf[emulated.BLS12381Fp]("563606101962267088332267732828721041043582587025246353580116346987890608083331016733410427361695949570965227520789") + PrecomputedLines[0][30].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1447710510444163006794573556490053458339113778154087329237293023456254973340533160850843017629474417093574674355164") + PrecomputedLines[1][30].A0 = emulated.ValueOf[emulated.BLS12381Fp]("271208876921370646099156956119790748292096870574678731086094113909284881874681676160093634310679644399904203270909") + PrecomputedLines[1][30].A1 = emulated.ValueOf[emulated.BLS12381Fp]("734014135075220340217360055586287717623361723625999277128094287557651444748005340858829326263676471443653917380028") + // i = 29 + PrecomputedLines[0][29].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("553104808200198436637099795198248457859794790285185304908007586812710981503044298036311960523402366230566783003653") + PrecomputedLines[0][29].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2003955291719445647798366785133721961713094985232133195662878880175842141635740179320897168340231736128384133534354") + PrecomputedLines[1][29].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2178147651395974400304695746874894783920829064896437782855376270213935679743301253682869488818132822182583360124241") + PrecomputedLines[1][29].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2021924461162721359703592906360932111888647290905383591545198703128435999832129964838104126457865901358842626541451") + // i = 28 + PrecomputedLines[0][28].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1777745173717667099342756119046464257151872581855146300331383956582867311823203204529781655612736609774045268711658") + PrecomputedLines[0][28].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1041287756528278863695913064899216705646770322785896202378661226081642192837187951527490143718478781509638589819902") + PrecomputedLines[1][28].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1967355981600226768199599440345674395956405622419425164671491474464778947731825391685654108331113188849139873946423") + PrecomputedLines[1][28].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3256694521155602227991873365090666099769483235980732713464480092130836478909542237706421755139181299068417733946618") + // i = 27 + PrecomputedLines[0][27].A0 = emulated.ValueOf[emulated.BLS12381Fp]("152937177363112535899008097612480126842969909709502300529346908402930035839474413363166062865253708049557944378422") + PrecomputedLines[0][27].A1 = emulated.ValueOf[emulated.BLS12381Fp]("897596282626309905062469669236886665340011419494984372382213783173084576758483513318145997370160010502859725827132") + PrecomputedLines[1][27].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("2346758308810875365168118646648040744389722928738625003408785886492711195838395334630420578844405693367496418920211") + PrecomputedLines[1][27].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2656981446205865501558128157203546483094652284695248391187003029925483050287333375605638675631123261261698615316370") + // i = 26 + PrecomputedLines[0][26].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3620566506706159804708100239488627313869689908825766518776558079707132636874047944658510407093741323189240311926227") + PrecomputedLines[0][26].A1 = emulated.ValueOf[emulated.BLS12381Fp]("587825566103826950433417267558764839534823650855845168865859272218400707217937201097874915851900418500280049177693") + PrecomputedLines[1][26].A0 = emulated.ValueOf[emulated.BLS12381Fp]("315851704572083517038658574619778523404095735851866846909808016842341190792198628387184169216243993584352479330528") + PrecomputedLines[1][26].A1 = emulated.ValueOf[emulated.BLS12381Fp]("229972078916735518731751050710560968941984639445244488397477314523703531099960305941377239301202435021499588716423") + // i = 25 + PrecomputedLines[0][25].A0 = emulated.ValueOf[emulated.BLS12381Fp]("979555040667686144734321069717744602937606084006969482488036046063313229494574591630057601636979700665914802203854") + PrecomputedLines[0][25].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1083328512005369786522146909067018806083197802603736575975871992537930820832853483250551205803668495310640749336142") + PrecomputedLines[1][25].A0 = emulated.ValueOf[emulated.BLS12381Fp]("792214870461187352036421584558521612384249649331727031805894500139426113040621396016975323081727321304814602705639") + PrecomputedLines[1][25].A1 = emulated.ValueOf[emulated.BLS12381Fp]("293126960439059888314336502484927086559096380766871756432786584741451879297750115116060526063096689019596377256462") + // i = 24 + PrecomputedLines[0][24].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("3023305939043287566376775840182397466273202632600006671914074027221173655043488882180663551326278326566212930523605") + PrecomputedLines[0][24].A1 = emulated.ValueOf[emulated.BLS12381Fp]("211345580778501530278518067818080421373636225064710189149613966659012842944717863500047454846418617510601303495382") + PrecomputedLines[1][24].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1083967895923493980361373119038766728523388392678856516526946919892547528455219127949867777528733801974645664040222") + PrecomputedLines[1][24].A1 = emulated.ValueOf[emulated.BLS12381Fp]("27259736067175495975204418838888242237774672904102535395102952871757057779325470610149549006171411043997579252843") + // i = 23 + PrecomputedLines[0][23].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3524508898240761195109795597763114367092144781256655615102582904100668274549691032362784563296727499220983570152441") + PrecomputedLines[0][23].A1 = emulated.ValueOf[emulated.BLS12381Fp]("154109098599910912999381594862896628943756799853257522903745986263346741806885494047064041211163702319803554533636") + PrecomputedLines[1][23].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1507508929647441259276327449399270586674159083549117937291801300548797928248720243445837206614212644609961366620912") + PrecomputedLines[1][23].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1868673061074969251396600047373562479665360443636499083160191471647935542691413181537129994029355285946490530299296") + // i = 22 + PrecomputedLines[0][22].A0 = emulated.ValueOf[emulated.BLS12381Fp]("908014978798891590190712328207935319381397213593311780079258255744362291021841393725131242015616027170757756445593") + PrecomputedLines[0][22].A1 = emulated.ValueOf[emulated.BLS12381Fp]("200639170601284184001774623835371352522039806052768444417749601795225843628193407696221357162246810418514142831843") + PrecomputedLines[1][22].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("2350371217279740579820143271264453949977038917578227565960325061648448714010069328159980890814283501689783961524702") + PrecomputedLines[1][22].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1593895841197462027228582101998016460630082782005665672367357528349861944810760713747730739706721098594296662524458") + // i = 21 + PrecomputedLines[0][21].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2360461348846359806382917405942817466536487234848314274910044461635090577997989305621192270573165846745607113076841") + PrecomputedLines[0][21].A1 = emulated.ValueOf[emulated.BLS12381Fp]("947341937929141884122447526020770991735996295941158962518194558209791699726852221249234839377751905155930729993341") + PrecomputedLines[1][21].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2371915419268388555770713628437540544538124706813717269034056295021127101714581368189393845415492802863632836193876") + PrecomputedLines[1][21].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3820176191003385401211443983255159967207554116103848359429328989702264327090629066820234833787173268959129598582288") + // i = 20 + PrecomputedLines[0][20].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1572650153562580171679461063610490558238392713851486235102056983067440099520233982149463624217641871971604769598051") + PrecomputedLines[0][20].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2618980509891451881049009354979162559943307077780444024131691302777717224275262406075437879469156653370163741776352") + PrecomputedLines[1][20].A0 = emulated.ValueOf[emulated.BLS12381Fp]("914440712713401078461992974124539925054406459556882995700401997575824770950958824081091612726953802622133206086132") + PrecomputedLines[1][20].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3761729476045264945742115214328854993283781354642998410783435809759489218467630853313060142896942920368784023624564") + // i = 19 + PrecomputedLines[0][19].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("3121648410086279280187939190420325208219240333621791237185060309660921657630544662333480321336847899031625344257388") + PrecomputedLines[0][19].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2427620357688217067501558172333787193041787904721169429351785053557711405904561884061991426620314542234334857364382") + PrecomputedLines[1][19].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2708786771252814590617678344796738582404650900020338773987915299039612698556736430693452560828818771063820703207467") + PrecomputedLines[1][19].A1 = emulated.ValueOf[emulated.BLS12381Fp]("517411305194592772809763582146889163055266117798808949181626918347108724425656691661675351924493197452027981397318") + // i = 18 + PrecomputedLines[0][18].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2517225236557550934850433410068330464823062596484997219643623944432723816139315676042622268129511422705552411378459") + PrecomputedLines[0][18].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2316477955778739510775794399658217674719353791845398241574961924479405948992154924007863632495072289492221251858022") + PrecomputedLines[1][18].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3355942814100146198477737240362747495473665280368926326861352687278625310382944378160190997562974389721544171232057") + PrecomputedLines[1][18].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1405256119069207071356021656215306373736938017990909711990090738437615436721191173454389898556042490086922502478452") + // i = 17 + PrecomputedLines[0][17].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1871012489878727224876595505058665788857585521570986653035860338904253701870605154927982874728131181637596420475443") + PrecomputedLines[0][17].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1901160185068128975791454167192809834079447334201748146854907081248444189191761875810491917795581071947191880288163") + PrecomputedLines[1][17].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("1508805355272822827788090744062050727018818455929522068131059654827967915946146897794939483471020907220031424503902") + PrecomputedLines[1][17].A1 = emulated.ValueOf[emulated.BLS12381Fp]("392176457381240086084667193945896246039077028491953999599542527450836164540698798542367795135108835354979937729176") + // i = 16 + PrecomputedLines[0][16].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2741229165070269117924902999145560172156178221528003351492036613433272629200676282018955648591068110182854114550924") + PrecomputedLines[0][16].A1 = emulated.ValueOf[emulated.BLS12381Fp]("394754449245515942947222210401725628247299909612001534445659911053162507221378908343285116791344668589530229643106") + PrecomputedLines[1][16].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3281821044964510076704259445805281918044215707355772313190483116134259563285207680309643894942896092949741436198183") + PrecomputedLines[1][16].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2279727743915060359386920658508386508396853897423119737675172726894449285404134596174496441719084502944932053797831") + PrecomputedLines[2][16].A0 = emulated.ValueOf[emulated.BLS12381Fp]("353345079382094137526926194700107849760485245746200219086878713290578018072683764914127968225936663626062935997429") + PrecomputedLines[2][16].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1564619051727543092268240068918461744410102185232141167890515232410738747169856682086977412134170322374787182604285") + PrecomputedLines[3][16].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1031783123982449084107331586608935203701652584407973805774838114682364185882412466765964240680264636252893789921832") + PrecomputedLines[3][16].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1674505309126597529030651480135394381023726406881712713659095642254644276683224627684705587858451842527993474727715") + // i = 15 + PrecomputedLines[0][15].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("2456248725387547397262525114643040834056469397193019364301066598073765581121143668240981415630210804907944829744944") + PrecomputedLines[0][15].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1043012900187598140574711455572545751706624394067345211899147242871854241505831150194052069901037949851965337396572") + PrecomputedLines[1][15].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2818047063092228393255521223699593672793582976332754074700446958296103353245361019407731583187313322019846120441360") + PrecomputedLines[1][15].A1 = emulated.ValueOf[emulated.BLS12381Fp]("609324778570449498894287960818577137803498106249221775461586760331332978890739893282037207179739452898113233757842") + // i = 14 + PrecomputedLines[0][14].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1583907611945126402969342158670251457555916265122712423938192754774277188592755130719033171706996928741333176355259") + PrecomputedLines[0][14].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1480043754732647907528102591669853624107177184606806464768162240804848473641199943564132307414898940662237778707880") + PrecomputedLines[1][14].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1073770434466187296401264518401836490010154190466732406500178370991399776232002701153045477118496857622558404986124") + PrecomputedLines[1][14].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2040088101627138737987871617873915758242178143051551613106980725563415248571813669113484542372670569201005661094588") + // i = 13 + PrecomputedLines[0][13].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3611482998132061804020855119834275656700118225321062337009879540758621588230203253953916980074315068068764047335052") + PrecomputedLines[0][13].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2346170214863078560104862811778279655418099426832641872989862189028138184629798194757339729949883260228219986435351") + PrecomputedLines[1][13].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("1200357469471453477416622874970904427768324541884587100292052206488927915687048371988316685540349619282470657037061") + PrecomputedLines[1][13].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3484088805784391480522749040297723013103515240980206599299861821416786877289405050312590194417722749870526191019673") + // i = 12 + PrecomputedLines[0][12].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1175103434969115663465243478736530988428094788060063630270795460718911007942618967214053003288533838183840301794166") + PrecomputedLines[0][12].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2153876779718109562719530248580382631701336274595917146227223627547555040502114589540617386615636683082237211673265") + PrecomputedLines[1][12].A0 = emulated.ValueOf[emulated.BLS12381Fp]("196906826244106976903807716766844478488002154571305434677738804785501147563736236264422551863677725322188941142814") + PrecomputedLines[1][12].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3003134029129368381569085969742217191347104600214895562498282523853129629358228016925981694439115028848189047930274") + // i = 11 + PrecomputedLines[0][11].A0 = emulated.ValueOf[emulated.BLS12381Fp]("288579508917537216032217176846572026112298341738458834446349503424029838166869418511639725695422373336634256356505") + PrecomputedLines[0][11].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3992933519722599233380063922132718381812819049463872046892870883016443953054792323020456056344685398204485574386423") + PrecomputedLines[1][11].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1901084398218251517716150525360269215916704505616118339628161776708362087441629181428748618965084417219509574513576") + PrecomputedLines[1][11].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1952189648467839796870124474683061875216340850091397726737866109626602833143447824132199531004749663678290861592934") + // i = 10 + PrecomputedLines[0][10].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("674049131642920716773292998275635125430185488384219499251573331404631012361307801528514714700059774333017906904071") + PrecomputedLines[0][10].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2986041691645477878693593345744447039959458233932414756236894338913165710653579464932153470232310464980926622514460") + PrecomputedLines[1][10].A0 = emulated.ValueOf[emulated.BLS12381Fp]("396980881920070119281371625216467857543862821632618429299867864848221772789254174263915490714491686803960122676841") + PrecomputedLines[1][10].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3363488693200112955217599550745786026344253629302768961678458665321585247924625609690843453377126238012557561047704") + // i = 9 + PrecomputedLines[0][9].A0 = emulated.ValueOf[emulated.BLS12381Fp]("903980085139727083659669247383921724226476199049306205220045968334604361831284928601247223629870059814604444022210") + PrecomputedLines[0][9].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1347827538635308235340411256030642664724853897798539165023305691688910037429723163256786901153134931835025303767837") + PrecomputedLines[1][9].A0 = emulated.ValueOf[emulated.BLS12381Fp]("362946095859932580358415564370912182581787648319175363986622115682210008263959828660939888692668763890646644654633") + PrecomputedLines[1][9].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1521162638370174260915444190906972864974462580865371310106941096343692947339423700927310107639026264108572639880763") + // i = 8 + PrecomputedLines[0][8].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1451288507938675216820664191752001443661941658972340336364138619052459160919783010387521742272982958657081720589057") + PrecomputedLines[0][8].A1 = emulated.ValueOf[emulated.BLS12381Fp]("780091844842754045057312647019433725300884829443248665854800094483946888790564309147834461359778064434157512664965") + PrecomputedLines[1][8].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("2025397797410857541801387239105980546331437474914109199821350959051310813036349354931803123302652471034747102363943") + PrecomputedLines[1][8].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1349804482686745175789105413999273853512844505258621270714326733460020468417764035353237622608805002148395864268050") + // i = 7 + PrecomputedLines[0][7].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3108259519725255622160865392868985983941931146911727818694357361341016397681604262681477576824444998720748752183853") + PrecomputedLines[0][7].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2746312800933380638719051663688878535959804042913702866743872475497377391010820849783618378511743189938152475719597") + PrecomputedLines[1][7].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2080362443193479154911701830952814560889592049737440328591160551538249466843193733836470398280062259043136979920854") + PrecomputedLines[1][7].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3964650959705446256700736846818753122873391146832851772353426653959701656164345607420376069170669631695894867766661") + // i = 6 + PrecomputedLines[0][6].A0 = emulated.ValueOf[emulated.BLS12381Fp]("998401711577030801499823016548091904661868235334825752872616271090737140683615722295667334501710630177910370409188") + PrecomputedLines[0][6].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3113580388271795111973808218003838749115342656107628938242379127409951356994761702310134652453351993186439880409956") + PrecomputedLines[1][6].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2209895036333273778785791339299432698003341556381168283301669691693492869926045293967113459553699553056033465779798") + PrecomputedLines[1][6].A1 = emulated.ValueOf[emulated.BLS12381Fp]("959647019662444226972580453114544223706755036638831753478672703623585644987424105312187217218661108832975493877168") + // i = 5 + PrecomputedLines[0][5].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("2841471010438431779967254954386974501546987002770608809461472519141013537463963266051229253731270768522508156999616") + PrecomputedLines[0][5].A1 = emulated.ValueOf[emulated.BLS12381Fp]("696566894503412538642690079753767069302913336129171106478522752438805851287179825258314756117624623080809738255414") + PrecomputedLines[1][5].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2911265147164022514710363886878608210045172248737826097784218655566417385178563726914426070452767585831687080991147") + PrecomputedLines[1][5].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2744818944054579004598792944387297677957805313049835664087933533436786259115891347376732045096926518910859165808954") + // i = 4 + PrecomputedLines[0][4].A0 = emulated.ValueOf[emulated.BLS12381Fp]("2626722221901546525450495131546494365497693562193436563425832546273615525185343652803128707511796650054409580651607") + PrecomputedLines[0][4].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2898515900951479363905069458547895124826329747625108263493772297820331408153856994348030852520595550159245759652161") + PrecomputedLines[1][4].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1921592673575952646060010989725488734583736903852633912443953723136259718959911482580429310015710706499067320270317") + PrecomputedLines[1][4].A1 = emulated.ValueOf[emulated.BLS12381Fp]("956251362899975176731442839970722287994171762098273356212496480881340499428283976466263674939566948363663432606399") + // i = 3 + PrecomputedLines[0][3].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3921602381215869979684331626432646546716509921608717347721382117849039021022887606324288626301870542752693783471927") + PrecomputedLines[0][3].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2523702364140531959195061579572246079207345591008472492207286834773077913074321681802231395020396532561725831993635") + PrecomputedLines[1][3].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("365550321315950743851268167652946217061103742807645488424195332015419615371374089484470568460278262393265512170365") + PrecomputedLines[1][3].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2367775129286209205281261235584543398195849023700633569371476648532398963866035316135040295345655465417403167550696") + // i = 2 + PrecomputedLines[0][2].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1680458512589067571896634706380419213249047340047151052382898514607798507250939135807299863867577040318561590393585") + PrecomputedLines[0][2].A1 = emulated.ValueOf[emulated.BLS12381Fp]("1033946335839836561404870194835625676362202728567409441052246502652005038621446258562651084280524280496231515354305") + PrecomputedLines[1][2].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3777662686667655344658815583791672223988649461202850693112863279144808037867254199316045088493867424310466772435262") + PrecomputedLines[1][2].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2064839480899148857837309570504205604502488145969799920908269678849234113812590058688143649401420761848749422330662") + // i = 1 + PrecomputedLines[0][1].A0 = emulated.ValueOf[emulated.BLS12381Fp]("3565448959257374800198419850448580122970161672942598117316274721696483682407630545127619505130456918163042800666213") + PrecomputedLines[0][1].A1 = emulated.ValueOf[emulated.BLS12381Fp]("3345751791634611319739837464169473359182614679402518952036363106531863475124919612432455420087836710231316659292948") + PrecomputedLines[1][1].A0 = emulated.ValueOf[emulated.BLS12381Fp]("1617681413845042635925097980110882198826532425223057956385624726626794238122669810593496494648798312736702002674001") + PrecomputedLines[1][1].A1 = emulated.ValueOf[emulated.BLS12381Fp]("623718078700217167207278824799657763292593672163484469137343331952443014063067608809944893696957512085202161610194") + // i = 0 + PrecomputedLines[0][0].A0 = 
emulated.ValueOf[emulated.BLS12381Fp]("3021393366864267596780779365081014671914599573258232428274167592862693896278193045742059529785963318957699876618598") + PrecomputedLines[0][0].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2774709764298624490594173467667228199179204402302427867108160491882432826696994080161627641500497330778328523750634") + PrecomputedLines[1][0].A0 = emulated.ValueOf[emulated.BLS12381Fp]("887601706871077344640254225359406175214680885957404575459587548760809083522151140664098119031225155506199574629967") + PrecomputedLines[1][0].A1 = emulated.ValueOf[emulated.BLS12381Fp]("2676643094794718705520885263561335025726613452251214882260540192490012199952733082631183814711038861520905961974067") + + return PrecomputedLines +} diff --git a/std/algebra/emulated/sw_bn254/doc.go b/std/algebra/emulated/sw_bn254/doc.go new file mode 100644 index 0000000000..e66e237ee9 --- /dev/null +++ b/std/algebra/emulated/sw_bn254/doc.go @@ -0,0 +1,6 @@ +// Package sw_bn254 implements G1 and G2 arithmetics and pairing computation over BN254 curve. +// +// The implementation follows [[Housni22]]: "Pairings in Rank-1 Constraint Systems". 
+// +// [Housni22]: https://eprint.iacr.org/2022/1162 +package sw_bn254 diff --git a/std/algebra/emulated/sw_bn254/doc_test.go b/std/algebra/emulated/sw_bn254/doc_test.go new file mode 100644 index 0000000000..7d8ef6a6cd --- /dev/null +++ b/std/algebra/emulated/sw_bn254/doc_test.go @@ -0,0 +1,105 @@ +package sw_bn254_test + +import ( + "crypto/rand" + "fmt" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/algebra/emulated/sw_bn254" +) + +type PairCircuit struct { + InG1 sw_bn254.G1Affine + InG2 sw_bn254.G2Affine + Res sw_bn254.GTEl +} + +func (c *PairCircuit) Define(api frontend.API) error { + pairing, err := sw_bn254.NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + // Pair method does not check that the points are in the proper groups. + pairing.AssertIsOnG1(&c.InG1) + pairing.AssertIsOnG2(&c.InG2) + // Compute the pairing + res, err := pairing.Pair([]*sw_bn254.G1Affine{&c.InG1}, []*sw_bn254.G2Affine{&c.InG2}) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + pairing.AssertIsEqual(res, &c.Res) + return nil +} + +func ExamplePairing() { + p, q, err := randomG1G2Affines() + if err != nil { + panic(err) + } + res, err := bn254.Pair([]bn254.G1Affine{p}, []bn254.G2Affine{q}) + if err != nil { + panic(err) + } + circuit := PairCircuit{} + witness := PairCircuit{ + InG1: sw_bn254.NewG1Affine(p), + InG2: sw_bn254.NewG2Affine(q), + Res: sw_bn254.NewGTEl(res), + } + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + panic(err) + } else { + fmt.Println("compiled") + } + pk, vk, err := groth16.Setup(ccs) + if err != nil { + panic(err) + } else { + fmt.Println("setup done") + } + secretWitness, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + 
panic(err) + } else { + fmt.Println("secret witness") + } + publicWitness, err := secretWitness.Public() + if err != nil { + panic(err) + } else { + fmt.Println("public witness") + } + proof, err := groth16.Prove(ccs, pk, secretWitness) + if err != nil { + panic(err) + } else { + fmt.Println("proof") + } + err = groth16.Verify(proof, vk, publicWitness) + if err != nil { + panic(err) + } else { + fmt.Println("verify") + } +} + +func randomG1G2Affines() (p bn254.G1Affine, q bn254.G2Affine, err error) { + _, _, G1AffGen, G2AffGen := bn254.Generators() + mod := bn254.ID.ScalarField() + s1, err := rand.Int(rand.Reader, mod) + if err != nil { + return p, q, err + } + s2, err := rand.Int(rand.Reader, mod) + if err != nil { + return p, q, err + } + p.ScalarMultiplication(&G1AffGen, s1) + q.ScalarMultiplication(&G2AffGen, s2) + return +} diff --git a/std/algebra/emulated/sw_bn254/g1.go b/std/algebra/emulated/sw_bn254/g1.go new file mode 100644 index 0000000000..69ce54898c --- /dev/null +++ b/std/algebra/emulated/sw_bn254/g1.go @@ -0,0 +1,16 @@ +package sw_bn254 + +import ( + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" +) + +type G1Affine = sw_emulated.AffinePoint[emulated.BN254Fp] + +func NewG1Affine(v bn254.G1Affine) G1Affine { + return G1Affine{ + X: emulated.ValueOf[emulated.BN254Fp](v.X), + Y: emulated.ValueOf[emulated.BN254Fp](v.Y), + } +} diff --git a/std/algebra/emulated/sw_bn254/g2.go b/std/algebra/emulated/sw_bn254/g2.go new file mode 100644 index 0000000000..71026f66f0 --- /dev/null +++ b/std/algebra/emulated/sw_bn254/g2.go @@ -0,0 +1,216 @@ +package sw_bn254 + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/fields_bn254" + "github.com/consensys/gnark/std/math/emulated" +) + +type G2 struct { + *fields_bn254.Ext2 + w 
*emulated.Element[emulated.BN254Fp] + u, v *fields_bn254.E2 +} + +type G2Affine struct { + X, Y fields_bn254.E2 +} + +func NewG2(api frontend.API) *G2 { + w := emulated.ValueOf[emulated.BN254Fp]("21888242871839275220042445260109153167277707414472061641714758635765020556616") + u := fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp]("21575463638280843010398324269430826099269044274347216827212613867836435027261"), + A1: emulated.ValueOf[emulated.BN254Fp]("10307601595873709700152284273816112264069230130616436755625194854815875713954"), + } + v := fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp]("2821565182194536844548159561693502659359617185244120367078079554186484126554"), + A1: emulated.ValueOf[emulated.BN254Fp]("3505843767911556378687030309984248845540243509899259641013678093033130930403"), + } + return &G2{ + Ext2: fields_bn254.NewExt2(api), + w: &w, + u: &u, + v: &v, + } +} + +func NewG2Affine(v bn254.G2Affine) G2Affine { + return G2Affine{ + X: fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp](v.X.A0), + A1: emulated.ValueOf[emulated.BN254Fp](v.X.A1), + }, + Y: fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp](v.Y.A0), + A1: emulated.ValueOf[emulated.BN254Fp](v.Y.A1), + }, + } +} + +func (g2 *G2) phi(q *G2Affine) *G2Affine { + x := g2.Ext2.MulByElement(&q.X, g2.w) + + return &G2Affine{ + X: *x, + Y: q.Y, + } +} + +func (g2 *G2) psi(q *G2Affine) *G2Affine { + x := g2.Ext2.Conjugate(&q.X) + x = g2.Ext2.Mul(x, g2.u) + y := g2.Ext2.Conjugate(&q.Y) + y = g2.Ext2.Mul(y, g2.v) + + return &G2Affine{ + X: *x, + Y: *y, + } +} + +func (g2 *G2) scalarMulBySeed(q *G2Affine) *G2Affine { + z := g2.double(q) + t0 := g2.add(q, z) + t2 := g2.add(q, t0) + t1 := g2.add(z, t2) + z = g2.doubleAndAdd(t1, t0) + t0 = g2.add(t0, z) + t2 = g2.add(t2, t0) + t1 = g2.add(t1, t2) + t0 = g2.add(t0, t1) + t1 = g2.add(t1, t0) + t0 = g2.add(t0, t1) + t2 = g2.add(t2, t0) + t1 = g2.doubleAndAdd(t2, t1) + t2 = g2.add(t2, t1) + z = g2.add(z, t2) + t2 = g2.add(t2, z) + 
z = g2.doubleAndAdd(t2, z) + t0 = g2.add(t0, z) + t1 = g2.add(t1, t0) + t3 := g2.double(t1) + t3 = g2.doubleAndAdd(t3, t1) + t2 = g2.add(t2, t3) + t1 = g2.add(t1, t2) + t2 = g2.add(t2, t1) + t2 = g2.doubleN(t2, 16) + t1 = g2.doubleAndAdd(t2, t1) + t1 = g2.doubleN(t1, 13) + t0 = g2.doubleAndAdd(t1, t0) + t0 = g2.doubleN(t0, 15) + z = g2.doubleAndAdd(t0, z) + + return z +} + +func (g2 G2) add(p, q *G2Affine) *G2Affine { + // compute λ = (q.y-p.y)/(q.x-p.x) + qypy := g2.Ext2.Sub(&q.Y, &p.Y) + qxpx := g2.Ext2.Sub(&q.X, &p.X) + λ := g2.Ext2.DivUnchecked(qypy, qxpx) + + // xr = λ²-p.x-q.x + λλ := g2.Ext2.Square(λ) + qxpx = g2.Ext2.Add(&p.X, &q.X) + xr := g2.Ext2.Sub(λλ, qxpx) + + // p.y = λ(p.x-r.x) - p.y + pxrx := g2.Ext2.Sub(&p.X, xr) + λpxrx := g2.Ext2.Mul(λ, pxrx) + yr := g2.Ext2.Sub(λpxrx, &p.Y) + + return &G2Affine{ + X: *xr, + Y: *yr, + } +} + +func (g2 G2) neg(p *G2Affine) *G2Affine { + xr := &p.X + yr := g2.Ext2.Neg(&p.Y) + return &G2Affine{ + X: *xr, + Y: *yr, + } +} + +func (g2 G2) sub(p, q *G2Affine) *G2Affine { + qNeg := g2.neg(q) + return g2.add(p, qNeg) +} + +func (g2 *G2) double(p *G2Affine) *G2Affine { + // compute λ = (3p.x²)/2*p.y + xx3a := g2.Square(&p.X) + xx3a = g2.MulByConstElement(xx3a, big.NewInt(3)) + y2 := g2.Double(&p.Y) + λ := g2.DivUnchecked(xx3a, y2) + + // xr = λ²-2p.x + x2 := g2.Double(&p.X) + λλ := g2.Square(λ) + xr := g2.Sub(λλ, x2) + + // yr = λ(p-xr) - p.y + pxrx := g2.Sub(&p.X, xr) + λpxrx := g2.Mul(λ, pxrx) + yr := g2.Sub(λpxrx, &p.Y) + + return &G2Affine{ + X: *xr, + Y: *yr, + } +} + +func (g2 *G2) doubleN(p *G2Affine, n int) *G2Affine { + pn := p + for s := 0; s < n; s++ { + pn = g2.double(pn) + } + return pn +} + +func (g2 G2) doubleAndAdd(p, q *G2Affine) *G2Affine { + + // compute λ1 = (q.y-p.y)/(q.x-p.x) + yqyp := g2.Ext2.Sub(&q.Y, &p.Y) + xqxp := g2.Ext2.Sub(&q.X, &p.X) + λ1 := g2.Ext2.DivUnchecked(yqyp, xqxp) + + // compute x2 = λ1²-p.x-q.x + λ1λ1 := g2.Ext2.Square(λ1) + xqxp = g2.Ext2.Add(&p.X, &q.X) + x2 := 
g2.Ext2.Sub(λ1λ1, xqxp) + + // omit y2 computation + // compute λ2 = -λ1-2*p.y/(x2-p.x) + ypyp := g2.Ext2.Add(&p.Y, &p.Y) + x2xp := g2.Ext2.Sub(x2, &p.X) + λ2 := g2.Ext2.DivUnchecked(ypyp, x2xp) + λ2 = g2.Ext2.Add(λ1, λ2) + λ2 = g2.Ext2.Neg(λ2) + + // compute x3 = λ2²-p.x-x2 + λ2λ2 := g2.Ext2.Square(λ2) + x3 := g2.Ext2.Sub(λ2λ2, &p.X) + x3 = g2.Ext2.Sub(x3, x2) + + // compute y3 = λ2*(p.x - x3)-p.y + y3 := g2.Ext2.Sub(&p.X, x3) + y3 = g2.Ext2.Mul(λ2, y3) + y3 = g2.Ext2.Sub(y3, &p.Y) + + return &G2Affine{ + X: *x3, + Y: *y3, + } +} + +// AssertIsEqual asserts that p and q are the same point. +func (g2 *G2) AssertIsEqual(p, q *G2Affine) { + g2.Ext2.AssertIsEqual(&p.X, &q.X) + g2.Ext2.AssertIsEqual(&p.Y, &q.Y) +} diff --git a/std/algebra/emulated/sw_bn254/g2_test.go b/std/algebra/emulated/sw_bn254/g2_test.go new file mode 100644 index 0000000000..3e61f05cc4 --- /dev/null +++ b/std/algebra/emulated/sw_bn254/g2_test.go @@ -0,0 +1,144 @@ +package sw_bn254 + +import ( + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type addG2Circuit struct { + In1, In2 G2Affine + Res G2Affine +} + +func (c *addG2Circuit) Define(api frontend.API) error { + g2 := NewG2(api) + res := g2.add(&c.In1, &c.In2) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestAddG2TestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + _, in2 := randomG1G2Affines() + var res bn254.G2Affine + res.Add(&in1, &in2) + witness := addG2Circuit{ + In1: NewG2Affine(in1), + In2: NewG2Affine(in2), + Res: NewG2Affine(res), + } + err := test.IsSolved(&addG2Circuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type doubleG2Circuit struct { + In1 G2Affine + Res G2Affine +} + +func (c *doubleG2Circuit) Define(api frontend.API) error { + g2 := NewG2(api) + res := g2.double(&c.In1) + g2.AssertIsEqual(res, &c.Res) + 
return nil +} + +func TestDoubleG2TestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + var res bn254.G2Affine + var in1Jac, resJac bn254.G2Jac + in1Jac.FromAffine(&in1) + resJac.Double(&in1Jac) + res.FromJacobian(&resJac) + witness := doubleG2Circuit{ + In1: NewG2Affine(in1), + Res: NewG2Affine(res), + } + err := test.IsSolved(&doubleG2Circuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type doubleAndAddG2Circuit struct { + In1, In2 G2Affine + Res G2Affine +} + +func (c *doubleAndAddG2Circuit) Define(api frontend.API) error { + g2 := NewG2(api) + res := g2.doubleAndAdd(&c.In1, &c.In2) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestDoubleAndAddG2TestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + _, in2 := randomG1G2Affines() + var res bn254.G2Affine + res.Double(&in1). + Add(&res, &in2) + witness := doubleAndAddG2Circuit{ + In1: NewG2Affine(in1), + In2: NewG2Affine(in2), + Res: NewG2Affine(res), + } + err := test.IsSolved(&doubleAndAddG2Circuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type scalarMulG2BySeedCircuit struct { + In1 G2Affine + Res G2Affine +} + +func (c *scalarMulG2BySeedCircuit) Define(api frontend.API) error { + g2 := NewG2(api) + res := g2.scalarMulBySeed(&c.In1) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestScalarMulG2BySeedTestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + var res bn254.G2Affine + x0, _ := new(big.Int).SetString("4965661367192848881", 10) + res.ScalarMultiplication(&in1, x0) + witness := scalarMulG2BySeedCircuit{ + In1: NewG2Affine(in1), + Res: NewG2Affine(res), + } + err := test.IsSolved(&scalarMulG2BySeedCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type endomorphismG2Circuit struct { + In1 G2Affine +} + +func (c *endomorphismG2Circuit) Define(api frontend.API) error { + g2 := NewG2(api) + res1 := g2.phi(&c.In1) + 
res1 = g2.neg(res1) + res2 := g2.psi(&c.In1) + res2 = g2.psi(res2) + g2.AssertIsEqual(res1, res2) + return nil +} + +func TestEndomorphismG2TestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + witness := endomorphismG2Circuit{ + In1: NewG2Affine(in1), + } + err := test.IsSolved(&endomorphismG2Circuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/emulated/sw_bn254/pairing.go b/std/algebra/emulated/sw_bn254/pairing.go new file mode 100644 index 0000000000..7b9bf43013 --- /dev/null +++ b/std/algebra/emulated/sw_bn254/pairing.go @@ -0,0 +1,1033 @@ +package sw_bn254 + +import ( + "errors" + "fmt" + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/fields_bn254" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" +) + +type Pairing struct { + api frontend.API + *fields_bn254.Ext12 + curveF *emulated.Field[emulated.BN254Fp] + curve *sw_emulated.Curve[emulated.BN254Fp, emulated.BN254Fr] + g2 *G2 + bTwist *fields_bn254.E2 + lines [4][67]fields_bn254.E2 +} + +type GTEl = fields_bn254.E12 + +func NewGTEl(v bn254.GT) GTEl { + return GTEl{ + C0: fields_bn254.E6{ + B0: fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp](v.C0.B0.A0), + A1: emulated.ValueOf[emulated.BN254Fp](v.C0.B0.A1), + }, + B1: fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp](v.C0.B1.A0), + A1: emulated.ValueOf[emulated.BN254Fp](v.C0.B1.A1), + }, + B2: fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp](v.C0.B2.A0), + A1: emulated.ValueOf[emulated.BN254Fp](v.C0.B2.A1), + }, + }, + C1: fields_bn254.E6{ + B0: fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp](v.C1.B0.A0), + A1: emulated.ValueOf[emulated.BN254Fp](v.C1.B0.A1), + }, + B1: fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp](v.C1.B1.A0), + A1: 
emulated.ValueOf[emulated.BN254Fp](v.C1.B1.A1), + }, + B2: fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp](v.C1.B2.A0), + A1: emulated.ValueOf[emulated.BN254Fp](v.C1.B2.A1), + }, + }, + } +} + +func NewPairing(api frontend.API) (*Pairing, error) { + ba, err := emulated.NewField[emulated.BN254Fp](api) + if err != nil { + return nil, fmt.Errorf("new base api: %w", err) + } + curve, err := sw_emulated.New[emulated.BN254Fp, emulated.BN254Fr](api, sw_emulated.GetBN254Params()) + if err != nil { + return nil, fmt.Errorf("new curve: %w", err) + } + bTwist := fields_bn254.E2{ + A0: emulated.ValueOf[emulated.BN254Fp]("19485874751759354771024239261021720505790618469301721065564631296452457478373"), + A1: emulated.ValueOf[emulated.BN254Fp]("266929791119991161246907387137283842545076965332900288569378510910307636690"), + } + return &Pairing{ + api: api, + Ext12: fields_bn254.NewExt12(api), + curveF: ba, + curve: curve, + g2: NewG2(api), + bTwist: &bTwist, + lines: getPrecomputeLines(), + }, nil +} + +// FinalExponentiation computes the exponentiation eᵈ where +// +// d = (p¹²-1)/r = (p¹²-1)/Φ₁₂(p) ⋅ Φ₁₂(p)/r = (p⁶-1)(p²+1)(p⁴ - p² +1)/r. +// +// We use instead d'= s ⋅ d, where s is the cofactor +// +// 2x₀(6x₀²+3x₀+1) +// +// and r does NOT divide d' +// +// FinalExponentiation returns a decompressed element in E12. +// +// This is the safe version of the method where e may be {-1,1}. If it is known +// that e ≠ {-1,1} then using the unsafe version of the method saves +// considerable amount of constraints. When called with the result of +// [MillerLoop], then current method is applicable when length of the inputs to +// Miller loop is 1. +func (pr Pairing) FinalExponentiation(e *GTEl) *GTEl { + return pr.finalExponentiation(e, false) +} + +// FinalExponentiationUnsafe computes the exponentiation eᵈ where +// +// d = (p¹²-1)/r = (p¹²-1)/Φ₁₂(p) ⋅ Φ₁₂(p)/r = (p⁶-1)(p²+1)(p⁴ - p² +1)/r. 
+// +// We use instead d'= s ⋅ d, where s is the cofactor +// +// 2x₀(6x₀²+3x₀+1) +// +// and r does NOT divide d' +// +// FinalExponentiationUnsafe returns a decompressed element in E12. +// +// This is the unsafe version of the method where e may NOT be {-1,1}. If e ∈ +// {-1, 1}, then there exists no valid solution to the circuit. This method is +// applicable when called with the result of [MillerLoop] method when the length +// of the inputs to Miller loop is 1. +func (pr Pairing) FinalExponentiationUnsafe(e *GTEl) *GTEl { + return pr.finalExponentiation(e, true) +} + +// finalExponentiation computes the exponentiation eᵈ where +// +// d = (p¹²-1)/r = (p¹²-1)/Φ₁₂(p) ⋅ Φ₁₂(p)/r = (p⁶-1)(p²+1)(p⁴ - p² +1)/r. +// +// We use instead d'= s ⋅ d, where s is the cofactor +// +// 2x₀(6x₀²+3x₀+1) +// +// and r does NOT divide d' +// +// finalExponentiation returns a decompressed element in E12 +func (pr Pairing) finalExponentiation(e *GTEl, unsafe bool) *GTEl { + + // 1. Easy part + // (p⁶-1)(p²+1) + var selector1, selector2 frontend.Variable + _dummy := pr.Ext6.One() + + if unsafe { + // The Miller loop result is ≠ {-1,1}, otherwise this means P and Q are + // linearly dependant and not from G1 and G2 respectively. + // So e ∈ G_{q,2} \ {-1,1} and hence e.C1 ≠ 0. + // Nothing to do. + + } else { + // However, for a product of Miller loops (n>=2) this might happen. If this is + // the case, the result is 1 in the torus. We assign a dummy value (1) to e.C1 + // and proceed further. + selector1 = pr.Ext6.IsZero(&e.C1) + e.C1 = *pr.Ext6.Select(selector1, _dummy, &e.C1) + } + + // Torus compression absorbed: + // Raising e to (p⁶-1) is + // e^(p⁶) / e = (e.C0 - w*e.C1) / (e.C0 + w*e.C1) + // = (-e.C0/e.C1 + w) / (-e.C0/e.C1 - w) + // So the fraction -e.C0/e.C1 is already in the torus. + // This absorbs the torus compression in the easy part. + c := pr.Ext6.DivUnchecked(&e.C0, &e.C1) + c = pr.Ext6.Neg(c) + t0 := pr.FrobeniusSquareTorus(c) + c = pr.MulTorus(t0, c) + + // 2. 
Hard part (up to permutation) + // 2x₀(6x₀²+3x₀+1)(p⁴-p²+1)/r + // Duquesne and Ghammam + // https://eprint.iacr.org/2015/192.pdf + // Fuentes et al. (alg. 6) + // performed in torus compressed form + t0 = pr.ExptTorus(c) + t0 = pr.InverseTorus(t0) + t0 = pr.SquareTorus(t0) + t1 := pr.SquareTorus(t0) + t1 = pr.MulTorus(t0, t1) + t2 := pr.ExptTorus(t1) + t2 = pr.InverseTorus(t2) + t3 := pr.InverseTorus(t1) + t1 = pr.MulTorus(t2, t3) + t3 = pr.SquareTorus(t2) + t4 := pr.ExptTorus(t3) + t4 = pr.MulTorus(t1, t4) + t3 = pr.MulTorus(t0, t4) + t0 = pr.MulTorus(t2, t4) + t0 = pr.MulTorus(c, t0) + t2 = pr.FrobeniusTorus(t3) + t0 = pr.MulTorus(t2, t0) + t2 = pr.FrobeniusSquareTorus(t4) + t0 = pr.MulTorus(t2, t0) + t2 = pr.InverseTorus(c) + t2 = pr.MulTorus(t2, t3) + t2 = pr.FrobeniusCubeTorus(t2) + + var result GTEl + // MulTorus(t0, t2) requires t0 ≠ -t2. When t0 = -t2, it means the + // product is 1 in the torus. + if unsafe { + // For a single pairing, this does not happen because the pairing is non-degenerate. + result = *pr.DecompressTorus(pr.MulTorus(t2, t0)) + } else { + // For a product of pairings this might happen when the result is expected to be 1. + // We assign a dummy value (1) to t0 and proceed further. + // Finally we do a select on both edge cases: + // - Only if selector1=0 and selector2=0, we return MulTorus(t2, t0) decompressed. + // - Otherwise, we return 1. + _sum := pr.Ext6.Add(t0, t2) + selector2 = pr.Ext6.IsZero(_sum) + t0 = pr.Ext6.Select(selector2, _dummy, t0) + selector := pr.api.Mul(pr.api.Sub(1, selector1), pr.api.Sub(1, selector2)) + result = *pr.Select(selector, pr.DecompressTorus(pr.MulTorus(t2, t0)), pr.One()) + } + + return &result +} + +// Pair calculates the reduced pairing for a set of points +// ∏ᵢ e(Pᵢ, Qᵢ). +// +// This function doesn't check that the inputs are in the correct subgroups. See AssertIsOnG1 and AssertIsOnG2. 
+func (pr Pairing) Pair(P []*G1Affine, Q []*G2Affine) (*GTEl, error) { + res, err := pr.MillerLoop(P, Q) + if err != nil { + return nil, fmt.Errorf("miller loop: %w", err) + } + res = pr.finalExponentiation(res, len(P) == 1) + return res, nil +} + +// PairingCheck calculates the reduced pairing for a set of points and asserts if the result is One +// ∏ᵢ e(Pᵢ, Qᵢ) =? 1 +// +// This function doesn't check that the inputs are in the correct subgroups. See AssertIsOnG1 and AssertIsOnG2. +func (pr Pairing) PairingCheck(P []*G1Affine, Q []*G2Affine) error { + f, err := pr.Pair(P, Q) + if err != nil { + return err + + } + one := pr.One() + pr.AssertIsEqual(f, one) + + return nil +} + +func (pr Pairing) AssertIsEqual(x, y *GTEl) { + pr.Ext12.AssertIsEqual(x, y) +} + +func (pr Pairing) AssertIsOnCurve(P *G1Affine) { + pr.curve.AssertIsOnCurve(P) +} + +func (pr Pairing) AssertIsOnTwist(Q *G2Affine) { + // Twist: Y² == X³ + aX + b, where a=0 and b=3/(9+u) + // (X,Y) ∈ {Y² == X³ + aX + b} U (0,0) + + // if Q=(0,0) we assign b=0 otherwise 3/(9+u), and continue + selector := pr.api.And(pr.Ext2.IsZero(&Q.X), pr.Ext2.IsZero(&Q.Y)) + b := pr.Ext2.Select(selector, pr.Ext2.Zero(), pr.bTwist) + + left := pr.Ext2.Square(&Q.Y) + right := pr.Ext2.Square(&Q.X) + right = pr.Ext2.Mul(right, &Q.X) + right = pr.Ext2.Add(right, b) + pr.Ext2.AssertIsEqual(left, right) +} + +func (pr Pairing) AssertIsOnG1(P *G1Affine) { + // BN254 has a prime order, so we only + // 1- Check P is on the curve + pr.AssertIsOnCurve(P) +} + +func (pr Pairing) AssertIsOnG2(Q *G2Affine) { + // 1- Check Q is on the curve + pr.AssertIsOnTwist(Q) + + // 2- Check Q has the right subgroup order + + // [x₀]Q + xQ := pr.g2.scalarMulBySeed(Q) + // ψ([x₀]Q) + psixQ := pr.g2.psi(xQ) + // ψ²([x₀]Q) = -ϕ([x₀]Q) + psi2xQ := pr.g2.phi(xQ) + // ψ³([2x₀]Q) + psi3xxQ := pr.g2.double(psi2xQ) + psi3xxQ = pr.g2.psi(psi3xxQ) + + // _Q = ψ³([2x₀]Q) - ψ²([x₀]Q) - ψ([x₀]Q) - [x₀]Q + _Q := pr.g2.sub(psi2xQ, psi3xxQ) + _Q = pr.g2.sub(_Q, 
psixQ) + _Q = pr.g2.sub(_Q, xQ) + + // [r]Q == 0 <==> _Q == Q + pr.g2.AssertIsEqual(Q, _Q) +} + +// loopCounter = 6x₀+2 = 29793968203157093288 +// +// in 2-NAF +var loopCounter = [66]int8{ + 0, 0, 0, 1, 0, 1, 0, -1, 0, 0, -1, + 0, 0, 0, 1, 0, 0, -1, 0, -1, 0, 0, + 0, 1, 0, -1, 0, 0, 0, 0, -1, 0, 0, + 1, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, + -1, 0, 0, -1, 0, 1, 0, -1, 0, 0, 0, + -1, 0, -1, 0, 0, 0, 1, 0, -1, 0, 1, +} + +// lineEvaluation represents a sparse Fp12 Elmt (result of the line evaluation) +// line: 1 + R0(x/y) + R1(1/y) = 0 instead of R0'*y + R1'*x + R2' = 0 This +// makes the multiplication by lines (MulBy034) and between lines (Mul034By034) +// circuit-efficient. +type lineEvaluation struct { + R0, R1 fields_bn254.E2 +} + +// MillerLoop computes the multi-Miller loop +// ∏ᵢ { fᵢ_{6x₀+2,Q}(P) · ℓᵢ_{[6x₀+2]Q,π(Q)}(P) · ℓᵢ_{[6x₀+2]Q+π(Q),-π²(Q)}(P) } +func (pr Pairing) MillerLoop(P []*G1Affine, Q []*G2Affine) (*GTEl, error) { + // check input size match + n := len(P) + if n == 0 || n != len(Q) { + return nil, errors.New("invalid inputs sizes") + } + + res := pr.Ext12.One() + var prodLines [5]fields_bn254.E2 + + var l1, l2 *lineEvaluation + Qacc := make([]*G2Affine, n) + QNeg := make([]*G2Affine, n) + yInv := make([]*emulated.Element[emulated.BN254Fp], n) + xOverY := make([]*emulated.Element[emulated.BN254Fp], n) + + for k := 0; k < n; k++ { + Qacc[k] = Q[k] + QNeg[k] = &G2Affine{X: Q[k].X, Y: *pr.Ext2.Neg(&Q[k].Y)} + // P and Q are supposed to be on G1 and G2 respectively of prime order r. + // The point (x,0) is of order 2. But this function does not check + // subgroup membership. + // Anyway (x,0) cannot be on BN254 because -3 is a cubic non-residue in Fp. + // So, 1/y is well defined for all points P's. 
+ yInv[k] = pr.curveF.Inverse(&P[k].Y) + xOverY[k] = pr.curveF.MulMod(&P[k].X, yInv[k]) + } + + // Compute ∏ᵢ { fᵢ_{6x₀+2,Q}(P) } + // i = 64, separately to avoid an E12 Square + // (Square(res) = 1² = 1) + + // k = 0, separately to avoid MulBy034 (res × ℓ) + // (assign line to res) + Qacc[0], l1 = pr.doubleStep(Qacc[0]) + // line evaluation at P[0] + res.C1.B0 = *pr.MulByElement(&l1.R0, xOverY[0]) + res.C1.B1 = *pr.MulByElement(&l1.R1, yInv[0]) + + if n >= 2 { + // k = 1, separately to avoid MulBy034 (res × ℓ) + // (res is also a line at this point, so we use Mul034By034 ℓ × ℓ) + Qacc[1], l1 = pr.doubleStep(Qacc[1]) + + // line evaluation at P[1] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[1]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[1]) + + // ℓ × res + prodLines = *pr.Mul034By034(&l1.R0, &l1.R1, &res.C1.B0, &res.C1.B1) + res.C0.B0 = prodLines[0] + res.C0.B1 = prodLines[1] + res.C0.B2 = prodLines[2] + res.C1.B0 = prodLines[3] + res.C1.B1 = prodLines[4] + } + + if n >= 3 { + // k = 2, separately to avoid MulBy034 (res × ℓ) + // (res has a zero E2 element, so we use Mul01234By034) + Qacc[2], l1 = pr.doubleStep(Qacc[2]) + + // line evaluation at P[1] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[2]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[2]) + + // ℓ × res + res = pr.Mul01234By034(&prodLines, &l1.R0, &l1.R1) + + // k >= 3 + for k := 3; k < n; k++ { + // Qacc[k] ← 2Qacc[k] and l1 the tangent ℓ passing 2Qacc[k] + Qacc[k], l1 = pr.doubleStep(Qacc[k]) + + // line evaluation at P[k] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[k]) + + // ℓ × res + res = pr.MulBy034(res, &l1.R0, &l1.R1) + } + } + + // i = 63, separately to avoid a doubleStep + // (at this point Qacc = 2Q, so 2Qacc-Q=3Q is equivalent to Qacc+Q=3Q + // this means doubleAndAddStep is equivalent to addStep here) + if n == 1 { + res = pr.Square034(res) + + } else { + res = pr.Square(res) + + } + for k := 0; k < n; k++ { + // l2 the line passing Qacc[k] and -Q + l2 = 
pr.lineCompute(Qacc[k], QNeg[k]) + + // line evaluation at P[k] + l2.R0 = *pr.MulByElement(&l2.R0, xOverY[k]) + l2.R1 = *pr.MulByElement(&l2.R1, yInv[k]) + + // Qacc[k] ← Qacc[k]+Q[k] and + // l1 the line ℓ passing Qacc[k] and Q[k] + Qacc[k], l1 = pr.addStep(Qacc[k], Q[k]) + + // line evaluation at P[k] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[k]) + + // ℓ × ℓ + prodLines = *pr.Mul034By034(&l1.R0, &l1.R1, &l2.R0, &l2.R1) + // (ℓ × ℓ) × res + res = pr.MulBy01234(res, &prodLines) + } + + l1s := make([]*lineEvaluation, n) + for i := 62; i >= 0; i-- { + // mutualize the square among n Miller loops + // (∏ᵢfᵢ)² + res = pr.Square(res) + + switch loopCounter[i] { + + case 0: + // precompute lines + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k] and l1 the tangent ℓ passing 2Qacc[k] + Qacc[k], l1s[k] = pr.doubleStep(Qacc[k]) + + // line evaluation at P[k] + l1s[k].R0 = *pr.MulByElement(&l1s[k].R0, xOverY[k]) + l1s[k].R1 = *pr.MulByElement(&l1s[k].R1, yInv[k]) + + } + + // if number of lines is odd, mul last line by res + // works for n=1 as well + if n%2 != 0 { + // ℓ × res + res = pr.MulBy034(res, &l1s[n-1].R0, &l1s[n-1].R1) + + } + + // mul lines 2-by-2 + for k := 1; k < n; k += 2 { + // ℓ × ℓ + prodLines = *pr.Mul034By034(&l1s[k].R0, &l1s[k].R1, &l1s[k-1].R0, &l1s[k-1].R1) + // (ℓ × ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + } + + case 1: + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k]+Q[k], + // l1 the line ℓ passing Qacc[k] and Q[k] + // l2 the line ℓ passing (Qacc[k]+Q[k]) and Qacc[k] + Qacc[k], l1, l2 = pr.doubleAndAddStep(Qacc[k], Q[k]) + + // line evaluation at P[k] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[k]) + + // line evaluation at P[k] + l2.R0 = *pr.MulByElement(&l2.R0, xOverY[k]) + l2.R1 = *pr.MulByElement(&l2.R1, yInv[k]) + + // ℓ × ℓ + prodLines = *pr.Mul034By034(&l1.R0, &l1.R1, &l2.R0, &l2.R1) + // (ℓ × ℓ) × res + res = pr.MulBy01234(res, &prodLines) + 
+ } + + case -1: + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k]-Q[k], + // l1 the line ℓ passing Qacc[k] and -Q[k] + // l2 the line ℓ passing (Qacc[k]-Q[k]) and Qacc[k] + Qacc[k], l1, l2 = pr.doubleAndAddStep(Qacc[k], QNeg[k]) + + // line evaluation at P[k] + l1.R0 = *pr.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.MulByElement(&l1.R1, yInv[k]) + + // line evaluation at P[k] + l2.R0 = *pr.MulByElement(&l2.R0, xOverY[k]) + l2.R1 = *pr.MulByElement(&l2.R1, yInv[k]) + + // ℓ × ℓ + prodLines = *pr.Mul034By034(&l1.R0, &l1.R1, &l2.R0, &l2.R1) + // (ℓ × ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + } + + default: + return nil, errors.New("invalid loopCounter") + } + } + + // Compute ∏ᵢ { ℓᵢ_{[6x₀+2]Q,π(Q)}(P) · ℓᵢ_{[6x₀+2]Q+π(Q),-π²(Q)}(P) } + Q1, Q2 := new(G2Affine), new(G2Affine) + for k := 0; k < n; k++ { + //Q1 = π(Q) + Q1.X = *pr.Ext2.Conjugate(&Q[k].X) + Q1.X = *pr.Ext2.MulByNonResidue1Power2(&Q1.X) + Q1.Y = *pr.Ext2.Conjugate(&Q[k].Y) + Q1.Y = *pr.Ext2.MulByNonResidue1Power3(&Q1.Y) + + // Q2 = -π²(Q) + Q2.X = *pr.Ext2.MulByNonResidue2Power2(&Q[k].X) + Q2.Y = *pr.Ext2.MulByNonResidue2Power3(&Q[k].Y) + Q2.Y = *pr.Ext2.Neg(&Q2.Y) + + // Qacc[k] ← Qacc[k]+π(Q) and + // l1 the line passing Qacc[k] and π(Q) + Qacc[k], l1 = pr.addStep(Qacc[k], Q1) + + // line evaluation at P[k] + l1.R0 = *pr.Ext2.MulByElement(&l1.R0, xOverY[k]) + l1.R1 = *pr.Ext2.MulByElement(&l1.R1, yInv[k]) + + // l2 the line passing Qacc[k] and -π²(Q) + l2 = pr.lineCompute(Qacc[k], Q2) + // line evaluation at P[k] + l2.R0 = *pr.MulByElement(&l2.R0, xOverY[k]) + l2.R1 = *pr.MulByElement(&l2.R1, yInv[k]) + + // ℓ × ℓ + prodLines = *pr.Mul034By034(&l1.R0, &l1.R1, &l2.R0, &l2.R1) + // (ℓ × ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + } + + return res, nil +} + +// doubleAndAddStep doubles p1 and adds p2 to the result in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func (pr Pairing) doubleAndAddStep(p1, p2 *G2Affine) 
(*G2Affine, *lineEvaluation, *lineEvaluation) { + + var line1, line2 lineEvaluation + var p G2Affine + + // compute λ1 = (y2-y1)/(x2-x1) + n := pr.Ext2.Sub(&p1.Y, &p2.Y) + d := pr.Ext2.Sub(&p1.X, &p2.X) + l1 := pr.Ext2.DivUnchecked(n, d) + + // compute x3 =λ1²-x1-x2 + x3 := pr.Ext2.Square(l1) + x3 = pr.Ext2.Sub(x3, &p1.X) + x3 = pr.Ext2.Sub(x3, &p2.X) + + // omit y3 computation + + // compute line1 + line1.R0 = *pr.Ext2.Neg(l1) + line1.R1 = *pr.Ext2.Mul(l1, &p1.X) + line1.R1 = *pr.Ext2.Sub(&line1.R1, &p1.Y) + + // compute λ2 = -λ1-2y1/(x3-x1) + n = pr.Ext2.Double(&p1.Y) + d = pr.Ext2.Sub(x3, &p1.X) + l2 := pr.Ext2.DivUnchecked(n, d) + l2 = pr.Ext2.Add(l2, l1) + l2 = pr.Ext2.Neg(l2) + + // compute x4 = λ2²-x1-x3 + x4 := pr.Ext2.Square(l2) + x4 = pr.Ext2.Sub(x4, &p1.X) + x4 = pr.Ext2.Sub(x4, x3) + + // compute y4 = λ2(x1 - x4)-y1 + y4 := pr.Ext2.Sub(&p1.X, x4) + y4 = pr.Ext2.Mul(l2, y4) + y4 = pr.Ext2.Sub(y4, &p1.Y) + + p.X = *x4 + p.Y = *y4 + + // compute line2 + line2.R0 = *pr.Ext2.Neg(l2) + line2.R1 = *pr.Ext2.Mul(l2, &p1.X) + line2.R1 = *pr.Ext2.Sub(&line2.R1, &p1.Y) + + return &p, &line1, &line2 +} + +// doubleStep doubles a point in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func (pr Pairing) doubleStep(p1 *G2Affine) (*G2Affine, *lineEvaluation) { + + var p G2Affine + var line lineEvaluation + + // λ = 3x²/2y + n := pr.Ext2.Square(&p1.X) + three := big.NewInt(3) + n = pr.Ext2.MulByConstElement(n, three) + d := pr.Ext2.Double(&p1.Y) + λ := pr.Ext2.DivUnchecked(n, d) + + // xr = λ²-2x + xr := pr.Ext2.Square(λ) + xr = pr.Ext2.Sub(xr, &p1.X) + xr = pr.Ext2.Sub(xr, &p1.X) + + // yr = λ(x-xr)-y + yr := pr.Ext2.Sub(&p1.X, xr) + yr = pr.Ext2.Mul(λ, yr) + yr = pr.Ext2.Sub(yr, &p1.Y) + + p.X = *xr + p.Y = *yr + + line.R0 = *pr.Ext2.Neg(λ) + line.R1 = *pr.Ext2.Mul(λ, &p1.X) + line.R1 = *pr.Ext2.Sub(&line.R1, &p1.Y) + + return &p, &line + +} + +// addStep adds two points in affine coordinates, and 
evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func (pr Pairing) addStep(p1, p2 *G2Affine) (*G2Affine, *lineEvaluation) { + + // compute λ = (y2-y1)/(x2-x1) + p2ypy := pr.Ext2.Sub(&p2.Y, &p1.Y) + p2xpx := pr.Ext2.Sub(&p2.X, &p1.X) + λ := pr.Ext2.DivUnchecked(p2ypy, p2xpx) + + // xr = λ²-x1-x2 + λλ := pr.Ext2.Square(λ) + p2xpx = pr.Ext2.Add(&p1.X, &p2.X) + xr := pr.Ext2.Sub(λλ, p2xpx) + + // yr = λ(x1-xr) - y1 + pxrx := pr.Ext2.Sub(&p1.X, xr) + λpxrx := pr.Ext2.Mul(λ, pxrx) + yr := pr.Ext2.Sub(λpxrx, &p1.Y) + + var res G2Affine + res.X = *xr + res.Y = *yr + + var line lineEvaluation + line.R0 = *pr.Ext2.Neg(λ) + line.R1 = *pr.Ext2.Mul(λ, &p1.X) + line.R1 = *pr.Ext2.Sub(&line.R1, &p1.Y) + + return &res, &line + +} + +// lineCompute computes the line that goes through p1 and p2 but does not compute p1+p2 +func (pr Pairing) lineCompute(p1, p2 *G2Affine) *lineEvaluation { + + // compute λ = (y2-y1)/(x2-x1) + qypy := pr.Ext2.Sub(&p2.Y, &p1.Y) + qxpx := pr.Ext2.Sub(&p2.X, &p1.X) + λ := pr.Ext2.DivUnchecked(qypy, qxpx) + + var line lineEvaluation + line.R0 = *pr.Ext2.Neg(λ) + line.R1 = *pr.Ext2.Mul(λ, &p1.X) + line.R1 = *pr.Ext2.Sub(&line.R1, &p1.Y) + + return &line + +} + +// MillerLoopAndMul computes the Miller loop between P and Q +// and multiplies it in 𝔽p¹² by previous. +// +// This method is needed for evmprecompiles/ecpair. +func (pr Pairing) MillerLoopAndMul(P *G1Affine, Q *G2Affine, previous *GTEl) (*GTEl, error) { + res, err := pr.MillerLoop([]*G1Affine{P}, []*G2Affine{Q}) + if err != nil { + return nil, fmt.Errorf("miller loop: %w", err) + } + res = pr.Mul(res, previous) + return res, err +} + +// FinalExponentiationIsOne performs the final exponentiation on e +// and checks that the result in 1 in GT. +// +// This method is needed for evmprecompiles/ecpair. 
+func (pr Pairing) FinalExponentiationIsOne(e *GTEl) { + res := pr.finalExponentiation(e, false) + one := pr.One() + pr.AssertIsEqual(res, one) +} + +// ---------------------------- +// Fixed-argument pairing +// ---------------------------- +// +// The second argument Q is the fixed canonical generator of G2. +// +// Q.X.A0 = 0x1800deef121f1e76426a00665e5c4479674322d4f75edadd46debd5cd992f6ed +// Q.X.A1 = 0x198e9393920d483a7260bfb731fb5d25f1aa493335a9e71297e485b7aef312c2 +// Q.Y.A0 = 0x12c85ea5db8c6deb4aab71808dcb408fe3d1e7690c43d37b4ce6cc0166fa7daa +// Q.Y.A1 = 0x90689d0585ff075ec9e99ad690c3395bc4b313370b38ef355acdadcd122975b + +// MillerLoopFixed computes the single Miller loop +// fᵢ_{u,g2}(P), where g2 is fixed. +func (pr Pairing) MillerLoopFixedQ(P *G1Affine) (*GTEl, error) { + + yInv := pr.curveF.Inverse(&P.Y) + xOverY := pr.curveF.MulMod(&P.X, yInv) + res := pr.Ext12.One() + + // Compute f_{6x₀+2,Q}(P) + // i = 64, separately to avoid an E12 Square + // (Square(res) = 1² = 1) + + // k = 0, separately to avoid MulBy034 (res × ℓ) + // (assign line(P) to res) + res.C1.B0 = *pr.MulByElement(&pr.lines[0][64], xOverY) + res.C1.B1 = *pr.MulByElement(&pr.lines[1][64], yInv) + + // i = 63 + res = pr.Square034(res) + // lines evaluations at P + // and ℓ × ℓ + prodLines := *pr.Mul034By034( + pr.MulByElement(&pr.lines[0][63], xOverY), + pr.MulByElement(&pr.lines[1][63], yInv), + pr.MulByElement(&pr.lines[2][63], xOverY), + pr.MulByElement(&pr.lines[3][63], yInv), + ) + // (ℓ × ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + for i := 62; i >= 0; i-- { + res = pr.Square(res) + + if loopCounter[i] == 0 { + + // line evaluation at P and ℓ × res + res = pr.MulBy034(res, + pr.MulByElement(&pr.lines[0][i], xOverY), + pr.MulByElement(&pr.lines[1][i], yInv), + ) + + } else { + // lines evaluations at P + // and ℓ × ℓ + prodLines := *pr.Mul034By034( + pr.MulByElement(&pr.lines[0][i], xOverY), + pr.MulByElement(&pr.lines[1][i], yInv), + pr.MulByElement(&pr.lines[2][i], 
xOverY), + pr.MulByElement(&pr.lines[3][i], yInv), + ) + // (ℓ × ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + } + } + + // Compute ℓ_{[6x₀+2]Q,π(Q)}(P) · ℓ_{[6x₀+2]Q+π(Q),-π²(Q)}(P) + // lines evaluations at P + // and ℓ × ℓ + prodLines = *pr.Mul034By034( + pr.MulByElement(&pr.lines[0][65], xOverY), + pr.MulByElement(&pr.lines[1][65], yInv), + pr.MulByElement(&pr.lines[0][66], xOverY), + pr.MulByElement(&pr.lines[1][66], yInv), + ) + // (ℓ × ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + return res, nil +} + +// DoubleMillerLoopFixedQ computes the double Miller loop +// fᵢ_{u,g2}(T) * fᵢ_{u,Q}(P), where g2 is fixed. +func (pr Pairing) DoubleMillerLoopFixedQ(P, T *G1Affine, Q *G2Affine) (*GTEl, error) { + res := pr.Ext12.One() + + var prodLines [5]fields_bn254.E2 + var l1, l2 *lineEvaluation + var Qacc, QNeg *G2Affine + Qacc = Q + QNeg = &G2Affine{X: Q.X, Y: *pr.Ext2.Neg(&Q.Y)} + var yInv, xOverY, y2Inv, x2OverY2 *emulated.Element[emulated.BN254Fp] + yInv = pr.curveF.Inverse(&P.Y) + xOverY = pr.curveF.MulMod(&P.X, yInv) + y2Inv = pr.curveF.Inverse(&T.Y) + x2OverY2 = pr.curveF.MulMod(&T.X, y2Inv) + + // Compute ∏ᵢ { fᵢ_{6x₀+2,Q}(P) } + // i = 64, separately to avoid an E12 Square + // (Square(res) = 1² = 1) + + // Qacc ← 2Qacc and l1 the tangent ℓ passing 2Qacc + Qacc, l1 = pr.doubleStep(Qacc) + + // line evaluation at P + l1.R0 = *pr.MulByElement(&l1.R0, xOverY) + l1.R1 = *pr.MulByElement(&l1.R1, yInv) + + // precomputed-ℓ × ℓ + prodLines = *pr.Mul034By034( + &l1.R0, + &l1.R1, + pr.MulByElement(&pr.lines[0][64], x2OverY2), + pr.MulByElement(&pr.lines[1][64], y2Inv), + ) + // (precomputed-ℓ × ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + // i = 63, separately to avoid a doubleStep + // (at this point Qacc = 2Q, so 2Qacc-Q=3Q is equivalent to Qacc+Q=3Q + // this means doubleAndAddStep is equivalent to addStep here) + res = pr.Square(res) + // l2 the line passing Qacc and -Q + l2 = pr.lineCompute(Qacc, QNeg) + + // line evaluation at P + l2.R0 = 
*pr.MulByElement(&l2.R0, xOverY) + l2.R1 = *pr.MulByElement(&l2.R1, yInv) + + // Qacc ← Qacc+Q and + // l1 the line ℓ passing Qacc and Q + Qacc, l1 = pr.addStep(Qacc, Q) + + // line evaluation at P + l1.R0 = *pr.MulByElement(&l1.R0, xOverY) + l1.R1 = *pr.MulByElement(&l1.R1, yInv) + + // ℓ × ℓ + prodLines = *pr.Mul034By034(&l1.R0, &l1.R1, &l2.R0, &l2.R1) + // (ℓ × ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + // precomputed-ℓ × precomputed-ℓ + prodLines = *pr.Mul034By034( + pr.MulByElement(&pr.lines[0][63], x2OverY2), + pr.MulByElement(&pr.lines[1][63], y2Inv), + pr.MulByElement(&pr.lines[2][63], x2OverY2), + pr.MulByElement(&pr.lines[3][63], y2Inv), + ) + // (precomputed-ℓ × precomputed-ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + // Compute ∏ᵢ { fᵢ_{6x₀+2,Q}(P) } + for i := 62; i >= 0; i-- { + // mutualize the square among n Miller loops + // (∏ᵢfᵢ)² + res = pr.Square(res) + + switch loopCounter[i] { + case 0: + + // Qacc ← 2Qacc and l1 the tangent ℓ passing 2Qacc + Qacc, l1 = pr.doubleStep(Qacc) + + // line evaluation at P + l1.R0 = *pr.MulByElement(&l1.R0, xOverY) + l1.R1 = *pr.MulByElement(&l1.R1, yInv) + + // precomputed-ℓ × ℓ + prodLines = *pr.Mul034By034( + &l1.R0, + &l1.R1, + pr.MulByElement(&pr.lines[0][i], x2OverY2), + pr.MulByElement(&pr.lines[1][i], y2Inv), + ) + // (precomputed-ℓ × ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + case 1: + // precomputed-ℓ × precomputed-ℓ + prodLines = *pr.Mul034By034( + pr.MulByElement(&pr.lines[0][i], x2OverY2), + pr.MulByElement(&pr.lines[1][i], y2Inv), + pr.MulByElement(&pr.lines[2][i], x2OverY2), + pr.MulByElement(&pr.lines[3][i], y2Inv), + ) + // (precomputed-ℓ × precomputed-ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + // Qacc ← 2Qacc+Q, + // l1 the line ℓ passing Qacc and Q + // l2 the line ℓ passing (Qacc+Q) and Qacc + Qacc, l1, l2 = pr.doubleAndAddStep(Qacc, Q) + + // line evaluation at P + l1.R0 = *pr.MulByElement(&l1.R0, xOverY) + l1.R1 = *pr.MulByElement(&l1.R1, yInv) + + // line 
evaluation at P + l2.R0 = *pr.MulByElement(&l2.R0, xOverY) + l2.R1 = *pr.MulByElement(&l2.R1, yInv) + + // ℓ × ℓ + prodLines = *pr.Mul034By034(&l1.R0, &l1.R1, &l2.R0, &l2.R1) + // (ℓ × ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + case -1: + // precomputed-ℓ × precomputed-ℓ + prodLines = *pr.Mul034By034( + pr.MulByElement(&pr.lines[0][i], x2OverY2), + pr.MulByElement(&pr.lines[1][i], y2Inv), + pr.MulByElement(&pr.lines[2][i], x2OverY2), + pr.MulByElement(&pr.lines[3][i], y2Inv), + ) + // (precomputed-ℓ × precomputed-ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + // Qacc ← 2Qacc-Q, + // l1 the line ℓ passing Qacc and -Q + // l2 the line ℓ passing (Qacc-Q) and Qacc + Qacc, l1, l2 = pr.doubleAndAddStep(Qacc, QNeg) + + // line evaluation at P + l1.R0 = *pr.MulByElement(&l1.R0, xOverY) + l1.R1 = *pr.MulByElement(&l1.R1, yInv) + + // line evaluation at P + l2.R0 = *pr.MulByElement(&l2.R0, xOverY) + l2.R1 = *pr.MulByElement(&l2.R1, yInv) + + // ℓ × ℓ + prodLines = *pr.Mul034By034(&l1.R0, &l1.R1, &l2.R0, &l2.R1) + // (ℓ × ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + default: + return nil, errors.New("invalid loopCounter") + } + } + + // Compute ∏ᵢ { ℓᵢ_{[6x₀+2]Q,π(Q)}(P) · ℓᵢ_{[6x₀+2]Q+π(Q),-π²(Q)}(P) } + Q1, Q2 := new(G2Affine), new(G2Affine) + //Q1 = π(Q) + Q1.X = *pr.Ext2.Conjugate(&Q.X) + Q1.X = *pr.Ext2.MulByNonResidue1Power2(&Q1.X) + Q1.Y = *pr.Ext2.Conjugate(&Q.Y) + Q1.Y = *pr.Ext2.MulByNonResidue1Power3(&Q1.Y) + + // Q2 = -π²(Q) + Q2.X = *pr.Ext2.MulByNonResidue2Power2(&Q.X) + Q2.Y = *pr.Ext2.MulByNonResidue2Power3(&Q.Y) + Q2.Y = *pr.Ext2.Neg(&Q2.Y) + + // Qacc ← Qacc+π(Q) and + // l1 the line passing Qacc and π(Q) + Qacc, l1 = pr.addStep(Qacc, Q1) + + // line evaluation at P + l1.R0 = *pr.Ext2.MulByElement(&l1.R0, xOverY) + l1.R1 = *pr.Ext2.MulByElement(&l1.R1, yInv) + + // l2 the line passing Qacc and -π²(Q) + l2 = pr.lineCompute(Qacc, Q2) + // line evaluation at P + l2.R0 = *pr.MulByElement(&l2.R0, xOverY) + l2.R1 = *pr.MulByElement(&l2.R1, 
yInv) + + // ℓ × ℓ + prodLines = *pr.Mul034By034(&l1.R0, &l1.R1, &l2.R0, &l2.R1) + // (ℓ × ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + // precomputed-ℓ × precomputed-ℓ + prodLines = *pr.Mul034By034( + pr.MulByElement(&pr.lines[0][65], x2OverY2), + pr.MulByElement(&pr.lines[1][65], y2Inv), + pr.MulByElement(&pr.lines[0][66], x2OverY2), + pr.MulByElement(&pr.lines[1][66], y2Inv), + ) + // (precomputed-ℓ × precomputed-ℓ) × res + res = pr.MulBy01234(res, &prodLines) + + return res, nil +} + +// PairFixedQ calculates the reduced pairing for a set of points +// e(P, g2), where g2 is fixed. +// +// This function doesn't check that the inputs are in the correct subgroups. +func (pr Pairing) PairFixedQ(P *G1Affine) (*GTEl, error) { + res, err := pr.MillerLoopFixedQ(P) + if err != nil { + return nil, fmt.Errorf("miller loop: %w", err) + } + res = pr.finalExponentiation(res, true) + return res, nil +} + +// DoublePairFixedQ calculates the reduced pairing for a set of points +// e(P, Q) * e(T, g2), where g2 is fixed. +// +// This function doesn't check that the inputs are in the correct subgroups. 
+func (pr Pairing) DoublePairFixedQ(P, T *G1Affine, Q *G2Affine) (*GTEl, error) { + res, err := pr.DoubleMillerLoopFixedQ(P, T, Q) + if err != nil { + return nil, fmt.Errorf("double miller loop: %w", err) + } + res = pr.finalExponentiation(res, false) + return res, nil +} diff --git a/std/algebra/emulated/sw_bn254/pairing_test.go b/std/algebra/emulated/sw_bn254/pairing_test.go new file mode 100644 index 0000000000..632972d950 --- /dev/null +++ b/std/algebra/emulated/sw_bn254/pairing_test.go @@ -0,0 +1,390 @@ +package sw_bn254 + +import ( + "bytes" + "crypto/rand" + "fmt" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/test" +) + +func randomG1G2Affines() (bn254.G1Affine, bn254.G2Affine) { + _, _, G1AffGen, G2AffGen := bn254.Generators() + mod := bn254.ID.ScalarField() + s1, err := rand.Int(rand.Reader, mod) + if err != nil { + panic(err) + } + s2, err := rand.Int(rand.Reader, mod) + if err != nil { + panic(err) + } + var p bn254.G1Affine + p.ScalarMultiplication(&G1AffGen, s1) + var q bn254.G2Affine + q.ScalarMultiplication(&G2AffGen, s2) + return p, q +} + +type FinalExponentiationCircuit struct { + InGt GTEl + Res GTEl +} + +func (c *FinalExponentiationCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + res1 := pairing.FinalExponentiation(&c.InGt) + pairing.AssertIsEqual(res1, &c.Res) + res2 := pairing.FinalExponentiationUnsafe(&c.InGt) + pairing.AssertIsEqual(res2, &c.Res) + return nil +} + +func TestFinalExponentiationTestSolve(t *testing.T) { + assert := test.NewAssert(t) + var gt bn254.GT + gt.SetRandom() + res := bn254.FinalExponentiation(>) + witness := FinalExponentiationCircuit{ + InGt: NewGTEl(gt), + Res: 
NewGTEl(res), + } + err := test.IsSolved(&FinalExponentiationCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type PairCircuit struct { + InG1 G1Affine + InG2 G2Affine + Res GTEl +} + +func (c *PairCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + pairing.AssertIsOnG1(&c.InG1) + pairing.AssertIsOnG2(&c.InG2) + res, err := pairing.Pair([]*G1Affine{&c.InG1}, []*G2Affine{&c.InG2}) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + pairing.AssertIsEqual(res, &c.Res) + return nil +} + +func TestPairTestSolve(t *testing.T) { + assert := test.NewAssert(t) + p, q := randomG1G2Affines() + res, err := bn254.Pair([]bn254.G1Affine{p}, []bn254.G2Affine{q}) + assert.NoError(err) + witness := PairCircuit{ + InG1: NewG1Affine(p), + InG2: NewG2Affine(q), + Res: NewGTEl(res), + } + err = test.IsSolved(&PairCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type MultiPairCircuit struct { + InG1 G1Affine + InG2 G2Affine + Res GTEl + n int +} + +func (c *MultiPairCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + pairing.AssertIsOnG1(&c.InG1) + pairing.AssertIsOnG2(&c.InG2) + P, Q := []*G1Affine{}, []*G2Affine{} + for i := 0; i < c.n; i++ { + P = append(P, &c.InG1) + Q = append(Q, &c.InG2) + } + res, err := pairing.Pair(P, Q) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + pairing.AssertIsEqual(res, &c.Res) + return nil +} + +func TestMultiPairTestSolve(t *testing.T) { + assert := test.NewAssert(t) + p1, q1 := randomG1G2Affines() + p := make([]bn254.G1Affine, 10) + q := make([]bn254.G2Affine, 10) + for i := 0; i < 10; i++ { + p[i] = p1 + q[i] = q1 + } + + for i := 2; i < 10; i++ { + res, err := bn254.Pair(p[:i], q[:i]) + assert.NoError(err) + witness := MultiPairCircuit{ + InG1: NewG1Affine(p1), + InG2: NewG2Affine(q1), + Res: 
NewGTEl(res), + } + err = test.IsSolved(&MultiPairCircuit{n: i}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + } +} + +type PairingCheckCircuit struct { + In1G1 G1Affine + In2G1 G1Affine + In1G2 G2Affine + In2G2 G2Affine +} + +func (c *PairingCheckCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + err = pairing.PairingCheck([]*G1Affine{&c.In1G1, &c.In1G1, &c.In2G1, &c.In2G1}, []*G2Affine{&c.In1G2, &c.In2G2, &c.In1G2, &c.In2G2}) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + return nil +} + +func TestPairingCheckTestSolve(t *testing.T) { + assert := test.NewAssert(t) + p1, q1 := randomG1G2Affines() + _, q2 := randomG1G2Affines() + var p2 bn254.G1Affine + p2.Neg(&p1) + witness := PairingCheckCircuit{ + In1G1: NewG1Affine(p1), + In1G2: NewG2Affine(q1), + In2G1: NewG1Affine(p2), + In2G2: NewG2Affine(q2), + } + err := test.IsSolved(&PairingCheckCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type FinalExponentiationSafeCircuit struct { + P1, P2 G1Affine + Q1, Q2 G2Affine +} + +func (c *FinalExponentiationSafeCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return err + } + res, err := pairing.MillerLoop([]*G1Affine{&c.P1, &c.P2}, []*G2Affine{&c.Q1, &c.Q2}) + if err != nil { + return err + } + res2 := pairing.FinalExponentiation(res) + one := pairing.Ext12.One() + pairing.AssertIsEqual(one, res2) + return nil +} + +func TestFinalExponentiationSafeCircuit(t *testing.T) { + assert := test.NewAssert(t) + _, _, p1, q1 := bn254.Generators() + var p2 bn254.G1Affine + var q2 bn254.G2Affine + p2.Neg(&p1) + q2.Set(&q1) + err := test.IsSolved(&FinalExponentiationSafeCircuit{}, &FinalExponentiationSafeCircuit{ + P1: NewG1Affine(p1), + P2: NewG1Affine(p2), + Q1: NewG2Affine(q1), + Q2: NewG2Affine(q2), + }, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type GroupMembershipCircuit struct 
{ + InG1 G1Affine + InG2 G2Affine +} + +func (c *GroupMembershipCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + pairing.AssertIsOnG1(&c.InG1) + pairing.AssertIsOnG2(&c.InG2) + return nil +} + +func TestGroupMembershipSolve(t *testing.T) { + assert := test.NewAssert(t) + p, q := randomG1G2Affines() + witness := GroupMembershipCircuit{ + InG1: NewG1Affine(p), + InG2: NewG2Affine(q), + } + err := test.IsSolved(&GroupMembershipCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +// ---------------------------- +// Fixed-argument pairing +// ---------------------------- +// +// The second argument Q is the fixed canonical generator of G2. +// +// Q.X.A0 = 0x1800deef121f1e76426a00665e5c4479674322d4f75edadd46debd5cd992f6ed +// Q.X.A1 = 0x198e9393920d483a7260bfb731fb5d25f1aa493335a9e71297e485b7aef312c2 +// Q.Y.A0 = 0x12c85ea5db8c6deb4aab71808dcb408fe3d1e7690c43d37b4ce6cc0166fa7daa +// Q.Y.A1 = 0x90689d0585ff075ec9e99ad690c3395bc4b313370b38ef355acdadcd122975b + +type PairFixedCircuit struct { + InG1 G1Affine + Res GTEl +} + +func (c *PairFixedCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + res, err := pairing.PairFixedQ(&c.InG1) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + pairing.AssertIsEqual(res, &c.Res) + return nil +} + +func TestPairFixedTestSolve(t *testing.T) { + assert := test.NewAssert(t) + p, _ := randomG1G2Affines() + _, _, _, G2AffGen := bn254.Generators() + res, err := bn254.Pair([]bn254.G1Affine{p}, []bn254.G2Affine{G2AffGen}) + assert.NoError(err) + witness := PairFixedCircuit{ + InG1: NewG1Affine(p), + Res: NewGTEl(res), + } + err = test.IsSolved(&PairFixedCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type DoublePairFixedCircuit struct { + In1G1 G1Affine + In2G1 G1Affine + In1G2 G2Affine + Res GTEl +} + 
+func (c *DoublePairFixedCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + res, err := pairing.DoublePairFixedQ(&c.In1G1, &c.In2G1, &c.In1G2) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + pairing.AssertIsEqual(res, &c.Res) + return nil +} + +func TestDoublePairFixedTestSolve(t *testing.T) { + assert := test.NewAssert(t) + p, q := randomG1G2Affines() + _, _, _, G2AffGen := bn254.Generators() + res, err := bn254.Pair([]bn254.G1Affine{p, p}, []bn254.G2Affine{q, G2AffGen}) + assert.NoError(err) + witness := DoublePairFixedCircuit{ + In1G1: NewG1Affine(p), + In2G1: NewG1Affine(p), + In1G2: NewG2Affine(q), + Res: NewGTEl(res), + } + err = test.IsSolved(&DoublePairFixedCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +// bench +func BenchmarkPairing(b *testing.B) { + + p1, q1 := randomG1G2Affines() + _, q2 := randomG1G2Affines() + var p2 bn254.G1Affine + p2.Neg(&p1) + witness := PairingCheckCircuit{ + In1G1: NewG1Affine(p1), + In1G2: NewG2Affine(q1), + In2G1: NewG1Affine(p2), + In2G2: NewG2Affine(q2), + } + w, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + b.Fatal(err) + } + var ccs constraint.ConstraintSystem + b.Run("compile scs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if ccs, err = frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &PairingCheckCircuit{}); err != nil { + b.Fatal(err) + } + } + }) + var buf bytes.Buffer + _, err = ccs.WriteTo(&buf) + if err != nil { + b.Fatal(err) + } + b.Logf("scs size: %d (bytes), nb constraints %d, nbInstructions: %d", buf.Len(), ccs.GetNbConstraints(), ccs.GetNbInstructions()) + b.Run("solve scs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := ccs.Solve(w); err != nil { + b.Fatal(err) + } + } + }) + b.Run("compile r1cs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if ccs, 
err = frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &PairingCheckCircuit{}); err != nil { + b.Fatal(err) + } + } + }) + buf.Reset() + _, err = ccs.WriteTo(&buf) + if err != nil { + b.Fatal(err) + } + b.Logf("r1cs size: %d (bytes), nb constraints %d, nbInstructions: %d", buf.Len(), ccs.GetNbConstraints(), ccs.GetNbInstructions()) + + b.Run("solve r1cs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := ccs.Solve(w); err != nil { + b.Fatal(err) + } + } + }) +} diff --git a/std/algebra/emulated/sw_bn254/precomputations.go b/std/algebra/emulated/sw_bn254/precomputations.go new file mode 100644 index 0000000000..7f9b67a8e3 --- /dev/null +++ b/std/algebra/emulated/sw_bn254/precomputations.go @@ -0,0 +1,450 @@ +package sw_bn254 + +import ( + "sync" + + "github.com/consensys/gnark/std/algebra/emulated/fields_bn254" + "github.com/consensys/gnark/std/math/emulated" +) + +// precomputed lines going through Q and multiples of Q +// where Q is the fixed canonical generator of G2 +// +// Q.X.A0 = 0x1800deef121f1e76426a00665e5c4479674322d4f75edadd46debd5cd992f6ed +// Q.X.A1 = 0x198e9393920d483a7260bfb731fb5d25f1aa493335a9e71297e485b7aef312c2 +// Q.Y.A0 = 0x12c85ea5db8c6deb4aab71808dcb408fe3d1e7690c43d37b4ce6cc0166fa7daa +// Q.Y.A1 = 0x90689d0585ff075ec9e99ad690c3395bc4b313370b38ef355acdadcd122975b +var precomputedLines [4][67]fields_bn254.E2 +var precomputedLinesOnce sync.Once + +func getPrecomputeLines() [4][67]fields_bn254.E2 { + precomputedLinesOnce.Do(func() { + precomputedLines = computePrecomputeLines() + }) + return precomputedLines +} + +func computePrecomputeLines() [4][67]fields_bn254.E2 { + var PrecomputedLines [4][67]fields_bn254.E2 + // i = 64 + PrecomputedLines[0][64].A0 = emulated.ValueOf[emulated.BN254Fp]("5835204804648978854777809389163082959673580093383091483568092875198341589362") + PrecomputedLines[0][64].A1 = 
emulated.ValueOf[emulated.BN254Fp]("13632706003546654277482391832141703292091762015816023705040318800028245927696") + PrecomputedLines[1][64].A0 = emulated.ValueOf[emulated.BN254Fp]("1680434087217908762188513888731180967069012235541138281753317594838287941133") + PrecomputedLines[1][64].A1 = emulated.ValueOf[emulated.BN254Fp]("19491433686921975987918669077017867748435767299607210121547313497083429353042") + // i = 63 + PrecomputedLines[0][63].A0 = emulated.ValueOf[emulated.BN254Fp]("8834747017950039806917730978057895018652669773221183534396319488771182322273") + PrecomputedLines[0][63].A1 = emulated.ValueOf[emulated.BN254Fp]("20569453214085543698303175670835565927230899674712780610087152439201543453755") + PrecomputedLines[1][63].A0 = emulated.ValueOf[emulated.BN254Fp]("12474451462113811170279691739806155555255035988412241892389371812852943279791") + PrecomputedLines[1][63].A1 = emulated.ValueOf[emulated.BN254Fp]("11583683749753447484324963631355529779732614390183595665835388782603973114554") + PrecomputedLines[2][63].A0 = emulated.ValueOf[emulated.BN254Fp]("4351619662097199247792407486841887892728868845867619468460367957111511777281") + PrecomputedLines[2][63].A1 = emulated.ValueOf[emulated.BN254Fp]("5684094267725805546491764759679865993449104797108069757958938517705465746853") + PrecomputedLines[3][63].A0 = emulated.ValueOf[emulated.BN254Fp]("10353962084714942711958392698892131194533550945227135631964812972535754938740") + PrecomputedLines[3][63].A1 = emulated.ValueOf[emulated.BN254Fp]("1928175489709988399997528177153275473647160619426141896807527496697530478537") + // i = 62 + PrecomputedLines[0][62].A0 = emulated.ValueOf[emulated.BN254Fp]("8235221535982217724508088798625418488934211808403480937732051788034499272591") + PrecomputedLines[0][62].A1 = emulated.ValueOf[emulated.BN254Fp]("21624988872631589384985418538844721483027269339384612878942921292555079826013") + PrecomputedLines[1][62].A0 = 
emulated.ValueOf[emulated.BN254Fp]("20348378478608120338812639349974340732102235507506642310441900724921087826247") + PrecomputedLines[1][62].A1 = emulated.ValueOf[emulated.BN254Fp]("8566792459109340182179639521038613097990900893660026554466166360331455617193") + // i = 61 + PrecomputedLines[0][61].A0 = emulated.ValueOf[emulated.BN254Fp]("6627719691581136027519842318508717081939451729821280503007308537592603220412") + PrecomputedLines[0][61].A1 = emulated.ValueOf[emulated.BN254Fp]("11524325490035505497336553724692563653964162500918217938886754632362466082811") + PrecomputedLines[1][61].A0 = emulated.ValueOf[emulated.BN254Fp]("19731985605225575090697166005028587616606038298772396496599168985556363200631") + PrecomputedLines[1][61].A1 = emulated.ValueOf[emulated.BN254Fp]("7373556318285282840971580083079025136539854915977927179933078831491469609682") + PrecomputedLines[2][61].A0 = emulated.ValueOf[emulated.BN254Fp]("13733440734738072597076384372509888040506082931566095722592032318509288104124") + PrecomputedLines[2][61].A1 = emulated.ValueOf[emulated.BN254Fp]("14721209502407905805334137781178425261132829203821641437369233014505902590526") + PrecomputedLines[3][61].A0 = emulated.ValueOf[emulated.BN254Fp]("5027080457503687104862577903377368485365806076692125497436352950815222879179") + PrecomputedLines[3][61].A1 = emulated.ValueOf[emulated.BN254Fp]("19289658640175986155793149849769305065167539432896819573858459105348955554557") + // i = 60 + PrecomputedLines[0][60].A0 = emulated.ValueOf[emulated.BN254Fp]("14632988473650232706638308445044627656823840762055592679861976480201049786937") + PrecomputedLines[0][60].A1 = emulated.ValueOf[emulated.BN254Fp]("18477931370497242185852032413411380760829755240315614825130799062365486125783") + PrecomputedLines[1][60].A0 = emulated.ValueOf[emulated.BN254Fp]("18049101236999327864886811068212417971066320579871553559963860415669018134496") + PrecomputedLines[1][60].A1 = 
emulated.ValueOf[emulated.BN254Fp]("3340541398203178600550723254278705350179917842084176758931787480715875388093") + // i = 59 + PrecomputedLines[0][59].A0 = emulated.ValueOf[emulated.BN254Fp]("17045135689595429000496650894177684583849536764983720634964441404516121899392") + PrecomputedLines[0][59].A1 = emulated.ValueOf[emulated.BN254Fp]("20874559392346445015406487528270220685966388146273495510982431769773849348197") + PrecomputedLines[1][59].A0 = emulated.ValueOf[emulated.BN254Fp]("8223740335264218373639292193525349753225213618348104360137437003379775025437") + PrecomputedLines[1][59].A1 = emulated.ValueOf[emulated.BN254Fp]("2548870952786184681128163676627501278690702287544993422150040172278516478745") + // i = 58 + PrecomputedLines[0][58].A0 = emulated.ValueOf[emulated.BN254Fp]("10954828682858274260092126382718192953262472685421314903371241946559589624873") + PrecomputedLines[0][58].A1 = emulated.ValueOf[emulated.BN254Fp]("8743995751245898721379411778296781208636410240371573947003929850344297435525") + PrecomputedLines[1][58].A0 = emulated.ValueOf[emulated.BN254Fp]("17024182879676720943763727838881486298580583134896607657788582772952988629132") + PrecomputedLines[1][58].A1 = emulated.ValueOf[emulated.BN254Fp]("233080409219735443943562019109568508823238327587973458579303286089303864680") + // i = 57 + PrecomputedLines[0][57].A0 = emulated.ValueOf[emulated.BN254Fp]("12410140729570783839406161286916119301565534084495020370980905806792203460134") + PrecomputedLines[0][57].A1 = emulated.ValueOf[emulated.BN254Fp]("14094884930234597736770350376505495314765932446313713250354944203186636575455") + PrecomputedLines[1][57].A0 = emulated.ValueOf[emulated.BN254Fp]("532840589523723053112594079794333892977065522820117356517422849664750230778") + PrecomputedLines[1][57].A1 = emulated.ValueOf[emulated.BN254Fp]("5660638474743049541851028697096356971110017724650313911198424181549845543920") + PrecomputedLines[2][57].A0 = 
emulated.ValueOf[emulated.BN254Fp]("18032936869020334689349219157241011538136257425003297594369894001691209779164") + PrecomputedLines[2][57].A1 = emulated.ValueOf[emulated.BN254Fp]("16754765645720625866506074952008486046920919307557056709799845067180351678981") + PrecomputedLines[3][57].A0 = emulated.ValueOf[emulated.BN254Fp]("18040737120690367958636522976877072351597960312418873862053894642308039267165") + PrecomputedLines[3][57].A1 = emulated.ValueOf[emulated.BN254Fp]("12847080167727762220604062477143165139413815667631144830150678996896243094397") + // i = 56 + PrecomputedLines[0][56].A0 = emulated.ValueOf[emulated.BN254Fp]("19645311236785860275323414293804502000247542616913482529314177973392964070041") + PrecomputedLines[0][56].A1 = emulated.ValueOf[emulated.BN254Fp]("11127380619767050390611672259834755402156496596006857437940271701361623250389") + PrecomputedLines[1][56].A0 = emulated.ValueOf[emulated.BN254Fp]("12866789663733235640870663615630153955969739486663450312836464903244952796759") + PrecomputedLines[1][56].A1 = emulated.ValueOf[emulated.BN254Fp]("20334247261155778215897034951123988527365486037964674499287611566337802994246") + // i = 55 + PrecomputedLines[0][55].A0 = emulated.ValueOf[emulated.BN254Fp]("4909602325743718030494127948247513676144699561762181851595755686145634562165") + PrecomputedLines[0][55].A1 = emulated.ValueOf[emulated.BN254Fp]("4980795661388523093721831770060089904711582345302707214936616089440639723701") + PrecomputedLines[1][55].A0 = emulated.ValueOf[emulated.BN254Fp]("15971329859948607389718743711907428875739552759192560895900443046558268384833") + PrecomputedLines[1][55].A1 = emulated.ValueOf[emulated.BN254Fp]("6865348535803696642084936610119046588478673910965860309150729592421280179161") + PrecomputedLines[2][55].A0 = emulated.ValueOf[emulated.BN254Fp]("2764448302387997018686838945031167114772677413806369496295142807496601012389") + PrecomputedLines[2][55].A1 = 
emulated.ValueOf[emulated.BN254Fp]("11531890680898353878360012539400900008526988690934761520248628457335690795123") + PrecomputedLines[3][55].A0 = emulated.ValueOf[emulated.BN254Fp]("20477902323488575991951213865423297622833142796361576949400620096696322938714") + PrecomputedLines[3][55].A1 = emulated.ValueOf[emulated.BN254Fp]("15588298079216618665269525070947878050724410941478388548632206918344265255660") + // i = 54 + PrecomputedLines[0][54].A0 = emulated.ValueOf[emulated.BN254Fp]("2423954684975409421113106111483795343402357285046083347957019460241561263903") + PrecomputedLines[0][54].A1 = emulated.ValueOf[emulated.BN254Fp]("21686350663615057943513793308336218547546151034882551854568393652361503247048") + PrecomputedLines[1][54].A0 = emulated.ValueOf[emulated.BN254Fp]("17983889902010442989037806626756227210350021563685842847697610334754612024861") + PrecomputedLines[1][54].A1 = emulated.ValueOf[emulated.BN254Fp]("16240742027186297373450412352773641086393171254541832522856010008961570443994") + // i = 53 + PrecomputedLines[0][53].A0 = emulated.ValueOf[emulated.BN254Fp]("3299397589651430835046889418728383096966563010050347574588126659195099410858") + PrecomputedLines[0][53].A1 = emulated.ValueOf[emulated.BN254Fp]("2702370610853020400314647094191315939694996588064902553275009984508683489928") + PrecomputedLines[1][53].A0 = emulated.ValueOf[emulated.BN254Fp]("3979157981082211890213327198809251890345860044112490020571189248429346769401") + PrecomputedLines[1][53].A1 = emulated.ValueOf[emulated.BN254Fp]("19944567808172014331943455052193835769131894893228158061444755101571479200005") + // i = 52 + PrecomputedLines[0][52].A0 = emulated.ValueOf[emulated.BN254Fp]("6258196544160211102321925954410279782816574037374433438978246710964080754940") + PrecomputedLines[0][52].A1 = emulated.ValueOf[emulated.BN254Fp]("3005660902529334848218755993004858578785696325272369744839253118680617897204") + PrecomputedLines[1][52].A0 = 
emulated.ValueOf[emulated.BN254Fp]("814264396833942665903531553752517163857204609821495556934616064593387497987") + PrecomputedLines[1][52].A1 = emulated.ValueOf[emulated.BN254Fp]("18008736309258122452427755394070359806936368847348640156095452637685946009270") + // i = 51 + PrecomputedLines[0][51].A0 = emulated.ValueOf[emulated.BN254Fp]("12420890053219508598807591154496379602809478224965227087113822934299408691873") + PrecomputedLines[0][51].A1 = emulated.ValueOf[emulated.BN254Fp]("17303367877986414511475092300779955134419370651210868516114472535617760078437") + PrecomputedLines[1][51].A0 = emulated.ValueOf[emulated.BN254Fp]("5654129892575944946690000888934348891080307733028900301840391685691529358356") + PrecomputedLines[1][51].A1 = emulated.ValueOf[emulated.BN254Fp]("16233036145896569525486180974188009724711185398467310790001445624076382127411") + PrecomputedLines[2][51].A0 = emulated.ValueOf[emulated.BN254Fp]("15280789913398818336264656867575336658350410371574894768730045385395713509640") + PrecomputedLines[2][51].A1 = emulated.ValueOf[emulated.BN254Fp]("19086876210297108145367386096540798080237895501217269732090299553431756186874") + PrecomputedLines[3][51].A0 = emulated.ValueOf[emulated.BN254Fp]("20152602455257424885705855540594687474352896500023277194652697233085988612702") + PrecomputedLines[3][51].A1 = emulated.ValueOf[emulated.BN254Fp]("5047067997772839341060798895369969865071494197407643998682395775472163143236") + // i = 50 + PrecomputedLines[0][50].A0 = emulated.ValueOf[emulated.BN254Fp]("9679722372596923253224153788895318885991703645078795262894298224910585336719") + PrecomputedLines[0][50].A1 = emulated.ValueOf[emulated.BN254Fp]("14514223747309844563686332289086628131330780550209469152711732915738910690017") + PrecomputedLines[1][50].A0 = emulated.ValueOf[emulated.BN254Fp]("21089781100202322623793764397713628697432828784246141248094996360041604110331") + PrecomputedLines[1][50].A1 = 
emulated.ValueOf[emulated.BN254Fp]("17528189601039012041588006220155360374406334568305836940619703484971862864495") + // i = 49 + PrecomputedLines[0][49].A0 = emulated.ValueOf[emulated.BN254Fp]("11070264611763304059093376142120848903484793346053636108950365869037426569578") + PrecomputedLines[0][49].A1 = emulated.ValueOf[emulated.BN254Fp]("5566620684243985148174385257832990705507517654546203571578376077206518542096") + PrecomputedLines[1][49].A0 = emulated.ValueOf[emulated.BN254Fp]("21439244283016676608324861890592606819171700780439376409220912790618219709100") + PrecomputedLines[1][49].A1 = emulated.ValueOf[emulated.BN254Fp]("7429886405799165109635340226052960810301612425806339795255948228548720117114") + PrecomputedLines[2][49].A0 = emulated.ValueOf[emulated.BN254Fp]("15262812223739557968076855432738368535216463506023656398166408211596744345480") + PrecomputedLines[2][49].A1 = emulated.ValueOf[emulated.BN254Fp]("21718155741215205740714059098811649981889099801055042894213824492357370586441") + PrecomputedLines[3][49].A0 = emulated.ValueOf[emulated.BN254Fp]("18117071879837371740798839911980038457907852615599242665408622350971595899012") + PrecomputedLines[3][49].A1 = emulated.ValueOf[emulated.BN254Fp]("8568859093111261946081380679772818415299601251733724245998939644866684988958") + // i = 48 + PrecomputedLines[0][48].A0 = emulated.ValueOf[emulated.BN254Fp]("7653502376183972189528005574912937698901707992597867413732313005949090861078") + PrecomputedLines[0][48].A1 = emulated.ValueOf[emulated.BN254Fp]("18078299108818183297076160903115040870971029681103562344505507059920109498724") + PrecomputedLines[1][48].A0 = emulated.ValueOf[emulated.BN254Fp]("5958828994437554107496173026037633753578905513469664500371044380353413307215") + PrecomputedLines[1][48].A1 = emulated.ValueOf[emulated.BN254Fp]("4457857287549567054803825263261489919553034648299704807245359432847888563495") + // i = 47 + PrecomputedLines[0][47].A0 = 
emulated.ValueOf[emulated.BN254Fp]("11785245607980003005839007342283447681532074921512117922377232062695696875063") + PrecomputedLines[0][47].A1 = emulated.ValueOf[emulated.BN254Fp]("3674471715303242650639836240573952653116010918695452109397036698752292897898") + PrecomputedLines[1][47].A0 = emulated.ValueOf[emulated.BN254Fp]("45627793441629135917601349379085503635034794546470734893491110198522468332") + PrecomputedLines[1][47].A1 = emulated.ValueOf[emulated.BN254Fp]("13592279364712648679231730257313513790926888930629759987134996708519290791723") + PrecomputedLines[2][47].A0 = emulated.ValueOf[emulated.BN254Fp]("4219137964241005960272202013946114675691474544442298976511002870769419504812") + PrecomputedLines[2][47].A1 = emulated.ValueOf[emulated.BN254Fp]("14310914677517468679944086988957529739710673981285059899394344618236313440099") + PrecomputedLines[3][47].A0 = emulated.ValueOf[emulated.BN254Fp]("10374119938212414799339884240754452606245790388879855930781662754854289264542") + PrecomputedLines[3][47].A1 = emulated.ValueOf[emulated.BN254Fp]("7249803302968745898565931551390491553765429802954021150194461040817468653516") + // i = 46 + PrecomputedLines[0][46].A0 = emulated.ValueOf[emulated.BN254Fp]("3322469610693328663691825947323331441415625540966696120651002654440253319052") + PrecomputedLines[0][46].A1 = emulated.ValueOf[emulated.BN254Fp]("1942780532870155484021476731037192630204446135294484553429490820308706502574") + PrecomputedLines[1][46].A0 = emulated.ValueOf[emulated.BN254Fp]("17202876709489003816438966092823149993355061110971733076837388178975675061543") + PrecomputedLines[1][46].A1 = emulated.ValueOf[emulated.BN254Fp]("1827707550130567292619674424799209527730638026770607682353371933371424896649") + // i = 45 + PrecomputedLines[0][45].A0 = emulated.ValueOf[emulated.BN254Fp]("4454860302834193997922027643677141826335477685905943160881719665216396316023") + PrecomputedLines[0][45].A1 = 
emulated.ValueOf[emulated.BN254Fp]("13460420380135308103263096461528931372790583621207990906826840013420510427402") + PrecomputedLines[1][45].A0 = emulated.ValueOf[emulated.BN254Fp]("6173471785502118124181307330287807246726679438042165280456980585789204024261") + PrecomputedLines[1][45].A1 = emulated.ValueOf[emulated.BN254Fp]("9115487602626298454909684608831879314662251126849935581628519764110399653742") + // i = 44 + PrecomputedLines[0][44].A0 = emulated.ValueOf[emulated.BN254Fp]("1808160525575058758917825482888189988480854443063379197012638697518523206015") + PrecomputedLines[0][44].A1 = emulated.ValueOf[emulated.BN254Fp]("17974039573517006177103900772046815180642726119197145299789772931932242373761") + PrecomputedLines[1][44].A0 = emulated.ValueOf[emulated.BN254Fp]("4836565830302726743825617174678612813412946168360212251833373381562334895992") + PrecomputedLines[1][44].A1 = emulated.ValueOf[emulated.BN254Fp]("6540127829396182735905780776540077847286814503740931399470223558889379807318") + PrecomputedLines[2][44].A0 = emulated.ValueOf[emulated.BN254Fp]("14217246290721035597535922596218461586467355332509420181286410302541613841694") + PrecomputedLines[2][44].A1 = emulated.ValueOf[emulated.BN254Fp]("2908093977878325873797969441052414806114066218549559765693286267809031453390") + PrecomputedLines[3][44].A0 = emulated.ValueOf[emulated.BN254Fp]("11160891947094503976396935083892652810330554397537039864076674697178398903810") + PrecomputedLines[3][44].A1 = emulated.ValueOf[emulated.BN254Fp]("8932087540233051702344716353214580354819556708224293257744342595927104649379") + // i = 43 + PrecomputedLines[0][43].A0 = emulated.ValueOf[emulated.BN254Fp]("21406534929501956719171835897604615263148474051911745550006151022849957388643") + PrecomputedLines[0][43].A1 = emulated.ValueOf[emulated.BN254Fp]("10279781929005127439140488793928581246508457419866317770606378400774241150772") + PrecomputedLines[1][43].A0 = 
emulated.ValueOf[emulated.BN254Fp]("9709078210642965619422714825761003836985439131762449854224392441007339412834") + PrecomputedLines[1][43].A1 = emulated.ValueOf[emulated.BN254Fp]("11129005740706905660803848575047973153280590126228714832868068527425037373751") + // i = 42 + PrecomputedLines[0][42].A0 = emulated.ValueOf[emulated.BN254Fp]("5570866902587445338732530591687200111549728761628124016050692968437706751601") + PrecomputedLines[0][42].A1 = emulated.ValueOf[emulated.BN254Fp]("10275848058977370789435005534809927911069178479505311836113062825745267179381") + PrecomputedLines[1][42].A0 = emulated.ValueOf[emulated.BN254Fp]("3633887160270413594942060642216918996537999582294210057554627044590039786234") + PrecomputedLines[1][42].A1 = emulated.ValueOf[emulated.BN254Fp]("8939668248755020860948171356447789359992139917425514662320694855157405768150") + // i = 41 + PrecomputedLines[0][41].A0 = emulated.ValueOf[emulated.BN254Fp]("18885470138073884864482120331203309626511755319182411648838895487000802451896") + PrecomputedLines[0][41].A1 = emulated.ValueOf[emulated.BN254Fp]("9622125734920158797843141885788251658263242254352775708729064763529753374409") + PrecomputedLines[1][41].A0 = emulated.ValueOf[emulated.BN254Fp]("19492892729379236004557729196741558779480783232265355978008529229208638077001") + PrecomputedLines[1][41].A1 = emulated.ValueOf[emulated.BN254Fp]("9026667103760964167291393581811590447916219414798161581254060489069424386713") + // i = 40 + PrecomputedLines[0][40].A0 = emulated.ValueOf[emulated.BN254Fp]("7917086886314173678014982140498314563333885917374381676048136849306589970300") + PrecomputedLines[0][40].A1 = emulated.ValueOf[emulated.BN254Fp]("12261633934171134354022033316006453963392366695723701109381951932703857762414") + PrecomputedLines[1][40].A0 = emulated.ValueOf[emulated.BN254Fp]("10914160321677904051918987802671845728099932842880059488686382944858229492896") + PrecomputedLines[1][40].A1 = 
emulated.ValueOf[emulated.BN254Fp]("8359557585977416078793909793984425896635087139789596461899708508514130896302") + // i = 39 + PrecomputedLines[0][39].A0 = emulated.ValueOf[emulated.BN254Fp]("2832877770969085666300187729756811975718245115828438742107798110254443365408") + PrecomputedLines[0][39].A1 = emulated.ValueOf[emulated.BN254Fp]("4015898436582302751338662177471784429329474816900569691956338217668304075143") + PrecomputedLines[1][39].A0 = emulated.ValueOf[emulated.BN254Fp]("6842872200246861427705093037496466614015446887756834045870833833876444837543") + PrecomputedLines[1][39].A1 = emulated.ValueOf[emulated.BN254Fp]("15273790836304547702356985139791848656234288403812600849950908226829046944652") + // i = 38 + PrecomputedLines[0][38].A0 = emulated.ValueOf[emulated.BN254Fp]("17748964682733066496477674137304072265649362699626972384780286384184403231976") + PrecomputedLines[0][38].A1 = emulated.ValueOf[emulated.BN254Fp]("2838655568440449121815178130424124290652171075642869140451330756670997612885") + PrecomputedLines[1][38].A0 = emulated.ValueOf[emulated.BN254Fp]("1913181888602803253410490851322440164398255901127376824362120108174814218344") + PrecomputedLines[1][38].A1 = emulated.ValueOf[emulated.BN254Fp]("17740518402110173452265054327145725313649280334580334996118883085307288091170") + PrecomputedLines[2][38].A0 = emulated.ValueOf[emulated.BN254Fp]("12808369270471380551462058943009794572165562282177042824025566058227667838719") + PrecomputedLines[2][38].A1 = emulated.ValueOf[emulated.BN254Fp]("13850937935385199095549631062956009887187878536528675610032140109569037386090") + PrecomputedLines[3][38].A0 = emulated.ValueOf[emulated.BN254Fp]("9571531692086612395026521386471462664121950697416902123408435110246350832630") + PrecomputedLines[3][38].A1 = emulated.ValueOf[emulated.BN254Fp]("9262658991286591312944640932012648892614883797518222287840538936149182139125") + // i = 37 + PrecomputedLines[0][37].A0 = 
emulated.ValueOf[emulated.BN254Fp]("12457672116858821147570417815713622479940621522732096986429050939590597339051") + PrecomputedLines[0][37].A1 = emulated.ValueOf[emulated.BN254Fp]("18398985605215365318947837405711002640755060748945931179484765291655005193783") + PrecomputedLines[1][37].A0 = emulated.ValueOf[emulated.BN254Fp]("6761100916752031499470061472959664103643695753371128647767956417871801525935") + PrecomputedLines[1][37].A1 = emulated.ValueOf[emulated.BN254Fp]("11516568990599057505448657013624491615577778080330416531299339213631684799703") + // i = 36 + PrecomputedLines[0][36].A0 = emulated.ValueOf[emulated.BN254Fp]("20157739938048499791825110639183940868502246006091528387643069694176344051024") + PrecomputedLines[0][36].A1 = emulated.ValueOf[emulated.BN254Fp]("2828679024172680272609250174898575639278853817285186536953499674429263641394") + PrecomputedLines[1][36].A0 = emulated.ValueOf[emulated.BN254Fp]("5507629493376754776780948128794176380876913178566429119195698066509123857842") + PrecomputedLines[1][36].A1 = emulated.ValueOf[emulated.BN254Fp]("11989866298421717405888814591756712355645135873368124241103659461001231130026") + // i = 35 + PrecomputedLines[0][35].A0 = emulated.ValueOf[emulated.BN254Fp]("18803362902571183160534990574868966535501222775740632166562002643682331204055") + PrecomputedLines[0][35].A1 = emulated.ValueOf[emulated.BN254Fp]("2233847910285387579556930418866824465027863750525494540604564132801222851847") + PrecomputedLines[1][35].A0 = emulated.ValueOf[emulated.BN254Fp]("14894594813989329665437337634052128835849772709222460160323937655662277269998") + PrecomputedLines[1][35].A1 = emulated.ValueOf[emulated.BN254Fp]("16401863794856018538890651330710306550150048203763791766885991058365625250312") + PrecomputedLines[2][35].A0 = emulated.ValueOf[emulated.BN254Fp]("13863001468971850689773930443498106621781607351046321740833110320784300885462") + PrecomputedLines[2][35].A1 = 
emulated.ValueOf[emulated.BN254Fp]("12531482429379825647834106004615712608548514329644694776489888393706881261235") + PrecomputedLines[3][35].A0 = emulated.ValueOf[emulated.BN254Fp]("9873979880229408105971223367529128014487901861772253050912642581888257648233") + PrecomputedLines[3][35].A1 = emulated.ValueOf[emulated.BN254Fp]("17024384825828294129839942917357501284449447510471307237753297102647863989056") + // i = 34 + PrecomputedLines[0][34].A0 = emulated.ValueOf[emulated.BN254Fp]("9653366074145323295722703240755007396582576595924214989534987903378311878297") + PrecomputedLines[0][34].A1 = emulated.ValueOf[emulated.BN254Fp]("20335540665422011227662851543501012053614149159296868563872809593806311578153") + PrecomputedLines[1][34].A0 = emulated.ValueOf[emulated.BN254Fp]("3346681777791788671775024643719873028656163698396937227379634931654104469710") + PrecomputedLines[1][34].A1 = emulated.ValueOf[emulated.BN254Fp]("3307435597320603699014389476882124244457440216712858513447652502717610656034") + // i = 33 + PrecomputedLines[0][33].A0 = emulated.ValueOf[emulated.BN254Fp]("18209443688907178902666130371838848829021303314723242040408374329257505874298") + PrecomputedLines[0][33].A1 = emulated.ValueOf[emulated.BN254Fp]("17012865229030954811928627442171265444826656974298650920955602237427025269570") + PrecomputedLines[1][33].A0 = emulated.ValueOf[emulated.BN254Fp]("13251168487096931816968492472936283913931508108412049848483510193471001761262") + PrecomputedLines[1][33].A1 = emulated.ValueOf[emulated.BN254Fp]("917662416056698492181621755862599484479280610762074595968390001996439032523") + PrecomputedLines[2][33].A0 = emulated.ValueOf[emulated.BN254Fp]("2780404211823942337646698841772801564153614837266356530930914670418264278722") + PrecomputedLines[2][33].A1 = emulated.ValueOf[emulated.BN254Fp]("3782019545450857785432418735694615956576199844498504213150631388689924073378") + PrecomputedLines[3][33].A0 = 
emulated.ValueOf[emulated.BN254Fp]("10219383385884950870363850119493120482938585228366425933545069183698860959545") + PrecomputedLines[3][33].A1 = emulated.ValueOf[emulated.BN254Fp]("4610898162835928978462538221841591782260443101663343291474718241976139555925") + // i = 32 + PrecomputedLines[0][32].A0 = emulated.ValueOf[emulated.BN254Fp]("20934811687581609260873070602688515368109046845218194607210991416167897816402") + PrecomputedLines[0][32].A1 = emulated.ValueOf[emulated.BN254Fp]("5862163452113823219899111570982726953048159166456119375917149119349465378872") + PrecomputedLines[1][32].A0 = emulated.ValueOf[emulated.BN254Fp]("18680771072947791515144826030430400508598943175893875281965984728348827258637") + PrecomputedLines[1][32].A1 = emulated.ValueOf[emulated.BN254Fp]("21040117418693567065352610126251750729568112020473406293397044779880248456691") + // i = 31 + PrecomputedLines[0][31].A0 = emulated.ValueOf[emulated.BN254Fp]("7964439593062737400984035539878664198651178025806860534579667949165595587005") + PrecomputedLines[0][31].A1 = emulated.ValueOf[emulated.BN254Fp]("2838157160991125797583478331796928440511613965670870614157463438056333538151") + PrecomputedLines[1][31].A0 = emulated.ValueOf[emulated.BN254Fp]("15229006183665644290447325495913132772531638749927123822183273630764466089789") + PrecomputedLines[1][31].A1 = emulated.ValueOf[emulated.BN254Fp]("16074291619482614777639379462192763167638195517822265884849973509949484533348") + // i = 30 + PrecomputedLines[0][30].A0 = emulated.ValueOf[emulated.BN254Fp]("17216673940334785289516808707485909812399730071261828078641261494030193339151") + PrecomputedLines[0][30].A1 = emulated.ValueOf[emulated.BN254Fp]("1558246640723974010862813923593922625293354601470135041696925666565577625238") + PrecomputedLines[1][30].A0 = emulated.ValueOf[emulated.BN254Fp]("3854786674985460796726122882377437532974958505502673869156016507727783114650") + PrecomputedLines[1][30].A1 = 
emulated.ValueOf[emulated.BN254Fp]("16797162886036336980590366085149259027597156341494947863166011673482030283822") + PrecomputedLines[2][30].A0 = emulated.ValueOf[emulated.BN254Fp]("18793648319122864915144966308137999498287217564919580214223847976324114889072") + PrecomputedLines[2][30].A1 = emulated.ValueOf[emulated.BN254Fp]("16864911664763277142359301452971910091257378264425969043189691884660256976414") + PrecomputedLines[3][30].A0 = emulated.ValueOf[emulated.BN254Fp]("6152358889597859795339082960365216775296352387558616966143609899597537421031") + PrecomputedLines[3][30].A1 = emulated.ValueOf[emulated.BN254Fp]("16653254093932107167798383542777416165270049087427470272658796467120270028484") + // i = 29 + PrecomputedLines[0][29].A0 = emulated.ValueOf[emulated.BN254Fp]("14632407693471058889728140077691609987771263933332165632341567611801065809022") + PrecomputedLines[0][29].A1 = emulated.ValueOf[emulated.BN254Fp]("11272021435858583497469551493401241107077029606274078664163635970617420849248") + PrecomputedLines[1][29].A0 = emulated.ValueOf[emulated.BN254Fp]("4770200410992146769475852516775361139844234572488330099462736062550552604829") + PrecomputedLines[1][29].A1 = emulated.ValueOf[emulated.BN254Fp]("7530403448372346045695321765829491479483699869272077605645557759889158081085") + // i = 28 + PrecomputedLines[0][28].A0 = emulated.ValueOf[emulated.BN254Fp]("11519923879341878511380494246354517413557387093469380775615805559630889395544") + PrecomputedLines[0][28].A1 = emulated.ValueOf[emulated.BN254Fp]("1407583269883799199470286256592915474697530621627456286943142767826758460634") + PrecomputedLines[1][28].A0 = emulated.ValueOf[emulated.BN254Fp]("11904307076773678983926013243694306017809942930429096021607731892636261237434") + PrecomputedLines[1][28].A1 = emulated.ValueOf[emulated.BN254Fp]("5418228693521695819333327920054742553191693728075186406978931357199320286728") + // i = 27 + PrecomputedLines[0][27].A0 = 
emulated.ValueOf[emulated.BN254Fp]("13583649344163235628273822059455216086651158972357128597527291053634930639606") + PrecomputedLines[0][27].A1 = emulated.ValueOf[emulated.BN254Fp]("1580303658246025496517878831829049949728218261141271124320904477564504770538") + PrecomputedLines[1][27].A0 = emulated.ValueOf[emulated.BN254Fp]("21116413106194933456702972814632992147195390921459561292248712868055344887668") + PrecomputedLines[1][27].A1 = emulated.ValueOf[emulated.BN254Fp]("2722273224111047596007628929518646734380443821631939025258952090508660165266") + // i = 26 + PrecomputedLines[0][26].A0 = emulated.ValueOf[emulated.BN254Fp]("18471407238975412722919516576273263250888565110885782624375161300902906625257") + PrecomputedLines[0][26].A1 = emulated.ValueOf[emulated.BN254Fp]("1703792855453408300033031347710143415065900804014322538641022938745091381930") + PrecomputedLines[1][26].A0 = emulated.ValueOf[emulated.BN254Fp]("10863470541316497325728428920551887981270488659171371431262418928646119718147") + PrecomputedLines[1][26].A1 = emulated.ValueOf[emulated.BN254Fp]("17241260064348476546268798180555052578472388987371855915353518730015387072298") + // i = 25 + PrecomputedLines[0][25].A0 = emulated.ValueOf[emulated.BN254Fp]("8932519706584413797822306415256639256040689258169061374436911262999997993142") + PrecomputedLines[0][25].A1 = emulated.ValueOf[emulated.BN254Fp]("393522374513586188088711426908867437661499514838744902729423194492781594674") + PrecomputedLines[1][25].A0 = emulated.ValueOf[emulated.BN254Fp]("16189841300225557029339625228091337040532368631861952893841567332696428948407") + PrecomputedLines[1][25].A1 = emulated.ValueOf[emulated.BN254Fp]("1478939903044857860077184972935680812866817631078410725667566430070755143644") + PrecomputedLines[2][25].A0 = emulated.ValueOf[emulated.BN254Fp]("11577786770240021634241162518303444432845727667924972478311944308780176883448") + PrecomputedLines[2][25].A1 = 
emulated.ValueOf[emulated.BN254Fp]("12996576065238071442557753878581565019438532181136511892084803477505718680757") + PrecomputedLines[3][25].A0 = emulated.ValueOf[emulated.BN254Fp]("10044194885594442142147609155416421922257472356796305453772928783393388899067") + PrecomputedLines[3][25].A1 = emulated.ValueOf[emulated.BN254Fp]("18263760812393990074188911070757643007564984011360754662297829292657271464150") + // i = 24 + PrecomputedLines[0][24].A0 = emulated.ValueOf[emulated.BN254Fp]("3061883141739315602214563556302122268976467112226903626183003265864700797561") + PrecomputedLines[0][24].A1 = emulated.ValueOf[emulated.BN254Fp]("18877577761913543892554361883552980798297636968407602018463962047995764813219") + PrecomputedLines[1][24].A0 = emulated.ValueOf[emulated.BN254Fp]("3126746139599510717075143025311775654763541372790089146406042067854741163507") + PrecomputedLines[1][24].A1 = emulated.ValueOf[emulated.BN254Fp]("17167127260886017815263816291999617868265214449605859656444200044917253310911") + // i = 23 + PrecomputedLines[0][23].A0 = emulated.ValueOf[emulated.BN254Fp]("16870441941931526976119500440658808713972889735578011962941661985358104161131") + PrecomputedLines[0][23].A1 = emulated.ValueOf[emulated.BN254Fp]("6001109932250934659271452903112877161042944352421424685637313477276198572097") + PrecomputedLines[1][23].A0 = emulated.ValueOf[emulated.BN254Fp]("16637142689871529585590417483409806422524779986734732708338690138313396058389") + PrecomputedLines[1][23].A1 = emulated.ValueOf[emulated.BN254Fp]("1309909770792043822424896387657498550838255076099809006357764402595094108197") + PrecomputedLines[2][23].A0 = emulated.ValueOf[emulated.BN254Fp]("18549939322311775505081964944043714528572927965273709888232607395084548206000") + PrecomputedLines[2][23].A1 = emulated.ValueOf[emulated.BN254Fp]("20399818991869794833324015990855132322105497628093638572012296248349524670190") + PrecomputedLines[3][23].A0 = 
emulated.ValueOf[emulated.BN254Fp]("7342977470454632355392560878259850035965508595638789766324590700762099135681") + PrecomputedLines[3][23].A1 = emulated.ValueOf[emulated.BN254Fp]("13133475924936788812547393936820540309170125548865828094250523111922518822075") + // i = 22 + PrecomputedLines[0][22].A0 = emulated.ValueOf[emulated.BN254Fp]("4748990156186402568189268203915060520971178335693786330788682736855281763837") + PrecomputedLines[0][22].A1 = emulated.ValueOf[emulated.BN254Fp]("1309123459585246519346967984684303594496756037532773816040933448912027627499") + PrecomputedLines[1][22].A0 = emulated.ValueOf[emulated.BN254Fp]("14774495602218432844736442669860287970046937120486943306048410603925270631316") + PrecomputedLines[1][22].A1 = emulated.ValueOf[emulated.BN254Fp]("7758103039306620389373197481991170462191047006336390431484191868400773850577") + // i = 21 + PrecomputedLines[0][21].A0 = emulated.ValueOf[emulated.BN254Fp]("20467216100325522645996376085496391619753268832330437756343044572011862940545") + PrecomputedLines[0][21].A1 = emulated.ValueOf[emulated.BN254Fp]("14887102390814534704591565166101282155253192464472393232200334099849552058977") + PrecomputedLines[1][21].A0 = emulated.ValueOf[emulated.BN254Fp]("21078606515401393469046677323685463512748861024754049000789577483054033214476") + PrecomputedLines[1][21].A1 = emulated.ValueOf[emulated.BN254Fp]("4564303136472462460176799031863979133770926707479849529175620334944346563602") + // i = 20 + PrecomputedLines[0][20].A0 = emulated.ValueOf[emulated.BN254Fp]("2934547587293842961452405179156964753642527525482090204385414595783458332955") + PrecomputedLines[0][20].A1 = emulated.ValueOf[emulated.BN254Fp]("13388881467399048052694263240074072556503679672952359032570384893452576065521") + PrecomputedLines[1][20].A0 = emulated.ValueOf[emulated.BN254Fp]("20835022106176713220060462057962830658550225802689628738113269878374295200567") + PrecomputedLines[1][20].A1 = 
emulated.ValueOf[emulated.BN254Fp]("16905312434058784658661572959194517726558593812351378232506422776680415520438") + // i = 19 + PrecomputedLines[0][19].A0 = emulated.ValueOf[emulated.BN254Fp]("9379108182894698430849303731451612013091506165395461542418088794313609818924") + PrecomputedLines[0][19].A1 = emulated.ValueOf[emulated.BN254Fp]("2907967239075992964474535088161648903162956259142634038906064837044072456436") + PrecomputedLines[1][19].A0 = emulated.ValueOf[emulated.BN254Fp]("20782609764653852909287250407883639725034699195478521117170508454748489944739") + PrecomputedLines[1][19].A1 = emulated.ValueOf[emulated.BN254Fp]("6619607978961240534355495781535994024557310385388068869619586172109347559558") + PrecomputedLines[2][19].A0 = emulated.ValueOf[emulated.BN254Fp]("19532033666164956798033098163550806595466824118787667796975924917702885634055") + PrecomputedLines[2][19].A1 = emulated.ValueOf[emulated.BN254Fp]("21262034978109692747726845264241152568669959314932048093743486024410820587806") + PrecomputedLines[3][19].A0 = emulated.ValueOf[emulated.BN254Fp]("4130709627975833054760191215714688654831243211864203332343057006624671877225") + PrecomputedLines[3][19].A1 = emulated.ValueOf[emulated.BN254Fp]("13472955160176293525849695637350417649009388964115438600570520044550603528353") + // i = 18 + PrecomputedLines[0][18].A0 = emulated.ValueOf[emulated.BN254Fp]("16217086076844556529363966917086131649757503473381572790624279844617553496364") + PrecomputedLines[0][18].A1 = emulated.ValueOf[emulated.BN254Fp]("10713288731971058998221606378429065825113414556825913172815543429990736818376") + PrecomputedLines[1][18].A0 = emulated.ValueOf[emulated.BN254Fp]("12552050545636774575337026751447963392888877749601484102741403327500446156874") + PrecomputedLines[1][18].A1 = emulated.ValueOf[emulated.BN254Fp]("1889715733419853004961620455197495456081564113924277969675613143136837135513") + // i = 17 + PrecomputedLines[0][17].A0 = 
emulated.ValueOf[emulated.BN254Fp]("281209871248542006516712473420034295321120298208074788459534204367198535142") + PrecomputedLines[0][17].A1 = emulated.ValueOf[emulated.BN254Fp]("11427002786003194988328672063057116278011347987499596141846392618935095919375") + PrecomputedLines[1][17].A0 = emulated.ValueOf[emulated.BN254Fp]("8249292749133127482785740847289832702955066233577280917498121478679299264218") + PrecomputedLines[1][17].A1 = emulated.ValueOf[emulated.BN254Fp]("8641484141124226670682171491065546757600227936199444728945999484137993005385") + PrecomputedLines[2][17].A0 = emulated.ValueOf[emulated.BN254Fp]("16599812170685991386506806371665534372750857143857635056957476184228979064472") + PrecomputedLines[2][17].A1 = emulated.ValueOf[emulated.BN254Fp]("10626908974014138715216269718238311856490393026039163714213665951817067052590") + PrecomputedLines[3][17].A0 = emulated.ValueOf[emulated.BN254Fp]("6357700871494731822410049284748295785098346320944639515384097511493616712415") + PrecomputedLines[3][17].A1 = emulated.ValueOf[emulated.BN254Fp]("18094386343075402462056828857430458085585468610519347879110366976800945012238") + // i = 16 + PrecomputedLines[0][16].A0 = emulated.ValueOf[emulated.BN254Fp]("14531595037104551623335765422361255737783733804534887309336561182386022894609") + PrecomputedLines[0][16].A1 = emulated.ValueOf[emulated.BN254Fp]("13693292038242340138317667055996944950242100560633566577428335327100558601931") + PrecomputedLines[1][16].A0 = emulated.ValueOf[emulated.BN254Fp]("19882300238879412083813185345065778671818694192928769536446198022672502614800") + PrecomputedLines[1][16].A1 = emulated.ValueOf[emulated.BN254Fp]("19023802829192861606681770049155217498069279515832660659542794970260956599076") + // i = 15 + PrecomputedLines[0][15].A0 = emulated.ValueOf[emulated.BN254Fp]("214428669950930239037502001140346676398224589363492833374258252706619225095") + PrecomputedLines[0][15].A1 = 
emulated.ValueOf[emulated.BN254Fp]("9700482781441182965875593020977851473671521652260949157117738997229458832612") + PrecomputedLines[1][15].A0 = emulated.ValueOf[emulated.BN254Fp]("5707939202694442208311052687419434583655218004451905019053384532157090849511") + PrecomputedLines[1][15].A1 = emulated.ValueOf[emulated.BN254Fp]("9249003779150082855802401674917195099999438620295626021241732108833595933661") + // i = 14 + PrecomputedLines[0][14].A0 = emulated.ValueOf[emulated.BN254Fp]("7805886586080369896587926569311726027953215021897217781996744411655756551999") + PrecomputedLines[0][14].A1 = emulated.ValueOf[emulated.BN254Fp]("4475945661578122172966964851067681040356518221787176436864599099315745378607") + PrecomputedLines[1][14].A0 = emulated.ValueOf[emulated.BN254Fp]("1436676498637654967294854037272027428354069617227014207131783637892060911873") + PrecomputedLines[1][14].A1 = emulated.ValueOf[emulated.BN254Fp]("2305366134511694415797376996283850181652766442173047187160012616555129059726") + PrecomputedLines[2][14].A0 = emulated.ValueOf[emulated.BN254Fp]("3948625959186731630883431212774050798377359369729173092466505343872731816213") + PrecomputedLines[2][14].A1 = emulated.ValueOf[emulated.BN254Fp]("2358246181418771103293301333572732567327456223492056127888059507221213191432") + PrecomputedLines[3][14].A0 = emulated.ValueOf[emulated.BN254Fp]("3214429786637011410044793447049402874657838824607091686668942744472659699178") + PrecomputedLines[3][14].A1 = emulated.ValueOf[emulated.BN254Fp]("3311041522569135867590138245723395084506957431406246791002073444583824677944") + // i = 13 + PrecomputedLines[0][13].A0 = emulated.ValueOf[emulated.BN254Fp]("7940653081686121254560574898875834993172673192201552554589347797020017756575") + PrecomputedLines[0][13].A1 = emulated.ValueOf[emulated.BN254Fp]("15974145135205498451459757927505164636050928369369965530985889343297252806673") + PrecomputedLines[1][13].A0 = 
emulated.ValueOf[emulated.BN254Fp]("7686257707869567857469834549336963702481639214633150259937731475944847915123") + PrecomputedLines[1][13].A1 = emulated.ValueOf[emulated.BN254Fp]("8581852611449322691153270314388156853032486316219977253503422966968101541654") + // i = 12 + PrecomputedLines[0][12].A0 = emulated.ValueOf[emulated.BN254Fp]("7606607159448995415026938236473845673703071556204161879755865676515825543592") + PrecomputedLines[0][12].A1 = emulated.ValueOf[emulated.BN254Fp]("8956068938006055699967046837110471704135545397148189742180442797078694877053") + PrecomputedLines[1][12].A0 = emulated.ValueOf[emulated.BN254Fp]("11022564885667925490414698424833447218714993895366895350576595161435910155603") + PrecomputedLines[1][12].A1 = emulated.ValueOf[emulated.BN254Fp]("2937380761826300692577553924917170147052369549814500103731019151442626339958") + // i = 11 + PrecomputedLines[0][11].A0 = emulated.ValueOf[emulated.BN254Fp]("12051248266980606591117486251497672077260922365434536343080997678580860092839") + PrecomputedLines[0][11].A1 = emulated.ValueOf[emulated.BN254Fp]("1101948046901408684769644765236167201751826825758237448303491655657473871883") + PrecomputedLines[1][11].A0 = emulated.ValueOf[emulated.BN254Fp]("19040361699025595665325643767494987684082158366896439415230194791631253146950") + PrecomputedLines[1][11].A1 = emulated.ValueOf[emulated.BN254Fp]("19593948793881594577280833499171904486133659865964978128524672716381353415110") + // i = 10 + PrecomputedLines[0][10].A0 = emulated.ValueOf[emulated.BN254Fp]("2695861635377070245469834129082472669158214946653615403814586888514439809659") + PrecomputedLines[0][10].A1 = emulated.ValueOf[emulated.BN254Fp]("3517424455933415445379245336041862787934743784030363349436486456627711568935") + PrecomputedLines[1][10].A0 = emulated.ValueOf[emulated.BN254Fp]("1297121294706129311008967261335655974545941670322230404082189407981090136606") + PrecomputedLines[1][10].A1 = 
emulated.ValueOf[emulated.BN254Fp]("7269007213738731831376716995811160240537269554959925765999109257469235068970") + PrecomputedLines[2][10].A0 = emulated.ValueOf[emulated.BN254Fp]("11136530791838387464153280137325004767292250958794387187550512501095403277023") + PrecomputedLines[2][10].A1 = emulated.ValueOf[emulated.BN254Fp]("106502518068427873618571970037977909894299363465928161471159218521528331255") + PrecomputedLines[3][10].A0 = emulated.ValueOf[emulated.BN254Fp]("9621932176503134785751044845073759336043415956999820764897251636989455619322") + PrecomputedLines[3][10].A1 = emulated.ValueOf[emulated.BN254Fp]("11988385993623676567525757027648256037399190847305440889617409768798277399896") + // i = 9 + PrecomputedLines[0][9].A0 = emulated.ValueOf[emulated.BN254Fp]("19372185202903332212213460689668073391237276001679830995407820540229540001238") + PrecomputedLines[0][9].A1 = emulated.ValueOf[emulated.BN254Fp]("5616465305834971731681446522628928621391712075162309643124420681983851054131") + PrecomputedLines[1][9].A0 = emulated.ValueOf[emulated.BN254Fp]("15130750112907645494599377984115117159035751050040977084315605482681536587330") + PrecomputedLines[1][9].A1 = emulated.ValueOf[emulated.BN254Fp]("4112821409985026926465852755077338777425800159968275868518690696940060997519") + // i = 8 + PrecomputedLines[0][8].A0 = emulated.ValueOf[emulated.BN254Fp]("1177350099731769374927755229371912682105254009198092299667981153894962892911") + PrecomputedLines[0][8].A1 = emulated.ValueOf[emulated.BN254Fp]("12783556398948310494078028567080245047035120485548200979455238088948172228620") + PrecomputedLines[1][8].A0 = emulated.ValueOf[emulated.BN254Fp]("12038222281185955050483388631945863715711114498195498301547718431783742484642") + PrecomputedLines[1][8].A1 = emulated.ValueOf[emulated.BN254Fp]("9264166105136475526327132322149055656031511662283307277292723856236620490189") + // i = 7 + PrecomputedLines[0][7].A0 = 
emulated.ValueOf[emulated.BN254Fp]("21486077995282264447124488458477451264448290571356619721096235651158475590070") + PrecomputedLines[0][7].A1 = emulated.ValueOf[emulated.BN254Fp]("14259294332493846514184808308757331443862876275854439881679087349107324605887") + PrecomputedLines[1][7].A0 = emulated.ValueOf[emulated.BN254Fp]("21310181591085388723115701669056667379511330213691931034851726064119888872673") + PrecomputedLines[1][7].A1 = emulated.ValueOf[emulated.BN254Fp]("5075823038830179820764812274438824585993277163376381791462971634098010103671") + PrecomputedLines[2][7].A0 = emulated.ValueOf[emulated.BN254Fp]("13112636618783066585137462084175830499516645653422124722623967609618373134219") + PrecomputedLines[2][7].A1 = emulated.ValueOf[emulated.BN254Fp]("6606562192543571614836403257264220466480425913053458051187575436069152776683") + PrecomputedLines[3][7].A0 = emulated.ValueOf[emulated.BN254Fp]("10163523267982706542267128230964678625528511744949894317202814240082601743946") + PrecomputedLines[3][7].A1 = emulated.ValueOf[emulated.BN254Fp]("6356510014456263346824710768921779981830557122736686713952594920913429569627") + // i = 6 + PrecomputedLines[0][6].A0 = emulated.ValueOf[emulated.BN254Fp]("6993754568930208115692311734335110502575073944799196719243980186522826192500") + PrecomputedLines[0][6].A1 = emulated.ValueOf[emulated.BN254Fp]("8644183822713440026441857877903243046934367497877832294531055867710247192861") + PrecomputedLines[1][6].A0 = emulated.ValueOf[emulated.BN254Fp]("6162215014365963127944303642424328743998922537326512693704891583562046811368") + PrecomputedLines[1][6].A1 = emulated.ValueOf[emulated.BN254Fp]("10154626677971349735949904871597330658975481223864906253916895014919837738925") + // i = 5 + PrecomputedLines[0][5].A0 = emulated.ValueOf[emulated.BN254Fp]("1937978127062894188242539535500798667590225073902686974782593130206599845007") + PrecomputedLines[0][5].A1 = 
emulated.ValueOf[emulated.BN254Fp]("17119013397235014212137323292998779519533594281540918081901210988375145292838") + PrecomputedLines[1][5].A0 = emulated.ValueOf[emulated.BN254Fp]("7459302385665395083210154080521520536970822621811543022810558924096495990803") + PrecomputedLines[1][5].A1 = emulated.ValueOf[emulated.BN254Fp]("21518092698439432423716311366625905640550440032600495926825070561121714064966") + PrecomputedLines[2][5].A0 = emulated.ValueOf[emulated.BN254Fp]("1037191233256242024591850390288460629996863167044965323205998218470478292293") + PrecomputedLines[2][5].A1 = emulated.ValueOf[emulated.BN254Fp]("3506829931079437251122292736914870037093525843321317425550720179685754066171") + PrecomputedLines[3][5].A0 = emulated.ValueOf[emulated.BN254Fp]("17820182080707051991327549024823189851054031788540897864701560845270403753109") + PrecomputedLines[3][5].A1 = emulated.ValueOf[emulated.BN254Fp]("9899305209455152807161919079699365227282179194615715116357423849253490779682") + // i = 4 + PrecomputedLines[0][4].A0 = emulated.ValueOf[emulated.BN254Fp]("5324324825155514158487515405280394681879278019015840159544809591606158257332") + PrecomputedLines[0][4].A1 = emulated.ValueOf[emulated.BN254Fp]("9103354596188444592813598822287678248145908880267402190573436459100897831528") + PrecomputedLines[1][4].A0 = emulated.ValueOf[emulated.BN254Fp]("18167200785083515576687640058379703244722940349781476704929858352562239814160") + PrecomputedLines[1][4].A1 = emulated.ValueOf[emulated.BN254Fp]("17137305555905969523835852662612004277081830974452307053507378102895278528589") + // i = 3 + PrecomputedLines[0][3].A0 = emulated.ValueOf[emulated.BN254Fp]("10264790697161180899816663648211298388740251958120764796906064810666177499646") + PrecomputedLines[0][3].A1 = emulated.ValueOf[emulated.BN254Fp]("20258752168525716639405362594120556561630156450371244903071234737141561446514") + PrecomputedLines[1][3].A0 = 
emulated.ValueOf[emulated.BN254Fp]("1047785956871651181652762185265609154515890741811384676308782178186674721333") + PrecomputedLines[1][3].A1 = emulated.ValueOf[emulated.BN254Fp]("11832321406774014396058335342903409824167804020892883894340494373874448406") + PrecomputedLines[2][3].A0 = emulated.ValueOf[emulated.BN254Fp]("14092236220532563722273193993345070075560441348432917207002037677338453582804") + PrecomputedLines[2][3].A1 = emulated.ValueOf[emulated.BN254Fp]("11497587520289618598164153674981445132323171186923426171334582641683949994368") + PrecomputedLines[3][3].A0 = emulated.ValueOf[emulated.BN254Fp]("20758819501863100067579713445623163018598253994701203661225784213541257554016") + PrecomputedLines[3][3].A1 = emulated.ValueOf[emulated.BN254Fp]("13856843617321777013534357727779974444221810617686309708742406908090120962388") + // i = 2 + PrecomputedLines[0][2].A0 = emulated.ValueOf[emulated.BN254Fp]("7373558527687620561422152162991916593489921204438285449301245125513755320314") + PrecomputedLines[0][2].A1 = emulated.ValueOf[emulated.BN254Fp]("11221203116796205233830200618441749399293664601851323715931973686288216972676") + PrecomputedLines[1][2].A0 = emulated.ValueOf[emulated.BN254Fp]("5779253402567033185739372381451694349656589328372930738070702115143767419084") + PrecomputedLines[1][2].A1 = emulated.ValueOf[emulated.BN254Fp]("14510717290459248510467845044717011062002712332273480525810490069657559246488") + // i = 1 + PrecomputedLines[0][1].A0 = emulated.ValueOf[emulated.BN254Fp]("4363337419110373219314875355410347765942328123649693011784654840812725908680") + PrecomputedLines[0][1].A1 = emulated.ValueOf[emulated.BN254Fp]("14128521847906711846651015249783668960206638766994768661570084127609973543329") + PrecomputedLines[1][1].A0 = emulated.ValueOf[emulated.BN254Fp]("14274248424128087078986321168245052972625654514554370562089216041782293981548") + PrecomputedLines[1][1].A1 = 
emulated.ValueOf[emulated.BN254Fp]("1819524229057471418630268602559233583119735893806768677221723572182212858124") + // i = 0 + PrecomputedLines[0][0].A0 = emulated.ValueOf[emulated.BN254Fp]("9362219973542874570450638939162889131446156209210285596288463924394915480984") + PrecomputedLines[0][0].A1 = emulated.ValueOf[emulated.BN254Fp]("2166292944666058936308575452356594861484875658446920124101542049719645145111") + PrecomputedLines[1][0].A0 = emulated.ValueOf[emulated.BN254Fp]("4110752764440847029993710333296870396753666785458870337440969782289123199548") + PrecomputedLines[1][0].A1 = emulated.ValueOf[emulated.BN254Fp]("2524539108828422865916215076272533083788815095038371749085255761852573579569") + + // precompute ℓ_{[6x₀+2]G,π(G)} and ℓ_{[6x₀+2]Q+π(Q),-π²(Q)} + PrecomputedLines[0][65].A0 = emulated.ValueOf[emulated.BN254Fp]("1783675334639145815870644667302053681284809203074760789174282843314959992696") + PrecomputedLines[0][65].A1 = emulated.ValueOf[emulated.BN254Fp]("1951629468503798175241267767783740387858451447279377240163542122176909999042") + PrecomputedLines[1][65].A0 = emulated.ValueOf[emulated.BN254Fp]("11606810498377529077474160506241087507512991084990061027495602653306616416578") + PrecomputedLines[1][65].A1 = emulated.ValueOf[emulated.BN254Fp]("3352534584315050939412504680522542162196297327743431591269455212982072101125") + PrecomputedLines[0][66].A0 = emulated.ValueOf[emulated.BN254Fp]("6096279428236379570154933180749579153826262770836405799784606344740772912667") + PrecomputedLines[0][66].A1 = emulated.ValueOf[emulated.BN254Fp]("14973662014392090789260536656747094881206477852734022080621808568128746451734") + PrecomputedLines[1][66].A0 = emulated.ValueOf[emulated.BN254Fp]("9419615873784151968180451321627355717720222550586076973652377412738922144509") + PrecomputedLines[1][66].A1 = emulated.ValueOf[emulated.BN254Fp]("9757224935408026300679308061023777371092034675697842738078668298357133523844") + + return PrecomputedLines +} diff --git 
a/std/algebra/weierstrass/doc.go b/std/algebra/emulated/sw_emulated/doc.go similarity index 68% rename from std/algebra/weierstrass/doc.go rename to std/algebra/emulated/sw_emulated/doc.go index 72ff229dad..dc83a96fcb 100644 --- a/std/algebra/weierstrass/doc.go +++ b/std/algebra/emulated/sw_emulated/doc.go @@ -1,5 +1,5 @@ /* -Package weierstrass implements elliptic curve group operations in (short) +Package sw_emulated implements elliptic curve group operations in (short) Weierstrass form. The elliptic curve is the set of points (X,Y) satisfying the equation: @@ -10,6 +10,10 @@ over some base field 𝐅p for some constants a, b ∈ 𝐅p. Additionally, for every curve we also define its generator (base point) G. All these parameters are stored in the variable of type [CurveParams]. +This package implements unified and complete point addition. The method +[Curve.AddUnified] can be used for point additions or in case of points at +infinity. As such, this package does not expose separate Add and Double methods. + The package provides a few curve parameters, see functions [GetSecp256k1Params] and [GetBN254Params]. @@ -22,12 +26,9 @@ field. For now, we only have a single curve defined on every base field, but this may change in the future with the addition of additional curves. This package uses field emulation (unlike packages -[github.com/consensys/gnark/std/algebra/sw_bls12377] and -[github.com/consensys/gnark/std/algebra/sw_bls24315], which use 2-chains). This +[github.com/consensys/gnark/std/algebra/native/sw_bls12377] and +[github.com/consensys/gnark/std/algebra/native/sw_bls24315], which use 2-chains). This allows to use any curve over any native (SNARK) field. The drawback of this -approach is the extreme cost of the operations. In R1CS, point addition on -256-bit fields is approximately 3500 constraints and doubling is approximately -4300 constraints. A full scalar multiplication is approximately 2M constraints. -It is several times more in PLONKish aritmetisation. 
+approach is the extreme cost of the operations. */ -package weierstrass +package sw_emulated diff --git a/std/algebra/weierstrass/doc_test.go b/std/algebra/emulated/sw_emulated/doc_test.go similarity index 87% rename from std/algebra/weierstrass/doc_test.go rename to std/algebra/emulated/sw_emulated/doc_test.go index cfa81a1fe0..2d63969ecf 100644 --- a/std/algebra/weierstrass/doc_test.go +++ b/std/algebra/emulated/sw_emulated/doc_test.go @@ -1,4 +1,4 @@ -package weierstrass_test +package sw_emulated_test import ( "fmt" @@ -9,16 +9,16 @@ import ( "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/std/algebra/weierstrass" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" "github.com/consensys/gnark/std/math/emulated" ) type ExampleCurveCircuit[Base, Scalar emulated.FieldParams] struct { - Res weierstrass.AffinePoint[Base] + Res sw_emulated.AffinePoint[Base] } func (c *ExampleCurveCircuit[B, S]) Define(api frontend.API) error { - curve, err := weierstrass.New[B, S](api, weierstrass.GetCurveParams[emulated.BN254Fp]()) + curve, err := sw_emulated.New[B, S](api, sw_emulated.GetCurveParams[emulated.BN254Fp]()) if err != nil { panic("initalize new curve") } @@ -27,7 +27,7 @@ func (c *ExampleCurveCircuit[B, S]) Define(api frontend.API) error { g4 := curve.ScalarMul(G, &scalar4) // 4*G scalar5 := emulated.ValueOf[S](5) g5 := curve.ScalarMul(G, &scalar5) // 5*G - g9 := curve.Add(g4, g5) // 9*G + g9 := curve.AddUnified(g4, g5) // 9*G curve.AssertIsEqual(g9, &c.Res) return nil } @@ -41,7 +41,7 @@ func ExampleCurve() { circuit := ExampleCurveCircuit[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} witness := ExampleCurveCircuit[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ - Res: weierstrass.AffinePoint[emulated.Secp256k1Fp]{ + Res: sw_emulated.AffinePoint[emulated.Secp256k1Fp]{ X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), Y: 
emulated.ValueOf[emulated.Secp256k1Fp](g.Y), }, diff --git a/std/algebra/emulated/sw_emulated/params.go b/std/algebra/emulated/sw_emulated/params.go new file mode 100644 index 0000000000..efce7a1566 --- /dev/null +++ b/std/algebra/emulated/sw_emulated/params.go @@ -0,0 +1,134 @@ +package sw_emulated + +import ( + "crypto/elliptic" + "math/big" + + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/secp256k1" + "github.com/consensys/gnark/std/math/emulated" +) + +// CurveParams defines parameters of an elliptic curve in short Weierstrass form +// given by the equation +// +// Y² = X³ + aX + b +// +// The base point is defined by (Gx, Gy). +type CurveParams struct { + A *big.Int // a in curve equation + B *big.Int // b in curve equation + Gx *big.Int // base point x + Gy *big.Int // base point y + Gm [][2]*big.Int // m*base point coords +} + +// GetSecp256k1Params returns curve parameters for the curve secp256k1. When +// initialising new curve, use the base field [emulated.Secp256k1Fp] and scalar +// field [emulated.Secp256k1Fr]. +func GetSecp256k1Params() CurveParams { + _, g1aff := secp256k1.Generators() + return CurveParams{ + A: big.NewInt(0), + B: big.NewInt(7), + Gx: g1aff.X.BigInt(new(big.Int)), + Gy: g1aff.Y.BigInt(new(big.Int)), + Gm: computeSecp256k1Table(), + } +} + +// GetBN254Params returns the curve parameters for the curve BN254 (alt_bn128). +// When initialising new curve, use the base field [emulated.BN254Fp] and scalar +// field [emulated.BN254Fr]. +func GetBN254Params() CurveParams { + _, _, g1aff, _ := bn254.Generators() + return CurveParams{ + A: big.NewInt(0), + B: big.NewInt(3), + Gx: g1aff.X.BigInt(new(big.Int)), + Gy: g1aff.Y.BigInt(new(big.Int)), + Gm: computeBN254Table(), + } +} + +// GetBLS12381Params returns the curve parameters for the curve BLS12-381. 
+// When initialising new curve, use the base field [emulated.BLS12381Fp] and scalar +// field [emulated.BLS12381Fr]. +func GetBLS12381Params() CurveParams { + _, _, g1aff, _ := bls12381.Generators() + return CurveParams{ + A: big.NewInt(0), + B: big.NewInt(4), + Gx: g1aff.X.BigInt(new(big.Int)), + Gy: g1aff.Y.BigInt(new(big.Int)), + Gm: computeBLS12381Table(), + } +} + +// GetP256Params returns the curve parameters for the curve P-256 (also +// SECP256r1). When initialising new curve, use the base field +// [emulated.P256Fp] and scalar field [emulated.P256Fr]. +func GetP256Params() CurveParams { + pr := elliptic.P256().Params() + a := new(big.Int).Sub(pr.P, big.NewInt(3)) + return CurveParams{ + A: a, + B: pr.B, + Gx: pr.Gx, + Gy: pr.Gy, + Gm: computeP256Table(), + } +} + +// GetP384Params returns the curve parameters for the curve P-384 (also +// SECP384r1). When initialising new curve, use the base field +// [emulated.P384Fp] and scalar field [emulated.P384Fr]. +func GetP384Params() CurveParams { + pr := elliptic.P384().Params() + a := new(big.Int).Sub(pr.P, big.NewInt(3)) + return CurveParams{ + A: a, + B: pr.B, + Gx: pr.Gx, + Gy: pr.Gy, + Gm: computeP384Table(), + } +} + +// GetCurveParams returns suitable curve parameters given the parametric type +// Base as base field. It caches the parameters and modifying the values in the +// parameters struct leads to undefined behaviour. 
+func GetCurveParams[Base emulated.FieldParams]() CurveParams { + var t Base + switch t.Modulus().String() { + case emulated.Secp256k1Fp{}.Modulus().String(): + return secp256k1Params + case emulated.BN254Fp{}.Modulus().String(): + return bn254Params + case emulated.BLS12381Fp{}.Modulus().String(): + return bls12381Params + case emulated.P256Fp{}.Modulus().String(): + return p256Params + case emulated.P384Fp{}.Modulus().String(): + return p384Params + default: + panic("no stored parameters") + } +} + +var ( + secp256k1Params CurveParams + bn254Params CurveParams + bls12381Params CurveParams + p256Params CurveParams + p384Params CurveParams +) + +func init() { + secp256k1Params = GetSecp256k1Params() + bn254Params = GetBN254Params() + bls12381Params = GetBLS12381Params() + p256Params = GetP256Params() + p384Params = GetP384Params() +} diff --git a/std/algebra/emulated/sw_emulated/params_compute.go b/std/algebra/emulated/sw_emulated/params_compute.go new file mode 100644 index 0000000000..5eaf21e87b --- /dev/null +++ b/std/algebra/emulated/sw_emulated/params_compute.go @@ -0,0 +1,132 @@ +package sw_emulated + +import ( + "crypto/elliptic" + "math/big" + + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/secp256k1" +) + +func computeSecp256k1Table() [][2]*big.Int { + Gjac, _ := secp256k1.Generators() + table := make([][2]*big.Int, 256) + tmp := new(secp256k1.G1Jac).Set(&Gjac) + aff := new(secp256k1.G1Affine) + jac := new(secp256k1.G1Jac) + for i := 1; i < 256; i++ { + tmp = tmp.Double(tmp) + switch i { + case 1, 2: + jac.Set(tmp).AddAssign(&Gjac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + case 3: + jac.Set(tmp).SubAssign(&Gjac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + fallthrough + default: + aff.FromJacobian(tmp) + table[i] = 
[2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + } + } + return table[:] +} + +func computeBN254Table() [][2]*big.Int { + Gjac, _, _, _ := bn254.Generators() + table := make([][2]*big.Int, 256) + tmp := new(bn254.G1Jac).Set(&Gjac) + aff := new(bn254.G1Affine) + jac := new(bn254.G1Jac) + for i := 1; i < 256; i++ { + tmp = tmp.Double(tmp) + switch i { + case 1, 2: + jac.Set(tmp).AddAssign(&Gjac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + case 3: + jac.Set(tmp).SubAssign(&Gjac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + fallthrough + default: + aff.FromJacobian(tmp) + table[i] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + } + } + return table +} + +func computeBLS12381Table() [][2]*big.Int { + Gjac, _, _, _ := bls12381.Generators() + table := make([][2]*big.Int, 256) + tmp := new(bls12381.G1Jac).Set(&Gjac) + aff := new(bls12381.G1Affine) + jac := new(bls12381.G1Jac) + for i := 1; i < 256; i++ { + tmp = tmp.Double(tmp) + switch i { + case 1, 2: + jac.Set(tmp).AddAssign(&Gjac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + case 3: + jac.Set(tmp).SubAssign(&Gjac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + fallthrough + default: + aff.FromJacobian(tmp) + table[i] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + } + } + return table +} + +func computeP256Table() [][2]*big.Int { + table := make([][2]*big.Int, 256) + p256 := elliptic.P256() + gx, gy := p256.Params().Gx, p256.Params().Gy + tmpx, tmpy := new(big.Int).Set(gx), new(big.Int).Set(gy) + for i := 1; i < 256; i++ { + tmpx, tmpy = p256.Double(tmpx, tmpy) + switch i { + case 1, 2: + xx, yy := p256.Add(tmpx, tmpy, gx, gy) + table[i-1] = [2]*big.Int{xx, yy} + case 3: + xx, yy := p256.Add(tmpx, 
tmpy, gx, new(big.Int).Sub(p256.Params().P, gy)) + table[i-1] = [2]*big.Int{xx, yy} + fallthrough + default: + table[i] = [2]*big.Int{tmpx, tmpy} + } + } + return table +} + +func computeP384Table() [][2]*big.Int { + table := make([][2]*big.Int, 384) + p384 := elliptic.P384() + gx, gy := p384.Params().Gx, p384.Params().Gy + tmpx, tmpy := new(big.Int).Set(gx), new(big.Int).Set(gy) + for i := 1; i < 384; i++ { + tmpx, tmpy = p384.Double(tmpx, tmpy) + switch i { + case 1, 2: + xx, yy := p384.Add(tmpx, tmpy, gx, gy) + table[i-1] = [2]*big.Int{xx, yy} + case 3: + xx, yy := p384.Add(tmpx, tmpy, gx, new(big.Int).Sub(p384.Params().P, gy)) + table[i-1] = [2]*big.Int{xx, yy} + fallthrough + default: + table[i] = [2]*big.Int{tmpx, tmpy} + } + } + return table +} diff --git a/std/algebra/emulated/sw_emulated/point.go b/std/algebra/emulated/sw_emulated/point.go new file mode 100644 index 0000000000..0dbf653b6a --- /dev/null +++ b/std/algebra/emulated/sw_emulated/point.go @@ -0,0 +1,587 @@ +package sw_emulated + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" +) + +// New returns a new [Curve] instance over the base field Base and scalar field +// Scalars defined by the curve parameters params. It returns an error if +// initialising the field emulation fails (for example, when the native field is +// too small) or when the curve parameters are incompatible with the fields. 
+func New[Base, Scalars emulated.FieldParams](api frontend.API, params CurveParams) (*Curve[Base, Scalars], error) { + ba, err := emulated.NewField[Base](api) + if err != nil { + return nil, fmt.Errorf("new base api: %w", err) + } + sa, err := emulated.NewField[Scalars](api) + if err != nil { + return nil, fmt.Errorf("new scalar api: %w", err) + } + emuGm := make([]AffinePoint[Base], len(params.Gm)) + for i, v := range params.Gm { + emuGm[i] = AffinePoint[Base]{emulated.ValueOf[Base](v[0]), emulated.ValueOf[Base](v[1])} + } + Gx := emulated.ValueOf[Base](params.Gx) + Gy := emulated.ValueOf[Base](params.Gy) + return &Curve[Base, Scalars]{ + params: params, + api: api, + baseApi: ba, + scalarApi: sa, + g: AffinePoint[Base]{ + X: Gx, + Y: Gy, + }, + gm: emuGm, + a: emulated.ValueOf[Base](params.A), + b: emulated.ValueOf[Base](params.B), + addA: params.A.Cmp(big.NewInt(0)) != 0, + }, nil +} + +// Curve is an initialised curve which allows performing group operations. +type Curve[Base, Scalars emulated.FieldParams] struct { + // params is the parameters of the curve + params CurveParams + // api is the native api, we construct it ourselves to be sure + api frontend.API + // baseApi is the api for point operations + baseApi *emulated.Field[Base] + // scalarApi is the api for scalar operations + scalarApi *emulated.Field[Scalars] + + // g is the generator (base point) of the curve. + g AffinePoint[Base] + + // gm are the pre-computed multiples the generator (base point) of the curve. + gm []AffinePoint[Base] + + a emulated.Element[Base] + b emulated.Element[Base] + addA bool +} + +// Generator returns the base point of the curve. The method does not copy and +// modifying the returned element leads to undefined behaviour! +func (c *Curve[B, S]) Generator() *AffinePoint[B] { + return &c.g +} + +// GeneratorMultiples returns the pre-computed multiples of the base point of the curve. 
The method does not copy and +// modifying the returned element leads to undefined behaviour! +func (c *Curve[B, S]) GeneratorMultiples() []AffinePoint[B] { + return c.gm +} + +// AffinePoint represents a point on the elliptic curve. We do not check that +// the point is actually on the curve. +// +// Point (0,0) represents point at the infinity. This representation is +// compatible with the EVM representations of points at infinity. +type AffinePoint[Base emulated.FieldParams] struct { + X, Y emulated.Element[Base] +} + +// Neg returns an inverse of p. It doesn't modify p. +func (c *Curve[B, S]) Neg(p *AffinePoint[B]) *AffinePoint[B] { + return &AffinePoint[B]{ + X: p.X, + Y: *c.baseApi.Neg(&p.Y), + } +} + +// AssertIsEqual asserts that p and q are the same point. +func (c *Curve[B, S]) AssertIsEqual(p, q *AffinePoint[B]) { + c.baseApi.AssertIsEqual(&p.X, &q.X) + c.baseApi.AssertIsEqual(&p.Y, &q.Y) +} + +// add adds p and q and returns it. It doesn't modify p nor q. +// +// ⚠️ p must be different than q and -q, and both nonzero. +// +// It uses incomplete formulas in affine coordinates. +func (c *Curve[B, S]) add(p, q *AffinePoint[B]) *AffinePoint[B] { + // compute λ = (q.y-p.y)/(q.x-p.x) + qypy := c.baseApi.Sub(&q.Y, &p.Y) + qxpx := c.baseApi.Sub(&q.X, &p.X) + λ := c.baseApi.Div(qypy, qxpx) + + // xr = λ²-p.x-q.x + λλ := c.baseApi.MulMod(λ, λ) + qxpx = c.baseApi.Add(&p.X, &q.X) + xr := c.baseApi.Sub(λλ, qxpx) + + // p.y = λ(p.x-r.x) - p.y + pxrx := c.baseApi.Sub(&p.X, xr) + λpxrx := c.baseApi.MulMod(λ, pxrx) + yr := c.baseApi.Sub(λpxrx, &p.Y) + + return &AffinePoint[B]{ + X: *c.baseApi.Reduce(xr), + Y: *c.baseApi.Reduce(yr), + } +} + +// AssertIsOnCurve asserts if p belongs to the curve. It doesn't modify p. 
+func (c *Curve[B, S]) AssertIsOnCurve(p *AffinePoint[B]) { + // (X,Y) ∈ {Y² == X³ + aX + b} U (0,0) + + // if p=(0,0) we assign b=0 and continue + selector := c.api.And(c.baseApi.IsZero(&p.X), c.baseApi.IsZero(&p.Y)) + b := c.baseApi.Select(selector, c.baseApi.Zero(), &c.b) + + left := c.baseApi.Mul(&p.Y, &p.Y) + right := c.baseApi.Mul(&p.X, c.baseApi.Mul(&p.X, &p.X)) + right = c.baseApi.Add(right, b) + if c.addA { + ax := c.baseApi.Mul(&c.a, &p.X) + right = c.baseApi.Add(right, ax) + } + c.baseApi.AssertIsEqual(left, right) +} + +// AddUnified adds p and q and returns it. It doesn't modify p nor q. +// +// ✅ p can be equal to q, and either or both can be (0,0). +// (0,0) is not on the curve but we conventionally take it as the +// neutral/infinity point as per the [EVM]. +// +// It uses the unified formulas of Brier and Joye ([[BriJoy02]] (Corollary 1)). +// +// [BriJoy02]: https://link.springer.com/content/pdf/10.1007/3-540-45664-3_24.pdf +// [EVM]: https://ethereum.github.io/yellowpaper/paper.pdf +func (c *Curve[B, S]) AddUnified(p, q *AffinePoint[B]) *AffinePoint[B] { + + // selector1 = 1 when p is (0,0) and 0 otherwise + selector1 := c.api.And(c.baseApi.IsZero(&p.X), c.baseApi.IsZero(&p.Y)) + // selector2 = 1 when q is (0,0) and 0 otherwise + selector2 := c.api.And(c.baseApi.IsZero(&q.X), c.baseApi.IsZero(&q.Y)) + + // λ = ((p.x+q.x)² - p.x*q.x + a)/(p.y + q.y) + pxqx := c.baseApi.MulMod(&p.X, &q.X) + pxplusqx := c.baseApi.Add(&p.X, &q.X) + num := c.baseApi.MulMod(pxplusqx, pxplusqx) + num = c.baseApi.Sub(num, pxqx) + if c.addA { + num = c.baseApi.Add(num, &c.a) + } + denum := c.baseApi.Add(&p.Y, &q.Y) + // if p.y + q.y = 0, assign dummy 1 to denum and continue + selector3 := c.baseApi.IsZero(denum) + denum = c.baseApi.Select(selector3, c.baseApi.One(), denum) + λ := c.baseApi.Div(num, denum) + + // x = λ^2 - p.x - q.x + xr := c.baseApi.MulMod(λ, λ) + xr = c.baseApi.Sub(xr, pxplusqx) + + // y = λ(p.x - xr) - p.y + yr := c.baseApi.Sub(&p.X, xr) + yr = 
c.baseApi.MulMod(yr, λ) + yr = c.baseApi.Sub(yr, &p.Y) + result := AffinePoint[B]{ + X: *c.baseApi.Reduce(xr), + Y: *c.baseApi.Reduce(yr), + } + + zero := c.baseApi.Zero() + infinity := AffinePoint[B]{X: *zero, Y: *zero} + // if p=(0,0) return q + result = *c.Select(selector1, q, &result) + // if q=(0,0) return p + result = *c.Select(selector2, p, &result) + // if p.y + q.y = 0, return (0, 0) + result = *c.Select(selector3, &infinity, &result) + + return &result +} + +// double doubles p and return it. It doesn't modify p. +// +// ⚠️ p.Y must be nonzero. +// +// It uses affine coordinates. +func (c *Curve[B, S]) double(p *AffinePoint[B]) *AffinePoint[B] { + + // compute λ = (3p.x²+a)/2*p.y, here we assume a=0 (j invariant 0 curve) + xx3a := c.baseApi.MulMod(&p.X, &p.X) + xx3a = c.baseApi.MulConst(xx3a, big.NewInt(3)) + if c.addA { + xx3a = c.baseApi.Add(xx3a, &c.a) + } + y2 := c.baseApi.MulConst(&p.Y, big.NewInt(2)) + λ := c.baseApi.Div(xx3a, y2) + + // xr = λ²-2p.x + x2 := c.baseApi.MulConst(&p.X, big.NewInt(2)) + λλ := c.baseApi.MulMod(λ, λ) + xr := c.baseApi.Sub(λλ, x2) + + // yr = λ(p-xr) - p.y + pxrx := c.baseApi.Sub(&p.X, xr) + λpxrx := c.baseApi.MulMod(λ, pxrx) + yr := c.baseApi.Sub(λpxrx, &p.Y) + + return &AffinePoint[B]{ + X: *c.baseApi.Reduce(xr), + Y: *c.baseApi.Reduce(yr), + } +} + +// triple triples p and return it. It follows [ELM03] (Section 3.1). +// Saves the computation of the y coordinate of 2p as it is used only in the computation of λ2, +// which can be computed as +// +// λ2 = -λ1-2*p.y/(x2-p.x) +// +// instead. It doesn't modify p. +// +// ⚠️ p.Y must be nonzero. 
+// +// [ELM03]: https://arxiv.org/pdf/math/0208038.pdf +func (c *Curve[B, S]) triple(p *AffinePoint[B]) *AffinePoint[B] { + + // compute λ1 = (3p.x²+a)/2p.y, here we assume a=0 (j invariant 0 curve) + xx := c.baseApi.MulMod(&p.X, &p.X) + xx = c.baseApi.MulConst(xx, big.NewInt(3)) + if c.addA { + xx = c.baseApi.Add(xx, &c.a) + } + y2 := c.baseApi.MulConst(&p.Y, big.NewInt(2)) + λ1 := c.baseApi.Div(xx, y2) + + // xr = λ1²-2p.x + x2 := c.baseApi.MulConst(&p.X, big.NewInt(2)) + λ1λ1 := c.baseApi.MulMod(λ1, λ1) + x2 = c.baseApi.Sub(λ1λ1, x2) + + // ommit y2 computation, and + // compute λ2 = 2p.y/(x2 − p.x) − λ1. + x1x2 := c.baseApi.Sub(&p.X, x2) + λ2 := c.baseApi.Div(y2, x1x2) + λ2 = c.baseApi.Sub(λ2, λ1) + + // xr = λ²-p.x-x2 + λ2λ2 := c.baseApi.MulMod(λ2, λ2) + qxrx := c.baseApi.Add(x2, &p.X) + xr := c.baseApi.Sub(λ2λ2, qxrx) + + // yr = λ(p.x-xr) - p.y + pxrx := c.baseApi.Sub(&p.X, xr) + λ2pxrx := c.baseApi.MulMod(λ2, pxrx) + yr := c.baseApi.Sub(λ2pxrx, &p.Y) + + return &AffinePoint[B]{ + X: *c.baseApi.Reduce(xr), + Y: *c.baseApi.Reduce(yr), + } +} + +// doubleAndAdd computes 2p+q as (p+q)+p. It follows [ELM03] (Section 3.1) +// Saves the computation of the y coordinate of p+q as it is used only in the computation of λ2, +// which can be computed as +// +// λ2 = -λ1-2*p.y/(x2-p.x) +// +// instead. It doesn't modify p nor q. +// +// ⚠️ p must be different than q and -q, and both nonzero. 
+// +// [ELM03]: https://arxiv.org/pdf/math/0208038.pdf +func (c *Curve[B, S]) doubleAndAdd(p, q *AffinePoint[B]) *AffinePoint[B] { + + // compute λ1 = (q.y-p.y)/(q.x-p.x) + yqyp := c.baseApi.Sub(&q.Y, &p.Y) + xqxp := c.baseApi.Sub(&q.X, &p.X) + λ1 := c.baseApi.Div(yqyp, xqxp) + + // compute x2 = λ1²-p.x-q.x + λ1λ1 := c.baseApi.MulMod(λ1, λ1) + xqxp = c.baseApi.Add(&p.X, &q.X) + x2 := c.baseApi.Sub(λ1λ1, xqxp) + + // ommit y2 computation + // compute λ2 = -λ1-2*p.y/(x2-p.x) + ypyp := c.baseApi.Add(&p.Y, &p.Y) + x2xp := c.baseApi.Sub(x2, &p.X) + λ2 := c.baseApi.Div(ypyp, x2xp) + λ2 = c.baseApi.Add(λ1, λ2) + λ2 = c.baseApi.Neg(λ2) + + // compute x3 =λ2²-p.x-x3 + λ2λ2 := c.baseApi.MulMod(λ2, λ2) + x3 := c.baseApi.Sub(λ2λ2, &p.X) + x3 = c.baseApi.Sub(x3, x2) + + // compute y3 = λ2*(p.x - x3)-p.y + y3 := c.baseApi.Sub(&p.X, x3) + y3 = c.baseApi.Mul(λ2, y3) + y3 = c.baseApi.Sub(y3, &p.Y) + + return &AffinePoint[B]{ + X: *c.baseApi.Reduce(x3), + Y: *c.baseApi.Reduce(y3), + } + +} + +// doubleAndAddSelect is the same as doubleAndAdd but computes either: +// +// 2p+q is b=1 or +// 2q+p is b=0 +// +// It first computes the x-coordinate of p+q via the slope(p,q) +// and then based on a Select adds either p or q. 
+func (c *Curve[B, S]) doubleAndAddSelect(b frontend.Variable, p, q *AffinePoint[B]) *AffinePoint[B] { + + // compute λ1 = (q.y-p.y)/(q.x-p.x) + yqyp := c.baseApi.Sub(&q.Y, &p.Y) + xqxp := c.baseApi.Sub(&q.X, &p.X) + λ1 := c.baseApi.Div(yqyp, xqxp) + + // compute x2 = λ1²-p.x-q.x + λ1λ1 := c.baseApi.MulMod(λ1, λ1) + xqxp = c.baseApi.Add(&p.X, &q.X) + x2 := c.baseApi.Sub(λ1λ1, xqxp) + + // ommit y2 computation + + // conditional second addition + t := c.Select(b, p, q) + + // compute λ2 = -λ1-2*t.y/(x2-t.x) + ypyp := c.baseApi.Add(&t.Y, &t.Y) + x2xp := c.baseApi.Sub(x2, &t.X) + λ2 := c.baseApi.Div(ypyp, x2xp) + λ2 = c.baseApi.Add(λ1, λ2) + λ2 = c.baseApi.Neg(λ2) + + // compute x3 =λ2²-t.x-x3 + λ2λ2 := c.baseApi.MulMod(λ2, λ2) + x3 := c.baseApi.Sub(λ2λ2, &t.X) + x3 = c.baseApi.Sub(x3, x2) + + // compute y3 = λ2*(t.x - x3)-t.y + y3 := c.baseApi.Sub(&t.X, x3) + y3 = c.baseApi.Mul(λ2, y3) + y3 = c.baseApi.Sub(y3, &t.Y) + + return &AffinePoint[B]{ + X: *c.baseApi.Reduce(x3), + Y: *c.baseApi.Reduce(y3), + } + +} + +// Select selects between p and q given the selector b. If b == 1, then returns +// p and q otherwise. +func (c *Curve[B, S]) Select(b frontend.Variable, p, q *AffinePoint[B]) *AffinePoint[B] { + x := c.baseApi.Select(b, &p.X, &q.X) + y := c.baseApi.Select(b, &p.Y, &q.Y) + return &AffinePoint[B]{ + X: *x, + Y: *y, + } +} + +// Lookup2 performs a 2-bit lookup between i0, i1, i2, i3 based on bits b0 +// and b1. Returns: +// - i0 if b0=0 and b1=0, +// - i1 if b0=1 and b1=0, +// - i2 if b0=0 and b1=1, +// - i3 if b0=1 and b1=1. +func (c *Curve[B, S]) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 *AffinePoint[B]) *AffinePoint[B] { + x := c.baseApi.Lookup2(b0, b1, &i0.X, &i1.X, &i2.X, &i3.X) + y := c.baseApi.Lookup2(b0, b1, &i0.Y, &i1.Y, &i2.Y, &i3.Y) + return &AffinePoint[B]{ + X: *x, + Y: *y, + } +} + +// ScalarMul computes s * p and returns it. It doesn't modify p nor s. +// This function doesn't check that the p is on the curve. See AssertIsOnCurve. 
+// +// ✅ p can can be (0,0) and s can be 0. +// (0,0) is not on the curve but we conventionally take it as the +// neutral/infinity point as per the [EVM]. +// +// It computes the right-to-left variable-base double-and-add algorithm ([Joye07], Alg.1). +// +// Since we use incomplete formulas for the addition law, we need to start with +// a non-zero accumulator point (R0). To do this, we skip the LSB (bit at +// position 0) and proceed assuming it was 1. At the end, we conditionally +// subtract the initial value (p) if LSB is 1. We also handle the bits at +// positions 1 and n-1 outside of the loop to optimize the number of +// constraints using [ELM03] (Section 3.1) +// +// [ELM03]: https://arxiv.org/pdf/math/0208038.pdf +// [EVM]: https://ethereum.github.io/yellowpaper/paper.pdf +// [Joye07]: https://www.iacr.org/archive/ches2007/47270135/47270135.pdf +func (c *Curve[B, S]) ScalarMul(p *AffinePoint[B], s *emulated.Element[S]) *AffinePoint[B] { + + // if p=(0,0) we assign a dummy (0,1) to p and continue + selector := c.api.And(c.baseApi.IsZero(&p.X), c.baseApi.IsZero(&p.Y)) + one := c.baseApi.One() + p = c.Select(selector, &AffinePoint[B]{X: *one, Y: *one}, p) + + var st S + sr := c.scalarApi.Reduce(s) + sBits := c.scalarApi.ToBits(sr) + n := st.Modulus().BitLen() + + // i = 1 + Rb := c.triple(p) + R0 := c.Select(sBits[1], Rb, p) + R1 := c.Select(sBits[1], p, Rb) + + for i := 2; i < n-1; i++ { + Rb = c.doubleAndAddSelect(sBits[i], R0, R1) + R0 = c.Select(sBits[i], Rb, R0) + R1 = c.Select(sBits[i], R1, Rb) + } + + // i = n-1 + Rb = c.doubleAndAddSelect(sBits[n-1], R0, R1) + R0 = c.Select(sBits[n-1], Rb, R0) + + // i = 0 + // we use AddUnified here instead of add so that when s=0, res=(0,0) + // because AddUnified(p, -p) = (0,0) + R0 = c.Select(sBits[0], R0, c.AddUnified(R0, c.Neg(p))) + + // if p=(0,0), return (0,0) + zero := c.baseApi.Zero() + R0 = c.Select(selector, &AffinePoint[B]{X: *zero, Y: *zero}, R0) + + return R0 +} + +// ScalarMulBase computes s * g and 
returns it, where g is the fixed generator. +// It doesn't modify s. +// +// ✅ When s=0, it returns (0,0). +// (0,0) is not on the curve but we conventionally take it as the +// neutral/infinity point as per the [EVM]. +// +// It computes the standard little-endian fixed-base double-and-add algorithm +// [HMV04] (Algorithm 3.26), with the points [2^i]g precomputed. The bits at +// positions 1 and 2 are handled outside of the loop to optimize the number of +// constraints using a Lookup2 with pre-computed [3]g, [5]g and [7]g points. +// +// [HMV04]: https://link.springer.com/book/10.1007/b97644 +// [EVM]: https://ethereum.github.io/yellowpaper/paper.pdf +func (c *Curve[B, S]) ScalarMulBase(s *emulated.Element[S]) *AffinePoint[B] { + g := c.Generator() + gm := c.GeneratorMultiples() + + var st S + sr := c.scalarApi.Reduce(s) + sBits := c.scalarApi.ToBits(sr) + + // i = 1, 2 + // gm[0] = 3g, gm[1] = 5g, gm[2] = 7g + res := c.Lookup2(sBits[1], sBits[2], g, &gm[0], &gm[1], &gm[2]) + + for i := 3; i < st.Modulus().BitLen(); i++ { + // gm[i] = [2^i]g + tmp := c.add(res, &gm[i]) + res = c.Select(sBits[i], tmp, res) + } + + // i = 0 + tmp := c.AddUnified(res, c.Neg(g)) + res = c.Select(sBits[0], res, tmp) + + return res +} + +// JointScalarMulBase computes s2 * p + s1 * g and returns it, where g is the +// fixed generator. It doesn't modify p, s1 and s2. +// +// ⚠️ p must NOT be (0,0). +// ⚠️ s1 and s2 must NOT be 0. +// +// It uses the logic from ScalarMul() for s1 * g and the logic from ScalarMulBase() for s2 * g. +// +// JointScalarMulBase is used to verify an ECDSA signature (r,s) on the +// secp256k1 curve. In this case, p is a public key, s2=r/s and s1=hash/s. +// - hash cannot be 0, because of pre-image resistance. +// - r cannot be 0, because r is the x coordinate of a random point on +// secp256k1 (y²=x³+7 mod p) and 7 is not a square mod p. For any other +// curve, (_,0) is a point of order 2 which is not the prime subgroup. +// - (0,0) is not a valid public key. 
+// +// The [EVM] specifies these checks, wich are performed on the zkEVM +// arithmetization side before calling the circuit that uses this method. +// +// This saves the Select logic related to (0,0) and the use of AddUnified to +// handle the 0-scalar edge case. +func (c *Curve[B, S]) JointScalarMulBase(p *AffinePoint[B], s2, s1 *emulated.Element[S]) *AffinePoint[B] { + g := c.Generator() + gm := c.GeneratorMultiples() + + var st S + s1r := c.scalarApi.Reduce(s1) + s1Bits := c.scalarApi.ToBits(s1r) + s2r := c.scalarApi.Reduce(s2) + s2Bits := c.scalarApi.ToBits(s2r) + n := st.Modulus().BitLen() + + // fixed-base + // i = 1, 2 + // gm[0] = 3g, gm[1] = 5g, gm[2] = 7g + res1 := c.Lookup2(s1Bits[1], s1Bits[2], g, &gm[0], &gm[1], &gm[2]) + // var-base + // i = 1 + Rb := c.triple(p) + R0 := c.Select(s2Bits[1], Rb, p) + R1 := c.Select(s2Bits[1], p, Rb) + // i = 2 + Rb = c.doubleAndAddSelect(s2Bits[2], R0, R1) + R0 = c.Select(s2Bits[2], Rb, R0) + R1 = c.Select(s2Bits[2], R1, Rb) + + for i := 3; i <= n-3; i++ { + // fixed-base + // gm[i] = [2^i]g + tmp1 := c.add(res1, &gm[i]) + res1 = c.Select(s1Bits[i], tmp1, res1) + // var-base + Rb = c.doubleAndAddSelect(s2Bits[i], R0, R1) + R0 = c.Select(s2Bits[i], Rb, R0) + R1 = c.Select(s2Bits[i], R1, Rb) + + } + + // i = n-2 + // fixed-base + tmp1 := c.add(res1, &gm[n-2]) + res1 = c.Select(s1Bits[n-2], tmp1, res1) + // var-base + Rb = c.doubleAndAddSelect(s2Bits[n-2], R0, R1) + R0 = c.Select(s2Bits[n-2], Rb, R0) + R1 = c.Select(s2Bits[n-2], R1, Rb) + + // i = n-1 + // fixed-base + tmp1 = c.add(res1, &gm[n-1]) + res1 = c.Select(s1Bits[n-1], tmp1, res1) + // var-base + Rb = c.doubleAndAddSelect(s2Bits[n-1], R0, R1) + R0 = c.Select(s2Bits[n-1], Rb, R0) + + // i = 0 + // fixed-base + tmp1 = c.add(res1, c.Neg(g)) + res1 = c.Select(s1Bits[0], res1, tmp1) + // var-base + R0 = c.Select(s2Bits[0], R0, c.add(R0, c.Neg(p))) + + return c.add(res1, R0) +} diff --git a/std/algebra/emulated/sw_emulated/point_test.go 
b/std/algebra/emulated/sw_emulated/point_test.go new file mode 100644 index 0000000000..0f16765816 --- /dev/null +++ b/std/algebra/emulated/sw_emulated/point_test.go @@ -0,0 +1,776 @@ +package sw_emulated + +import ( + "crypto/elliptic" + "crypto/rand" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + fr_bls381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bn254" + fr_bn "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/secp256k1" + fp_secp "github.com/consensys/gnark-crypto/ecc/secp256k1/fp" + fr_secp "github.com/consensys/gnark-crypto/ecc/secp256k1/fr" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/test" +) + +var testCurve = ecc.BN254 + +type NegTest[T, S emulated.FieldParams] struct { + P, Q AffinePoint[T] +} + +func (c *NegTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.Neg(&c.P) + cr.AssertIsEqual(res, &c.Q) + return nil +} + +func TestNeg(t *testing.T) { + assert := test.NewAssert(t) + _, g := secp256k1.Generators() + var yn fp_secp.Element + yn.Neg(&g.Y) + circuit := NegTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := NegTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), + }, + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](yn), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +type AddTest[T, S emulated.FieldParams] struct { + P, Q, R AffinePoint[T] +} + +func (c *AddTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, 
GetCurveParams[T]()) + if err != nil { + return err + } + res1 := cr.add(&c.P, &c.Q) + res2 := cr.AddUnified(&c.P, &c.Q) + cr.AssertIsEqual(res1, &c.R) + cr.AssertIsEqual(res2, &c.R) + return nil +} + +func TestAdd(t *testing.T) { + assert := test.NewAssert(t) + var dJac, aJac secp256k1.G1Jac + g, _ := secp256k1.Generators() + dJac.Double(&g) + aJac.Set(&dJac). + AddAssign(&g) + var dAff, aAff secp256k1.G1Affine + dAff.FromJacobian(&dJac) + aAff.FromJacobian(&aJac) + circuit := AddTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := AddTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), + }, + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](dAff.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](dAff.Y), + }, + R: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](aAff.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](aAff.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +type DoubleTest[T, S emulated.FieldParams] struct { + P, Q AffinePoint[T] +} + +func (c *DoubleTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res1 := cr.double(&c.P) + res2 := cr.AddUnified(&c.P, &c.P) + cr.AssertIsEqual(res1, &c.Q) + cr.AssertIsEqual(res2, &c.Q) + return nil +} + +func TestDouble(t *testing.T) { + assert := test.NewAssert(t) + g, _ := secp256k1.Generators() + var dJac secp256k1.G1Jac + dJac.Double(&g) + var dAff secp256k1.G1Affine + dAff.FromJacobian(&dJac) + circuit := DoubleTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := DoubleTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), + }, + 
Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](dAff.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](dAff.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +type TripleTest[T, S emulated.FieldParams] struct { + P, Q AffinePoint[T] +} + +func (c *TripleTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.triple(&c.P) + cr.AssertIsEqual(res, &c.Q) + return nil +} + +func TestTriple(t *testing.T) { + assert := test.NewAssert(t) + g, _ := secp256k1.Generators() + var dJac secp256k1.G1Jac + dJac.Double(&g).AddAssign(&g) + var dAff secp256k1.G1Affine + dAff.FromJacobian(&dJac) + circuit := TripleTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := TripleTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), + }, + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](dAff.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](dAff.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +type DoubleAndAddTest[T, S emulated.FieldParams] struct { + P, Q, R AffinePoint[T] +} + +func (c *DoubleAndAddTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.doubleAndAdd(&c.P, &c.Q) + cr.AssertIsEqual(res, &c.R) + return nil +} + +func TestDoubleAndAdd(t *testing.T) { + assert := test.NewAssert(t) + var pJac, qJac, rJac secp256k1.G1Jac + g, _ := secp256k1.Generators() + pJac.Double(&g) + qJac.Set(&g) + rJac.Double(&pJac). 
+ AddAssign(&qJac) + var pAff, qAff, rAff secp256k1.G1Affine + pAff.FromJacobian(&pJac) + qAff.FromJacobian(&qJac) + rAff.FromJacobian(&rJac) + circuit := DoubleAndAddTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := DoubleAndAddTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](pAff.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](pAff.Y), + }, + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](qAff.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](qAff.Y), + }, + R: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](rAff.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](rAff.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +type AddUnifiedEdgeCasesTest[T, S emulated.FieldParams] struct { + P, Q, R AffinePoint[T] +} + +func (c *AddUnifiedEdgeCasesTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.AddUnified(&c.P, &c.Q) + cr.AssertIsEqual(res, &c.R) + return nil +} + +func TestAddUnifiedEdgeCases(t *testing.T) { + assert := test.NewAssert(t) + var infinity bn254.G1Affine + _, _, g, _ := bn254.Generators() + var r fr_bn.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var S, Sn bn254.G1Affine + S.ScalarMultiplication(&g, s) + Sn.Neg(&S) + + circuit := AddUnifiedEdgeCasesTest[emulated.BN254Fp, emulated.BN254Fr]{} + + // (0,0) + (0,0) == (0,0) + witness1 := AddUnifiedEdgeCasesTest[emulated.BN254Fp, emulated.BN254Fr]{ + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + Q: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: 
emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + } + err := test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // S + (0,0) == S + witness2 := AddUnifiedEdgeCasesTest[emulated.BN254Fp, emulated.BN254Fr]{ + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](S.X), + Y: emulated.ValueOf[emulated.BN254Fp](S.Y), + }, + Q: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](S.X), + Y: emulated.ValueOf[emulated.BN254Fp](S.Y), + }, + } + err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) + assert.NoError(err) + + // (0,0) + S == S + witness3 := AddUnifiedEdgeCasesTest[emulated.BN254Fp, emulated.BN254Fr]{ + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + Q: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](S.X), + Y: emulated.ValueOf[emulated.BN254Fp](S.Y), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](S.X), + Y: emulated.ValueOf[emulated.BN254Fp](S.Y), + }, + } + err = test.IsSolved(&circuit, &witness3, testCurve.ScalarField()) + assert.NoError(err) + + // S + (-S) == (0,0) + witness4 := AddUnifiedEdgeCasesTest[emulated.BN254Fp, emulated.BN254Fr]{ + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](S.X), + Y: emulated.ValueOf[emulated.BN254Fp](S.Y), + }, + Q: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](Sn.X), + Y: emulated.ValueOf[emulated.BN254Fp](Sn.Y), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + } + err = test.IsSolved(&circuit, &witness4, testCurve.ScalarField()) + 
assert.NoError(err) + + // (-S) + S == (0,0) + witness5 := AddUnifiedEdgeCasesTest[emulated.BN254Fp, emulated.BN254Fr]{ + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](Sn.X), + Y: emulated.ValueOf[emulated.BN254Fp](Sn.Y), + }, + Q: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](S.X), + Y: emulated.ValueOf[emulated.BN254Fp](S.Y), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + } + err = test.IsSolved(&circuit, &witness5, testCurve.ScalarField()) + assert.NoError(err) +} + +type ScalarMulBaseTest[T, S emulated.FieldParams] struct { + Q AffinePoint[T] + S emulated.Element[S] +} + +func (c *ScalarMulBaseTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.ScalarMulBase(&c.S) + cr.AssertIsEqual(res, &c.Q) + return nil +} + +func TestScalarMulBase(t *testing.T) { + assert := test.NewAssert(t) + _, g := secp256k1.Generators() + var r fr_secp.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var S secp256k1.G1Affine + S.ScalarMultiplication(&g, s) + + circuit := ScalarMulBaseTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := ScalarMulBaseTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + S: emulated.ValueOf[emulated.Secp256k1Fr](s), + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](S.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](S.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestScalarMulBase2(t *testing.T) { + assert := test.NewAssert(t) + _, _, g, _ := bn254.Generators() + var r fr_bn.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var S bn254.G1Affine + S.ScalarMultiplication(&g, s) + + circuit := ScalarMulBaseTest[emulated.BN254Fp, emulated.BN254Fr]{} + witness := 
ScalarMulBaseTest[emulated.BN254Fp, emulated.BN254Fr]{ + S: emulated.ValueOf[emulated.BN254Fr](s), + Q: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](S.X), + Y: emulated.ValueOf[emulated.BN254Fp](S.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestScalarMulBase3(t *testing.T) { + assert := test.NewAssert(t) + _, _, g, _ := bls12381.Generators() + var r fr_bn.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var S bls12381.G1Affine + S.ScalarMultiplication(&g, s) + + circuit := ScalarMulBaseTest[emulated.BLS12381Fp, emulated.BLS12381Fr]{} + witness := ScalarMulBaseTest[emulated.BLS12381Fp, emulated.BLS12381Fr]{ + S: emulated.ValueOf[emulated.BLS12381Fr](s), + Q: AffinePoint[emulated.BLS12381Fp]{ + X: emulated.ValueOf[emulated.BLS12381Fp](S.X), + Y: emulated.ValueOf[emulated.BLS12381Fp](S.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +type ScalarMulTest[T, S emulated.FieldParams] struct { + P, Q AffinePoint[T] + S emulated.Element[S] +} + +func (c *ScalarMulTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.ScalarMul(&c.P, &c.S) + cr.AssertIsEqual(res, &c.Q) + return nil +} + +func TestScalarMul(t *testing.T) { + assert := test.NewAssert(t) + _, g := secp256k1.Generators() + var r fr_secp.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var S secp256k1.G1Affine + S.ScalarMultiplication(&g, s) + + circuit := ScalarMulTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := ScalarMulTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + S: emulated.ValueOf[emulated.Secp256k1Fr](s), + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), + }, + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: 
emulated.ValueOf[emulated.Secp256k1Fp](S.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](S.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestScalarMul2(t *testing.T) { + assert := test.NewAssert(t) + var r fr_bn.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var res bn254.G1Affine + _, _, gen, _ := bn254.Generators() + res.ScalarMultiplication(&gen, s) + + circuit := ScalarMulTest[emulated.BN254Fp, emulated.BN254Fr]{} + witness := ScalarMulTest[emulated.BN254Fp, emulated.BN254Fr]{ + S: emulated.ValueOf[emulated.BN254Fr](s), + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](gen.X), + Y: emulated.ValueOf[emulated.BN254Fp](gen.Y), + }, + Q: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](res.X), + Y: emulated.ValueOf[emulated.BN254Fp](res.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestScalarMul3(t *testing.T) { + assert := test.NewAssert(t) + var r fr_bls381.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var res bls12381.G1Affine + _, _, gen, _ := bls12381.Generators() + res.ScalarMultiplication(&gen, s) + + circuit := ScalarMulTest[emulated.BLS12381Fp, emulated.BLS12381Fr]{} + witness := ScalarMulTest[emulated.BLS12381Fp, emulated.BLS12381Fr]{ + S: emulated.ValueOf[emulated.BLS12381Fr](s), + P: AffinePoint[emulated.BLS12381Fp]{ + X: emulated.ValueOf[emulated.BLS12381Fp](gen.X), + Y: emulated.ValueOf[emulated.BLS12381Fp](gen.Y), + }, + Q: AffinePoint[emulated.BLS12381Fp]{ + X: emulated.ValueOf[emulated.BLS12381Fp](res.X), + Y: emulated.ValueOf[emulated.BLS12381Fp](res.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestScalarMul4(t *testing.T) { + assert := test.NewAssert(t) + p256 := elliptic.P256() + s, err := rand.Int(rand.Reader, p256.Params().N) + assert.NoError(err) + px, 
py := p256.ScalarBaseMult(s.Bytes()) + + circuit := ScalarMulTest[emulated.P256Fp, emulated.P256Fr]{} + witness := ScalarMulTest[emulated.P256Fp, emulated.P256Fr]{ + S: emulated.ValueOf[emulated.P256Fr](s), + P: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](p256.Params().Gx), + Y: emulated.ValueOf[emulated.P256Fp](p256.Params().Gy), + }, + Q: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](px), + Y: emulated.ValueOf[emulated.P256Fp](py), + }, + } + err = test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestScalarMul5(t *testing.T) { + assert := test.NewAssert(t) + p384 := elliptic.P384() + s, err := rand.Int(rand.Reader, p384.Params().N) + assert.NoError(err) + px, py := p384.ScalarBaseMult(s.Bytes()) + + circuit := ScalarMulTest[emulated.P384Fp, emulated.P384Fr]{} + witness := ScalarMulTest[emulated.P384Fp, emulated.P384Fr]{ + S: emulated.ValueOf[emulated.P384Fr](s), + P: AffinePoint[emulated.P384Fp]{ + X: emulated.ValueOf[emulated.P384Fp](p384.Params().Gx), + Y: emulated.ValueOf[emulated.P384Fp](p384.Params().Gy), + }, + Q: AffinePoint[emulated.P384Fp]{ + X: emulated.ValueOf[emulated.P384Fp](px), + Y: emulated.ValueOf[emulated.P384Fp](py), + }, + } + err = test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +type ScalarMulEdgeCasesTest[T, S emulated.FieldParams] struct { + P, R AffinePoint[T] + S emulated.Element[S] +} + +func (c *ScalarMulEdgeCasesTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.ScalarMul(&c.P, &c.S) + cr.AssertIsEqual(res, &c.R) + return nil +} + +func TestScalarMulEdgeCasesEdgeCases(t *testing.T) { + assert := test.NewAssert(t) + var infinity bn254.G1Affine + _, _, g, _ := bn254.Generators() + var r fr_bn.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var S bn254.G1Affine + S.ScalarMultiplication(&g, s) + + circuit 
:= ScalarMulEdgeCasesTest[emulated.BN254Fp, emulated.BN254Fr]{} + + // s * (0,0) == (0,0) + witness1 := ScalarMulEdgeCasesTest[emulated.BN254Fp, emulated.BN254Fr]{ + S: emulated.ValueOf[emulated.BN254Fr](s), + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + } + err := test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // 0 * S == (0,0) + witness2 := ScalarMulEdgeCasesTest[emulated.BN254Fp, emulated.BN254Fr]{ + S: emulated.ValueOf[emulated.BN254Fr](new(big.Int)), + P: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](S.X), + Y: emulated.ValueOf[emulated.BN254Fp](S.Y), + }, + R: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + } + err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) + assert.NoError(err) +} + +type IsOnCurveTest[T, S emulated.FieldParams] struct { + Q AffinePoint[T] +} + +func (c *IsOnCurveTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + cr.AssertIsOnCurve(&c.Q) + return nil +} + +func TestIsOnCurve(t *testing.T) { + assert := test.NewAssert(t) + _, g := secp256k1.Generators() + var r fr_secp.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var Q, infinity secp256k1.G1Affine + Q.ScalarMultiplication(&g, s) + + // Q=[s]G is on curve + circuit := IsOnCurveTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness1 := IsOnCurveTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](Q.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](Q.Y), + }, + } + err := 
test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // (0,0) is on curve + witness2 := IsOnCurveTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](infinity.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](infinity.Y), + }, + } + err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestIsOnCurve2(t *testing.T) { + assert := test.NewAssert(t) + _, _, g, _ := bn254.Generators() + var r fr_secp.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var Q, infinity bn254.G1Affine + Q.ScalarMultiplication(&g, s) + + // Q=[s]G is on curve + circuit := IsOnCurveTest[emulated.BN254Fp, emulated.BN254Fr]{} + witness1 := IsOnCurveTest[emulated.BN254Fp, emulated.BN254Fr]{ + Q: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](Q.X), + Y: emulated.ValueOf[emulated.BN254Fp](Q.Y), + }, + } + err := test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // (0,0) is on curve + witness2 := IsOnCurveTest[emulated.BN254Fp, emulated.BN254Fr]{ + Q: AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](infinity.X), + Y: emulated.ValueOf[emulated.BN254Fp](infinity.Y), + }, + } + err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) + assert.NoError(err) +} + +func TestIsOnCurve3(t *testing.T) { + assert := test.NewAssert(t) + _, _, g, _ := bls12381.Generators() + var r fr_secp.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + var Q, infinity bls12381.G1Affine + Q.ScalarMultiplication(&g, s) + + // Q=[s]G is on curve + circuit := IsOnCurveTest[emulated.BLS12381Fp, emulated.BLS12381Fr]{} + witness1 := IsOnCurveTest[emulated.BLS12381Fp, emulated.BLS12381Fr]{ + Q: AffinePoint[emulated.BLS12381Fp]{ + X: emulated.ValueOf[emulated.BLS12381Fp](Q.X), + Y: emulated.ValueOf[emulated.BLS12381Fp](Q.Y), + }, + } + err := 
test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // (0,0) is on curve + witness2 := IsOnCurveTest[emulated.BLS12381Fp, emulated.BLS12381Fr]{ + Q: AffinePoint[emulated.BLS12381Fp]{ + X: emulated.ValueOf[emulated.BLS12381Fp](infinity.X), + Y: emulated.ValueOf[emulated.BLS12381Fp](infinity.Y), + }, + } + err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) + assert.NoError(err) +} + +type JointScalarMulBaseTest[T, S emulated.FieldParams] struct { + P, Q AffinePoint[T] + S1, S2 emulated.Element[S] +} + +func (c *JointScalarMulBaseTest[T, S]) Define(api frontend.API) error { + cr, err := New[T, S](api, GetCurveParams[T]()) + if err != nil { + return err + } + res := cr.JointScalarMulBase(&c.P, &c.S2, &c.S1) + cr.AssertIsEqual(res, &c.Q) + return nil +} + +func TestJointScalarMulBase(t *testing.T) { + assert := test.NewAssert(t) + _, g := secp256k1.Generators() + var p secp256k1.G1Affine + p.Double(&g) + var r1, r2 fr_secp.Element + _, _ = r1.SetRandom() + _, _ = r2.SetRandom() + s1 := new(big.Int) + r1.BigInt(s1) + s2 := new(big.Int) + r2.BigInt(s2) + var Sj secp256k1.G1Jac + Sj.JointScalarMultiplicationBase(&p, s1, s2) + var S secp256k1.G1Affine + S.FromJacobian(&Sj) + + circuit := JointScalarMulBaseTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := JointScalarMulBaseTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + S1: emulated.ValueOf[emulated.Secp256k1Fr](s1), + S2: emulated.ValueOf[emulated.Secp256k1Fr](s2), + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](p.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](p.Y), + }, + Q: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](S.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](S.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/native/fields_bls12377/doc.go b/std/algebra/native/fields_bls12377/doc.go new file mode 
100644 index 0000000000..69308f8280 --- /dev/null +++ b/std/algebra/native/fields_bls12377/doc.go @@ -0,0 +1,9 @@ +// Package fields_bls12377 implements the fields arithmetic of the Fp12 tower +// used to compute the pairing over the BLS12-377 curve. +// +// 𝔽p²[u] = 𝔽p/u²+5 +// 𝔽p⁶[v] = 𝔽p²/v³-u +// 𝔽p¹²[w] = 𝔽p⁶/w²-v +// +// Reference: https://eprint.iacr.org/2022/1162 +package fields_bls12377 diff --git a/std/algebra/fields_bls12377/e12.go b/std/algebra/native/fields_bls12377/e12.go similarity index 90% rename from std/algebra/fields_bls12377/e12.go rename to std/algebra/native/fields_bls12377/e12.go index 64ed8fc024..ed787978d5 100644 --- a/std/algebra/fields_bls12377/e12.go +++ b/std/algebra/native/fields_bls12377/e12.go @@ -20,7 +20,8 @@ import ( "math/big" bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" - "github.com/consensys/gnark/backend/hint" + + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -262,6 +263,8 @@ func (e *E12) CyclotomicSquareCompressed(api frontend.API, x E12) *E12 { // Decompress Karabina's cyclotomic square result func (e *E12) Decompress(api frontend.API, x E12) *E12 { + // TODO: hadle the g3==0 case with MUX + var t [3]E2 var one E2 one.SetOne() @@ -344,51 +347,6 @@ func (e *E12) Conjugate(api frontend.API, e1 E12) *E12 { return e } -// MulBy034 multiplication by sparse element -func (e *E12) MulBy034(api frontend.API, c3, c4 E2) *E12 { - - var d E6 - - a := e.C0 - b := e.C1 - - b.MulBy01(api, c3, c4) - - c3.Add(api, E2{A0: 1, A1: 0}, c3) - d.Add(api, e.C0, e.C1) - d.MulBy01(api, c3, c4) - - e.C1.Add(api, a, b).Neg(api, e.C1).Add(api, e.C1, d) - e.C0.MulByNonResidue(api, b).Add(api, e.C0, a) - - return e -} - -// Mul034By034 multiplication of sparse element (1,0,0,c3,c4,0) by sparse element (1,0,0,d3,d4,0) -func (e *E12) Mul034By034(api frontend.API, d3, d4, c3, c4 E2) *E12 { - var one, tmp, x3, x4, x04, x03, x34 E2 - one.SetOne() - x3.Mul(api, c3, d3) - x4.Mul(api, c4, d4) - 
x04.Add(api, c4, d4) - x03.Add(api, c3, d3) - tmp.Add(api, c3, c4) - x34.Add(api, d3, d4). - Mul(api, x34, tmp). - Sub(api, x34, x3). - Sub(api, x34, x4) - - e.C0.B0.MulByNonResidue(api, x4). - Add(api, e.C0.B0, one) - e.C0.B1 = x3 - e.C0.B2 = x34 - e.C1.B0 = x03 - e.C1.B1 = x04 - e.C1.B2.SetZero() - - return e -} - // Frobenius applies frob to an fp12 elmt func (e *E12) Frobenius(api frontend.API, e1 E12) *E12 { @@ -464,7 +422,7 @@ var InverseE12Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(InverseE12Hint) + solver.RegisterHint(InverseE12Hint) } // Inverse e12 elmts @@ -537,7 +495,7 @@ var DivE12Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(DivE12Hint) + solver.RegisterHint(DivE12Hint) } // DivUnchecked e12 elmts @@ -587,35 +545,6 @@ func (e *E12) nSquareCompressed(api frontend.API, n int) { } } -// Expt compute e1**exponent, where the exponent is hardcoded -// This function is only used for the final expo of the pairing for bls12377, so the exponent is supposed to be hardcoded -// and on 64 bits. 
-func (e *E12) Expt(api frontend.API, e1 E12, exponent uint64) *E12 { - - res := e1 - - res.nSquareCompressed(api, 5) - res.Decompress(api, res) - res.Mul(api, res, e1) - x33 := res - res.nSquareCompressed(api, 7) - res.Decompress(api, res) - res.Mul(api, res, x33) - res.nSquareCompressed(api, 4) - res.Decompress(api, res) - res.Mul(api, res, e1) - res.CyclotomicSquare(api, res) - res.Mul(api, res, e1) - res.nSquareCompressed(api, 46) - res.Decompress(api, res) - res.Mul(api, res, e1) - - *e = res - - return e - -} - // Assign a value to self (witness assignment) func (e *E12) Assign(a *bls12377.E12) { e.C0.Assign(&a.C0) diff --git a/std/algebra/native/fields_bls12377/e12_pairing.go b/std/algebra/native/fields_bls12377/e12_pairing.go new file mode 100644 index 0000000000..32efdc3b1b --- /dev/null +++ b/std/algebra/native/fields_bls12377/e12_pairing.go @@ -0,0 +1,139 @@ +package fields_bls12377 + +import "github.com/consensys/gnark/frontend" + +// Square034 squares a sparse element in Fp12 +func (e *E12) Square034(api frontend.API, x E12) *E12 { + var c0, c2, c3 E6 + + c0.B0.Sub(api, x.C0.B0, x.C1.B0) + c0.B1.Neg(api, x.C1.B1) + c0.B2 = E2{0, 0} + + c3.B0 = x.C0.B0 + c3.B1.Neg(api, x.C1.B0) + c3.B2.Neg(api, x.C1.B1) + + c2.Mul0By01(api, x.C0.B0, x.C1.B0, x.C1.B1) + c3.MulBy01(api, c0.B0, c0.B1).Add(api, c3, c2) + e.C1.B0.Add(api, c2.B0, c2.B0) + e.C1.B1.Add(api, c2.B1, c2.B1) + + e.C0.B0 = c3.B0 + e.C0.B1.Add(api, c3.B1, c2.B0) + e.C0.B2.Add(api, c3.B2, c2.B1) + + return e +} + +// MulBy034 multiplication by sparse element +func (e *E12) MulBy034(api frontend.API, c3, c4 E2) *E12 { + + var d E6 + + a := e.C0 + b := e.C1 + + b.MulBy01(api, c3, c4) + + c3.Add(api, E2{A0: 1, A1: 0}, c3) + d.Add(api, e.C0, e.C1) + d.MulBy01(api, c3, c4) + + e.C1.Add(api, a, b).Neg(api, e.C1).Add(api, e.C1, d) + e.C0.MulByNonResidue(api, b).Add(api, e.C0, a) + + return e +} + +// Mul034By034 multiplication of sparse element (1,0,0,c3,c4,0) by sparse element (1,0,0,d3,d4,0) +func 
Mul034By034(api frontend.API, d3, d4, c3, c4 E2) *[5]E2 { + var one, tmp, x00, x3, x4, x04, x03, x34 E2 + one.SetOne() + x3.Mul(api, c3, d3) + x4.Mul(api, c4, d4) + x04.Add(api, c4, d4) + x03.Add(api, c3, d3) + tmp.Add(api, c3, c4) + x34.Add(api, d3, d4). + Mul(api, x34, tmp). + Sub(api, x34, x3). + Sub(api, x34, x4) + + x00.MulByNonResidue(api, x4). + Add(api, x00, one) + + return &[5]E2{x00, x3, x34, x03, x04} +} + +func Mul01234By034(api frontend.API, x [5]E2, z3, z4 E2) *E12 { + var a, b, z1, z0, one E6 + var zero E2 + zero.SetZero() + one.SetOne() + c0 := &E6{B0: x[0], B1: x[1], B2: x[2]} + c1 := &E6{B0: x[3], B1: x[4], B2: zero} + a.Add(api, one, E6{B0: z3, B1: z4, B2: zero}) + b.Add(api, *c0, *c1) + a.Mul(api, a, b) + c := *Mul01By01(api, z3, z4, x[3], x[4]) + z1.Sub(api, a, *c0) + z1.Sub(api, z1, c) + z0.MulByNonResidue(api, c) + z0.Add(api, z0, *c0) + return &E12{ + C0: z0, + C1: z1, + } +} + +func (e *E12) MulBy01234(api frontend.API, x [5]E2) *E12 { + var a, b, c, z1, z0 E6 + var zero E2 + zero.SetZero() + c0 := &E6{B0: x[0], B1: x[1], B2: x[2]} + c1 := &E6{B0: x[3], B1: x[4], B2: zero} + a.Add(api, e.C0, e.C1) + b.Add(api, *c0, *c1) + a.Mul(api, a, b) + b.Mul(api, e.C0, *c0) + c = e.C1 + c.MulBy01(api, x[3], x[4]) + z1.Sub(api, a, b) + z1.Sub(api, z1, c) + z0.MulByNonResidue(api, c) + z0.Add(api, z0, b) + + e.C0 = z0 + e.C1 = z1 + return e +} + +// Expt compute e1**exponent, where the exponent is hardcoded +// This function is only used for the final expo of the pairing for bls12377, so the exponent is supposed to be hardcoded +// and on 64 bits. 
+func (e *E12) Expt(api frontend.API, e1 E12, exponent uint64) *E12 { + + res := e1 + + res.nSquareCompressed(api, 5) + res.Decompress(api, res) + res.Mul(api, res, e1) + x33 := res + res.nSquareCompressed(api, 7) + res.Decompress(api, res) + res.Mul(api, res, x33) + res.nSquareCompressed(api, 4) + res.Decompress(api, res) + res.Mul(api, res, e1) + res.CyclotomicSquare(api, res) + res.Mul(api, res, e1) + res.nSquareCompressed(api, 46) + res.Decompress(api, res) + res.Mul(api, res, e1) + + *e = res + + return e + +} diff --git a/std/algebra/fields_bls12377/e12_test.go b/std/algebra/native/fields_bls12377/e12_test.go similarity index 100% rename from std/algebra/fields_bls12377/e12_test.go rename to std/algebra/native/fields_bls12377/e12_test.go diff --git a/std/algebra/fields_bls12377/e2.go b/std/algebra/native/fields_bls12377/e2.go similarity index 90% rename from std/algebra/fields_bls12377/e2.go rename to std/algebra/native/fields_bls12377/e2.go index 067e8242c4..fbb282fae8 100644 --- a/std/algebra/fields_bls12377/e2.go +++ b/std/algebra/native/fields_bls12377/e2.go @@ -21,7 +21,8 @@ import ( bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - "github.com/consensys/gnark/backend/hint" + + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -154,7 +155,7 @@ var InverseE2Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(InverseE2Hint) + solver.RegisterHint(InverseE2Hint) } // Inverse e2 elmts @@ -196,7 +197,7 @@ var DivE2Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(DivE2Hint) + solver.RegisterHint(DivE2Hint) } // DivUnchecked e2 elmts @@ -240,3 +241,16 @@ func (e *E2) Select(api frontend.API, b frontend.Variable, r1, r2 E2) *E2 { return e } + +// Lookup2 implements two-bit lookup. 
It returns: +// - r1 if b1=0 and b2=0, +// - r2 if b1=0 and b2=1, +// - r3 if b1=1 and b2=0, +// - r3 if b1=1 and b2=1. +func (e *E2) Lookup2(api frontend.API, b1, b2 frontend.Variable, r1, r2, r3, r4 E2) *E2 { + + e.A0 = api.Lookup2(b1, b2, r1.A0, r2.A0, r3.A0, r4.A0) + e.A1 = api.Lookup2(b1, b2, r1.A1, r2.A1, r3.A1, r4.A1) + + return e +} diff --git a/std/algebra/fields_bls12377/e2_test.go b/std/algebra/native/fields_bls12377/e2_test.go similarity index 100% rename from std/algebra/fields_bls12377/e2_test.go rename to std/algebra/native/fields_bls12377/e2_test.go diff --git a/std/algebra/fields_bls12377/e6.go b/std/algebra/native/fields_bls12377/e6.go similarity index 90% rename from std/algebra/fields_bls12377/e6.go rename to std/algebra/native/fields_bls12377/e6.go index b8d15605e8..f11dbbd9ee 100644 --- a/std/algebra/fields_bls12377/e6.go +++ b/std/algebra/native/fields_bls12377/e6.go @@ -20,7 +20,8 @@ import ( "math/big" bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" - "github.com/consensys/gnark/backend/hint" + + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -121,6 +122,21 @@ func (e *E6) Mul(api frontend.API, e1, e2 E6) *E6 { return e } +func (e *E6) Mul0By01(api frontend.API, a0, b0, b1 E2) *E6 { + + var t0, c1 E2 + + t0.Mul(api, a0, b0) + c1.Add(api, b0, b1) + c1.Mul(api, c1, a0).Sub(api, c1, t0) + + e.B0 = t0 + e.B1 = c1 + e.B2 = E2{0, 0} + + return e +} + // MulByFp2 creates a fp6elmt from fp elmts // icube is the imaginary elmt to the cube func (e *E6) MulByFp2(api frontend.API, e1 E6, e2 E2) *E6 { @@ -198,7 +214,7 @@ var DivE6Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(DivE6Hint) + solver.RegisterHint(DivE6Hint) } // DivUnchecked e6 elmts @@ -246,7 +262,7 @@ var InverseE6Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(InverseE6Hint) + solver.RegisterHint(InverseE6Hint) } // Inverse e6 elmts 
@@ -324,3 +340,28 @@ func (e *E6) MulBy01(api frontend.API, c0, c1 E2) *E6 { return e } + +func Mul01By01(api frontend.API, c0, c1, d0, d1 E2) *E6 { + var a, b, t0, t1, t2, tmp E2 + + a.Mul(api, d0, c0) + b.Mul(api, d1, c1) + t0.Mul(api, c1, d1) + t0.Sub(api, t0, b) + t0.MulByNonResidue(api, t0) + t0.Add(api, t0, a) + t2.Mul(api, c0, d0) + t2.Sub(api, t2, a) + t2.Add(api, t2, b) + t1.Add(api, c0, c1) + tmp.Add(api, d0, d1) + t1.Mul(api, t1, tmp) + t1.Sub(api, t1, a) + t1.Sub(api, t1, b) + + return &E6{ + B0: t0, + B1: t1, + B2: t2, + } +} diff --git a/std/algebra/fields_bls12377/e6_test.go b/std/algebra/native/fields_bls12377/e6_test.go similarity index 100% rename from std/algebra/fields_bls12377/e6_test.go rename to std/algebra/native/fields_bls12377/e6_test.go diff --git a/std/algebra/native/fields_bls24315/doc.go b/std/algebra/native/fields_bls24315/doc.go new file mode 100644 index 0000000000..8b669480b7 --- /dev/null +++ b/std/algebra/native/fields_bls24315/doc.go @@ -0,0 +1,10 @@ +// Package fields_bls24315 implements the fields arithmetic of the Fp24 tower +// used to compute the pairing over the BLS24-315 curve. 
+// +// 𝔽p²[u] = 𝔽p/u²-13 +// 𝔽p⁴[v] = 𝔽p²/v²-u +// 𝔽p¹²[w] = 𝔽p⁴/w³-v +// 𝔽p²⁴[i] = 𝔽p¹²/i²-w +// +// Reference: https://eprint.iacr.org/2022/1162 +package fields_bls24315 diff --git a/std/algebra/fields_bls24315/e12.go b/std/algebra/native/fields_bls24315/e12.go similarity index 96% rename from std/algebra/fields_bls24315/e12.go rename to std/algebra/native/fields_bls24315/e12.go index 2a8583f3f8..9d62cf5d75 100644 --- a/std/algebra/fields_bls24315/e12.go +++ b/std/algebra/native/fields_bls24315/e12.go @@ -20,7 +20,7 @@ import ( "math/big" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -209,7 +209,7 @@ var InverseE12Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(InverseE12Hint) + solver.RegisterHint(InverseE12Hint) } // Inverse e12 elmts @@ -282,7 +282,7 @@ var DivE12Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(DivE12Hint) + solver.RegisterHint(DivE12Hint) } // DivUnchecked e12 elmts @@ -353,6 +353,21 @@ func (e *E12) MulBy01(api frontend.API, c0, c1 E4) *E12 { return e } +func (e *E12) Mul0By01(api frontend.API, a0, b0, b1 E4) *E12 { + + var t0, c1 E4 + + t0.Mul(api, a0, b0) + c1.Add(api, b0, b1) + c1.Mul(api, c1, a0).Sub(api, c1, t0) + + e.C0 = t0 + e.C1 = c1 + e.C2 = E4{E2{0, 0}, E2{0, 0}} + + return e +} + // Assign a value to self (witness assignment) func (e *E12) Assign(a *bls24315.E12) { e.C0.Assign(&a.C0) diff --git a/std/algebra/fields_bls24315/e12_test.go b/std/algebra/native/fields_bls24315/e12_test.go similarity index 100% rename from std/algebra/fields_bls24315/e12_test.go rename to std/algebra/native/fields_bls24315/e12_test.go diff --git a/std/algebra/fields_bls24315/e2.go b/std/algebra/native/fields_bls24315/e2.go similarity index 91% rename from std/algebra/fields_bls24315/e2.go rename to 
std/algebra/native/fields_bls24315/e2.go index b850ef025b..441f575948 100644 --- a/std/algebra/fields_bls24315/e2.go +++ b/std/algebra/native/fields_bls24315/e2.go @@ -21,7 +21,7 @@ import ( bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/utils" ) @@ -160,7 +160,7 @@ var DivE2Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(DivE2Hint) + solver.RegisterHint(DivE2Hint) } // DivUnchecked e2 elmts @@ -199,7 +199,7 @@ var InverseE2Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(InverseE2Hint) + solver.RegisterHint(InverseE2Hint) } // Inverse e2 elmts @@ -244,3 +244,16 @@ func (e *E2) Select(api frontend.API, b frontend.Variable, r1, r2 E2) *E2 { return e } + +// Lookup2 implements two-bit lookup. It returns: +// - r1 if b1=0 and b2=0, +// - r2 if b1=0 and b2=1, +// - r3 if b1=1 and b2=0, +// - r3 if b1=1 and b2=1. 
+func (e *E2) Lookup2(api frontend.API, b1, b2 frontend.Variable, r1, r2, r3, r4 E2) *E2 { + + e.A0 = api.Lookup2(b1, b2, r1.A0, r2.A0, r3.A0, r4.A0) + e.A1 = api.Lookup2(b1, b2, r1.A1, r2.A1, r3.A1, r4.A1) + + return e +} diff --git a/std/algebra/fields_bls24315/e24.go b/std/algebra/native/fields_bls24315/e24.go similarity index 91% rename from std/algebra/fields_bls24315/e24.go rename to std/algebra/native/fields_bls24315/e24.go index d70722cb70..ea21cf09c5 100644 --- a/std/algebra/fields_bls24315/e24.go +++ b/std/algebra/native/fields_bls24315/e24.go @@ -20,7 +20,7 @@ import ( "math/big" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -341,53 +341,6 @@ func (e *E24) Conjugate(api frontend.API, e1 E24) *E24 { return e } -// MulBy034 multiplication by sparse element -func (e *E24) MulBy034(api frontend.API, c3, c4 E4) *E24 { - - var d E12 - var one E4 - one.SetOne() - - a := e.D0 - b := e.D1 - - b.MulBy01(api, c3, c4) - - c3.Add(api, one, c3) - d.Add(api, e.D0, e.D1) - d.MulBy01(api, c3, c4) - - e.D1.Add(api, a, b).Neg(api, e.D1).Add(api, e.D1, d) - e.D0.MulByNonResidue(api, b).Add(api, e.D0, a) - - return e -} - -// Mul034By034 multiplication of sparse element (1,0,0,c3,c4,0) by sparse element (1,0,0,d3,d4,0) -func (e *E24) Mul034By034(api frontend.API, d3, d4, c3, c4 E4) *E24 { - var one, tmp, x3, x4, x04, x03, x34 E4 - one.SetOne() - x3.Mul(api, c3, d3) - x4.Mul(api, c4, d4) - x04.Add(api, c4, d4) - x03.Add(api, c3, d3) - tmp.Add(api, c3, c4) - x34.Add(api, d3, d4). - Mul(api, x34, tmp). - Sub(api, x34, x3). - Sub(api, x34, x4) - - e.D0.C0.MulByNonResidue(api, x4). 
- Add(api, e.D0.C0, one) - e.D0.C1 = x3 - e.D0.C2 = x34 - e.D1.C0 = x03 - e.D1.C1 = x04 - e.D1.C2.SetZero() - - return e -} - var InverseE24Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { var a, c bls24315.E24 @@ -447,7 +400,7 @@ var InverseE24Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(InverseE24Hint) + solver.RegisterHint(InverseE24Hint) } // Inverse e24 elmts @@ -557,7 +510,7 @@ var DivE24Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(DivE24Hint) + solver.RegisterHint(DivE24Hint) } // DivUnchecked e24 elmts @@ -595,31 +548,6 @@ func (e *E24) nSquare(api frontend.API, n int) { } } -// Expt compute e1**exponent, where the exponent is hardcoded -// This function is only used for the final expo of the pairing for bls24315, so the exponent is supposed to be hardcoded and on 32 bits. -func (e *E24) Expt(api frontend.API, x E24, exponent uint64) *E24 { - - xInv := E24{} - res := x - xInv.Conjugate(api, x) - - res.nSquare(api, 2) - res.Mul(api, res, xInv) - res.nSquareCompressed(api, 8) - res.Decompress(api, res) - res.Mul(api, res, xInv) - res.nSquare(api, 2) - res.Mul(api, res, x) - res.nSquareCompressed(api, 20) - res.Decompress(api, res) - res.Mul(api, res, xInv) - res.Conjugate(api, res) - - *e = res - - return e -} - // AssertIsEqual constraint self to be equal to other into the given constraint system func (e *E24) AssertIsEqual(api frontend.API, other E24) { e.D0.AssertIsEqual(api, other.D0) diff --git a/std/algebra/native/fields_bls24315/e24_pairing.go b/std/algebra/native/fields_bls24315/e24_pairing.go new file mode 100644 index 0000000000..1a43c8df85 --- /dev/null +++ b/std/algebra/native/fields_bls24315/e24_pairing.go @@ -0,0 +1,94 @@ +package fields_bls24315 + +import "github.com/consensys/gnark/frontend" + +// Square034 squares a sparse element in Fp24 +func (e *E24) Square034(api frontend.API, x E24) *E24 { + var c0, c2, c3 E12 + + 
c0.C0.Sub(api, x.D0.C0, x.D1.C0) + c0.C1.Neg(api, x.D1.C1) + c0.C2 = E4{E2{0, 0}, E2{0, 0}} + + c3.C0 = x.D0.C0 + c3.C1.Neg(api, x.D1.C0) + c3.C2.Neg(api, x.D1.C1) + + c2.Mul0By01(api, x.D0.C0, x.D1.C0, x.D1.C1) + c3.MulBy01(api, c0.C0, c0.C1).Add(api, c3, c2) + e.D1.C0.Add(api, c2.C0, c2.C0) + e.D1.C1.Add(api, c2.C1, c2.C1) + + e.D0.C0 = c3.C0 + e.D0.C1.Add(api, c3.C1, c2.C0) + e.D0.C2.Add(api, c3.C2, c2.C1) + + return e +} + +// MulBy034 multiplication by sparse element +func (e *E24) MulBy034(api frontend.API, c3, c4 E4) *E24 { + + var d E12 + var one E4 + one.SetOne() + + a := e.D0 + b := e.D1 + + b.MulBy01(api, c3, c4) + + c3.Add(api, one, c3) + d.Add(api, e.D0, e.D1) + d.MulBy01(api, c3, c4) + + e.D1.Add(api, a, b).Neg(api, e.D1).Add(api, e.D1, d) + e.D0.MulByNonResidue(api, b).Add(api, e.D0, a) + + return e +} + +// Mul034By034 multiplication of sparse element (1,0,0,c3,c4,0) by sparse element (1,0,0,d3,d4,0) +func Mul034By034(api frontend.API, d3, d4, c3, c4 E4) *[5]E4 { + var one, tmp, x00, x3, x4, x04, x03, x34 E4 + one.SetOne() + x3.Mul(api, c3, d3) + x4.Mul(api, c4, d4) + x04.Add(api, c4, d4) + x03.Add(api, c3, d3) + tmp.Add(api, c3, c4) + x34.Add(api, d3, d4). + Mul(api, x34, tmp). + Sub(api, x34, x3). + Sub(api, x34, x4) + + x00.MulByNonResidue(api, x4). + Add(api, x00, one) + + return &[5]E4{x00, x3, x34, x03, x04} +} + +// Expt compute e1**exponent, where the exponent is hardcoded +// This function is only used for the final expo of the pairing for bls24315, so the exponent is supposed to be hardcoded and on 32 bits. 
+func (e *E24) Expt(api frontend.API, x E24, exponent uint64) *E24 { + + xInv := E24{} + res := x + xInv.Conjugate(api, x) + + res.nSquare(api, 2) + res.Mul(api, res, xInv) + res.nSquareCompressed(api, 8) + res.Decompress(api, res) + res.Mul(api, res, xInv) + res.nSquare(api, 2) + res.Mul(api, res, x) + res.nSquareCompressed(api, 20) + res.Decompress(api, res) + res.Mul(api, res, xInv) + res.Conjugate(api, res) + + *e = res + + return e +} diff --git a/std/algebra/fields_bls24315/e24_test.go b/std/algebra/native/fields_bls24315/e24_test.go similarity index 100% rename from std/algebra/fields_bls24315/e24_test.go rename to std/algebra/native/fields_bls24315/e24_test.go diff --git a/std/algebra/fields_bls24315/e2_test.go b/std/algebra/native/fields_bls24315/e2_test.go similarity index 100% rename from std/algebra/fields_bls24315/e2_test.go rename to std/algebra/native/fields_bls24315/e2_test.go diff --git a/std/algebra/fields_bls24315/e4.go b/std/algebra/native/fields_bls24315/e4.go similarity index 91% rename from std/algebra/fields_bls24315/e4.go rename to std/algebra/native/fields_bls24315/e4.go index 57abb6ccf8..f6fc78c405 100644 --- a/std/algebra/fields_bls24315/e4.go +++ b/std/algebra/native/fields_bls24315/e4.go @@ -20,7 +20,7 @@ import ( "math/big" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -165,7 +165,7 @@ var DivE4Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(DivE4Hint) + solver.RegisterHint(DivE4Hint) } // DivUnchecked e4 elmts @@ -208,7 +208,7 @@ var InverseE4Hint = func(_ *big.Int, inputs []*big.Int, res []*big.Int) error { } func init() { - hint.Register(InverseE4Hint) + solver.RegisterHint(InverseE4Hint) } // Inverse e4 elmts @@ -253,3 +253,16 @@ func (e *E4) Select(api frontend.API, b frontend.Variable, r1, r2 E4) *E4 { return e } + +// Lookup2 
implements two-bit lookup. It returns: +// - r1 if b1=0 and b2=0, +// - r2 if b1=0 and b2=1, +// - r3 if b1=1 and b2=0, +// - r3 if b1=1 and b2=1. +func (e *E4) Lookup2(api frontend.API, b1, b2 frontend.Variable, r1, r2, r3, r4 E4) *E4 { + + e.B0.Lookup2(api, b1, b2, r1.B0, r2.B0, r3.B0, r4.B0) + e.B1.Lookup2(api, b1, b2, r1.B1, r2.B1, r3.B1, r4.B1) + + return e +} diff --git a/std/algebra/fields_bls24315/e4_test.go b/std/algebra/native/fields_bls24315/e4_test.go similarity index 100% rename from std/algebra/fields_bls24315/e4_test.go rename to std/algebra/native/fields_bls24315/e4_test.go diff --git a/std/algebra/native/sw_bls12377/doc.go b/std/algebra/native/sw_bls12377/doc.go new file mode 100644 index 0000000000..3888dba835 --- /dev/null +++ b/std/algebra/native/sw_bls12377/doc.go @@ -0,0 +1,8 @@ +// Package sw_bls12377 implements the arithmetics of G1, G2 and the pairing +// computation on BLS12-377 as a SNARK circuit over BW6-761. These two curves +// form a 2-chain so the operations use native field arithmetic. 
+// +// References: +// BW6-761: https://eprint.iacr.org/2020/351 +// Pairings in R1CS: https://eprint.iacr.org/2022/1162 +package sw_bls12377 diff --git a/std/algebra/sw_bls12377/g1.go b/std/algebra/native/sw_bls12377/g1.go similarity index 92% rename from std/algebra/sw_bls12377/g1.go rename to std/algebra/native/sw_bls12377/g1.go index 303315a833..546fd19857 100644 --- a/std/algebra/sw_bls12377/g1.go +++ b/std/algebra/native/sw_bls12377/g1.go @@ -22,7 +22,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - "github.com/consensys/gnark/backend/hint" + + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -228,7 +229,7 @@ var DecomposeScalarG1 = func(scalarField *big.Int, inputs []*big.Int, res []*big } func init() { - hint.Register(DecomposeScalarG1) + solver.RegisterHint(DecomposeScalarG1) } // varScalarMul sets P = [s] Q and returns P. @@ -299,20 +300,18 @@ func (P *G1Affine) varScalarMul(api frontend.API, Q G1Affine, s frontend.Variabl // step value from [2] Acc (instead of conditionally adding step value to // Acc): // Acc = [2] (Q + Φ(Q)) ± Q ± Φ(Q) - Acc.Double(api, Acc) // only y coordinate differs for negation, select on that instead. 
B.X = tableQ[0].X B.Y = api.Select(s1bits[nbits-1], tableQ[1].Y, tableQ[0].Y) - Acc.AddAssign(api, B) + Acc.DoubleAndAdd(api, &Acc, &B) B.X = tablePhiQ[0].X B.Y = api.Select(s2bits[nbits-1], tablePhiQ[1].Y, tablePhiQ[0].Y) Acc.AddAssign(api, B) // second bit - Acc.Double(api, Acc) B.X = tableQ[0].X B.Y = api.Select(s1bits[nbits-2], tableQ[1].Y, tableQ[0].Y) - Acc.AddAssign(api, B) + Acc.DoubleAndAdd(api, &Acc, &B) B.X = tablePhiQ[0].X B.Y = api.Select(s2bits[nbits-2], tablePhiQ[1].Y, tablePhiQ[0].Y) Acc.AddAssign(api, B) @@ -451,3 +450,36 @@ func (p *G1Affine) DoubleAndAdd(api frontend.API, p1, p2 *G1Affine) *G1Affine { return p } + +// ScalarMulBase computes s * g1 and returns it, where g1 is the fixed generator. It doesn't modify s. +func (P *G1Affine) ScalarMulBase(api frontend.API, s frontend.Variable) *G1Affine { + + points := getCurvePoints() + + sBits := api.ToBinary(s, 253) + + var res, tmp G1Affine + + // i = 1, 2 + // gm[0] = 3g, gm[1] = 5g, gm[2] = 7g + res.X = api.Lookup2(sBits[1], sBits[2], points.G1x, points.G1m[0][0], points.G1m[1][0], points.G1m[2][0]) + res.Y = api.Lookup2(sBits[1], sBits[2], points.G1y, points.G1m[0][1], points.G1m[1][1], points.G1m[2][1]) + + for i := 3; i < 253; i++ { + // gm[i] = [2^i]g + tmp.X = res.X + tmp.Y = res.Y + tmp.AddAssign(api, G1Affine{points.G1m[i][0], points.G1m[i][1]}) + res.Select(api, sBits[i], tmp, res) + } + + // i = 0 + tmp.Neg(api, G1Affine{points.G1x, points.G1y}) + tmp.AddAssign(api, res) + res.Select(api, sBits[0], res, tmp) + + P.X = res.X + P.Y = res.Y + + return P +} diff --git a/std/algebra/sw_bls12377/g1_test.go b/std/algebra/native/sw_bls12377/g1_test.go similarity index 93% rename from std/algebra/sw_bls12377/g1_test.go rename to std/algebra/native/sw_bls12377/g1_test.go index 2a0a83a1a5..55e5c524bc 100644 --- a/std/algebra/sw_bls12377/g1_test.go +++ b/std/algebra/native/sw_bls12377/g1_test.go @@ -369,6 +369,37 @@ func TestScalarMulG1(t *testing.T) { assert.SolvingSucceeded(&circuit, &witness, 
test.WithCurves(ecc.BW6_761)) } +type g1varScalarMulBase struct { + C G1Affine `gnark:",public"` + R frontend.Variable +} + +func (circuit *g1varScalarMulBase) Define(api frontend.API) error { + expected := G1Affine{} + expected.ScalarMulBase(api, circuit.R) + expected.AssertIsEqual(api, circuit.C) + return nil +} + +func TestVarScalarMulBaseG1(t *testing.T) { + var c bls12377.G1Affine + gJac, _, _, _ := bls12377.Generators() + + // create the cs + var circuit, witness g1varScalarMulBase + var r fr.Element + _, _ = r.SetRandom() + witness.R = r.String() + // compute the result + var br big.Int + gJac.ScalarMultiplication(&gJac, r.BigInt(&br)) + c.FromJacobian(&gJac) + witness.C.Assign(&c) + + assert := test.NewAssert(t) + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_761)) +} + func randomPointG1() bls12377.G1Jac { p1, _, _, _ := bls12377.Generators() diff --git a/std/algebra/sw_bls12377/g2.go b/std/algebra/native/sw_bls12377/g2.go similarity index 87% rename from std/algebra/sw_bls12377/g2.go rename to std/algebra/native/sw_bls12377/g2.go index fcfcdbcf20..64fdf9e87f 100644 --- a/std/algebra/sw_bls12377/g2.go +++ b/std/algebra/native/sw_bls12377/g2.go @@ -21,9 +21,10 @@ import ( "github.com/consensys/gnark-crypto/ecc" bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" - "github.com/consensys/gnark/backend/hint" + + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/fields_bls12377" + "github.com/consensys/gnark/std/algebra/native/fields_bls12377" ) // G2Jac point in Jacobian coords @@ -243,7 +244,7 @@ var DecomposeScalarG2 = func(scalarField *big.Int, inputs []*big.Int, res []*big } func init() { - hint.Register(DecomposeScalarG2) + solver.RegisterHint(DecomposeScalarG2) } // varScalarMul sets P = [s] Q and returns P. 
@@ -314,20 +315,18 @@ func (P *G2Affine) varScalarMul(api frontend.API, Q G2Affine, s frontend.Variabl // step value from [2] Acc (instead of conditionally adding step value to // Acc): // Acc = [2] (Q + Φ(Q)) ± Q ± Φ(Q) - Acc.Double(api, Acc) // only y coordinate differs for negation, select on that instead. B.X = tableQ[0].X B.Y.Select(api, s1bits[nbits-1], tableQ[1].Y, tableQ[0].Y) - Acc.AddAssign(api, B) + Acc.DoubleAndAdd(api, &Acc, &B) B.X = tablePhiQ[0].X B.Y.Select(api, s2bits[nbits-1], tablePhiQ[1].Y, tablePhiQ[0].Y) Acc.AddAssign(api, B) // second bit - Acc.Double(api, Acc) B.X = tableQ[0].X B.Y.Select(api, s1bits[nbits-2], tableQ[1].Y, tableQ[0].Y) - Acc.AddAssign(api, B) + Acc.DoubleAndAdd(api, &Acc, &B) B.X = tablePhiQ[0].X B.Y.Select(api, s2bits[nbits-2], tablePhiQ[1].Y, tablePhiQ[0].Y) Acc.AddAssign(api, B) @@ -471,3 +470,68 @@ func (p *G2Affine) DoubleAndAdd(api frontend.API, p1, p2 *G2Affine) *G2Affine { return p } + +// ScalarMulBase computes s * g2 and returns it, where g2 is the fixed generator. It doesn't modify s. 
+func (P *G2Affine) ScalarMulBase(api frontend.API, s frontend.Variable) *G2Affine { + + points := getTwistPoints() + + sBits := api.ToBinary(s, 253) + + var res, tmp G2Affine + + // i = 1, 2 + // gm[0] = 3g, gm[1] = 5g, gm[2] = 7g + res.X.Lookup2(api, sBits[1], sBits[2], + fields_bls12377.E2{ + A0: points.G2x[0], + A1: points.G2x[1]}, + fields_bls12377.E2{ + A0: points.G2m[0][0], + A1: points.G2m[0][1]}, + fields_bls12377.E2{ + A0: points.G2m[1][0], + A1: points.G2m[1][1]}, + fields_bls12377.E2{ + A0: points.G2m[2][0], + A1: points.G2m[2][1]}) + res.Y.Lookup2(api, sBits[1], sBits[2], + fields_bls12377.E2{ + A0: points.G2y[0], + A1: points.G2y[1]}, + fields_bls12377.E2{ + A0: points.G2m[0][2], + A1: points.G2m[0][3]}, + fields_bls12377.E2{ + A0: points.G2m[1][2], + A1: points.G2m[1][3]}, + fields_bls12377.E2{ + A0: points.G2m[2][2], + A1: points.G2m[2][3]}) + + for i := 3; i < 253; i++ { + // gm[i] = [2^i]g + tmp.X = res.X + tmp.Y = res.Y + tmp.AddAssign(api, G2Affine{ + fields_bls12377.E2{ + A0: points.G2m[i][0], + A1: points.G2m[i][1]}, + fields_bls12377.E2{ + A0: points.G2m[i][2], + A1: points.G2m[i][3]}}) + res.Select(api, sBits[i], tmp, res) + } + + // i = 0 + tmp.Neg(api, G2Affine{ + fields_bls12377.E2{A0: points.G2x[0], A1: points.G2x[1]}, + fields_bls12377.E2{A0: points.G2y[0], A1: points.G2y[1]}}) + tmp.AddAssign(api, res) + res.Select(api, sBits[0], res, tmp) + + P.X = res.X + P.Y = res.Y + + return P +} diff --git a/std/algebra/sw_bls12377/g2_test.go b/std/algebra/native/sw_bls12377/g2_test.go similarity index 93% rename from std/algebra/sw_bls12377/g2_test.go rename to std/algebra/native/sw_bls12377/g2_test.go index 534a58b54c..2bca05f714 100644 --- a/std/algebra/sw_bls12377/g2_test.go +++ b/std/algebra/native/sw_bls12377/g2_test.go @@ -374,6 +374,38 @@ func TestScalarMulG2(t *testing.T) { assert := test.NewAssert(t) assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_761)) } + +type g2varScalarMulBase struct { + C G2Affine 
`gnark:",public"` + R frontend.Variable +} + +func (circuit *g2varScalarMulBase) Define(api frontend.API) error { + expected := G2Affine{} + expected.ScalarMulBase(api, circuit.R) + expected.AssertIsEqual(api, circuit.C) + return nil +} + +func TestVarScalarMulBaseG2(t *testing.T) { + var c bls12377.G2Affine + _, gJac, _, _ := bls12377.Generators() + + // create the cs + var circuit, witness g2varScalarMulBase + var r fr.Element + _, _ = r.SetRandom() + witness.R = r.String() + // compute the result + var br big.Int + gJac.ScalarMultiplication(&gJac, r.BigInt(&br)) + c.FromJacobian(&gJac) + witness.C.Assign(&c) + + assert := test.NewAssert(t) + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_761)) +} + func randomPointG2() bls12377.G2Jac { _, p2, _, _ := bls12377.Generators() diff --git a/std/algebra/sw_bls12377/inner.go b/std/algebra/native/sw_bls12377/inner.go similarity index 66% rename from std/algebra/sw_bls12377/inner.go rename to std/algebra/native/sw_bls12377/inner.go index 7843ac8f17..3396afa7d1 100644 --- a/std/algebra/sw_bls12377/inner.go +++ b/std/algebra/native/sw_bls12377/inner.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/consensys/gnark-crypto/ecc" + bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark/frontend" ) @@ -62,3 +63,44 @@ func getInnerCurveConfig(outerCurveScalarField *big.Int) *innerConfig { return &innerConfigBW6_761 } + +var ( + computedCurveTable [][2]*big.Int + computedTwistTable [][4]*big.Int +) + +func init() { + computedCurveTable = computeCurveTable() + computedTwistTable = computeTwistTable() +} + +type curvePoints struct { + G1x *big.Int // base point x + G1y *big.Int // base point y + G1m [][2]*big.Int // m*base points (x,y) +} + +func getCurvePoints() curvePoints { + _, _, g1aff, _ := bls12377.Generators() + return curvePoints{ + G1x: g1aff.X.BigInt(new(big.Int)), + G1y: g1aff.Y.BigInt(new(big.Int)), + G1m: computedCurveTable, + } +} + +type twistPoints struct { + G2x 
[2]*big.Int // base point x ∈ E2 + G2y [2]*big.Int // base point y ∈ E2 + G2m [][4]*big.Int // m*base points (x,y) +} + +func getTwistPoints() twistPoints { + _, _, _, g2aff := bls12377.Generators() + return twistPoints{ + G2x: [2]*big.Int{g2aff.X.A0.BigInt(new(big.Int)), g2aff.X.A1.BigInt(new(big.Int))}, + G2y: [2]*big.Int{g2aff.Y.A0.BigInt(new(big.Int)), g2aff.Y.A1.BigInt(new(big.Int))}, + G2m: computedTwistTable, + } + +} diff --git a/std/algebra/native/sw_bls12377/inner_compute.go b/std/algebra/native/sw_bls12377/inner_compute.go new file mode 100644 index 0000000000..1bf12a2f5c --- /dev/null +++ b/std/algebra/native/sw_bls12377/inner_compute.go @@ -0,0 +1,59 @@ +package sw_bls12377 + +import ( + "math/big" + + bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" +) + +func computeCurveTable() [][2]*big.Int { + G1jac, _, _, _ := bls12377.Generators() + table := make([][2]*big.Int, 253) + tmp := new(bls12377.G1Jac).Set(&G1jac) + aff := new(bls12377.G1Affine) + jac := new(bls12377.G1Jac) + for i := 1; i < 253; i++ { + tmp = tmp.Double(tmp) + switch i { + case 1, 2: + jac.Set(tmp).AddAssign(&G1jac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + case 3: + jac.Set(tmp).SubAssign(&G1jac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + fallthrough + default: + aff.FromJacobian(tmp) + table[i] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + } + } + return table[:] +} + +func computeTwistTable() [][4]*big.Int { + _, G2jac, _, _ := bls12377.Generators() + table := make([][4]*big.Int, 253) + tmp := new(bls12377.G2Jac).Set(&G2jac) + aff := new(bls12377.G2Affine) + jac := new(bls12377.G2Jac) + for i := 1; i < 253; i++ { + tmp = tmp.Double(tmp) + switch i { + case 1, 2: + jac.Set(tmp).AddAssign(&G2jac) + aff.FromJacobian(jac) + table[i-1] = [4]*big.Int{aff.X.A0.BigInt(new(big.Int)), aff.X.A1.BigInt(new(big.Int)), 
aff.Y.A0.BigInt(new(big.Int)), aff.Y.A1.BigInt(new(big.Int))} + case 3: + jac.Set(tmp).SubAssign(&G2jac) + aff.FromJacobian(jac) + table[i-1] = [4]*big.Int{aff.X.A0.BigInt(new(big.Int)), aff.X.A1.BigInt(new(big.Int)), aff.Y.A0.BigInt(new(big.Int)), aff.Y.A1.BigInt(new(big.Int))} + fallthrough + default: + aff.FromJacobian(tmp) + table[i] = [4]*big.Int{aff.X.A0.BigInt(new(big.Int)), aff.X.A1.BigInt(new(big.Int)), aff.Y.A0.BigInt(new(big.Int)), aff.Y.A1.BigInt(new(big.Int))} + } + } + return table[:] +} diff --git a/std/algebra/native/sw_bls12377/pairing.go b/std/algebra/native/sw_bls12377/pairing.go new file mode 100644 index 0000000000..d6d8fcde99 --- /dev/null +++ b/std/algebra/native/sw_bls12377/pairing.go @@ -0,0 +1,480 @@ +/* +Copyright © 2020 ConsenSys + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package sw_bls12377 + +import ( + "errors" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/native/fields_bls12377" +) + +// GT target group of the pairing +type GT = fields_bls12377.E12 + +// binary decomposition of x₀=9586122913090633729 little endian +var loopCounter = [64]int8{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1} + +// lineEvaluation represents a sparse Fp12 Elmt (result of the line evaluation) +// line: 1 + R0(x/y) + R1(1/y) = 0 instead of R0'*y + R1'*x + R2' = 0 This +// makes the multiplication by lines (MulBy034) and between lines (Mul034By034) +// circuit-efficient. +type lineEvaluation struct { + R0, R1 fields_bls12377.E2 +} + +// MillerLoop computes the product of n miller loops (n can be 1) +// ∏ᵢ { fᵢ_{x₀,Q}(P) } +func MillerLoop(api frontend.API, P []G1Affine, Q []G2Affine) (GT, error) { + // check input size match + n := len(P) + if n == 0 || n != len(Q) { + return GT{}, errors.New("invalid inputs sizes") + } + + var res GT + res.SetOne() + var prodLines [5]fields_bls12377.E2 + + var l1, l2 lineEvaluation + Qacc := make([]G2Affine, n) + yInv := make([]frontend.Variable, n) + xOverY := make([]frontend.Variable, n) + for k := 0; k < n; k++ { + Qacc[k] = Q[k] + // x=0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000000 + // TODO: point P=(x,0) should be ruled out + yInv[k] = api.DivUnchecked(1, P[k].Y) + xOverY[k] = api.Mul(P[k].X, yInv[k]) + } + + // Compute ∏ᵢ { fᵢ_{x₀,Q}(P) } + // i = 62, separately to avoid an E12 Square + // (Square(res) = 1² = 1) + + // k = 0, separately to avoid MulBy034 (res × ℓ) + // (assign line to res) + Qacc[0], l1 = doubleStep(api, &Qacc[0]) + // line evaluation at P[0] + res.C1.B0.MulByFp(api, l1.R0, xOverY[0]) + res.C1.B1.MulByFp(api, l1.R1, yInv[0]) + + if n >= 2 { + // k = 1, 
separately to avoid MulBy034 (res × ℓ) + // (res is also a line at this point, so we use Mul034By034 ℓ × ℓ) + Qacc[1], l1 = doubleStep(api, &Qacc[1]) + + // line evaluation at P[1] + l1.R0.MulByFp(api, l1.R0, xOverY[1]) + l1.R1.MulByFp(api, l1.R1, yInv[1]) + + // ℓ × res + prodLines = *fields_bls12377.Mul034By034(api, l1.R0, l1.R1, res.C1.B0, res.C1.B1) + res.C0.B0 = prodLines[0] + res.C0.B1 = prodLines[1] + res.C0.B2 = prodLines[2] + res.C1.B0 = prodLines[3] + res.C1.B1 = prodLines[4] + + } + + if n >= 3 { + // k = 2, separately to avoid MulBy034 (res × ℓ) + // (res has a zero E2 element, so we use Mul01234By034) + Qacc[2], l1 = doubleStep(api, &Qacc[2]) + + // line evaluation at P[1] + l1.R0.MulByFp(api, l1.R0, xOverY[2]) + l1.R1.MulByFp(api, l1.R1, yInv[2]) + + // ℓ × res + res = *fields_bls12377.Mul01234By034(api, prodLines, l1.R0, l1.R1) + + // k >= 3 + for k := 3; k < n; k++ { + // Qacc[k] ← 2Qacc[k] and l1 the tangent ℓ passing 2Qacc[k] + Qacc[k], l1 = doubleStep(api, &Qacc[k]) + + // line evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + } + } + + // i = 61, separately to use a special E12 Square + // k = 0 + // Qacc[0] ← 2Qacc[0] and l1 the tangent ℓ passing 2Qacc[0] + Qacc[0], l1 = doubleStep(api, &Qacc[0]) + // line evaluation at P[0] + l1.R0.MulByFp(api, l1.R0, xOverY[0]) + l1.R1.MulByFp(api, l1.R1, yInv[0]) + + if n == 1 { + res.Square034(api, res) + prodLines[0] = res.C0.B0 + prodLines[1] = res.C0.B1 + prodLines[2] = res.C0.B2 + prodLines[3] = res.C1.B0 + prodLines[4] = res.C1.B1 + // ℓ × res + res = *fields_bls12377.Mul01234By034(api, prodLines, l1.R0, l1.R1) + + } else { + res.Square(api, res) + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + + } + + for k := 1; k < n; k++ { + // Qacc[k] ← 2Qacc[k] and l1 the tangent ℓ passing 2Qacc[k] + Qacc[k], l1 = doubleStep(api, &Qacc[k]) + + // line evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + 
l1.R1.MulByFp(api, l1.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + } + + for i := 60; i >= 1; i-- { + // mutualize the square among n Miller loops + // (∏ᵢfᵢ)² + res.Square(api, res) + + if loopCounter[i] == 0 { + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k] and l1 the tangent ℓ passing 2Qacc[k] + Qacc[k], l1 = doubleStep(api, &Qacc[k]) + + // line evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + } + continue + } + + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k]+Q[k], + // l1 the line ℓ passing Qacc[k] and Q[k] + // l2 the line ℓ passing (Qacc[k]+Q[k]) and Qacc[k] + Qacc[k], l1, l2 = doubleAndAddStep(api, &Qacc[k], &Q[k]) + + // lines evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + l2.R0.MulByFp(api, l2.R0, xOverY[k]) + l2.R1.MulByFp(api, l2.R1, yInv[k]) + + // ℓ × ℓ + prodLines = *fields_bls12377.Mul034By034(api, l1.R0, l1.R1, l2.R0, l2.R1) + // (ℓ × ℓ) × res + res.MulBy01234(api, prodLines) + } + } + + // i = 0 + res.Square(api, res) + for k := 0; k < n; k++ { + // l1 line through Qacc[k] and Q[k] + // l2 line through Qacc[k]+Q[k] and Qacc[k] + l1, l2 = linesCompute(api, &Qacc[k], &Q[k]) + + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + l2.R0.MulByFp(api, l2.R0, xOverY[k]) + l2.R1.MulByFp(api, l2.R1, yInv[k]) + + // ℓ × ℓ + prodLines = *fields_bls12377.Mul034By034(api, l1.R0, l1.R1, l2.R0, l2.R1) + // (ℓ × ℓ) × res + res.MulBy01234(api, prodLines) + } + + return res, nil +} + +// FinalExponentiation computes the exponentiation e1ᵈ +// where d = (p¹²-1)/r = (p¹²-1)/Φ₁₂(p) ⋅ Φ₁₂(p)/r = (p⁶-1)(p²+1)(p⁴ - p² +1)/r +// we use instead d=s ⋅ (p⁶-1)(p²+1)(p⁴ - p² +1)/r +// where s is the cofactor 3 (Hayashida et al.) 
+func FinalExponentiation(api frontend.API, e1 GT) GT { + const genT = 9586122913090633729 + + result := e1 + + // https://eprint.iacr.org/2016/130.pdf + var t [3]GT + + // easy part + // (p⁶-1)(p²+1) + t[0].Conjugate(api, result) + t[0].DivUnchecked(api, t[0], result) + result.FrobeniusSquare(api, t[0]). + Mul(api, result, t[0]) + + // hard part (up to permutation) + // Daiki Hayashida and Kenichiro Hayasaka + // and Tadanori Teruya + // https://eprint.iacr.org/2020/875.pdf + t[0].CyclotomicSquare(api, result) + t[1].Expt(api, result, genT) + t[2].Conjugate(api, result) + t[1].Mul(api, t[1], t[2]) + t[2].Expt(api, t[1], genT) + t[1].Conjugate(api, t[1]) + t[1].Mul(api, t[1], t[2]) + t[2].Expt(api, t[1], genT) + t[1].Frobenius(api, t[1]) + t[1].Mul(api, t[1], t[2]) + result.Mul(api, result, t[0]) + t[0].Expt(api, t[1], genT) + t[2].Expt(api, t[0], genT) + t[0].FrobeniusSquare(api, t[1]) + t[1].Conjugate(api, t[1]) + t[1].Mul(api, t[1], t[2]) + t[1].Mul(api, t[1], t[0]) + result.Mul(api, result, t[1]) + + return result +} + +// Pair calculates the reduced pairing for a set of points +// ∏ᵢ e(Pᵢ, Qᵢ). +// +// This function doesn't check that the inputs are in the correct subgroup. See IsInSubGroup. +func Pair(api frontend.API, P []G1Affine, Q []G2Affine) (GT, error) { + f, err := MillerLoop(api, P, Q) + if err != nil { + return GT{}, err + } + return FinalExponentiation(api, f), nil +} + +// doubleAndAddStep doubles p1 and adds p2 to the result in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func doubleAndAddStep(api frontend.API, p1, p2 *G2Affine) (G2Affine, lineEvaluation, lineEvaluation) { + + var n, d, l1, l2, x3, x4, y4 fields_bls12377.E2 + var line1, line2 lineEvaluation + var p G2Affine + + // compute lambda1 = (y2-y1)/(x2-x1) + n.Sub(api, p1.Y, p2.Y) + d.Sub(api, p1.X, p2.X) + l1.DivUnchecked(api, n, d) + + // x3 =lambda1**2-p1.x-p2.x + x3.Square(api, l1). + Sub(api, x3, p1.X). 
+ Sub(api, x3, p2.X) + + // omit y3 computation + + // compute line1 + line1.R0.Neg(api, l1) + line1.R1.Mul(api, l1, p1.X).Sub(api, line1.R1, p1.Y) + + // compute lambda2 = -lambda1-2*y1/(x3-x1) + n.Double(api, p1.Y) + d.Sub(api, x3, p1.X) + l2.DivUnchecked(api, n, d) + l2.Add(api, l2, l1).Neg(api, l2) + + // compute x4 = lambda2**2-x1-x3 + x4.Square(api, l2). + Sub(api, x4, p1.X). + Sub(api, x4, x3) + + // compute y4 = lambda2*(x1 - x4)-y1 + y4.Sub(api, p1.X, x4). + Mul(api, l2, y4). + Sub(api, y4, p1.Y) + + p.X = x4 + p.Y = y4 + + // compute line2 + line2.R0.Neg(api, l2) + line2.R1.Mul(api, l2, p1.X).Sub(api, line2.R1, p1.Y) + + return p, line1, line2 +} + +// doubleStep doubles a point in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func doubleStep(api frontend.API, p1 *G2Affine) (G2Affine, lineEvaluation) { + + var n, d, l, xr, yr fields_bls12377.E2 + var p G2Affine + var line lineEvaluation + + // lambda = 3*p1.x**2/2*p.y + n.Square(api, p1.X).MulByFp(api, n, 3) + d.MulByFp(api, p1.Y, 2) + l.DivUnchecked(api, n, d) + + // xr = lambda**2-2*p1.x + xr.Square(api, l). + Sub(api, xr, p1.X). + Sub(api, xr, p1.X) + + // yr = lambda*(p.x-xr)-p.y + yr.Sub(api, p1.X, xr). + Mul(api, l, yr). + Sub(api, yr, p1.Y) + + p.X = xr + p.Y = yr + + line.R0.Neg(api, l) + line.R1.Mul(api, l, p1.X).Sub(api, line.R1, p1.Y) + + return p, line + +} + +// linesCompute computes the lines that goes through p1 and p2, and (p1+p2) and p1 but does not compute 2p1+p2 +func linesCompute(api frontend.API, p1, p2 *G2Affine) (lineEvaluation, lineEvaluation) { + + var n, d, l1, l2, x3 fields_bls12377.E2 + var line1, line2 lineEvaluation + + // compute lambda1 = (y2-y1)/(x2-x1) + n.Sub(api, p1.Y, p2.Y) + d.Sub(api, p1.X, p2.X) + l1.DivUnchecked(api, n, d) + + // x3 =lambda1**2-p1.x-p2.x + x3.Square(api, l1). + Sub(api, x3, p1.X). 
+ Sub(api, x3, p2.X) + + // omit y3 computation + + // compute line1 + line1.R0.Neg(api, l1) + line1.R1.Mul(api, l1, p1.X).Sub(api, line1.R1, p1.Y) + + // compute lambda2 = -lambda1-2*y1/(x3-x1) + n.Double(api, p1.Y) + d.Sub(api, x3, p1.X) + l2.DivUnchecked(api, n, d) + l2.Add(api, l2, l1).Neg(api, l2) + + // compute line2 + line2.R0.Neg(api, l2) + line2.R1.Mul(api, l2, p1.X).Sub(api, line2.R1, p1.Y) + + return line1, line2 +} + +// ---------------------------- +// Fixed-argument pairing +// ---------------------------- +// +// The second argument Q is the fixed canonical generator of G2. +// +// Q.X.A0 = 0x18480be71c785fec89630a2a3841d01c565f071203e50317ea501f557db6b9b71889f52bb53540274e3e48f7c005196 +// Q.X.A1 = 0xea6040e700403170dc5a51b1b140d5532777ee6651cecbe7223ece0799c9de5cf89984bff76fe6b26bfefa6ea16afe +// Q.Y.A0 = 0x690d665d446f7bd960736bcbb2efb4de03ed7274b49a58e458c282f832d204f2cf88886d8c7c2ef094094409fd4ddf +// Q.Y.A1 = 0xf8169fd28355189e549da3151a70aa61ef11ac3d591bf12463b01acee304c24279b83f5e52270bd9a1cdd185eb8f93 + +// MillerLoopFixed computes the single Miller loop +// fᵢ_{u,g2}(P), where g2 is fixed. 
+func MillerLoopFixedQ(api frontend.API, P G1Affine) (GT, error) { + + var res GT + res.SetOne() + var prodLines [5]fields_bls12377.E2 + + var l1, l2 lineEvaluation + var yInv, xOverY frontend.Variable + yInv = api.DivUnchecked(1, P.Y) + xOverY = api.Mul(P.X, yInv) + + // Compute ∏ᵢ { fᵢ_{x₀,Q}(P) } + // i = 62, separately to avoid an E12 Square + // (Square(res) = 1² = 1) + + // k = 0, separately to avoid MulBy034 (res × ℓ) + // (assign line(P) to res) + res.C1.B0.MulByFp(api, precomputedLines[0][62], xOverY) + res.C1.B1.MulByFp(api, precomputedLines[1][62], yInv) + + // i = 61, separately to use a special E12 Square + res.Square034(api, res) + prodLines[0] = res.C0.B0 + prodLines[1] = res.C0.B1 + prodLines[2] = res.C0.B2 + prodLines[3] = res.C1.B0 + prodLines[4] = res.C1.B1 + // line evaluation at P + l1.R0.MulByFp(api, precomputedLines[0][61], xOverY) + l1.R1.MulByFp(api, precomputedLines[1][61], yInv) + // ℓ × res + res = *fields_bls12377.Mul01234By034(api, prodLines, l1.R0, l1.R1) + + for i := 60; i >= 0; i-- { + // mutualize the square among n Miller loops + // (∏ᵢfᵢ)² + res.Square(api, res) + + if loopCounter[i] == 0 { + // line evaluation at P + l1.R0.MulByFp(api, precomputedLines[0][i], xOverY) + l1.R1.MulByFp(api, precomputedLines[1][i], yInv) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + continue + + } + + // lines evaluation at P + l1.R0.MulByFp(api, precomputedLines[0][i], xOverY) + l1.R1.MulByFp(api, precomputedLines[1][i], yInv) + l2.R0.MulByFp(api, precomputedLines[2][i], xOverY) + l2.R1.MulByFp(api, precomputedLines[3][i], yInv) + + // ℓ × ℓ + prodLines = *fields_bls12377.Mul034By034(api, l1.R0, l1.R1, l2.R0, l2.R1) + // (ℓ × ℓ) × res + res.MulBy01234(api, prodLines) + } + + return res, nil +} + +// PairFixedQ calculates the reduced pairing for a set of points +// e(P, g2), where g2 is fixed. +// +// This function doesn't check that the inputs are in the correct subgroups. 
+func PairFixedQ(api frontend.API, P G1Affine) (GT, error) { + f, err := MillerLoopFixedQ(api, P) + if err != nil { + return GT{}, err + } + return FinalExponentiation(api, f), nil +} diff --git a/std/algebra/sw_bls12377/pairing_test.go b/std/algebra/native/sw_bls12377/pairing_test.go similarity index 87% rename from std/algebra/sw_bls12377/pairing_test.go rename to std/algebra/native/sw_bls12377/pairing_test.go index bf09b6efad..a42b817290 100644 --- a/std/algebra/sw_bls12377/pairing_test.go +++ b/std/algebra/native/sw_bls12377/pairing_test.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/std/algebra/fields_bls12377" + "github.com/consensys/gnark/std/algebra/native/fields_bls12377" "github.com/consensys/gnark/test" ) @@ -126,6 +126,37 @@ func TestTriplePairingBLS377(t *testing.T) { } +type pairingFixedBLS377 struct { + P G1Affine `gnark:",public"` + pairingRes bls12377.GT +} + +func (circuit *pairingFixedBLS377) Define(api frontend.API) error { + + pairingRes, _ := PairFixedQ(api, circuit.P) + + mustbeEq(api, pairingRes, &circuit.pairingRes) + + return nil +} + +func TestPairingFixedBLS377(t *testing.T) { + + // pairing test data + P, _, _, pairingRes := pairingData() + + // create cs + var circuit, witness pairingFixedBLS377 + circuit.pairingRes = pairingRes + + // assign values to witness + witness.P.Assign(&P) + + assert := test.NewAssert(t) + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_761)) + +} + // utils func pairingData() (P bls12377.G1Affine, Q bls12377.G2Affine, milRes, pairingRes bls12377.GT) { _, _, P, Q = bls12377.Generators() diff --git a/std/algebra/native/sw_bls12377/precomputations.go b/std/algebra/native/sw_bls12377/precomputations.go new file mode 100644 index 0000000000..9616f99a58 --- /dev/null +++ b/std/algebra/native/sw_bls12377/precomputations.go @@ -0,0 +1,308 @@ +/* 
+Copyright © 2020 ConsenSys + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sw_bls12377 + +import "github.com/consensys/gnark/std/algebra/native/fields_bls12377" + +// precomputed lines going through Q and multiples of Q +// where Q is the fixed canonical generator of G2 +// +// Q.X.A0 = 0x18480be71c785fec89630a2a3841d01c565f071203e50317ea501f557db6b9b71889f52bb53540274e3e48f7c005196 +// Q.X.A1 = 0xea6040e700403170dc5a51b1b140d5532777ee6651cecbe7223ece0799c9de5cf89984bff76fe6b26bfefa6ea16afe +// Q.Y.A0 = 0x690d665d446f7bd960736bcbb2efb4de03ed7274b49a58e458c282f832d204f2cf88886d8c7c2ef094094409fd4ddf +// Q.Y.A1 = 0xf8169fd28355189e549da3151a70aa61ef11ac3d591bf12463b01acee304c24279b83f5e52270bd9a1cdd185eb8f93 + +var precomputedLines [4][63]fields_bls12377.E2 + +func init() { + precomputedLines[0][62].A0 = "140988482040386324248198112209067307758466420200971161148224569905129921720188825131108643620349828443596943631004" + precomputedLines[0][62].A1 = "38285511025528108185446422111142282591962051133412369448651977919088501218747243552500995429459187749235230706095" + precomputedLines[1][62].A0 = "240155915094625874419934231512470237700285240367926454007160179123226044819661538899285088085163427741534561817913" + precomputedLines[1][62].A1 = "223115033679672110617671091484598616458749264033732645675787392173820176087722042434509922701267792140213636611231" + precomputedLines[0][61].A0 = 
"200434958806057349802162752596109649646825612161451945981398752293313419856233843110613679972690184654975364417821" + precomputedLines[0][61].A1 = "58782839865014741808685925119247514734041203459952950234417602068856574067453950006450203782382677294229715822669" + precomputedLines[1][61].A0 = "221104783854457852217149670421014574568257761015321966015802821666781409364260524164573031195935814032318596063821" + precomputedLines[1][61].A1 = "215808263269556315801392787356915665282147289752591160048025718955504747606960075076763470727775607592602221010992" + precomputedLines[0][60].A0 = "76041212684332796496544503282559155730743386924998067938629401863728086863954381692679104043606708834502612121469" + precomputedLines[0][60].A1 = "100035609662702304463030411506002198944244805269834337236123382114910737217630245466671270103781538783706579105866" + precomputedLines[1][60].A0 = "74764375961719330445928565332016106267563990543199553848931319700934770734164802646565305287411909533142305235213" + precomputedLines[1][60].A1 = "158012842767786331849495340937588038244489118968577628545983571742104830293331067767096049488233731677510082972964" + precomputedLines[0][59].A0 = "193748126954112966035895420477410839340013406543588071113075172554110443264930976028399734543079189662772415617492" + precomputedLines[0][59].A1 = "103044909642891374077496405785099025851114547191286132815498126062521435466470249260770991509346188165263075287850" + precomputedLines[1][59].A0 = "64641117922972289412694403199160613990999019944207301201605925020441765831402155024052436276304356582448281076699" + precomputedLines[1][59].A1 = "115083312261402566141714078516986787111307264811665914936689089259045454992142142548204677044969906296967380044243" + precomputedLines[0][58].A0 = "67915906112893626131410758754363498551421413835839777031678634532329280037835197622992623489740395898322663148059" + precomputedLines[0][58].A1 = 
"173313446805658717197371811890032635005436241964263426050047721320291430359915009162322971665705181058998559235263" + precomputedLines[1][58].A0 = "220652074315326818132704565940002651528205251595436277895602495745266339816144887682182133039083839867027773044751" + precomputedLines[1][58].A1 = "102992145634896269413545294235366136092298779070167486700388724721785865183766315260503148313113902958119600238801" + precomputedLines[2][58].A0 = "24730766400316803857946042038542628446851566775501406429049092024956572334933897266525719094500802651967451877104" + precomputedLines[2][58].A1 = "241752628149868566672012188330652636152051946564517212780585454017277108424559367626456680725498525788959515586135" + precomputedLines[3][58].A0 = "30542136058754465199617571707255509790049946932941210501104976840541095202572961395579951366462651511063666296618" + precomputedLines[3][58].A1 = "188144464578529929130563311752704853358931809446141674400393035331183849353780969440230131118220745630710695729038" + precomputedLines[0][57].A0 = "203542720507623809237305077491776878815828671640739858075525189349207756813675901299937136298538210670943015983066" + precomputedLines[0][57].A1 = "122267097877800673355872550137995253009203160241772918862861823846916861120945377873440152245885886575116820447218" + precomputedLines[1][57].A0 = "241010188933240099439421890421590377132751995299722266619415194349228859514279889855603932334007400482783077583653" + precomputedLines[1][57].A1 = "246568490291142875524155284457991940489812407071546528381830727409319336551736503146335999393390849501813528441715" + precomputedLines[0][56].A0 = "212633273188386348461937963689047478311737376294722113582550164025934419771241861346880832815665535652196560209111" + precomputedLines[0][56].A1 = "52468567784554332368241391715335917402359805401532522717425862730964963067315969197228926701959222636436968893965" + precomputedLines[1][56].A0 = 
"101214307451482497743965474950389111293669486840506449543862843813549797384606690652197652506708811986285321910264" + precomputedLines[1][56].A1 = "31715746076589535848294115445740059535847765725766141583327080312355259209585058220005250980764055876391622751768" + precomputedLines[2][56].A0 = "168629293029304632286170568768702332236384447904909963327367145873993031054760955454127828033818710331443095639627" + precomputedLines[2][56].A1 = "216925787240186374969566846700812183945786718528836723540836037071059152966220157674422417496587882333270394842667" + precomputedLines[3][56].A0 = "227731030244902669822034034198287357304927285504045181193720954189047228804334056981471267342506406613130740169079" + precomputedLines[3][56].A1 = "251655417264240468159275057085684875405631462971234387395385359483060778954954485244685517807770934173420615035867" + precomputedLines[0][55].A0 = "69647542295523762035687275250742893896534867950558432281294053406455722968513626395174720962023816217066549750146" + precomputedLines[0][55].A1 = "86374134980846558048430600607671875021595489437559509552209510903378515568573393635649006113997944974205912317758" + precomputedLines[1][55].A0 = "29326477189608627032111067771967745146682542307757304826348053622190718175512195091725631166880091118078594012566" + precomputedLines[1][55].A1 = "174955420673204293542385641500734311605993640675970003481742264895699917035936887030918794628670844388976105483108" + precomputedLines[0][54].A0 = "171991363708194121472978060584199292364945187003016243383405563270304667365710332488330045122007137069329885369470" + precomputedLines[0][54].A1 = "243426352363069697911439121948918250088994837794888532959202848854408417906388388289358093612487467125815291063404" + precomputedLines[1][54].A0 = "106063100608103983067598257364471593632978801976409388432805708780404665354701495380114490637950028503103223559637" + precomputedLines[1][54].A1 = 
"188329967143844604554866061124618250676783471373114471959031762693226916978037054228364264440383353204211791706717" + precomputedLines[0][53].A0 = "214091666653090657319542665545941683440808786772326208699185056563290318417156204234941530065737811054981031247210" + precomputedLines[0][53].A1 = "96731617711626976892675884826845578958560766542574571579820430348168005867708261830828635866548830196001915529050" + precomputedLines[1][53].A0 = "143883820766051764850671495947961492420458444198997759644961583697549192192176411080086439209274633796205816834399" + precomputedLines[1][53].A1 = "143317514483005534418265519008755966439349456133008140442668309237863113100665670627540093779152241182561445862735" + precomputedLines[0][52].A0 = "151497223182486129841199062372509400530300996571744152646342452843296601418668570558349814149919457527539266078088" + precomputedLines[0][52].A1 = "188402120901762789029086895993400023976410651095772897842771932991017017076324967942540432856032705122173338834187" + precomputedLines[1][52].A0 = "236110368897175739403452298877857899405720926122091885775395967791713064198858559330397521016851953740687593385237" + precomputedLines[1][52].A1 = "185604871195988167998901933073566522815865534397868253229103650676649500057504419318493183044365511529428248776365" + precomputedLines[0][51].A0 = "151407873777411690678167938025891707857729360551843175521901080100924614112923800706449003238478817712558493254487" + precomputedLines[0][51].A1 = "124635153072321651137725119039276110305377634452310724384621256437966948390352207825248778661122886657978424705493" + precomputedLines[1][51].A0 = "65884731428159867714510181101890865971915393477778158291323219872508957428871319717921955217274758176624965255438" + precomputedLines[1][51].A1 = "94350765353828166168386200390905832733466928657507847537299328518177074423738323062253931346326414218705308941110" + precomputedLines[2][51].A0 = 
"144929072355365985674902738845501550079936376328476619318600256877561349496064309179237541842092362342195745784891" + precomputedLines[2][51].A1 = "44437147980879789866156846994099291093325312579185235864854239256726963378686220647533039481853201359084625125764" + precomputedLines[3][51].A0 = "155221871708143047682282710190382093696472045601337348902336872864615118871626157839345034967784906688581215997830" + precomputedLines[3][51].A1 = "71232907182490356033503979145028085712754143079663077216895622204621447881932676868270369150591172163126933887254" + precomputedLines[0][50].A0 = "112476898280051652932023240471267828132672683251091725917551660130621574850722581608621020651843931695399551394991" + precomputedLines[0][50].A1 = "178023757085448589000615475717388330579337562549560411713341146156044567124656718848186482950890230735597057216776" + precomputedLines[1][50].A0 = "99003996964022849715614307455124952281424652032125615162575128143184541926999502007666154991759624135964376019616" + precomputedLines[1][50].A1 = "52182858727791628909383427365500986520821899364350005707885438521874175971603243008768806573942703365532899093053" + precomputedLines[0][49].A0 = "65088282562948965761190662990830186380196959225738756558164334049580032043491015118050228600228202415441639672869" + precomputedLines[0][49].A1 = "223853433746839137696915628847608098601847228974179470986546755274849494275230146968036581448636309169459465176715" + precomputedLines[1][49].A0 = "55848217484244029410628743210585829585404368353761735581189152004695531401723359570629207819300644755754229776649" + precomputedLines[1][49].A1 = "47257613438405089663978967373233487629269597222625218814364763754867986586668323064958116490917460709577641319233" + precomputedLines[0][48].A0 = "33371276088436595465554944025407522420485220621018556569340671754687239890695554901739590443448140506769175564774" + precomputedLines[0][48].A1 = 
"31381776283940400693221103192835284141966270545629722313240597597473480872241895841695614446464420188420594815483" + precomputedLines[1][48].A0 = "97861431335031211212906524011909118216384280632977794574164978662461994919632244167477241209517957586485933813609" + precomputedLines[1][48].A1 = "14950676222934864970633486268161359604502796325909833831243020937644755838827844704908958417814623086408503531807" + precomputedLines[0][47].A0 = "59834359676924689103154622724559128302632485935277896410932203772395994475951292749159484763341691875467054215455" + precomputedLines[0][47].A1 = "121634830616628172793803878988299141485214443672514101561218120159913380519627491388651347461061811952889290532465" + precomputedLines[1][47].A0 = "104281214492117378808763969241391039252455168270214079366872079330480071719275354884341891448503989762420171340214" + precomputedLines[1][47].A1 = "68103241133974496648408070804835851163580285438904517607677472422014852700937120533228822433511355961062367320174" + precomputedLines[2][47].A0 = "168901673321089460931791758561025999148108036400257179421892024848602870259442595103430265873615973573420923314714" + precomputedLines[2][47].A1 = "116275983274326047668234817034342255449277965667042329332784724780585611389022704446928891120293152846354930265020" + precomputedLines[3][47].A0 = "75383946495203059731777121283679121000322954233531175456797523085675044614481601893492011669513364003858404281212" + precomputedLines[3][47].A1 = "32270804903692977693643352218587823508757216825781198439889489590138659918710102917314675600818261900391149654840" + precomputedLines[0][46].A0 = "76723276878327012045456058068102287184294072829295527286187448792759564489428281529704081028006352317435184449104" + precomputedLines[0][46].A1 = "33725934849602081359546849665280761456084295644277595438596842833350389330281014869114061231036427556949488967269" + precomputedLines[1][46].A0 = 
"1809062309286707529767840046230996711777004497969761910017468820698645646048993439725644550617131322444632019900" + precomputedLines[1][46].A1 = "241500766141718735641202287847083326226537429024237344407386495730151262921153141998514689577129755244952085473074" + precomputedLines[2][46].A0 = "14086279004459732270127334451590422919464384944294865439422851990271034317463472174161465876503634166002513358089" + precomputedLines[2][46].A1 = "225551169007861832921853726859252182731015316313812049356540957604226800563394509687092769645478358553628759270468" + precomputedLines[3][46].A0 = "183394982930716068751100083054933288533427117008948733485891009892942959050290569100665502177264899837441225397432" + precomputedLines[3][46].A1 = "34555061286391675012830223729899620640493078684911013075963842884474502832159313157533963215606542520028113910358" + precomputedLines[0][45].A0 = "84051322204090599322425691524798998095080466531985348228942107386393519084245087070637510678562600608969546678159" + precomputedLines[0][45].A1 = "224917158100474648123106712022137642776748421497756802909561002011563372758902952990431852972479759236554834463550" + precomputedLines[1][45].A0 = "114827928772035709256184122092025802807404696460525247358200827896023009003960574524841905128894572546212276492121" + precomputedLines[1][45].A1 = "98658418705502063530932941367124832253134351085930947647228852083467613663016036180612406002425924201990406005448" + precomputedLines[0][44].A0 = "7338872593430395006967912199087063219867840949383351191229366459898328840974258707799908493492319536992684737650" + precomputedLines[0][44].A1 = "146787497226037867596003983591424656653929510707614932816975953611056981862606611006751504153498756340687711813801" + precomputedLines[1][44].A0 = "23664613469726311529956696702053198756124772279740534032633117926205723370549721151880632478370417510043219515586" + precomputedLines[1][44].A1 = 
"254051787364760479648312591248287311944374116359993737392991975958518601767075430645350638515342403490806557403099" + precomputedLines[0][43].A0 = "109942451867557999843779896343871792864720572419192881487620380410753920800041133630326027817995559103265780802440" + precomputedLines[0][43].A1 = "218241193079413494073835081372801019824529801077267146179570351108550864993185917782859937827554117105001468972780" + precomputedLines[1][43].A0 = "29397114841983252768870751792863676606319556192631896418556758999267399125460333967781697401894655636177338251399" + precomputedLines[1][43].A1 = "121555286462688676679872453571868600224278136049532708191391921303218022273875291382703754044572514894810730162819" + precomputedLines[0][42].A0 = "184348357462969621855539644658812582967307983933034524109695268269289046000306196901346532564393269600353450049134" + precomputedLines[0][42].A1 = "155100031062974332654716518911833449249043290262455053566725575927816256962093671945069579099144255466346881106065" + precomputedLines[1][42].A0 = "115748627222217999417077585035147819732664349695232018072763165430613851792416146485676051786575407367329821009341" + precomputedLines[1][42].A1 = "101563556224286525341844837089004722548555297197270313547530612111737117999605596823473729414679565492567086769680" + precomputedLines[0][41].A0 = "201608486742718044611140908558139832942001651814545122400319629100992978249279567900744778389980913295352905512100" + precomputedLines[0][41].A1 = "193206833379679431060900354827409004443174464492903309613288257868758094136906434507676870318548079348888976113929" + precomputedLines[1][41].A0 = "56705622085037677448744812351889531891701892017576813250650340271094034648438803883256590281116794788775203237199" + precomputedLines[1][41].A1 = "220338221203257637946838563533059324341186623994837464953807785239794517068614081945083484069810517435918017491727" + precomputedLines[0][40].A0 = 
"85346577680801650470152770481893383991131333806437023603579497038884493033260549525345382960501338079722959491208" + precomputedLines[0][40].A1 = "35258068609740860538437593758156215249215806306551023929392762629752770796334473790402643565352295336678424656918" + precomputedLines[1][40].A0 = "186711750081906943179854367614401611898912692341262326737039586007795186153901435430878972306416326706038974036805" + precomputedLines[1][40].A1 = "139771768935052416715472595186523680802756126710813081174378782368365567654466903167005868899936405818231705738535" + precomputedLines[0][39].A0 = "215315380307196542246112644413032283334653056973473566502333530180998587039204515727312607670717947951792149988666" + precomputedLines[0][39].A1 = "83221370915290535695206340365399040207108851453836793468373757967278110141491245993321574774991771764252068213119" + precomputedLines[1][39].A0 = "187858159809479601598804605473835073651822326188048559434437714156022398470774814919936255353742452108515721156199" + precomputedLines[1][39].A1 = "188381703421717574804834943169997798742216470113887435075335233291158883898539581783430931360817912379273592222184" + precomputedLines[0][38].A0 = "24279569550991910668493079358195722719935007860518421910885476101812019854244587495951456936470893042305511323609" + precomputedLines[0][38].A1 = "137180635674096437935240294526074297322948216588415767737614952517527381848784610884502437230607108469572855819923" + precomputedLines[1][38].A0 = "234996204122927296652190204011716899749982730858672133878163718237645934119674470543742019170822471050807175927766" + precomputedLines[1][38].A1 = "92088885326255068919774110794098747808935609330103933500815700065926407984107332948756481559715088823140128682938" + precomputedLines[0][37].A0 = "12543992453036820805550342319209985200002293098642835372524920572609746076879723801361398395134241729458130104210" + precomputedLines[0][37].A1 = 
"150778101137020653679767211218697643081244007910741020138378495225715061556275631366940680022570064237554051394786" + precomputedLines[1][37].A0 = "156760423817321536652593670864554732957893617243405325139317813356786332981879509124800524269480378021334889717551" + precomputedLines[1][37].A1 = "138003713010895804157918477030967127644526487390835363622139195712252722364691214286434262332793714796835152059770" + precomputedLines[0][36].A0 = "24593304415506667023929194445269459675314941525437594101181618356894083519729400496565302361491137759007806558090" + precomputedLines[0][36].A1 = "40153328340430495636692843077727165888193778535423231663304969684663585352979484180061719799934938463821931715271" + precomputedLines[1][36].A0 = "238838536265404169034958293692294983034722654688261650748346214443421616406277227296124010869807080078237272058320" + precomputedLines[1][36].A1 = "231390648460389860655490242598518915268739735146387985877640100331563610665848738733122910410940489956989121920776" + precomputedLines[0][35].A0 = "251660404696011926907640668468689181976214510362587102660442691056218630886488060204044742794647442627470397227897" + precomputedLines[0][35].A1 = "7684887556551818035030313634725849355291883643695237841810995080771859954078538088181443891924734161777589308500" + precomputedLines[1][35].A0 = "195826319269118506411339469194468517979272424347723746723371862135612352670122475008676708792059982442535891396639" + precomputedLines[1][35].A1 = "159999919615779760351477088696848991749266782570106721609670770909436249046576459239697856164512584190403379602794" + precomputedLines[0][34].A0 = "197398646574132733287596831073978800879329655929121568878114769908643461810133004339970499467775967681376383106798" + precomputedLines[0][34].A1 = "96431253682360775275444638445028437884681084039928049658575397299591371801501475195489594601700429006425001153640" + precomputedLines[1][34].A0 = 
"257204781005187722945056255287079703727469022329784792426311312302269392950751646903642446226964182189565854293712" + precomputedLines[1][34].A1 = "240395372141813122553617389999703217960182539426100845071432690677330836376481074651678730334729130897353492662646" + precomputedLines[0][33].A0 = "28823977749324723652874308810272935820298443214932765824829850165753831842772136275057949057201937267323697696552" + precomputedLines[0][33].A1 = "223530247513752705701791979446230877206213338864560788550736847993798583541041590318085856192118493620990311389211" + precomputedLines[1][33].A0 = "158254544463932260307287354230604631601829465348133332757429290835080473412829320965860720952419508462438381692592" + precomputedLines[1][33].A1 = "181294449322636680473354250653911203936621049302388462691159447350018908622817786189218002808087785788031195469899" + precomputedLines[0][32].A0 = "143257627767117530029385249339945534526833185194529385799929621468636676347591411537426673375647638870773079719646" + precomputedLines[0][32].A1 = "251943136385092701640802622422781044470061262096491741701685500502937273178539785178790805403819217373079877417179" + precomputedLines[1][32].A0 = "105131241159338160701691424930954885768888402552581762740214020837757553962048658044366292035955759106395905075368" + precomputedLines[1][32].A1 = "137108519673071977400147187044710572459811350963278670325046016668463076208398815019550818616110510887136421606845" + precomputedLines[0][31].A0 = "130768925408234755650455498275318211701981697071492959740224626831026093390342157671893173195261362770808282774347" + precomputedLines[0][31].A1 = "12330099732690199070528098741516911002638271167059377791787103885132545809617469842220531842476816339316626509205" + precomputedLines[1][31].A0 = "139904524462105384763616768088760813261459404593868669225374394285748797412506127764146079104074529582413578963024" + precomputedLines[1][31].A1 = 
"181639279227666744728058835182512597002272731052592483796260667304563267807641292586038560231233856920925765682824" + precomputedLines[0][30].A0 = "18435023068089070660975125510810348439226838453192275034433072101321273131670724107924928772543708251389256210982" + precomputedLines[0][30].A1 = "189563661045795839798194513512204620616132859547311231887706974570325317681777303888732091441417807222126570541533" + precomputedLines[1][30].A0 = "82879219941078362237057560048807799890255744733402294525751276088980660814954382745170509104342168803628139752230" + precomputedLines[1][30].A1 = "214832037284986972918877188428974385639036750829472015494469650325398240425765032117430797714433823602868628599745" + precomputedLines[0][29].A0 = "42631895223927910144857931606909345494297658132607842384637215440118798866582309559546170263267972096485951244567" + precomputedLines[0][29].A1 = "77696136161382625418126500663459926214910905680686292190148360724475155184380854358222754908651587605503983046743" + precomputedLines[1][29].A0 = "152618360449309051999677256436373979257431914081079963807076984865761018415896770491445227308914139716521343272725" + precomputedLines[1][29].A1 = "26285601224874509462060998455784006072413538379277541730344155792924114453840379473296234538995690755698646672930" + precomputedLines[0][28].A0 = "24059256652944299509671812434551197387011779854490927729047728896044313117301846809233649351717543504845423468003" + precomputedLines[0][28].A1 = "35173868244825360731659516818142227782662997555430669343484562566533614553227431870180137184272899341963674458460" + precomputedLines[1][28].A0 = "26667394460546447876306834413892828289693134732864745652367340112561403687205234797217061137978822136998709173641" + precomputedLines[1][28].A1 = "159123407327533078699604126675314174180923011431096322695436611192287322573041539140147065743473972354260619089942" + precomputedLines[0][27].A0 = 
"136137288084666935649747488173741021676298106423657578185667532171055309995913193363233363931205272154620190114019" + precomputedLines[0][27].A1 = "89545627262110858618816260875502743018263833576903648871523115587806824889738746993902862119794958027648141908496" + precomputedLines[1][27].A0 = "252851890965596817007093722301557145686460857050762542196422275804544803647600405475604764180686200549884895433483" + precomputedLines[1][27].A1 = "211895528043297520342501334603623706301609616744938516319297789235751079187131514524137929678308687062877098810160" + precomputedLines[0][26].A0 = "50193146133047157273064142352635336606881019629626521778271784631258435754947840578939444032235429198156271337074" + precomputedLines[0][26].A1 = "187308061181887333305695393482067401486224313057210050665155383491225533144410006314190041246753308320130016613637" + precomputedLines[1][26].A0 = "183380713641524274675225232799185038264979602759013420895704652493807171176303172602525771574777626869966774842497" + precomputedLines[1][26].A1 = "211536301540535091190968313434777773205354151446626615094066362538939731040161694312263360203075496146152316874857" + precomputedLines[0][25].A0 = "49012561462424313854474827884810003215689871066458028597824053702588937311381406457503485197368540251490314034705" + precomputedLines[0][25].A1 = "233045787647737445522164033006937078916016324755882167346610048475859694537814370111952019396064867502666410244388" + precomputedLines[1][25].A0 = "53066323260736050373339130891333710639264834305000587376035160979315466065340287620198139746778954661453434541474" + precomputedLines[1][25].A1 = "29068598336974169070319221245419308506232095719649983758674044085550004221937258966074793560744319380821942029699" + precomputedLines[0][24].A0 = "20807998187272051641020299350036585881225828617094479315331980902443389723381915716261749249127707131179558982418" + precomputedLines[0][24].A1 = 
"2569663689378405610914568367524137984603047701932196203695535597896611363401435473912332128808636751057477346538" + precomputedLines[1][24].A0 = "31632356134371149321558379637836594093061670897101626940219252979141342890636973560289020008142158068752506236989" + precomputedLines[1][24].A1 = "162474034324462109228316512108971901178903590482633712387596206505160144691740841684973770349748752843856786461608" + precomputedLines[0][23].A0 = "129839621447024852204253788712951095861782376933163112694737897301784545514246565506499111741363541362828753366944" + precomputedLines[0][23].A1 = "252032500019395610356302317835987533665618352267173289007638823577918726144160569337652190692921020587072209400233" + precomputedLines[1][23].A0 = "103657763587722609044718311659913184275449540326928197212006182064775110199731542634341445972032889703662079067270" + precomputedLines[1][23].A1 = "112149295095636864542149228797329324731936818498378753437139587454413070852255126624126792032729914106818657977248" + precomputedLines[0][22].A0 = "52486947262637987748819651641734707131186753478990776483745888259487759278168892849920956440670013928774458823031" + precomputedLines[0][22].A1 = "130026131370368540937875186024276329689406518464978751396814138888168900897492508443953815599508534025638875048803" + precomputedLines[1][22].A0 = "107650629876464582636360157762068367991039629923335616888987246523659237458989457421832659567285764794451349527208" + precomputedLines[1][22].A1 = "141954369657977128951842006362287940386733613050487111689651115226149144952783808012994761182674964220084045831134" + precomputedLines[0][21].A0 = "231357853912571244811290186339647076350925128723982887136028779601691957836767776092764975073438998690778954042539" + precomputedLines[0][21].A1 = "163918374503037182075190710665683594592782915546267809877595016976354408186482637370451964750667787650027037177643" + precomputedLines[1][21].A0 = 
"253993499558111303948568836475719815650044010286172125286039237179604540180534497006890615118966317984491781332188" + precomputedLines[1][21].A1 = "91306106708399550536330826339834147222693050466834310408418697204838373472953516606442984599917896899599804778740" + precomputedLines[0][20].A0 = "149491755656432546060358592364992096426848624317881230311541543281657412565012173169687489424412241084935608737415" + precomputedLines[0][20].A1 = "10078230219440146386786363208158695607136749719344080072245090598284598926356535194231918833499131426671949138529" + precomputedLines[1][20].A0 = "166147337894102898100227035549249329668949790996210759646973072283347219527930468759756291641199437404179464357414" + precomputedLines[1][20].A1 = "51861054845967563098625817334634451877104987996694111485430298623777222199857885925165427648742064004455715026942" + precomputedLines[0][19].A0 = "28789169380786124146230039268084419149725235779893793593175041494326370399300705019113627565675012781002480144908" + precomputedLines[0][19].A1 = "51126887514521757575486479441947876918479046042269689861009267337206224784506047520042720016113428299273812525856" + precomputedLines[1][19].A0 = "78438693983982849393828413461352865239037292515401621197446705858742631866248615088464959258495107540867637019515" + precomputedLines[1][19].A1 = "30329505348723895036030070665431021879915138933160836543969295494737694955113148118481620605657594147854768977836" + precomputedLines[0][18].A0 = "232629455467381604850710076804873467028376110142364182928813051160289780652005784104979994611269911467485081182940" + precomputedLines[0][18].A1 = "11302181816175797313086942441676022578093450171261944005064665818199106423002990602323812409511695564778846486504" + precomputedLines[1][18].A0 = "112798629536320588320184285596464669659963010507488150416465775170471633583892504835346009133427451338673845958926" + precomputedLines[1][18].A1 = 
"27926580645333909599189868163648675153725875315970675391034557886751990831504045873049028135712872229090867974558" + precomputedLines[0][17].A0 = "115880852955533259983237547745802961730413488750579029504204913129940923720969712061455096004446351190825278987574" + precomputedLines[0][17].A1 = "134447633343807307002446878490529517345772145910295671415051491245078834184672094129755814689008997306484750672710" + precomputedLines[1][17].A0 = "22976552345422841104357728976423929852837525413407609214580909257749715608547087375217990967883923048681682308605" + precomputedLines[1][17].A1 = "157338698204491688132182193212504377919056836489218797119791952057900152817029465512703604739201110946672619204010" + precomputedLines[0][16].A0 = "223616143440112228991226989175091852969974948963122974521782410945509189615455977425545114009370493631838263568399" + precomputedLines[0][16].A1 = "204456530467684524281666624398716928398144704169310479093451616720328008867922931285712172430932066182333761659341" + precomputedLines[1][16].A0 = "5323234652857550745006860986786657094417255769060869381366387995350320947937193112663310128013625083305442218044" + precomputedLines[1][16].A1 = "130146525837735087135101786226894655769855676783773117828949024726687714195186092464866196227118662795909493406397" + precomputedLines[0][15].A0 = "35879762568593404521748945388121221928728456093111364511378254007576876239340745413450484633304834658392189522640" + precomputedLines[0][15].A1 = "189498163278461679432566273882294448630820178631341640667038207983840782570490822357113249076451752139873663143984" + precomputedLines[1][15].A0 = "70018852018551064324295536813587450346145659851468642379539210391979909457494870674966791996232606206993522459149" + precomputedLines[1][15].A1 = "90989240347122377994989913910001231432408075720877198347914691548987907164311952853485914212956028662810039358245" + precomputedLines[0][14].A0 = 
"49450616355009745073590491942349172641952447152992247519074453655520247591578774135011510212578195721686624093550" + precomputedLines[0][14].A1 = "12401442488680350856906608876104073482937429970806311895185470442443537365780479306427480274332087392750168338843" + precomputedLines[1][14].A0 = "239689211153486811252647709288510228061183578519613821643836148328129159496137010243388999000917180876509629186445" + precomputedLines[1][14].A1 = "2729693877011151850107624210373087413500134987774287943407095881238706058575973718334889626025410795007673129883" + precomputedLines[0][13].A0 = "224988846995901048602625674742964253154015552705325289016691103852891039786902546643150379715775366758870154456997" + precomputedLines[0][13].A1 = "129864077084375935127482368750154719138214019398521495103424546330219522473893632167459907912411228922504888485222" + precomputedLines[1][13].A0 = "257031197461954365270816098531134223468191890363766305522039173385634097452987309754237959512090527959223721572830" + precomputedLines[1][13].A1 = "122466151103419952486921662302417561494838288721480283521900993826876057065861889914651643741158715449734227743491" + precomputedLines[0][12].A0 = "77519089212247869512031452152361396700728443048834277046251446725191778430518665116455061523806578971515283750094" + precomputedLines[0][12].A1 = "216670799718102947497582353087794901569068081670444094081762144953436869434434429802205553883713179846770077888679" + precomputedLines[1][12].A0 = "183736750761953353428489658887668334264927966306319538675377272137084990606187126401097085903227676516857626620668" + precomputedLines[1][12].A1 = "33441425600543503095646506536272497008311643149746166957182997383389484599074905331740705533550039908467616058278" + precomputedLines[0][11].A0 = "109125999944265251051497250748848286314071419548578469603532176578057378123700584787703393908715413022302163824309" + precomputedLines[0][11].A1 = 
"191939541144374701510408179698968498811022562839363731627137525922561117664858398125448901177976411529102410827368" + precomputedLines[1][11].A0 = "82152132620638790059884104207024850791327826277801975666595875747028105840184659638846772692098165269675144781862" + precomputedLines[1][11].A1 = "7559927451068342936102765326913464360946087131141717784263991134444057480419251387046656228759047670546351117824" + precomputedLines[0][10].A0 = "166003711242129986674754155224319713028013274275697813925647745412305024975109601702927132458269708361831458147826" + precomputedLines[0][10].A1 = "140704354002836468730319871968472729446356510943707733767939545999605679030972007641847778004648892723657229645363" + precomputedLines[1][10].A0 = "138019027947902914668850687441800860725057189329171152827619148949651871647264993108750188014968050196266004048380" + precomputedLines[1][10].A1 = "175127105734089044309030319546852758402806450267428794409840892693197579050851248742153815264025174103795301941752" + precomputedLines[0][9].A0 = "322884605698676042736054982015762606887526209186036151886420493246880916880799974236934580169202877650518546479" + precomputedLines[0][9].A1 = "174574765301690677540171845377067190823972566038343757997294104332498229185961167196415696738496909920616844163005" + precomputedLines[1][9].A0 = "169827334695549978591135913938187156668844340713888059950523304844492499936993789835142377407404720189263860755039" + precomputedLines[1][9].A1 = "35743493467144760984862953496808989274253933594783069863302385497051091216341663932869884538938015250354819237699" + precomputedLines[0][8].A0 = "32877562857320275788294047128398910085417426883567159999018805395377102673727994118415489719650239064460847089629" + precomputedLines[0][8].A1 = "31281084859624716807763602831200447174635957976077576021062493802913427228827105855241149266034396030919713564406" + precomputedLines[1][8].A0 = 
"108030717014344148010003904135714795510428018303002360364178990413071781472343962564312043108941879783604395368712" + precomputedLines[1][8].A1 = "210920819790927220240971586181206629969352692717991107124786314274585499721973906447093310006677728331067636782651" + precomputedLines[0][7].A0 = "163427475217922738641993294129855923985645687152434412554872630371379102562823372497784029353917398606260239730258" + precomputedLines[0][7].A1 = "118819530940531029789849947204470232664797826181684365255839645996319519181478722933215883649290521127005999139221" + precomputedLines[1][7].A0 = "13886471993552988264390003817290840044997379111121386623777718673275006501485512955768171228070764389654182080024" + precomputedLines[1][7].A1 = "256948805864750546745378004297028797447468222553171160329191926068258555579477559818624921645670232448539395914413" + precomputedLines[0][6].A0 = "235030218538732571594250349871077492193085794176025976161528531669237500498229418079035022864807313235788851195612" + precomputedLines[0][6].A1 = "210363466204372243614806805250699091641060484734886304595764963532232546623588780423209970069371515691148877349449" + precomputedLines[1][6].A0 = "124826981981928301871797094724110753799965219618845362320803071236043891526785269398129716884281973465691829216198" + precomputedLines[1][6].A1 = "216362246149948424396144755031495986241303926845685977547670069647525973371864236234391057292556666685367413840493" + precomputedLines[0][5].A0 = "232906800811808919404891009707933938887113933811523804887536420044186714360947550768187449840531962591906627259154" + precomputedLines[0][5].A1 = "59055492302464996892962403082978279957286398889782926647107737998877732041477724516206993363813348869969124697307" + precomputedLines[1][5].A0 = "153471705388743237012441724371146910052549757476269739551587302400380635478788805654190845813661067011056444155754" + precomputedLines[1][5].A1 = 
"101922595818476693878719395451944337822842877286370094125282810221017979628480204964772176412992239541707269478103" + precomputedLines[0][4].A0 = "121420039214462101187725026359002873331598638102506997064698252885119870075012037867595075416782973292357064496357" + precomputedLines[0][4].A1 = "75564152468492694504348179055737273531330007723576651066802575464425237521194211376137171365078078737350290111342" + precomputedLines[1][4].A0 = "1593097566767576308029240508154533701656830081962877712965609998675149706396084144955694178086208240015616818444" + precomputedLines[1][4].A1 = "184062530242309291417160865801296211238421239909716053893556814153411495051332277671675035633671375254500929302435" + precomputedLines[0][3].A0 = "51700427750739671073611575202441559572073914667485738084618077310268290738668165174520844478225495643393861156048" + precomputedLines[0][3].A1 = "141584254129156301900561396028177758060423900718960649431021157018061307634951820523647478050712231401904804367159" + precomputedLines[1][3].A0 = "90189952007333137629335565771731594978910318916032598472031865375528994472165951170958611679372458832375137522023" + precomputedLines[1][3].A1 = "76703673612885769640131704168255930937306043945886427070215224890710272144946893774213854855521006725817111381920" + precomputedLines[0][2].A0 = "59716789251246526722445639828306628576038133617256894226448485037108059044016927549847970668455915682664331588702" + precomputedLines[0][2].A1 = "106717425920252455352014036708990000212206835288395632216924125801742378053744714107220258846139127231904389527101" + precomputedLines[1][2].A0 = "256743282762562807243623403317318998557234115078959890143446252619791292256666977175864041421282293738880557674396" + precomputedLines[1][2].A1 = "164597390837498370921852664427703952193302848530321249548502266241427395754964020479461823749428885111841700250722" + precomputedLines[0][1].A0 = 
"79305776554169698360957316504345121263643267795719186027863032758503762988343397060292054901085328228848010776838" + precomputedLines[0][1].A1 = "141527969714538781023725045732556938872320350522162319949153687911416209684200721327501370003173850035523727831599" + precomputedLines[1][1].A0 = "5424524198952252590410874133044949072822691120863041208715731093661227503555532826194168416666686402691663272534" + precomputedLines[1][1].A1 = "206701053330372021428751381013907839852086135598606451424090781620585087343475908611015073210245729008320956185705" + precomputedLines[0][0].A0 = "21918509549963353142563477210743590667822930653042041256924224080965352319418877195339836372078921609564661330896" + precomputedLines[0][0].A1 = "120934438634724038292959651292759203567355061095965298261715865048556477266515749474371641099553611155110796888934" + precomputedLines[1][0].A0 = "107423415014442786374537065283151248683544228690118025791369207390553414543826660054936565666605678660756414043296" + precomputedLines[1][0].A1 = "246657105382086152013987446007025836646804684734192383499579797135722249112691696484229304825231914433391408837084" + precomputedLines[2][0].A0 = "19717807836209176546596875341898237675777494083291898007155826480671167938094146804293770861515390512968035906242" + precomputedLines[2][0].A1 = "253448949986965201640142015487193666064970049736875989459160721495735068843457329330691149558024967345766685055031" + precomputedLines[3][0].A0 = "140484755385853800980542398118706659319129985018524465262910162212888011149435366569399119253527824495951283626988" + precomputedLines[3][0].A1 = "47750277710932407129925646204660662972416796147681200498768623233120086435692032956920132961212597993256915197783" +} diff --git a/std/algebra/native/sw_bls24315/doc.go b/std/algebra/native/sw_bls24315/doc.go new file mode 100644 index 0000000000..9a13dd4769 --- /dev/null +++ b/std/algebra/native/sw_bls24315/doc.go @@ -0,0 +1,8 @@ +// Package sw_bls24315 implements the arithmetics of 
G1, G2 and the pairing +// computation on BLS24-315 as a SNARK circuit over BW6-633. These two curves +// form a 2-chain so the operations use native field arithmetic. +// +// References: +// BLS24-315/BW6-633: https://eprint.iacr.org/2021/1359 +// Pairings in R1CS: https://eprint.iacr.org/2022/1162 +package sw_bls24315 diff --git a/std/algebra/sw_bls24315/g1.go b/std/algebra/native/sw_bls24315/g1.go similarity index 92% rename from std/algebra/sw_bls24315/g1.go rename to std/algebra/native/sw_bls24315/g1.go index 1e8b12fd0e..61cda042ff 100644 --- a/std/algebra/sw_bls24315/g1.go +++ b/std/algebra/native/sw_bls24315/g1.go @@ -22,7 +22,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -228,7 +228,7 @@ var DecomposeScalarG1 = func(scalarField *big.Int, inputs []*big.Int, res []*big } func init() { - hint.Register(DecomposeScalarG1) + solver.RegisterHint(DecomposeScalarG1) } // varScalarMul sets P = [s] Q and returns P. @@ -299,20 +299,18 @@ func (P *G1Affine) varScalarMul(api frontend.API, Q G1Affine, s frontend.Variabl // step value from [2] Acc (instead of conditionally adding step value to // Acc): // Acc = [2] (Q + Φ(Q)) ± Q ± Φ(Q) - Acc.Double(api, Acc) // only y coordinate differs for negation, select on that instead. 
B.X = tableQ[0].X B.Y = api.Select(s1bits[nbits-1], tableQ[1].Y, tableQ[0].Y) - Acc.AddAssign(api, B) + Acc.DoubleAndAdd(api, &Acc, &B) B.X = tablePhiQ[0].X B.Y = api.Select(s2bits[nbits-1], tablePhiQ[1].Y, tablePhiQ[0].Y) Acc.AddAssign(api, B) // second bit - Acc.Double(api, Acc) B.X = tableQ[0].X B.Y = api.Select(s1bits[nbits-2], tableQ[1].Y, tableQ[0].Y) - Acc.AddAssign(api, B) + Acc.DoubleAndAdd(api, &Acc, &B) B.X = tablePhiQ[0].X B.Y = api.Select(s2bits[nbits-2], tablePhiQ[1].Y, tablePhiQ[0].Y) Acc.AddAssign(api, B) @@ -451,3 +449,36 @@ func (p *G1Affine) DoubleAndAdd(api frontend.API, p1, p2 *G1Affine) *G1Affine { return p } + +// ScalarMulBase computes s * g1 and returns it, where g1 is the fixed generator. It doesn't modify s. +func (P *G1Affine) ScalarMulBase(api frontend.API, s frontend.Variable) *G1Affine { + + points := getCurvePoints() + + sBits := api.ToBinary(s, 253) + + var res, tmp G1Affine + + // i = 1, 2 + // gm[0] = 3g, gm[1] = 5g, gm[2] = 7g + res.X = api.Lookup2(sBits[1], sBits[2], points.G1x, points.G1m[0][0], points.G1m[1][0], points.G1m[2][0]) + res.Y = api.Lookup2(sBits[1], sBits[2], points.G1y, points.G1m[0][1], points.G1m[1][1], points.G1m[2][1]) + + for i := 3; i < 253; i++ { + // gm[i] = [2^i]g + tmp.X = res.X + tmp.Y = res.Y + tmp.AddAssign(api, G1Affine{points.G1m[i][0], points.G1m[i][1]}) + res.Select(api, sBits[i], tmp, res) + } + + // i = 0 + tmp.Neg(api, G1Affine{points.G1x, points.G1y}) + tmp.AddAssign(api, res) + res.Select(api, sBits[0], res, tmp) + + P.X = res.X + P.Y = res.Y + + return P +} diff --git a/std/algebra/sw_bls24315/g1_test.go b/std/algebra/native/sw_bls24315/g1_test.go similarity index 93% rename from std/algebra/sw_bls24315/g1_test.go rename to std/algebra/native/sw_bls24315/g1_test.go index 86feeeb571..c4e02779d0 100644 --- a/std/algebra/sw_bls24315/g1_test.go +++ b/std/algebra/native/sw_bls24315/g1_test.go @@ -369,6 +369,37 @@ func TestScalarMulG1(t *testing.T) { assert.SolvingSucceeded(&circuit, &witness, 
test.WithCurves(ecc.BW6_633)) } +type g1varScalarMulBase struct { + C G1Affine `gnark:",public"` + R frontend.Variable +} + +func (circuit *g1varScalarMulBase) Define(api frontend.API) error { + expected := G1Affine{} + expected.ScalarMulBase(api, circuit.R) + expected.AssertIsEqual(api, circuit.C) + return nil +} + +func TestVarScalarMulBaseG1(t *testing.T) { + var c bls24315.G1Affine + gJac, _, _, _ := bls24315.Generators() + + // create the cs + var circuit, witness g1varScalarMulBase + var r fr.Element + _, _ = r.SetRandom() + witness.R = r.String() + // compute the result + var br big.Int + gJac.ScalarMultiplication(&gJac, r.BigInt(&br)) + c.FromJacobian(&gJac) + witness.C.Assign(&c) + + assert := test.NewAssert(t) + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_633)) +} + func randomPointG1() bls24315.G1Jac { p1, _, _, _ := bls24315.Generators() diff --git a/std/algebra/sw_bls24315/g2.go b/std/algebra/native/sw_bls24315/g2.go similarity index 81% rename from std/algebra/sw_bls24315/g2.go rename to std/algebra/native/sw_bls24315/g2.go index 001fe45f5d..cd5dc9d064 100644 --- a/std/algebra/sw_bls24315/g2.go +++ b/std/algebra/native/sw_bls24315/g2.go @@ -21,9 +21,9 @@ import ( "github.com/consensys/gnark-crypto/ecc" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/fields_bls24315" + "github.com/consensys/gnark/std/algebra/native/fields_bls24315" ) // G2Jac point in Jacobian coords @@ -243,7 +243,7 @@ var DecomposeScalarG2 = func(scalarField *big.Int, inputs []*big.Int, res []*big } func init() { - hint.Register(DecomposeScalarG2) + solver.RegisterHint(DecomposeScalarG2) } // varScalarMul sets P = [s] Q and returns P. 
@@ -314,20 +314,18 @@ func (P *G2Affine) varScalarMul(api frontend.API, Q G2Affine, s frontend.Variabl // step value from [2] Acc (instead of conditionally adding step value to // Acc): // Acc = [2] (Q + Φ(Q)) ± Q ± Φ(Q) - Acc.Double(api, Acc) // only y coordinate differs for negation, select on that instead. B.X = tableQ[0].X B.Y.Select(api, s1bits[nbits-1], tableQ[1].Y, tableQ[0].Y) - Acc.AddAssign(api, B) + Acc.DoubleAndAdd(api, &Acc, &B) B.X = tablePhiQ[0].X B.Y.Select(api, s2bits[nbits-1], tablePhiQ[1].Y, tablePhiQ[0].Y) Acc.AddAssign(api, B) // second bit - Acc.Double(api, Acc) B.X = tableQ[0].X B.Y.Select(api, s1bits[nbits-2], tableQ[1].Y, tableQ[0].Y) - Acc.AddAssign(api, B) + Acc.DoubleAndAdd(api, &Acc, &B) B.X = tablePhiQ[0].X B.Y.Select(api, s2bits[nbits-2], tablePhiQ[1].Y, tablePhiQ[0].Y) Acc.AddAssign(api, B) @@ -471,3 +469,73 @@ func (p *G2Affine) DoubleAndAdd(api frontend.API, p1, p2 *G2Affine) *G2Affine { return p } + +// ScalarMulBase computes s * g2 and returns it, where g2 is the fixed generator. It doesn't modify s. 
+func (P *G2Affine) ScalarMulBase(api frontend.API, s frontend.Variable) *G2Affine { + + points := getTwistPoints() + + sBits := api.ToBinary(s, 253) + + var res, tmp G2Affine + + // i = 1, 2 + // gm[0] = 3g, gm[1] = 5g, gm[2] = 7g + res.X.Lookup2(api, sBits[1], sBits[2], + fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2x[0], A1: points.G2x[1]}, + B1: fields_bls24315.E2{A0: points.G2x[2], A1: points.G2x[3]}}, + fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2m[0][0], A1: points.G2m[0][1]}, + B1: fields_bls24315.E2{A0: points.G2m[0][2], A1: points.G2m[0][3]}}, + fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2m[1][0], A1: points.G2m[1][1]}, + B1: fields_bls24315.E2{A0: points.G2m[1][2], A1: points.G2m[1][3]}}, + fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2m[2][0], A1: points.G2m[2][1]}, + B1: fields_bls24315.E2{A0: points.G2m[2][2], A1: points.G2m[2][3]}}) + + res.Y.Lookup2(api, sBits[1], sBits[2], + fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2y[0], A1: points.G2y[1]}, + B1: fields_bls24315.E2{A0: points.G2y[2], A1: points.G2y[3]}}, + fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2m[0][4], A1: points.G2m[0][5]}, + B1: fields_bls24315.E2{A0: points.G2m[0][6], A1: points.G2m[0][7]}}, + fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2m[1][4], A1: points.G2m[1][5]}, + B1: fields_bls24315.E2{A0: points.G2m[1][6], A1: points.G2m[1][7]}}, + fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2m[2][4], A1: points.G2m[2][5]}, + B1: fields_bls24315.E2{A0: points.G2m[2][6], A1: points.G2m[2][7]}}) + + for i := 3; i < 253; i++ { + // gm[i] = [2^i]g + tmp.X = res.X + tmp.Y = res.Y + tmp.AddAssign(api, G2Affine{ + X: fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2m[i][0], A1: points.G2m[i][1]}, + B1: fields_bls24315.E2{A0: points.G2m[i][2], A1: points.G2m[i][3]}}, + Y: fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2m[i][4], A1: points.G2m[i][5]}, + B1: fields_bls24315.E2{A0: 
points.G2m[i][6], A1: points.G2m[i][7]}}}) + res.Select(api, sBits[i], tmp, res) + } + + // i = 0 + tmp.Neg(api, G2Affine{ + X: fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2x[0], A1: points.G2x[1]}, + B1: fields_bls24315.E2{A0: points.G2x[2], A1: points.G2x[3]}}, + Y: fields_bls24315.E4{ + B0: fields_bls24315.E2{A0: points.G2y[0], A1: points.G2y[1]}, + B1: fields_bls24315.E2{A0: points.G2y[2], A1: points.G2y[3]}}}) + tmp.AddAssign(api, res) + res.Select(api, sBits[0], res, tmp) + + P.X = res.X + P.Y = res.Y + + return P +} diff --git a/std/algebra/sw_bls24315/g2_test.go b/std/algebra/native/sw_bls24315/g2_test.go similarity index 93% rename from std/algebra/sw_bls24315/g2_test.go rename to std/algebra/native/sw_bls24315/g2_test.go index 0c8fb3ad20..13034f058e 100644 --- a/std/algebra/sw_bls24315/g2_test.go +++ b/std/algebra/native/sw_bls24315/g2_test.go @@ -374,6 +374,38 @@ func TestScalarMulG2(t *testing.T) { assert := test.NewAssert(t) assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_633)) } + +type g2varScalarMulBase struct { + C G2Affine `gnark:",public"` + R frontend.Variable +} + +func (circuit *g2varScalarMulBase) Define(api frontend.API) error { + expected := G2Affine{} + expected.ScalarMulBase(api, circuit.R) + expected.AssertIsEqual(api, circuit.C) + return nil +} + +func TestVarScalarMulBaseG2(t *testing.T) { + var c bls24315.G2Affine + _, gJac, _, _ := bls24315.Generators() + + // create the cs + var circuit, witness g2varScalarMulBase + var r fr.Element + _, _ = r.SetRandom() + witness.R = r.String() + // compute the result + var br big.Int + gJac.ScalarMultiplication(&gJac, r.BigInt(&br)) + c.FromJacobian(&gJac) + witness.C.Assign(&c) + + assert := test.NewAssert(t) + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_633)) +} + func randomPointG2() bls24315.G2Jac { _, p2, _, _ := bls24315.Generators() diff --git a/std/algebra/sw_bls24315/inner.go b/std/algebra/native/sw_bls24315/inner.go similarity 
index 63% rename from std/algebra/sw_bls24315/inner.go rename to std/algebra/native/sw_bls24315/inner.go index 42b6355cd0..8630c645ed 100644 --- a/std/algebra/sw_bls24315/inner.go +++ b/std/algebra/native/sw_bls24315/inner.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/consensys/gnark-crypto/ecc" + bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark/frontend" ) @@ -65,3 +66,54 @@ func (cc *innerConfig) phi2(api frontend.API, res, P *G2Affine) *G2Affine { res.Y = P.Y return res } + +type curvePoints struct { + G1x *big.Int // base point x + G1y *big.Int // base point y + G1m [][2]*big.Int // m*base points (x,y) +} + +var ( + computedCurveTable [][2]*big.Int + computedTwistTable [][8]*big.Int +) + +func init() { + computedCurveTable = computeCurveTable() + computedTwistTable = computeTwistTable() +} + +func getCurvePoints() curvePoints { + _, _, g1aff, _ := bls24315.Generators() + return curvePoints{ + G1x: g1aff.X.BigInt(new(big.Int)), + G1y: g1aff.Y.BigInt(new(big.Int)), + G1m: computedCurveTable, + } +} + +type twistPoints struct { + G2x [4]*big.Int // base point x ∈ E4 + G2y [4]*big.Int // base point y ∈ E4 + G2m [][8]*big.Int // m*base points (x,y) +} + +func getTwistPoints() twistPoints { + _, _, _, g2aff := bls24315.Generators() + return twistPoints{ + G2x: [4]*big.Int{ + g2aff.X.B0.A0.BigInt(new(big.Int)), + g2aff.X.B0.A1.BigInt(new(big.Int)), + g2aff.X.B1.A0.BigInt(new(big.Int)), + g2aff.X.B1.A1.BigInt(new(big.Int)), + }, + G2y: [4]*big.Int{ + g2aff.Y.B0.A0.BigInt(new(big.Int)), + g2aff.Y.B0.A1.BigInt(new(big.Int)), + g2aff.Y.B1.A0.BigInt(new(big.Int)), + g2aff.Y.B1.A1.BigInt(new(big.Int)), + }, + G2m: computedTwistTable, + } + +} diff --git a/std/algebra/native/sw_bls24315/inner_compute.go b/std/algebra/native/sw_bls24315/inner_compute.go new file mode 100644 index 0000000000..158ef72999 --- /dev/null +++ b/std/algebra/native/sw_bls24315/inner_compute.go @@ -0,0 +1,59 @@ +package sw_bls24315 + +import ( + "math/big" + + 
bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" +) + +func computeCurveTable() [][2]*big.Int { + G1jac, _, _, _ := bls24315.Generators() + table := make([][2]*big.Int, 253) + tmp := new(bls24315.G1Jac).Set(&G1jac) + aff := new(bls24315.G1Affine) + jac := new(bls24315.G1Jac) + for i := 1; i < 253; i++ { + tmp = tmp.Double(tmp) + switch i { + case 1, 2: + jac.Set(tmp).AddAssign(&G1jac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + case 3: + jac.Set(tmp).SubAssign(&G1jac) + aff.FromJacobian(jac) + table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + fallthrough + default: + aff.FromJacobian(tmp) + table[i] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))} + } + } + return table[:] +} + +func computeTwistTable() [][8]*big.Int { + _, G2jac, _, _ := bls24315.Generators() + table := make([][8]*big.Int, 253) + tmp := new(bls24315.G2Jac).Set(&G2jac) + aff := new(bls24315.G2Affine) + jac := new(bls24315.G2Jac) + for i := 1; i < 253; i++ { + tmp = tmp.Double(tmp) + switch i { + case 1, 2: + jac.Set(tmp).AddAssign(&G2jac) + aff.FromJacobian(jac) + table[i-1] = [8]*big.Int{aff.X.B0.A0.BigInt(new(big.Int)), aff.X.B0.A1.BigInt(new(big.Int)), aff.X.B1.A0.BigInt(new(big.Int)), aff.X.B1.A1.BigInt(new(big.Int)), aff.Y.B0.A0.BigInt(new(big.Int)), aff.Y.B0.A1.BigInt(new(big.Int)), aff.Y.B1.A0.BigInt(new(big.Int)), aff.Y.B1.A1.BigInt(new(big.Int))} + case 3: + jac.Set(tmp).SubAssign(&G2jac) + aff.FromJacobian(jac) + table[i-1] = [8]*big.Int{aff.X.B0.A0.BigInt(new(big.Int)), aff.X.B0.A1.BigInt(new(big.Int)), aff.X.B1.A0.BigInt(new(big.Int)), aff.X.B1.A1.BigInt(new(big.Int)), aff.Y.B0.A0.BigInt(new(big.Int)), aff.Y.B0.A1.BigInt(new(big.Int)), aff.Y.B1.A0.BigInt(new(big.Int)), aff.Y.B1.A1.BigInt(new(big.Int))} + fallthrough + default: + aff.FromJacobian(tmp) + table[i] = [8]*big.Int{aff.X.B0.A0.BigInt(new(big.Int)), aff.X.B0.A1.BigInt(new(big.Int)), 
aff.X.B1.A0.BigInt(new(big.Int)), aff.X.B1.A1.BigInt(new(big.Int)), aff.Y.B0.A0.BigInt(new(big.Int)), aff.Y.B0.A1.BigInt(new(big.Int)), aff.Y.B1.A0.BigInt(new(big.Int)), aff.Y.B1.A1.BigInt(new(big.Int))} + } + } + return table[:] +} diff --git a/std/algebra/native/sw_bls24315/pairing.go b/std/algebra/native/sw_bls24315/pairing.go new file mode 100644 index 0000000000..ab9c08c3b5 --- /dev/null +++ b/std/algebra/native/sw_bls24315/pairing.go @@ -0,0 +1,592 @@ +/* +Copyright © 2020 ConsenSys + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package sw_bls24315 + +import ( + "errors" + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/native/fields_bls24315" +) + +// GT target group of the pairing +type GT = fields_bls24315.E24 + +const ateLoop = 3218079743 + +// lineEvaluation represents a sparse Fp12 Elmt (result of the line evaluation) +// line: 1 + R0(x/y) + R1(1/y) = 0 instead of R0'*y + R1'*x + R2' = 0 This +// makes the multiplication by lines (MulBy034) and between lines (Mul034By034) +type lineEvaluation struct { + R0, R1 fields_bls24315.E4 +} + +// MillerLoop computes the product of n miller loops (n can be 1) +// ∏ᵢ { fᵢ_{x₀,Q}(P) } +func MillerLoop(api frontend.API, P []G1Affine, Q []G2Affine) (GT, error) { + // check input size match + n := len(P) + if n == 0 || n != len(Q) { + return GT{}, errors.New("invalid inputs sizes") + } + + var ateLoop2NAF [33]int8 + ecc.NafDecomposition(big.NewInt(ateLoop), ateLoop2NAF[:]) + + var res GT + res.SetOne() + var prodLines [5]fields_bls24315.E4 + + var l1, l2 lineEvaluation + Qacc := make([]G2Affine, n) + Qneg := make([]G2Affine, n) + yInv := make([]frontend.Variable, n) + xOverY := make([]frontend.Variable, n) + for k := 0; k < n; k++ { + Qacc[k] = Q[k] + Qneg[k].Neg(api, Q[k]) + // TODO: point P=(x,O) should be ruled out + yInv[k] = api.DivUnchecked(1, P[k].Y) + xOverY[k] = api.Mul(P[k].X, yInv[k]) + } + + // Compute ∏ᵢ { fᵢ_{x₀,Q}(P) } + // i = 32, separately to avoid an E24 Square + // (Square(res) = 1² = 1) + + // k = 0, separately to avoid MulBy034 (res × ℓ) + // (assign line to res) + Qacc[0], l1 = doubleStep(api, &Qacc[0]) + res.D1.C0.MulByFp(api, l1.R0, xOverY[0]) + res.D1.C1.MulByFp(api, l1.R1, yInv[0]) + + if n >= 2 { + // k = 1, separately to avoid MulBy034 (res × ℓ) + // (res is also a line at this point, so we use Mul034By034 ℓ × ℓ) + Qacc[1], l1 = doubleStep(api, &Qacc[1]) + + // line evaluation at P[1] + l1.R0.MulByFp(api, l1.R0, xOverY[1]) + 
l1.R1.MulByFp(api, l1.R1, yInv[1]) + + // ℓ × res + prodLines = *fields_bls24315.Mul034By034(api, l1.R0, l1.R1, res.D1.C0, res.D1.C1) + res.D0.C0 = prodLines[0] + res.D0.C1 = prodLines[1] + res.D0.C2 = prodLines[2] + res.D1.C0 = prodLines[3] + res.D1.C1 = prodLines[4] + + } + + if n >= 3 { + // k >= 2 + for k := 2; k < n; k++ { + // Qacc[k] ← 2Qacc[k] and l1 the tangent ℓ passing 2Qacc[k] + Qacc[k], l1 = doubleStep(api, &Qacc[k]) + + // line evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + } + } + + // i = 30, separately to avoid a doubleStep + // (at this point Qacc = 2Q, so 2Qacc-Q=3Q is equivalent to Qacc+Q=3Q + // this means doubleAndAddStep is equivalent to addStep here) + if n == 1 { + res.Square034(api, res) + } else { + res.Square(api, res) + + } + for k := 0; k < n; k++ { + // l2 the line passing Qacc[k] and -Q + l2 = lineCompute(api, &Qacc[k], &Qneg[k]) + + // line evaluation at P[k] + l2.R0.MulByFp(api, l2.R0, xOverY[k]) + l2.R1.MulByFp(api, l2.R1, yInv[k]) + + // Qacc[k] ← Qacc[k]+Q[k] and + // l1 the line ℓ passing Qacc[k] and Q[k] + Qacc[k], l1 = addStep(api, &Qacc[k], &Q[k]) + + // line evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + // ℓ × res + res.MulBy034(api, l2.R0, l2.R1) + } + + for i := 29; i >= 1; i-- { + // mutualize the square among n Miller loops + // (∏ᵢfᵢ)² + res.Square(api, res) + + switch ateLoop2NAF[i] { + case 0: + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k] and l1 the tangent ℓ passing 2Qacc[k] + Qacc[k], l1 = doubleStep(api, &Qacc[k]) + + // line evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + } + case 1: + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k]+Q[k], + // l1 the line ℓ passing Qacc[k] and Q[k] + // l2 the line ℓ passing 
(Qacc[k]+Q[k]) and Qacc[k] + Qacc[k], l1, l2 = doubleAndAddStep(api, &Qacc[k], &Q[k]) + + // line evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + + // line evaluation at P[k] + l2.R0.MulByFp(api, l2.R0, xOverY[k]) + l2.R1.MulByFp(api, l2.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l2.R0, l2.R1) + } + case -1: + for k := 0; k < n; k++ { + // Qacc[k] ← 2Qacc[k]-Q[k], + // l1 the line ℓ passing Qacc[k] and Q[k] + // l2 the line ℓ passing (Qacc[k]-Q[k]) and Qacc[k] + Qacc[k], l1, l2 = doubleAndAddStep(api, &Qacc[k], &Qneg[k]) + + // line evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + + // line evaluation at P[k] + l2.R0.MulByFp(api, l2.R0, xOverY[k]) + l2.R1.MulByFp(api, l2.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l2.R0, l2.R1) + } + default: + return GT{}, errors.New("invalid loopCounter") + } + } + + // i = 0 + res.Square(api, res) + for k := 0; k < n; k++ { + // l1 the line ℓ passing Qacc[k] and -Q[k] + // l2 the line ℓ passing (Qacc[k]-Q[k]) and Qacc[k] + l1, l2 = linesCompute(api, &Qacc[k], &Qneg[k]) + + // line evaluation at P[k] + l1.R0.MulByFp(api, l1.R0, xOverY[k]) + l1.R1.MulByFp(api, l1.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + + // line evaluation at P[k] + l2.R0.MulByFp(api, l2.R0, xOverY[k]) + l2.R1.MulByFp(api, l2.R1, yInv[k]) + + // ℓ × res + res.MulBy034(api, l2.R0, l2.R1) + } + + res.Conjugate(api, res) + + return res, nil +} + +// FinalExponentiation computes the exponentiation e1ᵈ +// where d = (p²⁴-1)/r = (p²⁴-1)/Φ₂₄(p) ⋅ Φ₂₄(p)/r = (p¹²-1)(p⁴+1)(p⁸ - p⁴ +1)/r +// we use instead d=s ⋅ (p¹²-1)(p⁴+1)(p⁸ - p⁴ +1)/r +// where s is the cofactor 3 (Hayashida et al.) 
+func FinalExponentiation(api frontend.API, e1 GT) GT { + const genT = ateLoop + result := e1 + + // https://eprint.iacr.org/2012/232.pdf, section 7 + var t [9]GT + + // easy part + // (p¹²-1)(p⁴+1) + t[0].Conjugate(api, result) + t[0].DivUnchecked(api, t[0], result) + result.FrobeniusQuad(api, t[0]). + Mul(api, result, t[0]) + + // hard part (up to permutation) + // Daiki Hayashida and Kenichiro Hayasaka + // and Tadanori Teruya + // https://eprint.iacr.org/2020/875.pdf + // 3(p⁸ - p⁴ +1)/r = (x₀-1)² * (x₀+p) * (x₀²+p²) * (x₀⁴+p⁴-1) + 3 + t[0].CyclotomicSquare(api, result) + t[1].Expt(api, result, genT) + t[2].Conjugate(api, result) + t[1].Mul(api, t[1], t[2]) + t[2].Expt(api, t[1], genT) + t[1].Conjugate(api, t[1]) + t[1].Mul(api, t[1], t[2]) + t[2].Expt(api, t[1], genT) + t[1].Frobenius(api, t[1]) + t[1].Mul(api, t[1], t[2]) + result.Mul(api, result, t[0]) + t[0].Expt(api, t[1], genT) + t[2].Expt(api, t[0], genT) + t[0].FrobeniusSquare(api, t[1]) + t[2].Mul(api, t[0], t[2]) + t[1].Expt(api, t[2], genT) + t[1].Expt(api, t[1], genT) + t[1].Expt(api, t[1], genT) + t[1].Expt(api, t[1], genT) + t[0].FrobeniusQuad(api, t[2]) + t[0].Mul(api, t[0], t[1]) + t[2].Conjugate(api, t[2]) + t[0].Mul(api, t[0], t[2]) + result.Mul(api, result, t[0]) + + return result +} + +// Pair calculates the reduced pairing for a set of points +// ∏ᵢ e(Pᵢ, Qᵢ) +// +// This function doesn't check that the inputs are in the correct subgroup. See IsInSubGroup. 
+func Pair(api frontend.API, P []G1Affine, Q []G2Affine) (GT, error) { + f, err := MillerLoop(api, P, Q) + if err != nil { + return GT{}, err + } + return FinalExponentiation(api, f), nil +} + +// doubleAndAddStep doubles p1 and adds p2 to the result in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func doubleAndAddStep(api frontend.API, p1, p2 *G2Affine) (G2Affine, lineEvaluation, lineEvaluation) { + + var n, d, l1, l2, x3, x4, y4 fields_bls24315.E4 + var line1, line2 lineEvaluation + var p G2Affine + + // compute lambda1 = (y2-y1)/(x2-x1) + n.Sub(api, p1.Y, p2.Y) + d.Sub(api, p1.X, p2.X) + l1.DivUnchecked(api, n, d) + + // x3 =lambda1**2-p1.x-p2.x + x3.Square(api, l1). + Sub(api, x3, p1.X). + Sub(api, x3, p2.X) + + // omit y3 computation + + // compute line1 + line1.R0.Neg(api, l1) + line1.R1.Mul(api, l1, p1.X).Sub(api, line1.R1, p1.Y) + + // compute lambda2 = -lambda1-2*y1/(x3-x1) + n.Double(api, p1.Y) + d.Sub(api, x3, p1.X) + l2.DivUnchecked(api, n, d) + l2.Add(api, l2, l1).Neg(api, l2) + + // compute x4 = lambda2**2-x1-x3 + x4.Square(api, l2). + Sub(api, x4, p1.X). + Sub(api, x4, x3) + + // compute y4 = lambda2*(x1 - x4)-y1 + y4.Sub(api, p1.X, x4). + Mul(api, l2, y4). + Sub(api, y4, p1.Y) + + p.X = x4 + p.Y = y4 + + // compute line2 + line2.R0.Neg(api, l2) + line2.R1.Mul(api, l2, p1.X).Sub(api, line2.R1, p1.Y) + + return p, line1, line2 +} + +// doubleStep doubles a point in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func doubleStep(api frontend.API, p1 *G2Affine) (G2Affine, lineEvaluation) { + + var n, d, l, xr, yr fields_bls24315.E4 + var p G2Affine + var line lineEvaluation + + // lambda = 3*p1.x**2/2*p.y + n.Square(api, p1.X).MulByFp(api, n, 3) + d.MulByFp(api, p1.Y, 2) + l.DivUnchecked(api, n, d) + + // xr = lambda**2-2*p1.x + xr.Square(api, l). + Sub(api, xr, p1.X). 
+ Sub(api, xr, p1.X) + + // yr = lambda*(p.x-xr)-p.y + yr.Sub(api, p1.X, xr). + Mul(api, l, yr). + Sub(api, yr, p1.Y) + + p.X = xr + p.Y = yr + + line.R0.Neg(api, l) + line.R1.Mul(api, l, p1.X).Sub(api, line.R1, p1.Y) + + return p, line + +} + +// addStep adds two points in affine coordinates, and evaluates the line in Miller loop +// https://eprint.iacr.org/2022/1162 (Section 6.1) +func addStep(api frontend.API, p1, p2 *G2Affine) (G2Affine, lineEvaluation) { + + var p2ypy, p2xpx, λ, λλ, pxrx, λpxrx, xr, yr fields_bls24315.E4 + // compute λ = (y2-y1)/(x2-x1) + p2ypy.Sub(api, p2.Y, p1.Y) + p2xpx.Sub(api, p2.X, p1.X) + λ.DivUnchecked(api, p2ypy, p2xpx) + + // xr = λ²-x1-x2 + λλ.Square(api, λ) + p2xpx.Add(api, p1.X, p2.X) + xr.Sub(api, λλ, p2xpx) + + // yr = λ(x1-xr) - y1 + pxrx.Sub(api, p1.X, xr) + λpxrx.Mul(api, λ, pxrx) + yr.Sub(api, λpxrx, p1.Y) + + var res G2Affine + res.X = xr + res.Y = yr + + var line lineEvaluation + line.R0.Neg(api, λ) + line.R1.Mul(api, λ, p1.X) + line.R1.Sub(api, line.R1, p1.Y) + + return res, line + +} + +// linesCompute computes the lines that goes through p1 and p2, and (p1+p2) and p1 but does not compute 2p1+p2 +func linesCompute(api frontend.API, p1, p2 *G2Affine) (lineEvaluation, lineEvaluation) { + + var n, d, l1, l2, x3 fields_bls24315.E4 + var line1, line2 lineEvaluation + + // compute lambda1 = (y2-y1)/(x2-x1) + n.Sub(api, p1.Y, p2.Y) + d.Sub(api, p1.X, p2.X) + l1.DivUnchecked(api, n, d) + + // x3 =lambda1**2-p1.x-p2.x + x3.Square(api, l1). + Sub(api, x3, p1.X). 
+ Sub(api, x3, p2.X) + + // omit y3 computation + + // compute line1 + line1.R0.Neg(api, l1) + line1.R1.Mul(api, l1, p1.X).Sub(api, line1.R1, p1.Y) + + // compute lambda2 = -lambda1-2*y1/(x3-x1) + n.Double(api, p1.Y) + d.Sub(api, x3, p1.X) + l2.DivUnchecked(api, n, d) + l2.Add(api, l2, l1).Neg(api, l2) + + // compute line2 + line2.R0.Neg(api, l2) + line2.R1.Mul(api, l2, p1.X).Sub(api, line2.R1, p1.Y) + + return line1, line2 +} + +// lineCompute computes the line that goes through p1 and p2 but does not compute p1+p2 +func lineCompute(api frontend.API, p1, p2 *G2Affine) lineEvaluation { + + var qypy, qxpx, λ fields_bls24315.E4 + + // compute λ = (y2-y1)/(x2-x1) + qypy.Sub(api, p2.Y, p1.Y) + qxpx.Sub(api, p2.X, p1.X) + λ.DivUnchecked(api, qypy, qxpx) + + var line lineEvaluation + line.R0.Neg(api, λ) + line.R1.Mul(api, λ, p1.X) + line.R1.Sub(api, line.R1, p1.Y) + + return line + +} + +// ---------------------------- +// Fixed-argument pairing +// ---------------------------- +// +// The second argument Q is the fixed canonical generator of G2. +// +// Q.X.B0.A0 = 0x2f339ada8942f92aefa14196bfee2552a7c5675f5e5e9da798458f72ff50f96f5c357cf13710f63 +// Q.X.B0.A1 = 0x20b1a8dca4b18842b40079be727cbfd1a16ed134a080b759ae503618e92871697838dc4c689911c +// Q.X.B1.A0 = 0x16eab1e76670eb9affa1bc77400be688d5cd69566f9325b329b40db85b47f236d5c34e8ffed7536 +// Q.X.B1.A1 = 0x6e8c608261f21c41f2479ca4824deba561b9689a9c03a5b8b36a6cbbed0a7d9468e07e557d8569 +// Q.Y.B0.A0 = 0x3cdd8218baa5276421c9923cde33a45399a1d878d5202fae600a8502a29681f74ccdcc053b278b7 +// Q.Y.B0.A1 = 0x3a079c670190bb49b1bd21e10aac3191535e32ce99da592ddfa8bd09d57a7374ed63ad7f25e398d +// Q.Y.B1.A0 = 0x1b38dd0c5ec49a0883a950c631c688eb3b01f45b7c0d2990cd99052005ebf2fa9e7043bbd605ef5 +// Q.Y.B1.A1 = 0x495d6de2e4fed6be3e1d24dd724163e01d88643f7e83d31528ab0a80ced619175a1a104574ac83 + +// MillerLoopFixedQ computes the single Miller loop +// fᵢ_{u,g2}(P), where g2 is fixed. 
+func MillerLoopFixedQ(api frontend.API, P G1Affine) (GT, error) { + + var ateLoop2NAF [33]int8 + ecc.NafDecomposition(big.NewInt(ateLoop), ateLoop2NAF[:]) + + var res GT + res.SetOne() + + var l1, l2 lineEvaluation + var yInv, xOverY frontend.Variable + yInv = api.DivUnchecked(1, P.Y) + xOverY = api.Mul(P.X, yInv) + + // Compute ∏ᵢ { fᵢ_{x₀,Q}(P) } + // i = 31, separately to avoid an E24 Square + // (Square(res) = 1² = 1) + + // k = 0, separately to avoid MulBy034 (res × ℓ) + // (assign line(P) to res) + res.D1.C0.MulByFp(api, + fields_bls24315.E4{B0: precomputedLines[0][31], B1: precomputedLines[1][31]}, + xOverY) + res.D1.C1.MulByFp(api, + fields_bls24315.E4{B0: precomputedLines[2][31], B1: precomputedLines[3][31]}, + yInv) + + // i = 30 + res.Square034(api, res) + // line evaluation at P + l1.R0.MulByFp(api, + fields_bls24315.E4{B0: precomputedLines[0][30], B1: precomputedLines[1][30]}, + xOverY) + l1.R1.MulByFp(api, + fields_bls24315.E4{B0: precomputedLines[2][30], B1: precomputedLines[3][30]}, + yInv) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + + // line evaluation at P + l2.R0.MulByFp(api, + fields_bls24315.E4{B0: precomputedLines[4][30], B1: precomputedLines[5][30]}, + xOverY) + l2.R1.MulByFp(api, + fields_bls24315.E4{B0: precomputedLines[6][30], B1: precomputedLines[7][30]}, + yInv) + + // ℓ × res + res.MulBy034(api, l2.R0, l2.R1) + + for i := 29; i >= 0; i-- { + // mutualize the square among n Miller loops + // (∏ᵢfᵢ)² + res.Square(api, res) + + if ateLoop2NAF[i] == 0 { + // line evaluation at P + l1.R0.MulByFp(api, + fields_bls24315.E4{B0: precomputedLines[0][i], B1: precomputedLines[1][i]}, + xOverY) + l1.R1.MulByFp(api, + fields_bls24315.E4{B0: precomputedLines[2][i], B1: precomputedLines[3][i]}, + yInv) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + } else { + // line evaluation at P + l1.R0.MulByFp(api, + fields_bls24315.E4{B0: precomputedLines[0][i], B1: precomputedLines[1][i]}, + xOverY) + l1.R1.MulByFp(api, + fields_bls24315.E4{B0: 
precomputedLines[2][i], B1: precomputedLines[3][i]}, + yInv) + + // ℓ × res + res.MulBy034(api, l1.R0, l1.R1) + + // line evaluation at P + l2.R0.MulByFp(api, + fields_bls24315.E4{B0: precomputedLines[4][i], B1: precomputedLines[5][i]}, + xOverY) + l2.R1.MulByFp(api, + fields_bls24315.E4{B0: precomputedLines[6][i], B1: precomputedLines[7][i]}, + yInv) + + // ℓ × res + res.MulBy034(api, l2.R0, l2.R1) + } + } + + res.Conjugate(api, res) + + return res, nil +} + +// PairFixedQ calculates the reduced pairing for a set of points +// e(P, g2), where g2 is fixed. +// +// This function doesn't check that the inputs are in the correct subgroups. +func PairFixedQ(api frontend.API, P G1Affine) (GT, error) { + f, err := MillerLoopFixedQ(api, P) + if err != nil { + return GT{}, err + } + return FinalExponentiation(api, f), nil +} diff --git a/std/algebra/sw_bls24315/pairing_test.go b/std/algebra/native/sw_bls24315/pairing_test.go similarity index 88% rename from std/algebra/sw_bls24315/pairing_test.go rename to std/algebra/native/sw_bls24315/pairing_test.go index c7bb53b4a5..216bf0b345 100644 --- a/std/algebra/sw_bls24315/pairing_test.go +++ b/std/algebra/native/sw_bls24315/pairing_test.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/std/algebra/fields_bls24315" + "github.com/consensys/gnark/std/algebra/native/fields_bls24315" "github.com/consensys/gnark/test" ) @@ -127,6 +127,37 @@ func TestTriplePairingBLS24315(t *testing.T) { } +type pairingFixedBLS315 struct { + P G1Affine `gnark:",public"` + pairingRes bls24315.GT +} + +func (circuit *pairingFixedBLS315) Define(api frontend.API) error { + + pairingRes, _ := PairFixedQ(api, circuit.P) + + mustbeEq(api, pairingRes, &circuit.pairingRes) + + return nil +} + +func TestPairingFixedBLS315(t *testing.T) { + + // pairing test data + P, _, _, pairingRes := pairingData() + + // create cs 
+ var circuit, witness pairingFixedBLS315 + circuit.pairingRes = pairingRes + + // assign values to witness + witness.P.Assign(&P) + + assert := test.NewAssert(t) + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_633)) + +} + // utils func pairingData() (P bls24315.G1Affine, Q bls24315.G2Affine, milRes bls24315.E24, pairingRes bls24315.GT) { _, _, P, Q = bls24315.Generators() diff --git a/std/algebra/native/sw_bls24315/precomputations.go b/std/algebra/native/sw_bls24315/precomputations.go new file mode 100644 index 0000000000..0daaab582a --- /dev/null +++ b/std/algebra/native/sw_bls24315/precomputations.go @@ -0,0 +1,324 @@ +/* +Copyright © 2020 ConsenSys + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package sw_bls24315 + +import "github.com/consensys/gnark/std/algebra/native/fields_bls24315" + +// precomputed lines going through Q and multiples of Q +// where Q is the fixed canonical generator of G2 +// +// Q.X.B0.A0 = 0x2f339ada8942f92aefa14196bfee2552a7c5675f5e5e9da798458f72ff50f96f5c357cf13710f63 +// Q.X.B0.A1 = 0x20b1a8dca4b18842b40079be727cbfd1a16ed134a080b759ae503618e92871697838dc4c689911c +// Q.X.B1.A0 = 0x16eab1e76670eb9affa1bc77400be688d5cd69566f9325b329b40db85b47f236d5c34e8ffed7536 +// Q.X.B1.A1 = 0x6e8c608261f21c41f2479ca4824deba561b9689a9c03a5b8b36a6cbbed0a7d9468e07e557d8569 +// Q.Y.B0.A0 = 0x3cdd8218baa5276421c9923cde33a45399a1d878d5202fae600a8502a29681f74ccdcc053b278b7 +// Q.Y.B0.A1 = 0x3a079c670190bb49b1bd21e10aac3191535e32ce99da592ddfa8bd09d57a7374ed63ad7f25e398d +// Q.Y.B1.A0 = 0x1b38dd0c5ec49a0883a950c631c688eb3b01f45b7c0d2990cd99052005ebf2fa9e7043bbd605ef5 +// Q.Y.B1.A1 = 0x495d6de2e4fed6be3e1d24dd724163e01d88643f7e83d31528ab0a80ced619175a1a104574ac83 + +var precomputedLines [8][32]fields_bls24315.E2 + +func init() { + precomputedLines[0][31].A0 = "2781727895159262619460256409706287591924453463655172989784955348152343938146585905490452975032" + precomputedLines[0][31].A1 = "20743040417459989930623561845602589771631449600215198323538516287726550241235287115679407786845" + precomputedLines[1][31].A0 = "37827392216841557097204025035278323275223365579534097097252687825026777262856397606000655067832" + precomputedLines[1][31].A1 = "10178945294409318168940915139478758628116906709732642430558232843003175855355513179037001936084" + precomputedLines[2][31].A0 = "9268466886752397507456884855877792769302861346475683862931987092075954986067022396137895751064" + precomputedLines[2][31].A1 = "35776315690986656338137813349452405648722069133391461576095108348893303823506705187450710629560" + precomputedLines[3][31].A0 = "34392177303163153605588192336284221250980197014684318869994241271134150019763838507446731038390" + precomputedLines[3][31].A1 = 
"13194024163631068600629692688955799202181581968937073159572603162230683537538603191556380343912" + precomputedLines[4][30].A0 = "36923414814354175715565433480702682153009048953259576345279330157485540154979756441583164158537" + precomputedLines[4][30].A1 = "18962102292053448404402128044806379973302052816699551011525769217911333851891055231394209346724" + precomputedLines[5][30].A0 = "1877750492671881237821664855130646469710136837380652237811597680611106830269944741072962065737" + precomputedLines[5][30].A1 = "29526197415104120166084774750930211116816595707182106904506052662634708237770829168036615197485" + precomputedLines[6][30].A0 = "30436675822761040827568805034531176975630641070439065472132298413561929107059319950935721382505" + precomputedLines[6][30].A1 = "3928827018526781996887876540956564096211433283523287758969177156744580269619637159622906504009" + precomputedLines[7][30].A0 = "5312965406350284729437497554124748493953305402230430465070044234503734073362503839626886095179" + precomputedLines[7][30].A1 = "26511118545882369734395997201453170542751920447977676175491682343407200555587739155517236789657" + precomputedLines[0][30].A0 = "461236740181425465397659913431799545589263748526554273770288722107013015330673780414530164491" + precomputedLines[0][30].A1 = "22109606767856473710298596353168431563366439430064228381353601705604249164343745550841848509426" + precomputedLines[1][30].A0 = "33387760941962928997831938618853079610298494109458657444811760367567408985369697315043385588614" + precomputedLines[1][30].A1 = "35700314461735701196263231559108204373201038473833731959663398341703231623996071740173212310084" + precomputedLines[2][30].A0 = "10329378385221266344532466551506767851778327290324740493490344241246113863412295958246960179810" + precomputedLines[2][30].A1 = "6604106544210738470847737536373068472609504210483433139402660739765438480505001427938032820014" + precomputedLines[3][30].A0 = 
"56176824678400801751944537915603387431936789382030871047984526853009955441483023923669464381" + precomputedLines[3][30].A1 = "27938806984042648796501054486346250447762374553429279197564045218301634928056111595725018396833" + precomputedLines[0][29].A0 = "36207223000216692209594824085685861604898944305483774833844371014395036671197665701483312391612" + precomputedLines[0][29].A1 = "13025479363374537690881509758209881371297964971897024462202786841749718333089577298946918188049" + precomputedLines[1][29].A0 = "30876654915508348462099814699006435195390585650953934460637247464710825669079097318727044955324" + precomputedLines[1][29].A1 = "19728007336919325853909059021906555308384288457370960345674978745741589196394349315480630813174" + precomputedLines[2][29].A0 = "909311813858823350494627504711134525979860239017585631521289842377706657228953582774346888578" + precomputedLines[2][29].A1 = "10815646958156379379958903143623204614357516989953661486788700840240635604266238610799737236225" + precomputedLines[3][29].A0 = "37176231655166055487771177626322523043108110244212752571578415150533964508255208301128944686425" + precomputedLines[3][29].A1 = "11787323836696620816459020008980422057328092899400811297543719827323850831750048769709960317713" + precomputedLines[0][28].A0 = "35526633032731215739646583827960776336242394556588137819759929534542416292367136756261285825267" + precomputedLines[0][28].A1 = "36077665212540379762430984913125162349272623724469806995440805947904970285562898975048807438042" + precomputedLines[1][28].A0 = "9831914100430604698284412987456612279245996036568435658393063771186842655014269332391936465552" + precomputedLines[1][28].A1 = "27803990512020310024520249171063066302535110320329458505559586380744692999620567118441755982575" + precomputedLines[2][28].A0 = "36842381519646369093383442662162685647092011333154011205746704344697763573995075808581470975686" + precomputedLines[2][28].A1 = 
"23738273962255737410559068577145745540405127638282298581175256942234085985037760246920691581027" + precomputedLines[3][28].A0 = "19343978517730572718860522682802814216593994929782009605063447864313634694271841044925479570531" + precomputedLines[3][28].A1 = "18116753199019354688284163829906760866044736814208423320319015511125968056743303656129386441866" + precomputedLines[0][27].A0 = "4139820564413273660975376065894434885439027728093365575251324515220070481565883126170375341910" + precomputedLines[0][27].A1 = "8911617021784386057126708729119542632362492047656263377663066761297578404636216402825484775928" + precomputedLines[1][27].A0 = "12814690092798145943598085489221043532460227672773271375898281963005211525225801195783249931850" + precomputedLines[1][27].A1 = "16718273049980440280025910942614061717135303745636448098710467300253988904425074671319519741118" + precomputedLines[2][27].A0 = "16594655699271726571663255998433955793737368741212675757555410571776232951818397604519762189474" + precomputedLines[2][27].A1 = "13367696895090740812829664711766537167099840996127082529013954488222000074820602120439148648099" + precomputedLines[3][27].A0 = "6420076358944069431155841265780960311058634777614248960904550355452537151743611496161985131365" + precomputedLines[3][27].A1 = "28404447674004042853398410237840282183092464989903366148759784000427834047150506348995795717809" + precomputedLines[0][26].A0 = "37604353693033877742048049322576348702586207874617868492276241315554724060655783733052847409121" + precomputedLines[0][26].A1 = "39261971732539568838334542788443899609780416935149888558783126155588190446903694995730051900365" + precomputedLines[1][26].A0 = "38050604325386366053843363684112564677444199391886450001583256090460002797590842990124536031532" + precomputedLines[1][26].A1 = "12486901814459443970000088039080266608927306267447595061288549132864309252725791153409686329572" + precomputedLines[2][26].A0 = 
"8669205977841268940729248303373678041304029372967447821499144672052231998430251413982068112691" + precomputedLines[2][26].A1 = "33620850128131642869212216609407998185833602518893271009726530757748512592480536577630033061823" + precomputedLines[3][26].A0 = "34737275666822811008205566641686065222055453870944740858999888281078924292513175782852078323733" + precomputedLines[3][26].A1 = "18254080092181842209873918423427595803784921393769444082922571074940602698162386992379649096313" + precomputedLines[0][25].A0 = "33494531665767237337261136359671192283112751425803623165525167804892931224732622129607693307742" + precomputedLines[0][25].A1 = "9278100857534226025641330955131726409005382296989257273200584893537287401626951197952759674493" + precomputedLines[1][25].A0 = "33240375891711195631532877073721510948230877401110722965994143251798304959718236446828310159307" + precomputedLines[1][25].A1 = "26079000065055324834065204274857645020355317243833426222126668524329876558926142414599523276665" + precomputedLines[2][25].A0 = "15859294583156632783490072933932976874325809775899861051860187063056143273724358578019071960208" + precomputedLines[2][25].A1 = "12989259388599173305353446328696899481319414347442645028403869152186327038897262252520273540653" + precomputedLines[3][25].A0 = "30910077587886253917237656134320384393297765416421847027710187087988032545714250816681314685395" + precomputedLines[3][25].A1 = "1716077440394980974184818665053901533075435418008868965829851716169023435817436946951992897145" + precomputedLines[0][24].A0 = "38022987570716725690786302461727029651251809425645564834320979952332491325234334275367555849402" + precomputedLines[0][24].A1 = "22382089991707690239526080202236231301691793169321624277174049168646187848585487410406877412638" + precomputedLines[1][24].A0 = "10967122452099721484284940822576738208730793239658139693691586644357653971101083380482640239541" + precomputedLines[1][24].A1 = 
"6296200253699122091062733628573824162005116152394040262487245126738085376673977020081066648197" + precomputedLines[2][24].A0 = "15230995971830076859540762883104721497290684779392350376519365839020422313547195786628566027018" + precomputedLines[2][24].A1 = "47156447686332804224749056692281334136512352840202610295766985993853353666738313542931677220" + precomputedLines[3][24].A0 = "849643811912878419391768807241710728614181448531746373457483263920503655984689592118094446534" + precomputedLines[3][24].A1 = "23339306528845812168401482536603175735911774908855336545906964282095492870899177538846007151833" + precomputedLines[0][23].A0 = "14359525235776738828677656394202177198450893605810206196321461456718335209603231189046101545786" + precomputedLines[0][23].A1 = "21905252054027406652457867781137328983540923495314965272151593431849687527300735227939466547289" + precomputedLines[1][23].A0 = "39382244317721003918493545183150607399986933016526387926104862429213314789415250721275756151152" + precomputedLines[1][23].A1 = "106243383565730591022760449063524503954123458163223854914349331658072014475496462036898439089" + precomputedLines[2][23].A0 = "33934942169602274641760331750320105797404450932919603382251206008276251972260298974387919552378" + precomputedLines[2][23].A1 = "33198016940814999542946565768547235437681716979391065640098597621267869123443352758585470345914" + precomputedLines[3][23].A0 = "5534548901822064116777190944148021233508971316189553493184194741722058457018036124061400423176" + precomputedLines[3][23].A1 = "12450586196857202466430334026063247534362616816769777985477490853948574226500709928953928483681" + precomputedLines[0][22].A0 = "24116766928000882868438866910633589391418836325075997880936288605683749233917273601430913867031" + precomputedLines[0][22].A1 = "39641459595363522752499405527453498563991561082882708933775085712799158900143555067567301537983" + precomputedLines[1][22].A0 = 
"33598915614427410979683100795843864286261129821019053262288609702298499715582702732617214231187" + precomputedLines[1][22].A1 = "5262124792700221839188027462135374454489777679898301160495199180366034600060177822615081319085" + precomputedLines[2][22].A0 = "21516158337900277544913378192323896600527462999773790482202122032899914748600997674871376133562" + precomputedLines[2][22].A1 = "26622692400379167273620273511794447392112917166155187611001125357341264621174997045508900222642" + precomputedLines[3][22].A0 = "24310512702047941666047912370112427276891086638312249313320034339578121005174457165501307529713" + precomputedLines[3][22].A1 = "953333043881074040029315836857160269189732242709815118819954273558410764837343910692838326858" + precomputedLines[4][22].A0 = "27797173901566533295835843427326389241696442503453326091896689429213362671225250697676067935913" + precomputedLines[4][22].A1 = "7517895773302930044260266502468142246264386231422282398327453567847615104743394934091649459894" + precomputedLines[5][22].A0 = "36795251871685034478895150962528633739369363057215926071434219813535609447328478237410192307911" + precomputedLines[5][22].A1 = "18631169883627265440413151432606515200529868645898343520010489641957929541123850319019405475583" + precomputedLines[6][22].A0 = "11705862010716127199557613402683638512734030914536324890969280998879266545619285486128282095245" + precomputedLines[6][22].A1 = "4136600114848342094372402413941254036790131333069895448041593703036415178147550851484619523807" + precomputedLines[7][22].A0 = "37570867334046458225556776814838531593062522405010872509375689812482340806086772547461756774989" + precomputedLines[7][22].A1 = "12705940194511871546737770092253197245319394453873275898868195705204742918732917851325735189413" + precomputedLines[0][21].A0 = "5564969130128611283443654088819047572325892968389373131978354554670476140461835693919147073797" + precomputedLines[0][21].A1 = 
"10405661676677191369024323590229828775160908784965280281109537362658768266187777657999846903573" + precomputedLines[1][21].A0 = "30443493416573592906568262842998768428837720507865461213890972799050694066318812305869741893672" + precomputedLines[1][21].A1 = "15811528795176940740940616950789290866928742392374901158067071809715820836098654521829894497774" + precomputedLines[2][21].A0 = "36191875581225300799953049486633601639219348428995002152491715141504244839468688455045103963193" + precomputedLines[2][21].A1 = "10741867649125824941445258223818249329541732185217053214056036134013544852877186210783300006280" + precomputedLines[3][21].A0 = "22025601763169617275138986640849565558497176491908055865179892965288740712151900077822995530707" + precomputedLines[3][21].A1 = "35921234000830599852062445984786207077604495291021339304696660534730772057827303879728374813667" + precomputedLines[0][20].A0 = "31292015151909897416862294383130387802270356153332611876479302527747121889695346094558281778626" + precomputedLines[0][20].A1 = "13930542767389197671780209812978098488006924391941364140133977004723504633676932297996587248462" + precomputedLines[1][20].A0 = "34946579270084510349797924193532992671175059737423804616586716993078605078383939413797416120629" + precomputedLines[1][20].A1 = "21600338660960091899188574723370720179415828875563944649700209097980124005578421836080183279970" + precomputedLines[2][20].A0 = "18169712894327601399213351769231157132715149760944734200744840366750506752688360819382613874084" + precomputedLines[2][20].A1 = "7645950143787026813902842681149678177583107749869670405693443743494687966581375097230624951328" + precomputedLines[3][20].A0 = "10450687616941180280565867644324481147392536154961528225571993709714382825608157331038868456922" + precomputedLines[3][20].A1 = "31234744532381175574690832289575864589990265796762799317453912537673129095374871060349189087186" + precomputedLines[4][20].A0 = 
"18244625572952092208855342985581672626704810583380789315491039814709180238553735342592467104584" + precomputedLines[4][20].A1 = "1812917121253519219272380100939626114961474223442348473164158509370884376991775922540678424199" + precomputedLines[5][20].A0 = "4189079746200411362955037024732870541887365528356539473660567146507288281566342955731832126006" + precomputedLines[5][20].A1 = "31900638572412423733941232119698174782113881083785624385877953164719457347151597910976214458123" + precomputedLines[6][20].A0 = "402770421806183780722383435479835361895687265374460829694721879250313327019707581203616025256" + precomputedLines[6][20].A1 = "5338804131974300149478741839324725635802753248255198206093031524755093495029756469870209601565" + precomputedLines[7][20].A0 = "5409895333826784292644146025721399648968781270561122220855341836188047092905081274597523342518" + precomputedLines[7][20].A1 = "39076474830253731650659728277389802583219194312090255919170765687240428468792575065044338782628" + precomputedLines[0][19].A0 = "31114053283660962755867941195430012349215244783592757014237105772729165951036529677535655764387" + precomputedLines[0][19].A1 = "26822526825655815954045850138735075943818643161753815292894141881344419682839348625362034776193" + precomputedLines[1][19].A0 = "36060932166827797053887087007312810483855750454789428793677012723843064422582690020032456436137" + precomputedLines[1][19].A1 = "32098851055206798994594140778152652170455628508809132360573721679531263277066224209526131544901" + precomputedLines[2][19].A0 = "33801674850098656199124087787875248358328904798265960397567489707101908735393438520124157033615" + precomputedLines[2][19].A1 = "30388543397528963488853967085247778789266555304165343706837114367817951380548395345416331818275" + precomputedLines[3][19].A0 = "14891046073431502444741776433179748283399031556474785478201428859843785153141080394172107380767" + precomputedLines[3][19].A1 = 
"14372927041143425217372675069945959163058516628206839468882738798600008977366241082678710901924" + precomputedLines[0][18].A0 = "13414738086623137406615537866474474463183085838797716105830381924959826685725541209811893282037" + precomputedLines[0][18].A1 = "3736434517030899315256895169222446219292367752471290470925591098446103735370391383316426906671" + precomputedLines[1][18].A0 = "27457662102772747875857281126160470929918189643676828927161142822635201230511806204158445258846" + precomputedLines[1][18].A1 = "8687723610396450806827779759734065833949285071849089469933023783233391375441334053580973117229" + precomputedLines[2][18].A0 = "25722520882778153104563773560791697643948710509606670528510675267399602342406424097692271045055" + precomputedLines[2][18].A1 = "29535973152591303515555109393093740342020112624245910269420195229605409967402121816379425371615" + precomputedLines[3][18].A0 = "27506837773009407303785749209983697773749280643569416582960145475677280879470617771597342608846" + precomputedLines[3][18].A1 = "36141120534019515011681579131673899522338008246732005608412672064797529071634381472272325079525" + precomputedLines[0][17].A0 = "11717729000798988755422327893657787502661742987483544438175664467582020736874329899394245827688" + precomputedLines[0][17].A1 = "35532398135757357842389195803912115274827231292215469213918665001193807391350504744207914512285" + precomputedLines[1][17].A0 = "27057869421122262463151296129607612527238166833416480317951849056817703415117856038870841075921" + precomputedLines[1][17].A1 = "35943495644629983092269742142112256669032715123998196686602605147199114366505422512978205180155" + precomputedLines[2][17].A0 = "20985982115845415422566222393633171703815426175558593525040069603963863193710495164216772967648" + precomputedLines[2][17].A1 = "19728721577898825674660676901527073362349099944485394582072672518755164176381336267403570967557" + precomputedLines[3][17].A0 = 
"32208590718373937996107859595531446484772546454309502340935577766089544713928955003503272168584" + precomputedLines[3][17].A1 = "35312467176506667889629624386719510190282306587106852325544487949295619725637672658623115197281" + precomputedLines[0][16].A0 = "23036550774181972055464268404581204486105348557110170797706229138632090821339780471252296036728" + precomputedLines[0][16].A1 = "25656367780631864867257121578529031509814056270395128874148266549825794784067928329183854045416" + precomputedLines[1][16].A0 = "31939375271100011054085397962761554649179472285416036905149533036049213890060862213696975178985" + precomputedLines[1][16].A1 = "29110450946744370523652481595026553804574080247283689100346029405455274064035699779161276563575" + precomputedLines[2][16].A0 = "11048652868656229622246713598468676653559194082394139171140334872854204435365270750622900970783" + precomputedLines[2][16].A1 = "5501045731158119241610343653715295192870010364988359296127703174125217661873062886285848290874" + precomputedLines[3][16].A0 = "10206617939281427436790821175466013132135463542793696400290736439226081160734912356344269495097" + precomputedLines[3][16].A1 = "23813366313682011241706618836273528182268227940742050117953442494772159753351145932582435276846" + precomputedLines[0][15].A0 = "2133957659689706241276078671260919836456005285642708712976150803445255403162988177486926528351" + precomputedLines[0][15].A1 = "2488865112498676099110995354357285551752828319577073780941239150787924374728797950742549805623" + precomputedLines[1][15].A0 = "31167162327730814661364104237093800994496436807804786606093723607382273791382056294755346487358" + precomputedLines[1][15].A1 = "21506110711357040293361847052932809781529370612490152336911389948883578220495820618999597994791" + precomputedLines[2][15].A0 = "20000323660737379451664253537658499845423809848107930420298365558711703360899527092122237446566" + precomputedLines[2][15].A1 = 
"31047060047836132392761977163028756479088420162728326540865842433353499566492521695141939840178" + precomputedLines[3][15].A0 = "17173959277400293388979886131899176887760680318927911053354913834493273771462738023327301110638" + precomputedLines[3][15].A1 = "17690847038526634082638452638406157619665265984714651835044148846270139683871674918989914991861" + precomputedLines[0][14].A0 = "8111539519530413403803687927346187442019219042699904805461482616010202496993719690007189331625" + precomputedLines[0][14].A1 = "448874242454015329502057427657609947097344721600901494139407558636947535507148044601679475338" + precomputedLines[1][14].A0 = "6218224846167399692495961484545008252888979230943249892046885179375462584517565030060699453260" + precomputedLines[1][14].A1 = "36063628801519082441041995499662529065568543366380306753785812995710345234643156265593260199398" + precomputedLines[2][14].A0 = "7107829277058020196592080645093919681504961945037352092824943592023331449849925180695967870952" + precomputedLines[2][14].A1 = "18521245061965915188690044864549452571768992417060630092786539234014150527836638102193541186930" + precomputedLines[3][14].A0 = "27949020407470802004123788569285881348200653440280118520071211341660295914387006718083595769450" + precomputedLines[3][14].A1 = "26600270865790640429498738283238103998943985868581555514403667340147378499140645385091149272320" + precomputedLines[0][13].A0 = "6263321087330035064543049353138514314872408573889371232977305847809668975578000603666123878809" + precomputedLines[0][13].A1 = "4437541035845212988887471619223261566163168416138845790130323284483493293456020730759985439616" + precomputedLines[1][13].A0 = "25957435781877328912903401606972387309657750887055914042318386430795058962632467015667236108737" + precomputedLines[1][13].A1 = "36484840707780237190652084958089403575855195433133705409171537067380719627822308356069547120407" + precomputedLines[2][13].A0 = 
"14078978732105111410266385516296065543805048769606604450952902572185901498532149408619418152615" + precomputedLines[2][13].A1 = "17105517229330563803416104936775670022295399858733502355816791347049041093723928257004830886314" + precomputedLines[3][13].A0 = "10806899777627309766583532112933795669722060146450781380905829513515561385861326722929487755970" + precomputedLines[3][13].A1 = "24744536129411616919025527263026014756889669153574689022224600256459279846694316013120809733959" + precomputedLines[0][12].A0 = "23745999228289863953038255466656692436128367145219424905260990630312803229022735746030137733828" + precomputedLines[0][12].A1 = "33043785396578354783823233801388646630365939289922797159528247950882848308520829194615208228944" + precomputedLines[1][12].A0 = "35667126958325848321558983643011193398385133756466623685740557583313309860496785085555779907345" + precomputedLines[1][12].A1 = "29846895267525552008785346486773063046705602516926641663402481851740686623931460069730648217240" + precomputedLines[2][12].A0 = "29597284782639667059757928445753529939787392264151304171730433038218711146814859657965902047945" + precomputedLines[2][12].A1 = "25715886907316254047489461998769334219356787282119170086215917458712091462735832245826070514160" + precomputedLines[3][12].A0 = "292375600601563939084165235044214945947839844260137135633138762794392833955576677040018470614" + precomputedLines[3][12].A1 = "10782898790206765796851940396216761698327803867150223664950414174800616001660838411107097732513" + precomputedLines[0][11].A0 = "35920784748362887997150609964824872451642503710077341491261872934893540761715225937818741663141" + precomputedLines[0][11].A1 = "29876228319052725073594000323621515998922348469513361981163063601203155860569381774837315962847" + precomputedLines[1][11].A0 = "4212495452122458330182425723774512440189759514626412374365415039223584356123780613299473198610" + precomputedLines[1][11].A1 = 
"1924421839004317581491579214074183545725849260222256043217791340260606843958334802109284247328" + precomputedLines[2][11].A0 = "32450120564888057044978023267757897386115283935203239564918265071119319854711736454073065716299" + precomputedLines[2][11].A1 = "14627232376477158058719961423236457189108246940355146459764098157667232685908054707942516603735" + precomputedLines[3][11].A0 = "33246276106474112244032779374880539276361226771361852668236976949574816261665744531948520234458" + precomputedLines[3][11].A1 = "16738514735881697141673102901510656079850601540955222380817787014529146699763159679255943897522" + precomputedLines[0][10].A0 = "16286015288723309133954051516717312399685307247774691544684805104457873255509108606909781786193" + precomputedLines[0][10].A1 = "21508020735163668634573477663793665084606427145298809578845346929500021954680174684114047156004" + precomputedLines[1][10].A0 = "22170901052984276563045981224777970943757111625131992851467733466895567389186002803562894657620" + precomputedLines[1][10].A1 = "20337170937548070048567203165477927964771253618042537150883255191215948265952196359827184550330" + precomputedLines[2][10].A0 = "29906991281551506621636532931204159502175916577305006921380536441512844623285740508751212126233" + precomputedLines[2][10].A1 = "24335676534041396051393095052783560187088880042200281495643305117568953364274944821505673996016" + precomputedLines[3][10].A0 = "13957288315623373478703709704733087043362935862231059168890530400805564895009003777202682694415" + precomputedLines[3][10].A1 = "1223871989653055119559112084775210590118007076675188522413134903997074491991246945427276141612" + precomputedLines[0][9].A0 = "33218255529131975999088663058330655956045567133753943273653953175057884050772880832965922717131" + precomputedLines[0][9].A1 = "10934630289844167494091225248241728899411916373020488899801686481348333710046438403201308407898" + precomputedLines[1][9].A0 = 
"31171955139286976761621361169471847318936284421194495870537234475204313307611429963373436954972" + precomputedLines[1][9].A1 = "37762207515030513011712878404932937122821781320720198379211210922604824640148215745984360827381" + precomputedLines[2][9].A0 = "7639977714782287724858945813838810686127449004710130755077378818504348680931360176328235852245" + precomputedLines[2][9].A1 = "14201260674771382381524759204339476708757557284778093460490840195148895537922738096328133771620" + precomputedLines[3][9].A0 = "21945500212457128574884187136695872308120878917315713873080347140140151745839200655885062130431" + precomputedLines[3][9].A1 = "3809341615215073238952696293307895945619207782803386115293091067210212389340888964722673077811" + precomputedLines[0][8].A0 = "26110965917253549666256221802911467147812063004190184340810268406452869899534449242111993795498" + precomputedLines[0][8].A1 = "28665625022129116235910532291369957987707633646475081662386334109210571010387515497812169012502" + precomputedLines[1][8].A0 = "26451788831222611657973874982635393576669681250091787126040136726448053584740364938778739214828" + precomputedLines[1][8].A1 = "8092412749874681496481925780617715361704844759940353397789192243839662493692124489294082295158" + precomputedLines[2][8].A0 = "9270321145350674224652554836320317593554375249479983474961603816700059783004381286965429390482" + precomputedLines[2][8].A1 = "19460616333360456802646816585485122591767150252894632440673689080046716133956675348791450575734" + precomputedLines[3][8].A0 = "35048749147662922007182020957342969178799873311637936228343821723495325237711290970289680914054" + precomputedLines[3][8].A1 = "18003681420666747728710171023823502122601279722022718669136923587542584794890757704574215654811" + precomputedLines[0][7].A0 = "35920816427330927972525474782764888756472968812350270443756178404983645205402827397794728641419" + precomputedLines[0][7].A1 = 
"15994857004777387801523732390459086789329045070561619970738011877659412905592886788721336065320" + precomputedLines[1][7].A0 = "32307047297034634147920679145288257882068272966784093493163171956462800709378566240125774734459" + precomputedLines[1][7].A1 = "32503166008950505728027188500963046174019071051722346957049170811959599106835852367956077898450" + precomputedLines[2][7].A0 = "17188922968435351423464680881182634768443529491819980925443673922398748568760440087398413189739" + precomputedLines[2][7].A1 = "8203181179247593599607931817400907357630217341514412542051793645833385459473298497518390881747" + precomputedLines[3][7].A0 = "36440689787034268806769917957707530654398395198351329488854970878100680130641727965973755953660" + precomputedLines[3][7].A1 = "30679528465263273472825625188094330379677454172085063888065203363920746866308504331426136016632" + precomputedLines[0][6].A0 = "9746562881158199581754622002118658451775449778442538123773942812978675925619714307701608614265" + precomputedLines[0][6].A1 = "25077682231375998062120150498797738973764872256742159666379011754585511911704890059225842456765" + precomputedLines[1][6].A0 = "38149811606296469898249554197414370147831935200262058720129887805317716991590322718652394876130" + precomputedLines[1][6].A1 = "30696438177578316186486393145442061351477812398526658739011576551023953565377188952667300217541" + precomputedLines[2][6].A0 = "20406102086667904306287403007884612827895803656375390480800545915711192967660799055320459684169" + precomputedLines[2][6].A1 = "36438504904653507065337738487839489560368654316652823550153237168286541361435234899569132772549" + precomputedLines[3][6].A0 = "37987488421951908051031816262335108334339298854182743404335685414421418604813976821688878536348" + precomputedLines[3][6].A1 = "32318146974876420645374647678669067270153844319876894354662919353352967207186790225682505637696" + precomputedLines[0][5].A0 = 
"22086129509694338627689225821291303317065063099098758223153192146808732747265013048748474294784" + precomputedLines[0][5].A1 = "37713807874245437781971305473721596363261815037521027566687799823515739077799227119925269839477" + precomputedLines[1][5].A0 = "18305854523413278457362868424577397873627830576907357622543815286052456245298666405508603410248" + precomputedLines[1][5].A1 = "300352913086796864066129281465076525456392732485761752807122219597254739676369156711180503852" + precomputedLines[2][5].A0 = "20686954459670710909377147635558654224518784080467327251844688248566032911966476571520179024492" + precomputedLines[2][5].A1 = "32418278882492425990451033275415963330415246847268149115611295027327117819375976636489500759285" + precomputedLines[3][5].A0 = "29528080515296253150121160848487841289421169041905693546117487538395739713695784460135788925455" + precomputedLines[3][5].A1 = "14725369051989293428532315695564495727508711043089266443831770444687199384028823901585006515157" + precomputedLines[0][4].A0 = "11458755346771378923831390021279188897090759918109210611404640522482969121019340913474362814450" + precomputedLines[0][4].A1 = "5491646844778900481050609932477053446856385605326869179687773192043234983862925376426042823062" + precomputedLines[1][4].A0 = "2195792007395234473305294229497731661918165676767453580741033197739477651120884434849445481873" + precomputedLines[1][4].A1 = "35637486887615981385928236107906602309261340786681513612138705991192128208559403292517003522358" + precomputedLines[2][4].A0 = "7476711848655575935584455007742858141977887239768917879896257564469461339953893938487413265919" + precomputedLines[2][4].A1 = "4731493677401477466160641225706901688803319007881509349892520157965552879783809521430335296113" + precomputedLines[3][4].A0 = "22560035734391637122977009698050609219219528141199101330022470988252363349921328539598517264573" + precomputedLines[3][4].A1 = 
"22016994390256035220432108516137296693575079484469774730435120240206101824130335022690522606673" + precomputedLines[0][3].A0 = "561026844021162847797910825643502565986545760969197833602141452550737812367896989367376646678" + precomputedLines[0][3].A1 = "11721460935985564351661434572740369101264121414387214674580668057455509143075012733621306263160" + precomputedLines[1][3].A0 = "11808502682088039880722509762195472504203773853102286471310387299634983381936302705137647223160" + precomputedLines[1][3].A1 = "26538581612425607175835690631472952433905434434768278072266019143920745630871659432221060279695" + precomputedLines[2][3].A0 = "27606916857442401348764661255657194264503630672912498851881092896368366375983876434079099099874" + precomputedLines[2][3].A1 = "37956826298854942678069311886424501685374365961755063685897878146156569601064782435322324155014" + precomputedLines[3][3].A0 = "235301833307671160789456884764336659360776113012976206463039289051674884923378757734620155387" + precomputedLines[3][3].A1 = "22676518637187712379266280864503689487888952646012078752202444933765414477121820653980407402223" + precomputedLines[0][2].A0 = "8946296651087559201328866959697194061842930633555895807753409047564077241609164640172744325591" + precomputedLines[0][2].A1 = "3319607416658750987480328291257530998644880938665698099988153697093848972477741037705985373413" + precomputedLines[1][2].A0 = "21115088988943258887371918917298048296112016085508013659929088912794658894435312997633938718848" + precomputedLines[1][2].A1 = "15403196945183287022105606876235579784619138796584603549084162428427901257866238967108717156830" + precomputedLines[2][2].A0 = "30086421369964614829761112460189477561385430103013431226789288615711640049604859752728577463661" + precomputedLines[2][2].A1 = "32494472004796215644376401761829654272752028908154832006541628371298896457302251005793013624319" + precomputedLines[3][2].A0 = 
"3483805687361374349482384485194080792410631243665507602941560487698743995560939645834652333362" + precomputedLines[3][2].A1 = "12496566982713423909986875509720764905742865192367287712443457586797454850685596047185671329668" + precomputedLines[0][1].A0 = "6584966133329639236916288510823692756333523373657595559980423047785801835150451163398206709164" + precomputedLines[0][1].A1 = "23087541058091714056100440776570484163614800443645247394473680828159748175004218845384987394170" + precomputedLines[1][1].A0 = "33598335456381242914709689023405401247972031498650092597658640301337693353697557725875560825878" + precomputedLines[1][1].A1 = "28808356477788583273585533281656056813283300122087581181854003271434390538943992540594868508436" + precomputedLines[2][1].A0 = "5369558223971857359192375700693164050775171433842412048790430850709296904506550410374189272093" + precomputedLines[2][1].A1 = "12570776851160974101715301489141282986792606850267907619241328889643920000513411014075528661720" + precomputedLines[3][1].A0 = "7102613089855881008880031016512857177075020531235177335802273651321847018710400403953443025043" + precomputedLines[3][1].A1 = "26412315373791493086632742270151451815796411687352927504173275053874594996584272622703049539109" + precomputedLines[0][0].A0 = "26939917288618287837435994497286754734319814047381202278445699245198492659140682604803928442" + precomputedLines[0][0].A1 = "36307083451757752370037493568154867296089840041797721165157682443136574289533042102378682907333" + precomputedLines[1][0].A0 = "30134411929374439282070757093202226487683672560009996546375344388413283785195631245882409928480" + precomputedLines[1][0].A1 = "8190546611799117959040813309423929226298240331266083167332059473930415104488470664015468274944" + precomputedLines[2][0].A0 = "35968217177741653586235067123341278643055805540482803270425032142778796968110701453189054475679" + precomputedLines[2][0].A1 = 
"2378547078261756302068098906777684902427360760815197472783487789905541860696063882232879376698" + precomputedLines[3][0].A0 = "1311834620354653701125926150934537397252645697980987940429971549693082775464693242860287168453" + precomputedLines[3][0].A1 = "9858119616175232679471599562772881660625635766752192897305265187511481511415034406418578459871" + precomputedLines[4][0].A0 = "25213285515565765134884371308766134134092902165298087726481160116123780862735396408526245177234" + precomputedLines[4][0].A1 = "37555564900963723190977084523023057246345584683203092460041165381572484156341281098606773525380" + precomputedLines[5][0].A0 = "18562892436190599137230611018470635636139728772085915477592967448302772358263615993321607318344" + precomputedLines[5][0].A1 = "33335315303478291126554115666534790699905798851095690303139759682572549302553767095371061145590" + precomputedLines[6][0].A0 = "26480435403576272788049023501152121730463556514121554029731688153525076573621971275652034890185" + precomputedLines[6][0].A1 = "2092411457452315499459472232708795745151480576834991904202013360559686875332995307870727262937" + precomputedLines[7][0].A0 = "4767331024262154459508151143669605262339023811239095074946039395631163193503827493846823803450" + precomputedLines[7][0].A1 = "20273645137181877747181784868269095229533277679129595131465597069022573781875500624218871490886" +} diff --git a/std/algebra/twistededwards/curve.go b/std/algebra/native/twistededwards/curve.go similarity index 100% rename from std/algebra/twistededwards/curve.go rename to std/algebra/native/twistededwards/curve.go diff --git a/std/algebra/twistededwards/curve_test.go b/std/algebra/native/twistededwards/curve_test.go similarity index 100% rename from std/algebra/twistededwards/curve_test.go rename to std/algebra/native/twistededwards/curve_test.go diff --git a/std/algebra/native/twistededwards/doc.go b/std/algebra/native/twistededwards/doc.go new file mode 100644 index 0000000000..95b1486681 --- /dev/null +++ 
b/std/algebra/native/twistededwards/doc.go @@ -0,0 +1,8 @@ +// Package twistededwards implements the arithmetic of twisted Edwards curves +// in native fields. This uses associated twisted Edwards curves defined over +// the scalar field of the SNARK curves. +// +// Examples: +// Jubjub, Bandersnatch (a twisted Edwards) is defined over BLS12-381's scalar field +// Baby-Jubjub (a twisted Edwards) is defined over BN254's scalar field +package twistededwards diff --git a/std/algebra/twistededwards/point.go b/std/algebra/native/twistededwards/point.go similarity index 100% rename from std/algebra/twistededwards/point.go rename to std/algebra/native/twistededwards/point.go diff --git a/std/algebra/twistededwards/scalarmul_glv.go b/std/algebra/native/twistededwards/scalarmul_glv.go similarity index 98% rename from std/algebra/twistededwards/scalarmul_glv.go rename to std/algebra/native/twistededwards/scalarmul_glv.go index b429d866ce..7b959a2db4 100644 --- a/std/algebra/twistededwards/scalarmul_glv.go +++ b/std/algebra/native/twistededwards/scalarmul_glv.go @@ -22,7 +22,7 @@ import ( "sync" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -82,7 +82,7 @@ } func init() { - hint.Register(DecomposeScalar) + solver.RegisterHint(DecomposeScalar) } // ScalarMul computes the scalar multiplication of a point on a twisted Edwards curve diff --git a/std/algebra/twistededwards/twistededwards.go b/std/algebra/native/twistededwards/twistededwards.go similarity index 100% rename from std/algebra/twistededwards/twistededwards.go rename to std/algebra/native/twistededwards/twistededwards.go diff --git a/std/algebra/sw_bls12377/doc.go b/std/algebra/sw_bls12377/doc.go deleted file mode 100644 index c9270127e8..0000000000 --- a/std/algebra/sw_bls12377/doc.go +++ /dev/null @@ -1,18 +0,0 
@@ -/* -Copyright © 2020 ConsenSys - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Package sw (short weierstrass) -package sw_bls12377 diff --git a/std/algebra/sw_bls12377/pairing.go b/std/algebra/sw_bls12377/pairing.go deleted file mode 100644 index 7ca86131e7..0000000000 --- a/std/algebra/sw_bls12377/pairing.go +++ /dev/null @@ -1,243 +0,0 @@ -/* -Copyright © 2020 ConsenSys - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package sw_bls12377 - -import ( - "errors" - "math/big" - - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/fields_bls12377" -) - -// GT target group of the pairing -type GT = fields_bls12377.E12 - -const ateLoop = 9586122913090633729 - -// LineEvaluation represents a sparse Fp12 Elmt (result of the line evaluation) -type LineEvaluation struct { - R0, R1 fields_bls12377.E2 -} - -// MillerLoop computes the product of n miller loops (n can be 1) -func MillerLoop(api frontend.API, P []G1Affine, Q []G2Affine) (GT, error) { - // check input size match - n := len(P) - if n == 0 || n != len(Q) { - return GT{}, errors.New("invalid inputs sizes") - } - - var ateLoopBin [64]uint - var ateLoopBigInt big.Int - ateLoopBigInt.SetUint64(ateLoop) - for i := 0; i < 64; i++ { - ateLoopBin[i] = ateLoopBigInt.Bit(i) - } - - var res GT - res.SetOne() - - var l1, l2 LineEvaluation - Qacc := make([]G2Affine, n) - yInv := make([]frontend.Variable, n) - xOverY := make([]frontend.Variable, n) - for k := 0; k < n; k++ { - Qacc[k] = Q[k] - yInv[k] = api.DivUnchecked(1, P[k].Y) - xOverY[k] = api.DivUnchecked(P[k].X, P[k].Y) - } - - // k = 0 - Qacc[0], l1 = DoubleStep(api, &Qacc[0]) - res.C1.B0.MulByFp(api, l1.R0, xOverY[0]) - res.C1.B1.MulByFp(api, l1.R1, yInv[0]) - - if n >= 2 { - // k = 1 - Qacc[1], l1 = DoubleStep(api, &Qacc[1]) - l1.R0.MulByFp(api, l1.R0, xOverY[1]) - l1.R1.MulByFp(api, l1.R1, yInv[1]) - res.Mul034By034(api, l1.R0, l1.R1, res.C1.B0, res.C1.B1) - } - - if n >= 3 { - // k >= 2 - for k := 2; k < n; k++ { - Qacc[k], l1 = DoubleStep(api, &Qacc[k]) - l1.R0.MulByFp(api, l1.R0, xOverY[k]) - l1.R1.MulByFp(api, l1.R1, yInv[k]) - res.MulBy034(api, l1.R0, l1.R1) - } - } - - for i := len(ateLoopBin) - 3; i >= 0; i-- { - res.Square(api, res) - - if ateLoopBin[i] == 0 { - for k := 0; k < n; k++ { - Qacc[k], l1 = DoubleStep(api, &Qacc[k]) - l1.R0.MulByFp(api, l1.R0, xOverY[k]) - l1.R1.MulByFp(api, l1.R1, yInv[k]) - res.MulBy034(api, l1.R0, l1.R1) - } 
- continue - } - - for k := 0; k < n; k++ { - Qacc[k], l1, l2 = DoubleAndAddStep(api, &Qacc[k], &Q[k]) - l1.R0.MulByFp(api, l1.R0, xOverY[k]) - l1.R1.MulByFp(api, l1.R1, yInv[k]) - res.MulBy034(api, l1.R0, l1.R1) - l2.R0.MulByFp(api, l2.R0, xOverY[k]) - l2.R1.MulByFp(api, l2.R1, yInv[k]) - res.MulBy034(api, l2.R0, l2.R1) - } - } - - return res, nil -} - -// FinalExponentiation computes the final expo x**(p**6-1)(p**2+1)(p**4 - p**2 +1)/r -func FinalExponentiation(api frontend.API, e1 GT) GT { - const genT = ateLoop - - result := e1 - - // https://eprint.iacr.org/2016/130.pdf - var t [3]GT - - // easy part - t[0].Conjugate(api, result) - t[0].DivUnchecked(api, t[0], result) - result.FrobeniusSquare(api, t[0]). - Mul(api, result, t[0]) - - // hard part (up to permutation) - // Daiki Hayashida and Kenichiro Hayasaka - // and Tadanori Teruya - // https://eprint.iacr.org/2020/875.pdf - t[0].CyclotomicSquare(api, result) - t[1].Expt(api, result, genT) - t[2].Conjugate(api, result) - t[1].Mul(api, t[1], t[2]) - t[2].Expt(api, t[1], genT) - t[1].Conjugate(api, t[1]) - t[1].Mul(api, t[1], t[2]) - t[2].Expt(api, t[1], genT) - t[1].Frobenius(api, t[1]) - t[1].Mul(api, t[1], t[2]) - result.Mul(api, result, t[0]) - t[0].Expt(api, t[1], genT) - t[2].Expt(api, t[0], genT) - t[0].FrobeniusSquare(api, t[1]) - t[1].Conjugate(api, t[1]) - t[1].Mul(api, t[1], t[2]) - t[1].Mul(api, t[1], t[0]) - result.Mul(api, result, t[1]) - - return result -} - -// Pair calculates the reduced pairing for a set of points -func Pair(api frontend.API, P []G1Affine, Q []G2Affine) (GT, error) { - f, err := MillerLoop(api, P, Q) - if err != nil { - return GT{}, err - } - return FinalExponentiation(api, f), nil -} - -// DoubleAndAddStep -func DoubleAndAddStep(api frontend.API, p1, p2 *G2Affine) (G2Affine, LineEvaluation, LineEvaluation) { - - var n, d, l1, l2, x3, x4, y4 fields_bls12377.E2 - var line1, line2 LineEvaluation - var p G2Affine - - // compute lambda1 = (y2-y1)/(x2-x1) - n.Sub(api, p1.Y, p2.Y) - 
d.Sub(api, p1.X, p2.X) - l1.DivUnchecked(api, n, d) - - // x3 =lambda1**2-p1.x-p2.x - x3.Square(api, l1). - Sub(api, x3, p1.X). - Sub(api, x3, p2.X) - - // omit y3 computation - - // compute line1 - line1.R0.Neg(api, l1) - line1.R1.Mul(api, l1, p1.X).Sub(api, line1.R1, p1.Y) - - // compute lambda2 = -lambda1-2*y1/(x3-x1) - n.Double(api, p1.Y) - d.Sub(api, x3, p1.X) - l2.DivUnchecked(api, n, d) - l2.Add(api, l2, l1).Neg(api, l2) - - // compute x4 = lambda2**2-x1-x3 - x4.Square(api, l2). - Sub(api, x4, p1.X). - Sub(api, x4, x3) - - // compute y4 = lambda2*(x1 - x4)-y1 - y4.Sub(api, p1.X, x4). - Mul(api, l2, y4). - Sub(api, y4, p1.Y) - - p.X = x4 - p.Y = y4 - - // compute line2 - line2.R0.Neg(api, l2) - line2.R1.Mul(api, l2, p1.X).Sub(api, line2.R1, p1.Y) - - return p, line1, line2 -} - -func DoubleStep(api frontend.API, p1 *G2Affine) (G2Affine, LineEvaluation) { - - var n, d, l, xr, yr fields_bls12377.E2 - var p G2Affine - var line LineEvaluation - - // lambda = 3*p1.x**2/2*p.y - n.Square(api, p1.X).MulByFp(api, n, 3) - d.MulByFp(api, p1.Y, 2) - l.DivUnchecked(api, n, d) - - // xr = lambda**2-2*p1.x - xr.Square(api, l). - Sub(api, xr, p1.X). - Sub(api, xr, p1.X) - - // yr = lambda*(p.x-xr)-p.y - yr.Sub(api, p1.X, xr). - Mul(api, l, yr). - Sub(api, yr, p1.Y) - - p.X = xr - p.Y = yr - - line.R0.Neg(api, l) - line.R1.Mul(api, l, p1.X).Sub(api, line.R1, p1.Y) - - return p, line - -} diff --git a/std/algebra/sw_bls24315/doc.go b/std/algebra/sw_bls24315/doc.go deleted file mode 100644 index 279e09879b..0000000000 --- a/std/algebra/sw_bls24315/doc.go +++ /dev/null @@ -1,18 +0,0 @@ -/* -Copyright © 2020 ConsenSys - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Package sw (short weierstrass) -package sw_bls24315 diff --git a/std/algebra/sw_bls24315/pairing.go b/std/algebra/sw_bls24315/pairing.go deleted file mode 100644 index b1e1fb11e0..0000000000 --- a/std/algebra/sw_bls24315/pairing.go +++ /dev/null @@ -1,259 +0,0 @@ -/* -Copyright © 2020 ConsenSys - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package sw_bls24315 - -import ( - "errors" - "math/big" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/fields_bls24315" -) - -// GT target group of the pairing -type GT = fields_bls24315.E24 - -const ateLoop = 3218079743 - -// LineEvaluation represents a sparse Fp12 Elmt (result of the line evaluation) -type LineEvaluation struct { - R0, R1 fields_bls24315.E4 -} - -// MillerLoop computes the product of n miller loops (n can be 1) -func MillerLoop(api frontend.API, P []G1Affine, Q []G2Affine) (GT, error) { - // check input size match - n := len(P) - if n == 0 || n != len(Q) { - return GT{}, errors.New("invalid inputs sizes") - } - - var ateLoop2NAF [33]int8 - ecc.NafDecomposition(big.NewInt(ateLoop), ateLoop2NAF[:]) - - var res GT - res.SetOne() - - var l1, l2 LineEvaluation - Qacc := make([]G2Affine, n) - Qneg := make([]G2Affine, n) - yInv := make([]frontend.Variable, n) - xOverY := make([]frontend.Variable, n) - for k := 0; k < n; k++ { - Qacc[k] = Q[k] - Qneg[k].Neg(api, Q[k]) - yInv[k] = api.DivUnchecked(1, P[k].Y) - xOverY[k] = api.DivUnchecked(P[k].X, P[k].Y) - } - - // k = 0 - Qacc[0], l1 = DoubleStep(api, &Qacc[0]) - res.D1.C0.MulByFp(api, l1.R0, xOverY[0]) - res.D1.C1.MulByFp(api, l1.R1, yInv[0]) - - if n >= 2 { - // k = 1 - Qacc[1], l1 = DoubleStep(api, &Qacc[1]) - l1.R0.MulByFp(api, l1.R0, xOverY[1]) - l1.R1.MulByFp(api, l1.R1, yInv[1]) - res.Mul034By034(api, l1.R0, l1.R1, res.D1.C0, res.D1.C1) - } - - if n >= 3 { - // k >= 2 - for k := 2; k < n; k++ { - Qacc[k], l1 = DoubleStep(api, &Qacc[k]) - l1.R0.MulByFp(api, l1.R0, xOverY[k]) - l1.R1.MulByFp(api, l1.R1, yInv[k]) - res.MulBy034(api, l1.R0, l1.R1) - } - } - - for i := len(ateLoop2NAF) - 3; i >= 0; i-- { - res.Square(api, res) - - if ateLoop2NAF[i] == 0 { - for k := 0; k < n; k++ { - Qacc[k], l1 = DoubleStep(api, &Qacc[k]) - l1.R0.MulByFp(api, l1.R0, xOverY[k]) - l1.R1.MulByFp(api, l1.R1, yInv[k]) - res.MulBy034(api, 
l1.R0, l1.R1) - } - } else if ateLoop2NAF[i] == 1 { - for k := 0; k < n; k++ { - Qacc[k], l1, l2 = DoubleAndAddStep(api, &Qacc[k], &Q[k]) - l1.R0.MulByFp(api, l1.R0, xOverY[k]) - l1.R1.MulByFp(api, l1.R1, yInv[k]) - res.MulBy034(api, l1.R0, l1.R1) - l2.R0.MulByFp(api, l2.R0, xOverY[k]) - l2.R1.MulByFp(api, l2.R1, yInv[k]) - res.MulBy034(api, l2.R0, l2.R1) - } - } else { - for k := 0; k < n; k++ { - Qacc[k], l1, l2 = DoubleAndAddStep(api, &Qacc[k], &Qneg[k]) - l1.R0.MulByFp(api, l1.R0, xOverY[k]) - l1.R1.MulByFp(api, l1.R1, yInv[k]) - res.MulBy034(api, l1.R0, l1.R1) - l2.R0.MulByFp(api, l2.R0, xOverY[k]) - l2.R1.MulByFp(api, l2.R1, yInv[k]) - res.MulBy034(api, l2.R0, l2.R1) - } - } - } - - res.Conjugate(api, res) - - return res, nil -} - -// FinalExponentiation computes the final expo x**(p**12-1)(p**4+1)(p**8 - p**4 +1)/r -func FinalExponentiation(api frontend.API, e1 GT) GT { - const genT = ateLoop - result := e1 - - // https://eprint.iacr.org/2012/232.pdf, section 7 - var t [9]GT - - // easy part - t[0].Conjugate(api, result) - t[0].DivUnchecked(api, t[0], result) - result.FrobeniusQuad(api, t[0]). 
- Mul(api, result, t[0]) - - // hard part (api, up to permutation) - // Daiki Hayashida and Kenichiro Hayasaka - // and Tadanori Teruya - // https://eprint.iacr.org/2020/875.pdf - // 3*Phi_24(p)/r = (u-1)² * (u+p) * (u²+p²) * (u⁴+p⁴-1) + 3 - t[0].CyclotomicSquare(api, result) - t[1].Expt(api, result, genT) - t[2].Conjugate(api, result) - t[1].Mul(api, t[1], t[2]) - t[2].Expt(api, t[1], genT) - t[1].Conjugate(api, t[1]) - t[1].Mul(api, t[1], t[2]) - t[2].Expt(api, t[1], genT) - t[1].Frobenius(api, t[1]) - t[1].Mul(api, t[1], t[2]) - result.Mul(api, result, t[0]) - t[0].Expt(api, t[1], genT) - t[2].Expt(api, t[0], genT) - t[0].FrobeniusSquare(api, t[1]) - t[2].Mul(api, t[0], t[2]) - t[1].Expt(api, t[2], genT) - t[1].Expt(api, t[1], genT) - t[1].Expt(api, t[1], genT) - t[1].Expt(api, t[1], genT) - t[0].FrobeniusQuad(api, t[2]) - t[0].Mul(api, t[0], t[1]) - t[2].Conjugate(api, t[2]) - t[0].Mul(api, t[0], t[2]) - result.Mul(api, result, t[0]) - - return result -} - -// Pair calculates the reduced pairing for a set of points -func Pair(api frontend.API, P []G1Affine, Q []G2Affine) (GT, error) { - f, err := MillerLoop(api, P, Q) - if err != nil { - return GT{}, err - } - return FinalExponentiation(api, f), nil -} - -// DoubleAndAddStep -func DoubleAndAddStep(api frontend.API, p1, p2 *G2Affine) (G2Affine, LineEvaluation, LineEvaluation) { - - var n, d, l1, l2, x3, x4, y4 fields_bls24315.E4 - var line1, line2 LineEvaluation - var p G2Affine - - // compute lambda1 = (y2-y1)/(x2-x1) - n.Sub(api, p1.Y, p2.Y) - d.Sub(api, p1.X, p2.X) - l1.DivUnchecked(api, n, d) - - // x3 =lambda1**2-p1.x-p2.x - x3.Square(api, l1). - Sub(api, x3, p1.X). 
- Sub(api, x3, p2.X) - - // omit y3 computation - - // compute line1 - line1.R0.Neg(api, l1) - line1.R1.Mul(api, l1, p1.X).Sub(api, line1.R1, p1.Y) - - // compute lambda2 = -lambda1-2*y1/(x3-x1) - n.Double(api, p1.Y) - d.Sub(api, x3, p1.X) - l2.DivUnchecked(api, n, d) - l2.Add(api, l2, l1).Neg(api, l2) - - // compute x4 = lambda2**2-x1-x3 - x4.Square(api, l2). - Sub(api, x4, p1.X). - Sub(api, x4, x3) - - // compute y4 = lambda2*(x1 - x4)-y1 - y4.Sub(api, p1.X, x4). - Mul(api, l2, y4). - Sub(api, y4, p1.Y) - - p.X = x4 - p.Y = y4 - - // compute line2 - line2.R0.Neg(api, l2) - line2.R1.Mul(api, l2, p1.X).Sub(api, line2.R1, p1.Y) - - return p, line1, line2 -} - -func DoubleStep(api frontend.API, p1 *G2Affine) (G2Affine, LineEvaluation) { - - var n, d, l, xr, yr fields_bls24315.E4 - var p G2Affine - var line LineEvaluation - - // lambda = 3*p1.x**2/2*p.y - n.Square(api, p1.X).MulByFp(api, n, 3) - d.MulByFp(api, p1.Y, 2) - l.DivUnchecked(api, n, d) - - // xr = lambda**2-2*p1.x - xr.Square(api, l). - Sub(api, xr, p1.X). - Sub(api, xr, p1.X) - - // yr = lambda*(p.x-xr)-p.y - yr.Sub(api, p1.X, xr). - Mul(api, l, yr). - Sub(api, yr, p1.Y) - - p.X = xr - p.Y = yr - - line.R0.Neg(api, l) - line.R1.Mul(api, l, p1.X).Sub(api, line.R1, p1.Y) - - return p, line - -} diff --git a/std/algebra/weierstrass/params.go b/std/algebra/weierstrass/params.go deleted file mode 100644 index aadf2bce6a..0000000000 --- a/std/algebra/weierstrass/params.go +++ /dev/null @@ -1,71 +0,0 @@ -package weierstrass - -import ( - "math/big" - - "github.com/consensys/gnark/std/math/emulated" -) - -// CurveParams defines parameters of an elliptic curve in short Weierstrass form -// given by the equation -// -// Y² = X³ + aX + b -// -// The base point is defined by (Gx, Gy). -type CurveParams struct { - A *big.Int // a in curve equation - B *big.Int // b in curve equation - Gx *big.Int // base point x - Gy *big.Int // base point y -} - -// GetSecp256k1Params returns curve parameters for the curve secp256k1. 
When -// initialising new curve, use the base field [emulated.Secp256k1Fp] and scalar -// field [emulated.Secp256k1Fr]. -func GetSecp256k1Params() CurveParams { - gx, _ := new(big.Int).SetString("79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798", 16) - gy, _ := new(big.Int).SetString("483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8", 16) - return CurveParams{ - A: big.NewInt(0), - B: big.NewInt(7), - Gx: gx, - Gy: gy, - } -} - -// GetBN254Params returns the curve parameters for the curve BN254 (alt_bn128). -// When initialising new curve, use the base field [emulated.BN254Fp] and scalar -// field [emulated.BN254Fr]. -func GetBN254Params() CurveParams { - gx := big.NewInt(1) - gy := big.NewInt(2) - return CurveParams{ - A: big.NewInt(0), - B: big.NewInt(3), - Gx: gx, - Gy: gy, - } -} - -// GetCurveParams returns suitable curve parameters given the parametric type Base as base field. -func GetCurveParams[Base emulated.FieldParams]() CurveParams { - var t Base - switch t.Modulus().Text(16) { - case "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f": - return secp256k1Params - case "30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47": - return bn254Params - default: - panic("no stored parameters") - } -} - -var ( - secp256k1Params CurveParams - bn254Params CurveParams -) - -func init() { - secp256k1Params = GetSecp256k1Params() - bn254Params = GetBN254Params() -} diff --git a/std/algebra/weierstrass/point.go b/std/algebra/weierstrass/point.go deleted file mode 100644 index 8c91ed2761..0000000000 --- a/std/algebra/weierstrass/point.go +++ /dev/null @@ -1,163 +0,0 @@ -package weierstrass - -import ( - "fmt" - "math/big" - - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/math/emulated" -) - -// New returns a new [Curve] instance over the base field Base and scalar field -// Scalars defined by the curve parameters params. 
It returns an error if -// initialising the field emulation fails (for example, when the native field is -// too small) or when the curve parameters are incompatible with the fields. -func New[Base, Scalars emulated.FieldParams](api frontend.API, params CurveParams) (*Curve[Base, Scalars], error) { - ba, err := emulated.NewField[Base](api) - if err != nil { - return nil, fmt.Errorf("new base api: %w", err) - } - sa, err := emulated.NewField[Scalars](api) - if err != nil { - return nil, fmt.Errorf("new scalar api: %w", err) - } - Gx := emulated.ValueOf[Base](params.Gx) - Gy := emulated.ValueOf[Base](params.Gy) - return &Curve[Base, Scalars]{ - params: params, - api: api, - baseApi: ba, - scalarApi: sa, - g: AffinePoint[Base]{ - X: Gx, - Y: Gy, - }, - a: emulated.ValueOf[Base](params.A), - addA: params.A.Cmp(big.NewInt(0)) != 0, - }, nil -} - -// Curve is an initialised curve which allows performing group operations. -type Curve[Base, Scalars emulated.FieldParams] struct { - // params is the parameters of the curve - params CurveParams - // api is the native api, we construct it ourselves to be sure - api frontend.API - // baseApi is the api for point operations - baseApi *emulated.Field[Base] - // scalarApi is the api for scalar operations - scalarApi *emulated.Field[Scalars] - - // g is the generator (base point) of the curve. - g AffinePoint[Base] - - a emulated.Element[Base] - addA bool -} - -// Generator returns the base point of the curve. The method does not copy and -// modifying the returned element leads to undefined behaviour! -func (c *Curve[B, S]) Generator() *AffinePoint[B] { - return &c.g -} - -// AffinePoint represents a point on the elliptic curve. We do not check that -// the point is actually on the curve. -type AffinePoint[Base emulated.FieldParams] struct { - X, Y emulated.Element[Base] -} - -// Neg returns an inverse of p. It doesn't modify p. 
-func (c *Curve[B, S]) Neg(p *AffinePoint[B]) *AffinePoint[B] { - return &AffinePoint[B]{ - X: p.X, - Y: *c.baseApi.Neg(&p.Y), - } -} - -// AssertIsEqual asserts that p and q are the same point. -func (c *Curve[B, S]) AssertIsEqual(p, q *AffinePoint[B]) { - c.baseApi.AssertIsEqual(&p.X, &q.X) - c.baseApi.AssertIsEqual(&p.Y, &q.Y) -} - -// Add adds q and r and returns it. -func (c *Curve[B, S]) Add(q, r *AffinePoint[B]) *AffinePoint[B] { - // compute lambda = (p1.y-p.y)/(p1.x-p.x) - p1ypy := c.baseApi.Sub(&r.Y, &q.Y) - p1xpx := c.baseApi.Sub(&r.X, &q.X) - lambda := c.baseApi.Div(p1ypy, p1xpx) - - // xr = lambda**2-p.x-p1.x - lambdaSq := c.baseApi.MulMod(lambda, lambda) - qxrx := c.baseApi.Add(&q.X, &r.X) - xr := c.baseApi.Sub(lambdaSq, qxrx) - - // p.y = lambda(p.x-xr) - p.y - pxxr := c.baseApi.Sub(&q.X, xr) - lpxxr := c.baseApi.MulMod(lambda, pxxr) - py := c.baseApi.Sub(lpxxr, &q.Y) - - return &AffinePoint[B]{ - X: *c.baseApi.Reduce(xr), - Y: *c.baseApi.Reduce(py), - } -} - -// Double doubles p and return it. It doesn't modify p. -func (c *Curve[B, S]) Double(p *AffinePoint[B]) *AffinePoint[B] { - - // compute lambda = (3*p1.x**2+a)/2*p1.y, here we assume a=0 (j invariant 0 curve) - xSq3a := c.baseApi.MulMod(&p.X, &p.X) - xSq3a = c.baseApi.MulConst(xSq3a, big.NewInt(3)) - if c.addA { - xSq3a = c.baseApi.Add(xSq3a, &c.a) - } - y2 := c.baseApi.MulConst(&p.Y, big.NewInt(2)) - lambda := c.baseApi.Div(xSq3a, y2) - - // xr = lambda**2-p1.x-p1.x - x2 := c.baseApi.MulConst(&p.X, big.NewInt(2)) - lambdaSq := c.baseApi.MulMod(lambda, lambda) - xr := c.baseApi.Sub(lambdaSq, x2) - - // p.y = lambda(p.x-xr) - p.y - pxxr := c.baseApi.Sub(&p.X, xr) - lpxxr := c.baseApi.MulMod(lambda, pxxr) - py := c.baseApi.Sub(lpxxr, &p.Y) - - return &AffinePoint[B]{ - X: *c.baseApi.Reduce(xr), - Y: *c.baseApi.Reduce(py), - } -} - -// Select selects between p and q given the selector b. If b == 0, then returns -// p and q otherwise. 
-func (c *Curve[B, S]) Select(b frontend.Variable, p, q *AffinePoint[B]) *AffinePoint[B] { - x := c.baseApi.Select(b, &p.X, &q.X) - y := c.baseApi.Select(b, &p.Y, &q.Y) - return &AffinePoint[B]{ - X: *x, - Y: *y, - } -} - -// ScalarMul computes s * p and returns it. It doesn't modify p nor s. -func (c *Curve[B, S]) ScalarMul(p *AffinePoint[B], s *emulated.Element[S]) *AffinePoint[B] { - res := p - acc := c.Double(p) - - var st S - sr := c.scalarApi.Reduce(s) - sBits := c.scalarApi.ToBits(sr) - for i := 1; i < st.Modulus().BitLen(); i++ { - tmp := c.Add(res, acc) - res = c.Select(sBits[i], tmp, res) - acc = c.Double(acc) - } - - tmp := c.Add(res, c.Neg(p)) - res = c.Select(sBits[0], res, tmp) - return res -} diff --git a/std/algebra/weierstrass/point_test.go b/std/algebra/weierstrass/point_test.go deleted file mode 100644 index 28faf42109..0000000000 --- a/std/algebra/weierstrass/point_test.go +++ /dev/null @@ -1,197 +0,0 @@ -package weierstrass - -import ( - "math/big" - "testing" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/bn254" - "github.com/consensys/gnark-crypto/ecc/secp256k1" - "github.com/consensys/gnark-crypto/ecc/secp256k1/fp" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/std/math/emulated" - "github.com/consensys/gnark/test" -) - -var testCurve = ecc.BN254 - -type NegTest[T, S emulated.FieldParams] struct { - P, Q AffinePoint[T] -} - -func (c *NegTest[T, S]) Define(api frontend.API) error { - cr, err := New[T, S](api, GetCurveParams[T]()) - if err != nil { - return err - } - res := cr.Neg(&c.P) - cr.AssertIsEqual(res, &c.Q) - return nil -} - -func TestNeg(t *testing.T) { - assert := test.NewAssert(t) - _, g := secp256k1.Generators() - var yn fp.Element - yn.Neg(&g.Y) - circuit := NegTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} - witness := NegTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ - P: AffinePoint[emulated.Secp256k1Fp]{ - X: 
emulated.ValueOf[emulated.Secp256k1Fp](g.X), - Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), - }, - Q: AffinePoint[emulated.Secp256k1Fp]{ - X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), - Y: emulated.ValueOf[emulated.Secp256k1Fp](yn), - }, - } - err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) - assert.NoError(err) -} - -type AddTest[T, S emulated.FieldParams] struct { - P, Q, R AffinePoint[T] -} - -func (c *AddTest[T, S]) Define(api frontend.API) error { - cr, err := New[T, S](api, GetCurveParams[T]()) - if err != nil { - return err - } - res := cr.Add(&c.P, &c.Q) - cr.AssertIsEqual(res, &c.R) - return nil -} - -func TestAdd(t *testing.T) { - assert := test.NewAssert(t) - var dJac, aJac secp256k1.G1Jac - g, _ := secp256k1.Generators() - dJac.Double(&g) - aJac.Set(&dJac). - AddAssign(&g) - var dAff, aAff secp256k1.G1Affine - dAff.FromJacobian(&dJac) - aAff.FromJacobian(&aJac) - circuit := AddTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} - witness := AddTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ - P: AffinePoint[emulated.Secp256k1Fp]{ - X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), - Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), - }, - Q: AffinePoint[emulated.Secp256k1Fp]{ - X: emulated.ValueOf[emulated.Secp256k1Fp](dAff.X), - Y: emulated.ValueOf[emulated.Secp256k1Fp](dAff.Y), - }, - R: AffinePoint[emulated.Secp256k1Fp]{ - X: emulated.ValueOf[emulated.Secp256k1Fp](aAff.X), - Y: emulated.ValueOf[emulated.Secp256k1Fp](aAff.Y), - }, - } - err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) - assert.NoError(err) -} - -type DoubleTest[T, S emulated.FieldParams] struct { - P, Q AffinePoint[T] -} - -func (c *DoubleTest[T, S]) Define(api frontend.API) error { - cr, err := New[T, S](api, GetCurveParams[T]()) - if err != nil { - return err - } - res := cr.Double(&c.P) - cr.AssertIsEqual(res, &c.Q) - return nil -} - -func TestDouble(t *testing.T) { - assert := test.NewAssert(t) - g, _ := secp256k1.Generators() - var dJac 
secp256k1.G1Jac - dJac.Double(&g) - var dAff secp256k1.G1Affine - dAff.FromJacobian(&dJac) - circuit := DoubleTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} - witness := DoubleTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ - P: AffinePoint[emulated.Secp256k1Fp]{ - X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), - Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), - }, - Q: AffinePoint[emulated.Secp256k1Fp]{ - X: emulated.ValueOf[emulated.Secp256k1Fp](dAff.X), - Y: emulated.ValueOf[emulated.Secp256k1Fp](dAff.Y), - }, - } - err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) - assert.NoError(err) -} - -type ScalarMulTest[T, S emulated.FieldParams] struct { - P, Q AffinePoint[T] - S emulated.Element[S] -} - -func (c *ScalarMulTest[T, S]) Define(api frontend.API) error { - cr, err := New[T, S](api, GetCurveParams[T]()) - if err != nil { - return err - } - res := cr.ScalarMul(&c.P, &c.S) - cr.AssertIsEqual(res, &c.Q) - return nil -} - -func TestScalarMul(t *testing.T) { - assert := test.NewAssert(t) - _, g := secp256k1.Generators() - s, ok := new(big.Int).SetString("44693544921776318736021182399461740191514036429448770306966433218654680512345", 10) - assert.True(ok) - var S secp256k1.G1Affine - S.ScalarMultiplication(&g, s) - - circuit := ScalarMulTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} - witness := ScalarMulTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ - S: emulated.ValueOf[emulated.Secp256k1Fr](s), - P: AffinePoint[emulated.Secp256k1Fp]{ - X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), - Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), - }, - Q: AffinePoint[emulated.Secp256k1Fp]{ - X: emulated.ValueOf[emulated.Secp256k1Fp](S.X), - Y: emulated.ValueOf[emulated.Secp256k1Fp](S.Y), - }, - } - err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) - assert.NoError(err) - _, err = frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit) - assert.NoError(err) -} - -func TestScalarMul2(t *testing.T) { - assert := 
test.NewAssert(t) - s, ok := new(big.Int).SetString("14108069686105661647148607545884343550368786660735262576656400957535521042679", 10) - assert.True(ok) - var res bn254.G1Affine - _, _, gen, _ := bn254.Generators() - res.ScalarMultiplication(&gen, s) - - circuit := ScalarMulTest[emulated.BN254Fp, emulated.BN254Fr]{} - witness := ScalarMulTest[emulated.BN254Fp, emulated.BN254Fr]{ - S: emulated.ValueOf[emulated.BN254Fr](s), - P: AffinePoint[emulated.BN254Fp]{ - X: emulated.ValueOf[emulated.BN254Fp](gen.X), - Y: emulated.ValueOf[emulated.BN254Fp](gen.Y), - }, - Q: AffinePoint[emulated.BN254Fp]{ - X: emulated.ValueOf[emulated.BN254Fp](res.X), - Y: emulated.ValueOf[emulated.BN254Fp](res.Y), - }, - } - err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) - assert.NoError(err) - _, err = frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit) - assert.NoError(err) -} diff --git a/std/commitments/fri/fri.go b/std/commitments/fri/fri.go index 96b8ac2c3d..93dd0b301f 100644 --- a/std/commitments/fri/fri.go +++ b/std/commitments/fri/fri.go @@ -50,7 +50,7 @@ type RadixTwoFri struct { // hash function that is used for Fiat Shamir and for committing to // the oracles. - h hash.Hash + h hash.FieldHasher // nbSteps number of interactions between the prover and the verifier nbSteps int @@ -66,7 +66,7 @@ type RadixTwoFri struct { // NewRadixTwoFri creates an FFT-like oracle proof of proximity. 
// * h is the hash function that is used for the Merkle proofs // * gen is the generator of the cyclic group of unity of size \rho * size -func NewRadixTwoFri(size uint64, h hash.Hash, gen big.Int) RadixTwoFri { +func NewRadixTwoFri(size uint64, h hash.FieldHasher, gen big.Int) RadixTwoFri { var res RadixTwoFri diff --git a/std/commitments/fri/utils.go b/std/commitments/fri/utils.go index 13efcf98d9..24c0716073 100644 --- a/std/commitments/fri/utils.go +++ b/std/commitments/fri/utils.go @@ -3,7 +3,7 @@ package fri import ( "math/big" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -84,5 +84,5 @@ var DeriveQueriesPositions = func(_ *big.Int, inputs []*big.Int, res []*big.Int) } func init() { - hint.Register(DeriveQueriesPositions) + solver.RegisterHint(DeriveQueriesPositions) } diff --git a/std/commitments/kzg_bls12377/verifier.go b/std/commitments/kzg_bls12377/verifier.go index 49f9ed5b19..e7cffd85a3 100644 --- a/std/commitments/kzg_bls12377/verifier.go +++ b/std/commitments/kzg_bls12377/verifier.go @@ -19,8 +19,8 @@ package kzg_bls12377 import ( "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/fields_bls12377" - "github.com/consensys/gnark/std/algebra/sw_bls12377" + "github.com/consensys/gnark/std/algebra/native/fields_bls12377" + "github.com/consensys/gnark/std/algebra/native/sw_bls12377" ) // Digest commitment of a polynomial. 
@@ -28,7 +28,6 @@ type Digest = sw_bls12377.G1Affine // VK verification key (G2 part of SRS) type VK struct { - G1 sw_bls12377.G1Affine // G₁ G2 [2]sw_bls12377.G2Affine // [G₂, [α]G₂] } @@ -53,7 +52,7 @@ func Verify(api frontend.API, commitment Digest, proof OpeningProof, point front // [f(a)]G₁ var claimedValueG1Aff sw_bls12377.G1Affine - claimedValueG1Aff.ScalarMul(api, srs.G1, proof.ClaimedValue) + claimedValueG1Aff.ScalarMulBase(api, proof.ClaimedValue) // [f(α) - f(a)]G₁ var fminusfaG1 sw_bls12377.G1Affine @@ -64,17 +63,16 @@ func Verify(api frontend.API, commitment Digest, proof OpeningProof, point front var negH sw_bls12377.G1Affine negH.Neg(api, proof.H) - // [α-a]G₂ - var alphaMinusaG2 sw_bls12377.G2Affine - alphaMinusaG2.ScalarMul(api, srs.G2[0], point). - Neg(api, alphaMinusaG2). - AddAssign(api, srs.G2[1]) + // [f(α) - f(a) + a*H(α)]G₁ + var totalG1 sw_bls12377.G1Affine + totalG1.ScalarMul(api, proof.H, point). + AddAssign(api, fminusfaG1) - // e([f(α) - f(a)]G₁, G₂).e([-H(α)]G₁, [α-a]G₂) ==? 
1 + // e([f(α)-f(a)+aH(α)]G₁], G₂).e([-H(α)]G₁, [α]G₂) == 1 resPairing, _ := sw_bls12377.Pair( api, - []sw_bls12377.G1Affine{fminusfaG1, negH}, - []sw_bls12377.G2Affine{srs.G2[0], alphaMinusaG2}, + []sw_bls12377.G1Affine{totalG1, negH}, + []sw_bls12377.G2Affine{srs.G2[0], srs.G2[1]}, ) var one fields_bls12377.E12 diff --git a/std/commitments/kzg_bls12377/verifier_test.go b/std/commitments/kzg_bls12377/verifier_test.go index 55f9c9453a..09b7684b5e 100644 --- a/std/commitments/kzg_bls12377/verifier_test.go +++ b/std/commitments/kzg_bls12377/verifier_test.go @@ -69,17 +69,17 @@ func TestVerifierDynamic(t *testing.T) { } // commit to the polynomial - com, err := kzg.Commit(f, srs) + com, err := kzg.Commit(f, srs.Pk) assert.NoError(err) // create opening proof var point fr.Element point.SetRandom() - proof, err := kzg.Open(f, point, srs) + proof, err := kzg.Open(f, point, srs.Pk) assert.NoError(err) // check that the proof is correct - err = kzg.Verify(&com, &proof, point, srs) + err = kzg.Verify(&com, &proof, point, srs.Vk) if err != nil { t.Fatal(err) } @@ -98,17 +98,14 @@ func TestVerifierDynamic(t *testing.T) { witness.S = point.String() - witness.VerifKey.G1.X = srs.G1[0].X.String() - witness.VerifKey.G1.Y = srs.G1[0].Y.String() - - witness.VerifKey.G2[0].X.A0 = srs.G2[0].X.A0.String() - witness.VerifKey.G2[0].X.A1 = srs.G2[0].X.A1.String() - witness.VerifKey.G2[0].Y.A0 = srs.G2[0].Y.A0.String() - witness.VerifKey.G2[0].Y.A1 = srs.G2[0].Y.A1.String() - witness.VerifKey.G2[1].X.A0 = srs.G2[1].X.A0.String() - witness.VerifKey.G2[1].X.A1 = srs.G2[1].X.A1.String() - witness.VerifKey.G2[1].Y.A0 = srs.G2[1].Y.A0.String() - witness.VerifKey.G2[1].Y.A1 = srs.G2[1].Y.A1.String() + witness.VerifKey.G2[0].X.A0 = srs.Vk.G2[0].X.A0.String() + witness.VerifKey.G2[0].X.A1 = srs.Vk.G2[0].X.A1.String() + witness.VerifKey.G2[0].Y.A0 = srs.Vk.G2[0].Y.A0.String() + witness.VerifKey.G2[0].Y.A1 = srs.Vk.G2[0].Y.A1.String() + witness.VerifKey.G2[1].X.A0 = srs.Vk.G2[1].X.A0.String() + 
witness.VerifKey.G2[1].X.A1 = srs.Vk.G2[1].X.A1.String() + witness.VerifKey.G2[1].Y.A0 = srs.Vk.G2[1].Y.A0.String() + witness.VerifKey.G2[1].Y.A1 = srs.Vk.G2[1].Y.A1.String() // check if the circuit is solved var circuit verifierCircuit @@ -132,8 +129,6 @@ func TestVerifier(t *testing.T) { witness.Proof.ClaimedValue = "7211341386127354417397285211336133449231039596179023429378585109196698597268" witness.S = "4321" - witness.VerifKey.G1.X = "81937999373150964239938255573465948239988671502647976594219695644855304257327692006745978603320413799295628339695" - witness.VerifKey.G1.Y = "241266749859715473739788878240585681733927191168601896383759122102112907357779751001206799952863815012735208165030" witness.VerifKey.G2[0].X.A0 = "233578398248691099356572568220835526895379068987715365179118596935057653620464273615301663571204657964920925606294" witness.VerifKey.G2[0].X.A1 = "140913150380207355837477652521042157274541796891053068589147167627541651775299824604154852141315666357241556069118" witness.VerifKey.G2[0].Y.A0 = "63160294768292073209381361943935198908131692476676907196754037919244929611450776219210369229519898517858833747423" diff --git a/std/commitments/kzg_bls24315/verifier.go b/std/commitments/kzg_bls24315/verifier.go index 4d7625e2e9..260064b691 100644 --- a/std/commitments/kzg_bls24315/verifier.go +++ b/std/commitments/kzg_bls24315/verifier.go @@ -19,8 +19,8 @@ package kzg_bls24315 import ( "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/fields_bls24315" - "github.com/consensys/gnark/std/algebra/sw_bls24315" + "github.com/consensys/gnark/std/algebra/native/fields_bls24315" + "github.com/consensys/gnark/std/algebra/native/sw_bls24315" ) // Digest commitment of a polynomial. 
@@ -28,7 +28,6 @@ type Digest = sw_bls24315.G1Affine // VK verification key (G2 part of SRS) type VK struct { - G1 sw_bls24315.G1Affine // G₁ G2 [2]sw_bls24315.G2Affine // [G₂, [α]G₂] } @@ -53,7 +52,7 @@ func Verify(api frontend.API, commitment Digest, proof OpeningProof, point front // [f(a)]G₁ var claimedValueG1Aff sw_bls24315.G1Affine - claimedValueG1Aff.ScalarMul(api, srs.G1, proof.ClaimedValue) + claimedValueG1Aff.ScalarMulBase(api, proof.ClaimedValue) // [f(α) - f(a)]G₁ var fminusfaG1 sw_bls24315.G1Affine @@ -64,17 +63,16 @@ func Verify(api frontend.API, commitment Digest, proof OpeningProof, point front var negH sw_bls24315.G1Affine negH.Neg(api, proof.H) - // [α-a]G₂ - var alphaMinusaG2 sw_bls24315.G2Affine - alphaMinusaG2.ScalarMul(api, srs.G2[0], point). - Neg(api, alphaMinusaG2). - AddAssign(api, srs.G2[1]) + // [f(α) - f(a) + a*H(α)]G₁ + var totalG1 sw_bls24315.G1Affine + totalG1.ScalarMul(api, proof.H, point). + AddAssign(api, fminusfaG1) - // e([f(α) - f(a)]G₁, G₂).e([-H(α)]G₁, [α-a]G₂) ==? 
1 + // e([f(α)-f(a)+aH(α)]G₁], G₂).e([-H(α)]G₁, [α]G₂) == 1 resPairing, _ := sw_bls24315.Pair( api, - []sw_bls24315.G1Affine{fminusfaG1, negH}, - []sw_bls24315.G2Affine{srs.G2[0], alphaMinusaG2}, + []sw_bls24315.G1Affine{totalG1, negH}, + []sw_bls24315.G2Affine{srs.G2[0], srs.G2[1]}, ) var one fields_bls24315.E24 diff --git a/std/commitments/kzg_bls24315/verifier_test.go b/std/commitments/kzg_bls24315/verifier_test.go index d331ce82ed..d2588cacee 100644 --- a/std/commitments/kzg_bls24315/verifier_test.go +++ b/std/commitments/kzg_bls24315/verifier_test.go @@ -69,17 +69,17 @@ func TestVerifierDynamic(t *testing.T) { } // commit to the polynomial - com, err := kzg.Commit(f, srs) + com, err := kzg.Commit(f, srs.Pk) assert.NoError(err) // create opening proof var point fr.Element point.SetRandom() - proof, err := kzg.Open(f, point, srs) + proof, err := kzg.Open(f, point, srs.Pk) assert.NoError(err) // check that the proof is correct - err = kzg.Verify(&com, &proof, point, srs) + err = kzg.Verify(&com, &proof, point, srs.Vk) if err != nil { t.Fatal(err) } @@ -98,26 +98,23 @@ func TestVerifierDynamic(t *testing.T) { witness.S = point.String() - witness.VerifKey.G1.X = srs.G1[0].X.String() - witness.VerifKey.G1.Y = srs.G1[0].Y.String() - - witness.VerifKey.G2[0].X.B0.A0 = srs.G2[0].X.B0.A0.String() - witness.VerifKey.G2[0].X.B0.A1 = srs.G2[0].X.B0.A1.String() - witness.VerifKey.G2[0].X.B1.A0 = srs.G2[0].X.B1.A0.String() - witness.VerifKey.G2[0].X.B1.A1 = srs.G2[0].X.B1.A1.String() - witness.VerifKey.G2[0].Y.B0.A0 = srs.G2[0].Y.B0.A0.String() - witness.VerifKey.G2[0].Y.B0.A1 = srs.G2[0].Y.B0.A1.String() - witness.VerifKey.G2[0].Y.B1.A0 = srs.G2[0].Y.B1.A0.String() - witness.VerifKey.G2[0].Y.B1.A1 = srs.G2[0].Y.B1.A1.String() - - witness.VerifKey.G2[1].X.B0.A0 = srs.G2[1].X.B0.A0.String() - witness.VerifKey.G2[1].X.B0.A1 = srs.G2[1].X.B0.A1.String() - witness.VerifKey.G2[1].X.B1.A0 = srs.G2[1].X.B1.A0.String() - witness.VerifKey.G2[1].X.B1.A1 = srs.G2[1].X.B1.A1.String() - 
witness.VerifKey.G2[1].Y.B0.A0 = srs.G2[1].Y.B0.A0.String() - witness.VerifKey.G2[1].Y.B0.A1 = srs.G2[1].Y.B0.A1.String() - witness.VerifKey.G2[1].Y.B1.A0 = srs.G2[1].Y.B1.A0.String() - witness.VerifKey.G2[1].Y.B1.A1 = srs.G2[1].Y.B1.A1.String() + witness.VerifKey.G2[0].X.B0.A0 = srs.Vk.G2[0].X.B0.A0.String() + witness.VerifKey.G2[0].X.B0.A1 = srs.Vk.G2[0].X.B0.A1.String() + witness.VerifKey.G2[0].X.B1.A0 = srs.Vk.G2[0].X.B1.A0.String() + witness.VerifKey.G2[0].X.B1.A1 = srs.Vk.G2[0].X.B1.A1.String() + witness.VerifKey.G2[0].Y.B0.A0 = srs.Vk.G2[0].Y.B0.A0.String() + witness.VerifKey.G2[0].Y.B0.A1 = srs.Vk.G2[0].Y.B0.A1.String() + witness.VerifKey.G2[0].Y.B1.A0 = srs.Vk.G2[0].Y.B1.A0.String() + witness.VerifKey.G2[0].Y.B1.A1 = srs.Vk.G2[0].Y.B1.A1.String() + + witness.VerifKey.G2[1].X.B0.A0 = srs.Vk.G2[1].X.B0.A0.String() + witness.VerifKey.G2[1].X.B0.A1 = srs.Vk.G2[1].X.B0.A1.String() + witness.VerifKey.G2[1].X.B1.A0 = srs.Vk.G2[1].X.B1.A0.String() + witness.VerifKey.G2[1].X.B1.A1 = srs.Vk.G2[1].X.B1.A1.String() + witness.VerifKey.G2[1].Y.B0.A0 = srs.Vk.G2[1].Y.B0.A0.String() + witness.VerifKey.G2[1].Y.B0.A1 = srs.Vk.G2[1].Y.B0.A1.String() + witness.VerifKey.G2[1].Y.B1.A0 = srs.Vk.G2[1].Y.B1.A0.String() + witness.VerifKey.G2[1].Y.B1.A1 = srs.Vk.G2[1].Y.B1.A1.String() // check if the circuit is solved var circuit verifierCircuit @@ -141,8 +138,6 @@ func TestVerifier(t *testing.T) { witness.Proof.ClaimedValue = "10347231107172233075459792371577505115223937655290126532055162077965558980163" witness.S = "4321" - witness.VerifKey.G1.X = "34223510504517033132712852754388476272837911830964394866541204856091481856889569724484362330263" - witness.VerifKey.G1.Y = "24215295174889464585413596429561903295150472552154479431771837786124301185073987899223459122783" witness.VerifKey.G2[0].X.B0.A0 = "24614737899199071964341749845083777103809664018538138889239909664991294445469052467064654073699" witness.VerifKey.G2[0].X.B0.A1 = 
"17049297748993841127032249156255993089778266476087413538366212660716380683149731996715975282972" diff --git a/std/evmprecompiles/01-ecrecover.go b/std/evmprecompiles/01-ecrecover.go new file mode 100644 index 0000000000..2e7dea01f3 --- /dev/null +++ b/std/evmprecompiles/01-ecrecover.go @@ -0,0 +1,86 @@ +package evmprecompiles + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/bits" + "github.com/consensys/gnark/std/math/emulated" +) + +// ECRecover implements [ECRECOVER] precompile contract at address 0x01. +// +// [ECRECOVER]: https://ethereum.github.io/execution-specs/autoapi/ethereum/paris/vm/precompiled_contracts/ecrecover/index.html +func ECRecover(api frontend.API, msg emulated.Element[emulated.Secp256k1Fr], + v frontend.Variable, r, s emulated.Element[emulated.Secp256k1Fr], + strictRange frontend.Variable) *sw_emulated.AffinePoint[emulated.Secp256k1Fp] { + // EVM uses v \in {27, 28}, but everyone else v >= 0. Convert back + v = api.Sub(v, 27) + var emfp emulated.Secp256k1Fp + var emfr emulated.Secp256k1Fr + fpField, err := emulated.NewField[emulated.Secp256k1Fp](api) + if err != nil { + panic(fmt.Sprintf("new field: %v", err)) + } + frField, err := emulated.NewField[emulated.Secp256k1Fr](api) + if err != nil { + panic(fmt.Sprintf("new field: %v", err)) + } + // with the encoding we may have that r,s < 2*Fr (i.e. not r,s < Fr). Apply more thorough checks. + frField.AssertIsLessOrEqual(&r, frField.Modulus()) + // Ethereum Yellow Paper defines that the check for s should be more strict + // when checking transaction signatures (Appendix F). 
There we should check + // that s <= (Fr-1)/2 + halfFr := new(big.Int).Sub(emfr.Modulus(), big.NewInt(1)) + halfFr.Div(halfFr, big.NewInt(2)) + bound := frField.Select(strictRange, frField.NewElement(halfFr), frField.Modulus()) + frField.AssertIsLessOrEqual(&s, bound) + + curve, err := sw_emulated.New[emulated.Secp256k1Fp, emulated.Secp256k1Fr](api, sw_emulated.GetSecp256k1Params()) + if err != nil { + panic(fmt.Sprintf("new curve: %v", err)) + } + // we cannot directly use the field emulation hint calling wrappers as we work between two fields. + Rlimbs, err := api.Compiler().NewHint(recoverPointHint, 2*int(emfp.NbLimbs()), recoverPointHintArgs(v, r)...) + if err != nil { + panic(fmt.Sprintf("point hint: %v", err)) + } + R := sw_emulated.AffinePoint[emulated.Secp256k1Fp]{ + X: *fpField.NewElement(Rlimbs[0:emfp.NbLimbs()]), + Y: *fpField.NewElement(Rlimbs[emfp.NbLimbs() : 2*emfp.NbLimbs()]), + } + // we cannot directly use the field emulation hint calling wrappers as we work between two fields. + Plimbs, err := api.Compiler().NewHint(recoverPublicKeyHint, 2*int(emfp.NbLimbs()), recoverPublicKeyHintArgs(msg, v, r, s)...) + if err != nil { + panic(fmt.Sprintf("point hint: %v", err)) + } + P := sw_emulated.AffinePoint[emulated.Secp256k1Fp]{ + X: *fpField.NewElement(Plimbs[0:emfp.NbLimbs()]), + Y: *fpField.NewElement(Plimbs[emfp.NbLimbs() : 2*emfp.NbLimbs()]), + } + // check that len(v) = 2 + vbits := bits.ToBinary(api, v, bits.WithNbDigits(2)) + // check that Rx is correct: x = r+v[1]*fr + tmp := fpField.Select(vbits[1], fpField.NewElement(emfr.Modulus()), fpField.NewElement(0)) + rbits := frField.ToBits(&r) + rfp := fpField.FromBits(rbits...) 
+ tmp = fpField.Add(rfp, tmp) + fpField.AssertIsEqual(tmp, &R.X) + // check that Ry is correct: oddity(y) = v[0] + Rynormal := fpField.Reduce(&R.Y) + Rybits := fpField.ToBits(Rynormal) + api.AssertIsEqual(vbits[0], Rybits[0]) + // compute rinv = r^{-1} mod fr + rinv := frField.Inverse(&r) + // compute u1 = -msg * rinv + u1 := frField.MulMod(&msg, rinv) + u1 = frField.Neg(u1) + // compute u2 = s * rinv + u2 := frField.MulMod(&s, rinv) + // check u1 * G + u2 R == P + C := curve.JointScalarMulBase(&R, u2, u1) + curve.AssertIsEqual(C, &P) + return &P +} diff --git a/std/evmprecompiles/01-ecrecover_test.go b/std/evmprecompiles/01-ecrecover_test.go new file mode 100644 index 0000000000..577a45b75b --- /dev/null +++ b/std/evmprecompiles/01-ecrecover_test.go @@ -0,0 +1,141 @@ +package evmprecompiles + +import ( + "crypto/rand" + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/secp256k1/ecdsa" + "github.com/consensys/gnark-crypto/ecc/secp256k1/fr" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/test" +) + +func TestSignForRecoverCorrectness(t *testing.T) { + sk, err := ecdsa.GenerateKey(rand.Reader) + if err != nil { + t.Fatal("generate", err) + } + pk := sk.PublicKey + msg := []byte("test") + _, r, s, err := sk.SignForRecover(msg, nil) + if err != nil { + t.Fatal("sign", err) + } + var sig ecdsa.Signature + r.FillBytes(sig.R[:fr.Bytes]) + s.FillBytes(sig.S[:fr.Bytes]) + sigM := sig.Bytes() + ok, err := pk.Verify(sigM, msg, nil) + if err != nil { + t.Fatal("verify", err) + } + if !ok { + t.Fatal("not verified") + } +} + +type ecrecoverCircuit struct { + Message emulated.Element[emulated.Secp256k1Fr] + V frontend.Variable + R emulated.Element[emulated.Secp256k1Fr] + S emulated.Element[emulated.Secp256k1Fr] + Strict frontend.Variable + 
Expected sw_emulated.AffinePoint[emulated.Secp256k1Fp] +} + +func (c *ecrecoverCircuit) Define(api frontend.API) error { + curve, err := sw_emulated.New[emulated.Secp256k1Fp, emulated.Secp256k1Fr](api, sw_emulated.GetSecp256k1Params()) + if err != nil { + return fmt.Errorf("new curve: %w", err) + } + res := ECRecover(api, c.Message, c.V, c.R, c.S, c.Strict) + curve.AssertIsEqual(&c.Expected, res) + return nil +} + +func testRoutineECRecover(t *testing.T, wantStrict bool) (circ, wit *ecrecoverCircuit, largeS bool) { + halfFr := new(big.Int).Sub(fr.Modulus(), big.NewInt(1)) + halfFr.Div(halfFr, big.NewInt(2)) + + sk, err := ecdsa.GenerateKey(rand.Reader) + if err != nil { + t.Fatal("generate", err) + } + pk := sk.PublicKey + msg := []byte("test") + var r, s *big.Int + var v uint + for { + v, r, s, err = sk.SignForRecover(msg, nil) + if err != nil { + t.Fatal("sign", err) + } + if !wantStrict || halfFr.Cmp(s) > 0 { + break + } + } + strict := 0 + if wantStrict { + strict = 1 + } + circuit := ecrecoverCircuit{} + witness := ecrecoverCircuit{ + Message: emulated.ValueOf[emulated.Secp256k1Fr](ecdsa.HashToInt(msg)), + V: v + 27, // EVM constant + R: emulated.ValueOf[emulated.Secp256k1Fr](r), + S: emulated.ValueOf[emulated.Secp256k1Fr](s), + Strict: strict, + Expected: sw_emulated.AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](pk.A.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](pk.A.Y), + }, + } + return &circuit, &witness, halfFr.Cmp(s) <= 0 +} + +func TestECRecoverCircuitShortStrict(t *testing.T) { + assert := test.NewAssert(t) + circuit, witness, _ := testRoutineECRecover(t, true) + err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +func TestECRecoverCircuitShortLax(t *testing.T) { + assert := test.NewAssert(t) + circuit, witness, _ := testRoutineECRecover(t, false) + err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +func 
TestECRecoverCircuitShortMismatch(t *testing.T) { + assert := test.NewAssert(t) + halfFr := new(big.Int).Sub(fr.Modulus(), big.NewInt(1)) + halfFr.Div(halfFr, big.NewInt(2)) + var circuit, witness *ecrecoverCircuit + var largeS bool + for { + circuit, witness, largeS = testRoutineECRecover(t, false) + if largeS { + witness.Strict = 1 + break + } + } + err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()) + assert.Error(err) +} + +func TestECRecoverCircuitFull(t *testing.T) { + t.Skip("skipping very long test") + assert := test.NewAssert(t) + circuit, witness, _ := testRoutineECRecover(t, false) + assert.ProverSucceeded(circuit, witness, + test.NoFuzzing(), test.NoSerialization(), + test.WithBackends(backend.GROTH16, backend.PLONK), test.WithCurves(ecc.BN254), + ) +} diff --git a/std/evmprecompiles/02-sha256.go b/std/evmprecompiles/02-sha256.go new file mode 100644 index 0000000000..ae0c33e733 --- /dev/null +++ b/std/evmprecompiles/02-sha256.go @@ -0,0 +1,3 @@ +package evmprecompiles + +// will implement later diff --git a/std/evmprecompiles/04-id.go b/std/evmprecompiles/04-id.go new file mode 100644 index 0000000000..d554d428c8 --- /dev/null +++ b/std/evmprecompiles/04-id.go @@ -0,0 +1,3 @@ +package evmprecompiles + +// not going to implement. It is trivial. diff --git a/std/evmprecompiles/05-expmod.go b/std/evmprecompiles/05-expmod.go new file mode 100644 index 0000000000..6b1eb16123 --- /dev/null +++ b/std/evmprecompiles/05-expmod.go @@ -0,0 +1 @@ +package evmprecompiles diff --git a/std/evmprecompiles/06-bnadd.go b/std/evmprecompiles/06-bnadd.go new file mode 100644 index 0000000000..ff6c397fc9 --- /dev/null +++ b/std/evmprecompiles/06-bnadd.go @@ -0,0 +1,21 @@ +package evmprecompiles + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" +) + +// ECAdd implements [ALT_BN128_ADD] precompile contract at address 0x06. 
+// +// [ALT_BN128_ADD]: https://ethereum.github.io/execution-specs/autoapi/ethereum/paris/vm/precompiled_contracts/alt_bn128/index.html#alt-bn128-add +func ECAdd(api frontend.API, P, Q *sw_emulated.AffinePoint[emulated.BN254Fp]) *sw_emulated.AffinePoint[emulated.BN254Fp] { + curve, err := sw_emulated.New[emulated.BN254Fp, emulated.BN254Fr](api, sw_emulated.GetBN254Params()) + if err != nil { + panic(err) + } + // Check that P and Q are on the curve (done in the zkEVM ⚠️ ) + // We use AddUnified because P can be equal to Q, -Q and either or both can be (0,0) + res := curve.AddUnified(P, Q) + return res +} diff --git a/std/evmprecompiles/07-bnmul.go b/std/evmprecompiles/07-bnmul.go new file mode 100644 index 0000000000..eb1c02ccda --- /dev/null +++ b/std/evmprecompiles/07-bnmul.go @@ -0,0 +1,20 @@ +package evmprecompiles + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" +) + +// ECMul implements [ALT_BN128_MUL] precompile contract at address 0x07. 
+// +// [ALT_BN128_MUL]: https://ethereum.github.io/execution-specs/autoapi/ethereum/paris/vm/precompiled_contracts/alt_bn128/index.html#alt-bn128-mul +func ECMul(api frontend.API, P *sw_emulated.AffinePoint[emulated.BN254Fp], u *emulated.Element[emulated.BN254Fr]) *sw_emulated.AffinePoint[emulated.BN254Fp] { + curve, err := sw_emulated.New[emulated.BN254Fp, emulated.BN254Fr](api, sw_emulated.GetBN254Params()) + if err != nil { + panic(err) + } + // Check that P is on the curve (done in the zkEVM ⚠️ ) + res := curve.ScalarMul(P, u) + return res +} diff --git a/std/evmprecompiles/08-bnpairing.go b/std/evmprecompiles/08-bnpairing.go new file mode 100644 index 0000000000..9d25cd20ec --- /dev/null +++ b/std/evmprecompiles/08-bnpairing.go @@ -0,0 +1,50 @@ +package evmprecompiles + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_bn254" +) + +// ECPair implements [ALT_BN128_PAIRING_CHECK] precompile contract at address 0x08. +// +// [ALT_BN128_PAIRING_CHECK]: https://ethereum.github.io/execution-specs/autoapi/ethereum/paris/vm/precompiled_contracts/alt_bn128/index.html#alt-bn128-pairing-check +// +// To have a fixed-circuit regardless of the number of inputs, we need 2 fixed circuits: +// - A Miller loop of fixed size 1 followed with a multiplication in 𝔽p¹² (MillerLoopAndMul) +// - A final exponentiation followed with an equality check in GT (FinalExponentiationIsOne) +// +// N.B.: This is a sub-optimal routine but defines a fixed circuit regardless +// of the number of inputs. We can extend this routine to handle a 2-by-2 +// logic but we prefer a minimal number of circuits (2). 
+ +func ECPair(api frontend.API, P []*sw_bn254.G1Affine, Q []*sw_bn254.G2Affine) { + if len(P) != len(Q) { + panic("P and Q length mismatch") + } + if len(P) < 2 { + panic("invalid multipairing size bound") + } + n := len(P) + pair, err := sw_bn254.NewPairing(api) + if err != nil { + panic(err) + } + // 1- Check that Pᵢ are on G1 (done in the zkEVM ⚠️ ) + // 2- Check that Qᵢ are on G2 + for i := 0; i < len(Q); i++ { + pair.AssertIsOnG2(Q[i]) + } + + // 3- Check that ∏ᵢ e(Pᵢ, Qᵢ) == 1 + ml := pair.One() + for i := 0; i < n; i++ { + // fixed circuit 1 + ml, err = pair.MillerLoopAndMul(P[i], Q[i], ml) + if err != nil { + panic(err) + } + } + + // fixed circuit 2 + pair.FinalExponentiationIsOne(ml) +} diff --git a/std/evmprecompiles/bn_test.go b/std/evmprecompiles/bn_test.go new file mode 100644 index 0000000000..24cb21c7ba --- /dev/null +++ b/std/evmprecompiles/bn_test.go @@ -0,0 +1,197 @@ +package evmprecompiles + +import ( + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_bn254" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/test" +) + +type ecaddCircuit struct { + X0 sw_emulated.AffinePoint[emulated.BN254Fp] + X1 sw_emulated.AffinePoint[emulated.BN254Fp] + Expected sw_emulated.AffinePoint[emulated.BN254Fp] +} + +func (c *ecaddCircuit) Define(api frontend.API) error { + curve, err := sw_emulated.New[emulated.BN254Fp, emulated.BN254Fr](api, sw_emulated.GetBN254Params()) + if err != nil { + return err + } + res := ECAdd(api, &c.X0, &c.X1) + curve.AssertIsEqual(res, &c.Expected) + return nil +} + +func testRoutineECAdd() (circ, wit frontend.Circuit) { + _, _, G, _ := bn254.Generators() + var u, v fr.Element + 
u.SetRandom() + v.SetRandom() + var P, Q bn254.G1Affine + P.ScalarMultiplication(&G, u.BigInt(new(big.Int))) + Q.ScalarMultiplication(&G, v.BigInt(new(big.Int))) + var expected bn254.G1Affine + expected.Add(&P, &Q) + circuit := ecaddCircuit{} + witness := ecaddCircuit{ + X0: sw_emulated.AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](P.X), + Y: emulated.ValueOf[emulated.BN254Fp](P.Y), + }, + X1: sw_emulated.AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](Q.X), + Y: emulated.ValueOf[emulated.BN254Fp](Q.Y), + }, + Expected: sw_emulated.AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](expected.X), + Y: emulated.ValueOf[emulated.BN254Fp](expected.Y), + }, + } + return &circuit, &witness +} + +func TestECAddCircuitShort(t *testing.T) { + assert := test.NewAssert(t) + circuit, witness := testRoutineECAdd() + err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +func TestECAddCircuitFull(t *testing.T) { + assert := test.NewAssert(t) + circuit, witness := testRoutineECAdd() + assert.ProverSucceeded(circuit, witness, + test.NoFuzzing(), test.NoSerialization(), + test.WithBackends(backend.GROTH16, backend.PLONK), test.WithCurves(ecc.BN254), + ) +} + +type ecmulCircuit struct { + X0 sw_emulated.AffinePoint[emulated.BN254Fp] + U emulated.Element[emulated.BN254Fr] + Expected sw_emulated.AffinePoint[emulated.BN254Fp] +} + +func (c *ecmulCircuit) Define(api frontend.API) error { + curve, err := sw_emulated.New[emulated.BN254Fp, emulated.BN254Fr](api, sw_emulated.GetBN254Params()) + if err != nil { + return err + } + res := ECMul(api, &c.X0, &c.U) + curve.AssertIsEqual(res, &c.Expected) + return nil +} + +func testRoutineECMul(t *testing.T) (circ, wit frontend.Circuit) { + _, _, G, _ := bn254.Generators() + var u, v fr.Element + u.SetRandom() + v.SetRandom() + var P bn254.G1Affine + P.ScalarMultiplication(&G, u.BigInt(new(big.Int))) + var expected bn254.G1Affine + 
expected.ScalarMultiplication(&P, v.BigInt(new(big.Int))) + circuit := ecmulCircuit{} + witness := ecmulCircuit{ + X0: sw_emulated.AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](P.X), + Y: emulated.ValueOf[emulated.BN254Fp](P.Y), + }, + U: emulated.ValueOf[emulated.BN254Fr](v), + Expected: sw_emulated.AffinePoint[emulated.BN254Fp]{ + X: emulated.ValueOf[emulated.BN254Fp](expected.X), + Y: emulated.ValueOf[emulated.BN254Fp](expected.Y), + }, + } + return &circuit, &witness +} + +func TestECMulCircuitShort(t *testing.T) { + assert := test.NewAssert(t) + circuit, witness := testRoutineECMul(t) + err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +func TestECMulCircuitFull(t *testing.T) { + t.Skip("skipping very long test") + assert := test.NewAssert(t) + circuit, witness := testRoutineECMul(t) + assert.ProverSucceeded(circuit, witness, + test.NoFuzzing(), test.NoSerialization(), + test.WithBackends(backend.GROTH16, backend.PLONK), test.WithCurves(ecc.BN254), + ) +} + +type ecPairBatchCircuit struct { + P sw_bn254.G1Affine + NP sw_bn254.G1Affine + DP sw_bn254.G1Affine + Q sw_bn254.G2Affine + n int +} + +func (c *ecPairBatchCircuit) Define(api frontend.API) error { + Q := make([]*sw_bn254.G2Affine, c.n) + for i := range Q { + Q[i] = &c.Q + } + switch c.n { + case 2: + ECPair(api, []*sw_emulated.AffinePoint[emulated.BN254Fp]{&c.P, &c.NP}, Q) + case 3: + ECPair(api, []*sw_emulated.AffinePoint[emulated.BN254Fp]{&c.NP, &c.NP, &c.DP}, Q) + case 4: + ECPair(api, []*sw_emulated.AffinePoint[emulated.BN254Fp]{&c.P, &c.NP, &c.P, &c.NP}, Q) + case 5: + ECPair(api, []*sw_emulated.AffinePoint[emulated.BN254Fp]{&c.P, &c.NP, &c.NP, &c.NP, &c.DP}, Q) + case 6: + ECPair(api, []*sw_emulated.AffinePoint[emulated.BN254Fp]{&c.P, &c.NP, &c.P, &c.NP, &c.P, &c.NP}, Q) + case 7: + ECPair(api, []*sw_emulated.AffinePoint[emulated.BN254Fp]{&c.P, &c.NP, &c.P, &c.NP, &c.NP, &c.NP, &c.DP}, Q) + case 8: + ECPair(api, 
[]*sw_emulated.AffinePoint[emulated.BN254Fp]{&c.P, &c.NP, &c.P, &c.NP, &c.P, &c.NP, &c.P, &c.NP}, Q) + case 9: + ECPair(api, []*sw_emulated.AffinePoint[emulated.BN254Fp]{&c.P, &c.NP, &c.P, &c.NP, &c.P, &c.NP, &c.NP, &c.NP, &c.DP}, Q) + default: + return fmt.Errorf("not handled %d", c.n) + } + return nil +} + +func TestECPairMulBatch(t *testing.T) { + assert := test.NewAssert(t) + _, _, p, q := bn254.Generators() + + var u, v fr.Element + u.SetRandom() + v.SetRandom() + + p.ScalarMultiplication(&p, u.BigInt(new(big.Int))) + q.ScalarMultiplication(&q, v.BigInt(new(big.Int))) + + var dp, np bn254.G1Affine + dp.Double(&p) + np.Neg(&p) + + for i := 2; i < 10; i++ { + err := test.IsSolved(&ecPairBatchCircuit{n: i}, &ecPairBatchCircuit{ + n: i, + P: sw_bn254.NewG1Affine(p), + NP: sw_bn254.NewG1Affine(np), + DP: sw_bn254.NewG1Affine(dp), + Q: sw_bn254.NewG2Affine(q), + }, ecc.BN254.ScalarField()) + assert.NoError(err) + } +} diff --git a/std/evmprecompiles/compose.go b/std/evmprecompiles/compose.go new file mode 100644 index 0000000000..c1ffaff2aa --- /dev/null +++ b/std/evmprecompiles/compose.go @@ -0,0 +1,37 @@ +package evmprecompiles + +import ( + "fmt" + "math/big" +) + +func recompose(inputs []*big.Int, nbBits uint) *big.Int { + res := new(big.Int) + if len(inputs) == 0 { + return res + } + for i := range inputs { + res.Lsh(res, nbBits) + res.Add(res, inputs[len(inputs)-i-1]) + } + return res +} + +func decompose(input *big.Int, nbBits uint, res []*big.Int) error { + // limb modulus + if input.BitLen() > len(res)*int(nbBits) { + return fmt.Errorf("decomposed integer does not fit into res") + } + for _, r := range res { + if r == nil { + return fmt.Errorf("result slice element uninitalized") + } + } + base := new(big.Int).Lsh(big.NewInt(1), nbBits) + tmp := new(big.Int).Set(input) + for i := 0; i < len(res); i++ { + res[i].Mod(tmp, base) + tmp.Rsh(tmp, nbBits) + } + return nil +} diff --git a/std/evmprecompiles/doc.go b/std/evmprecompiles/doc.go new file mode 100644 
index 0000000000..7c515eaa51 --- /dev/null +++ b/std/evmprecompiles/doc.go @@ -0,0 +1,18 @@ +// Package evmprecompiles implements the Ethereum VM precompile contracts. +// +// This package collects all the precompile functions into a single location for +// easier integration. The main functionality is implemented elsewhere. This +// package right now implements: +// 1. ECRECOVER ✅ -- function [ECRecover] +// 2. SHA256 ❌ -- in progress +// 3. RIPEMD160 ❌ -- postponed +// 4. ID ❌ -- trivial to implement without function +// 5. EXPMOD ❌ -- in progress +// 6. BN_ADD ✅ -- function [ECAdd] +// 7. BN_MUL ✅ -- function [ECMul] +// 8. SNARKV ✅ -- function [ECPair] +// 9. BLAKE2F ❌ -- postponed +// +// This package uses local representation for the arguments. It is up to the +// user to instantiate corresponding types from their application-specific data. +package evmprecompiles diff --git a/std/evmprecompiles/hints.go b/std/evmprecompiles/hints.go new file mode 100644 index 0000000000..0ebc231ae0 --- /dev/null +++ b/std/evmprecompiles/hints.go @@ -0,0 +1,97 @@ +package evmprecompiles + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark-crypto/ecc/secp256k1/ecdsa" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" +) + +func init() { + solver.RegisterHint(GetHints()...) +} + +// GetHints returns all the hints used in this package. +func GetHints() []solver.Hint { + return []solver.Hint{recoverPointHint, recoverPublicKeyHint} +} + +func recoverPointHintArgs(v frontend.Variable, r emulated.Element[emulated.Secp256k1Fr]) []frontend.Variable { + args := []frontend.Variable{v} + args = append(args, r.Limbs...) 
+ return args +} + +func recoverPointHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + var emfp emulated.Secp256k1Fp + if len(inputs) != int(emfp.NbLimbs())+1 { + return fmt.Errorf("expected input %d limbs got %d", emfp.NbLimbs()+1, len(inputs)) + } + if !inputs[0].IsInt64() { + return fmt.Errorf("first input supposed to be in [0,3]") + } + if len(outputs) != 2*int(emfp.NbLimbs()) { + return fmt.Errorf("expected output %d limbs got %d", 2*emfp.NbLimbs(), len(outputs)) + } + v := inputs[0].Uint64() + r := recompose(inputs[1:], emfp.BitsPerLimb()) + P, err := ecdsa.RecoverP(uint(v), r) + if err != nil { + return fmt.Errorf("recover: %s", err) + } + if err := decompose(P.X.BigInt(new(big.Int)), emfp.BitsPerLimb(), outputs[0:emfp.NbLimbs()]); err != nil { + return fmt.Errorf("decompose x: %w", err) + } + if err := decompose(P.Y.BigInt(new(big.Int)), emfp.BitsPerLimb(), outputs[emfp.NbLimbs():]); err != nil { + return fmt.Errorf("decompose y: %w", err) + } + return nil +} + +func recoverPublicKeyHintArgs(msg emulated.Element[emulated.Secp256k1Fr], + v frontend.Variable, r, s emulated.Element[emulated.Secp256k1Fr]) []frontend.Variable { + args := msg.Limbs + args = append(args, v) + args = append(args, r.Limbs...) + args = append(args, s.Limbs...) 
+ return args +} + +func recoverPublicKeyHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + // message -nb limbs + // then v - 1 + // r -- nb limbs + // s -- nb limbs + // return 2x nb limbs + var emfr emulated.Secp256k1Fr + var emfp emulated.Secp256k1Fp + if len(inputs) != int(emfr.NbLimbs())*3+1 { + return fmt.Errorf("expected %d limbs got %d", emfr.NbLimbs()*3+1, len(inputs)) + } + if !inputs[emfr.NbLimbs()].IsInt64() { + return fmt.Errorf("second input input must be in [0,3]") + } + if len(outputs) != 2*int(emfp.NbLimbs()) { + return fmt.Errorf("expected output %d limbs got %d", 2*emfp.NbLimbs(), len(outputs)) + } + msg := recompose(inputs[:emfr.NbLimbs()], emfr.BitsPerLimb()) + v := inputs[emfr.NbLimbs()].Uint64() + r := recompose(inputs[emfr.NbLimbs()+1:2*emfr.NbLimbs()+1], emfr.BitsPerLimb()) + s := recompose(inputs[2*emfr.NbLimbs()+1:3*emfr.NbLimbs()+1], emfr.BitsPerLimb()) + var pk ecdsa.PublicKey + if err := pk.RecoverFrom(msg.Bytes(), uint(v), r, s); err != nil { + return fmt.Errorf("recover public key: %w", err) + } + Px := pk.A.X.BigInt(new(big.Int)) + Py := pk.A.Y.BigInt(new(big.Int)) + if err := decompose(Px, emfp.BitsPerLimb(), outputs[0:emfp.NbLimbs()]); err != nil { + return fmt.Errorf("decompose x: %w", err) + } + if err := decompose(Py, emfp.BitsPerLimb(), outputs[emfp.NbLimbs():2*emfp.NbLimbs()]); err != nil { + return fmt.Errorf("decompose y: %w", err) + } + return nil +} diff --git a/std/fiat-shamir/settings.go b/std/fiat-shamir/settings.go index 7b1f2a31a7..146a64355e 100644 --- a/std/fiat-shamir/settings.go +++ b/std/fiat-shamir/settings.go @@ -9,7 +9,7 @@ type Settings struct { Transcript *Transcript Prefix string BaseChallenges []frontend.Variable - Hash hash.Hash + Hash hash.FieldHasher } func WithTranscript(transcript *Transcript, prefix string, baseChallenges ...frontend.Variable) Settings { @@ -20,7 +20,7 @@ func WithTranscript(transcript *Transcript, prefix string, baseChallenges ...fro } } -func WithHash(hash 
hash.Hash, baseChallenges ...frontend.Variable) Settings { +func WithHash(hash hash.FieldHasher, baseChallenges ...frontend.Variable) Settings { return Settings{ BaseChallenges: baseChallenges, Hash: hash, diff --git a/std/fiat-shamir/transcript.go b/std/fiat-shamir/transcript.go index 3ce9af5d98..63fe8e786d 100644 --- a/std/fiat-shamir/transcript.go +++ b/std/fiat-shamir/transcript.go @@ -18,6 +18,7 @@ package fiatshamir import ( "errors" + "github.com/consensys/gnark/constant" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/hash" @@ -33,7 +34,7 @@ var ( // Transcript handles the creation of challenges for Fiat Shamir. type Transcript struct { // hash function that is used. - h hash.Hash + h hash.FieldHasher challenges map[string]challenge previous *challenge @@ -52,7 +53,7 @@ type challenge struct { // NewTranscript returns a new transcript. // h is the hash function that is used to compute the challenges. // challenges are the name of the challenges. The order is important. 
-func NewTranscript(api frontend.API, h hash.Hash, challengesID ...string) Transcript { +func NewTranscript(api frontend.API, h hash.FieldHasher, challengesID ...string) Transcript { n := len(challengesID) t := Transcript{ challenges: make(map[string]challenge, n), diff --git a/std/gkr/api.go b/std/gkr/api.go new file mode 100644 index 0000000000..d13c65e2d0 --- /dev/null +++ b/std/gkr/api.go @@ -0,0 +1,50 @@ +package gkr + +import ( + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/std/utils/algo_utils" +) + +func frontendVarToInt(a constraint.GkrVariable) int { + return int(a) +} + +func (api *API) NamedGate(gate string, in ...constraint.GkrVariable) constraint.GkrVariable { + api.toStore.Circuit = append(api.toStore.Circuit, constraint.GkrWire{ + Gate: gate, + Inputs: algo_utils.Map(in, frontendVarToInt), + }) + api.assignments = append(api.assignments, nil) + return constraint.GkrVariable(len(api.toStore.Circuit) - 1) +} + +func (api *API) namedGate2PlusIn(gate string, in1, in2 constraint.GkrVariable, in ...constraint.GkrVariable) constraint.GkrVariable { + inCombined := make([]constraint.GkrVariable, 2+len(in)) + inCombined[0] = in1 + inCombined[1] = in2 + for i := range in { + inCombined[i+2] = in[i] + } + return api.NamedGate(gate, inCombined...) +} + +func (api *API) Add(i1, i2 constraint.GkrVariable, in ...constraint.GkrVariable) constraint.GkrVariable { + return api.namedGate2PlusIn("add", i1, i2, in...) +} + +func (api *API) Neg(i1 constraint.GkrVariable) constraint.GkrVariable { + return api.NamedGate("neg", i1) +} + +func (api *API) Sub(i1, i2 constraint.GkrVariable, in ...constraint.GkrVariable) constraint.GkrVariable { + return api.namedGate2PlusIn("sub", i1, i2, in...) +} + +func (api *API) Mul(i1, i2 constraint.GkrVariable, in ...constraint.GkrVariable) constraint.GkrVariable { + return api.namedGate2PlusIn("mul", i1, i2, in...) 
+} + +// TODO @Tabaie This can be useful +func (api *API) Println(a ...constraint.GkrVariable) { + panic("not implemented") +} diff --git a/std/gkr/api_test.go b/std/gkr/api_test.go new file mode 100644 index 0000000000..d39b103962 --- /dev/null +++ b/std/gkr/api_test.go @@ -0,0 +1,663 @@ +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/kzg" + "github.com/consensys/gnark/backend/plonk" + bn254r1cs "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/test" + "github.com/stretchr/testify/require" + "hash" + "math/rand" + "strconv" + "testing" + "time" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr" + bn254MiMC "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" + stdHash "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/hash/mimc" + test_vector_utils "github.com/consensys/gnark/std/utils/test_vectors_utils" +) + +// compressThreshold --> if linear expressions are larger than this, the frontend will introduce +// intermediate constraints. The lower this number is, the faster compile time should be (to a point) +// but resulting circuit will have more constraints (slower proving time). 
+const compressThreshold = 1000 + +type doubleNoDependencyCircuit struct { + X []frontend.Variable + hashName string +} + +func (c *doubleNoDependencyCircuit) Define(api frontend.API) error { + gkr := NewApi() + var x constraint.GkrVariable + var err error + if x, err = gkr.Import(c.X); err != nil { + return err + } + z := gkr.Add(x, x) + var solution Solution + if solution, err = gkr.Solve(api); err != nil { + return err + } + Z := solution.Export(z) + + for i := range Z { + api.AssertIsEqual(Z[i], api.Mul(2, c.X[i])) + } + + return solution.Verify(c.hashName) +} + +func TestDoubleNoDependencyCircuit(t *testing.T) { + + xValuess := [][]frontend.Variable{ + {1, 1}, + {1, 2}, + } + + hashes := []string{"-1", "-20"} + + for _, xValues := range xValuess { + for _, hashName := range hashes { + assignment := doubleNoDependencyCircuit{X: xValues} + circuit := doubleNoDependencyCircuit{X: make([]frontend.Variable, len(xValues)), hashName: hashName} + + testGroth16(t, &circuit, &assignment) + testPlonk(t, &circuit, &assignment) + } + } +} + +type sqNoDependencyCircuit struct { + X []frontend.Variable + hashName string +} + +func (c *sqNoDependencyCircuit) Define(api frontend.API) error { + gkr := NewApi() + var x constraint.GkrVariable + var err error + if x, err = gkr.Import(c.X); err != nil { + return err + } + z := gkr.Mul(x, x) + var solution Solution + if solution, err = gkr.Solve(api); err != nil { + return err + } + Z := solution.Export(z) + + for i := range Z { + api.AssertIsEqual(Z[i], api.Mul(c.X[i], c.X[i])) + } + + return solution.Verify(c.hashName) +} + +func TestSqNoDependencyCircuit(t *testing.T) { + + xValuess := [][]frontend.Variable{ + {1, 1}, + {1, 2}, + } + + hashes := []string{"-1", "-20"} + + for _, xValues := range xValuess { + for _, hashName := range hashes { + assignment := sqNoDependencyCircuit{X: xValues} + circuit := sqNoDependencyCircuit{X: make([]frontend.Variable, len(xValues)), hashName: hashName} + testGroth16(t, &circuit, &assignment) + 
testPlonk(t, &circuit, &assignment) + } + } +} + +type mulNoDependencyCircuit struct { + X, Y []frontend.Variable + hashName string +} + +func (c *mulNoDependencyCircuit) Define(api frontend.API) error { + gkr := NewApi() + var x, y constraint.GkrVariable + var err error + if x, err = gkr.Import(c.X); err != nil { + return err + } + if y, err = gkr.Import(c.Y); err != nil { + return err + } + z := gkr.Mul(x, y) + var solution Solution + if solution, err = gkr.Solve(api); err != nil { + return err + } + X := solution.Export(x) + Y := solution.Export(y) + Z := solution.Export(z) + api.Println("after solving, z=", Z, ", x=", X, ", y=", Y) + + for i := range c.X { + api.Println("z@", i, " = ", Z[i]) + api.Println("x.y = ", api.Mul(c.X[i], c.Y[i])) + api.AssertIsEqual(Z[i], api.Mul(c.X[i], c.Y[i])) + } + + return solution.Verify(c.hashName) +} + +func TestMulNoDependency(t *testing.T) { + xValuess := [][]frontend.Variable{ + {1, 2}, + } + yValuess := [][]frontend.Variable{ + {0, 3}, + } + + hashes := []string{"-1", "-20"} + + for i := range xValuess { + for _, hashName := range hashes { + + assignment := mulNoDependencyCircuit{ + X: xValuess[i], + Y: yValuess[i], + } + circuit := mulNoDependencyCircuit{ + X: make([]frontend.Variable, len(xValuess[i])), + Y: make([]frontend.Variable, len(yValuess[i])), + hashName: hashName, + } + + testGroth16(t, &circuit, &assignment) + testPlonk(t, &circuit, &assignment) + } + } +} + +type mulWithDependencyCircuit struct { + XLast frontend.Variable + Y []frontend.Variable + hashName string +} + +func (c *mulWithDependencyCircuit) Define(api frontend.API) error { + gkr := NewApi() + var x, y constraint.GkrVariable + var err error + + X := make([]frontend.Variable, len(c.Y)) + X[len(c.Y)-1] = c.XLast + if x, err = gkr.Import(X); err != nil { + return err + } + if y, err = gkr.Import(c.Y); err != nil { + return err + } + z := gkr.Mul(x, y) + + for i := len(X) - 1; i > 0; i-- { + gkr.Series(x, z, i-1, i) + } + + var solution Solution + if 
solution, err = gkr.Solve(api); err != nil { + return err + } + X = solution.Export(x) + Y := solution.Export(y) + Z := solution.Export(z) + + api.Println("after solving, z=", Z, ", x=", X, ", y=", Y) + + lastI := len(X) - 1 + api.AssertIsEqual(Z[lastI], api.Mul(c.XLast, Y[lastI])) + for i := 0; i < lastI; i++ { + api.AssertIsEqual(Z[i], api.Mul(Z[i+1], Y[i])) + } + return solution.Verify(c.hashName) +} + +func TestSolveMulWithDependency(t *testing.T) { + assignment := mulWithDependencyCircuit{ + XLast: 1, + Y: []frontend.Variable{3, 2}, + } + circuit := mulWithDependencyCircuit{Y: make([]frontend.Variable, len(assignment.Y)), hashName: "-20"} + + testGroth16(t, &circuit, &assignment) + testPlonk(t, &circuit, &assignment) +} + +func TestApiMul(t *testing.T) { + var ( + x constraint.GkrVariable + y constraint.GkrVariable + z constraint.GkrVariable + err error + ) + api := NewApi() + x, err = api.Import([]frontend.Variable{nil, nil}) + require.NoError(t, err) + y, err = api.Import([]frontend.Variable{nil, nil}) + require.NoError(t, err) + z = api.Mul(x, y) + test_vector_utils.AssertSliceEqual(t, api.toStore.Circuit[z].Inputs, []int{int(x), int(y)}) // TODO: Find out why assert.Equal gives false positives ( []*Wire{x,x} as second argument passes when it shouldn't ) +} + +func BenchmarkMiMCMerkleTree(b *testing.B) { + depth := 14 + //fmt.Println("start") + bottom := make([]frontend.Variable, 1<= 0; d-- { + for i := 0; i < 1<= 2 { - rest = v[2:] + switch len(v) { + case 0: + return 0 + case 1: + return v[0] } + rest := v[2:] return api.Add(v[0], v[1], rest...) 
} func (a AddGate) Degree() int { return 1 } + +var Gates = map[string]Gate{ + "identity": IdentityGate{}, + "add": AddGate{}, + "mul": MulGate{}, +} diff --git a/std/gkr/gkr_test.go b/std/gkr/gkr_test.go index 29060c59b6..a16165a40d 100644 --- a/std/gkr/gkr_test.go +++ b/std/gkr/gkr_test.go @@ -11,10 +11,11 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" fiatshamir "github.com/consensys/gnark/std/fiat-shamir" - "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/polynomial" "github.com/consensys/gnark/test" "github.com/stretchr/testify/assert" + + "github.com/consensys/gnark/std/hash" ) func TestGkrVectors(t *testing.T) { @@ -110,7 +111,7 @@ func (c *GkrVerifierCircuit) Define(api frontend.API) error { } assignment := makeInOutAssignment(testCase.Circuit, c.Input, c.Output) - var hsh hash.Hash + var hsh hash.FieldHasher if c.ToFail { hsh = NewMessageCounter(api, 1, 1) } else { @@ -240,7 +241,7 @@ func (c CircuitInfo) toCircuit() (circuit Circuit, err error) { } var found bool - if circuit[i].Gate, found = RegisteredGates[wireInfo.Gate]; !found && wireInfo.Gate != "" { + if circuit[i].Gate, found = Gates[wireInfo.Gate]; !found && wireInfo.Gate != "" { err = fmt.Errorf("undefined gate \"%s\"", wireInfo.Gate) } } @@ -251,7 +252,7 @@ func (c CircuitInfo) toCircuit() (circuit Circuit, err error) { type _select int func init() { - RegisteredGates["select-input-3"] = _select(2) + Gates["select-input-3"] = _select(2) } func (g _select) Evaluate(_ frontend.API, in ...frontend.Variable) frontend.Variable { @@ -414,7 +415,7 @@ func SliceEqual[T comparable](expected, seen []T) bool { type HashDescription map[string]interface{} -func HashFromDescription(api frontend.API, d HashDescription) (hash.Hash, error) { +func HashFromDescription(api frontend.API, d HashDescription) (hash.FieldHasher, error) { if _type, ok := d["type"]; ok { switch _type { case "const": @@ -456,13 +457,13 @@ func (m *MessageCounter) Reset() { 
m.state = m.startState } -func NewMessageCounter(api frontend.API, startState, step int) hash.Hash { +func NewMessageCounter(api frontend.API, startState, step int) hash.FieldHasher { transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step), api: api} return transcript } -func NewMessageCounterGenerator(startState, step int) func(frontend.API) hash.Hash { - return func(api frontend.API) hash.Hash { +func NewMessageCounterGenerator(startState, step int) func(frontend.API) hash.FieldHasher { + return func(api frontend.API) hash.FieldHasher { return NewMessageCounter(api, startState, step) } } @@ -482,3 +483,25 @@ func (c *constHashCircuit) Define(api frontend.API) error { func TestConstHash(t *testing.T) { test.NewAssert(t).SolvingSucceeded(&constHashCircuit{}, &constHashCircuit{X: 1}) } + +var mimcSnarkTotalCalls = 0 + +type MiMCCipherGate struct { + Ark frontend.Variable +} + +func (m MiMCCipherGate) Evaluate(api frontend.API, input ...frontend.Variable) frontend.Variable { + mimcSnarkTotalCalls++ + + if len(input) != 2 { + panic("mimc has fan-in 2") + } + sum := api.Add(input[0], input[1], m.Ark) + + sumCubed := api.Mul(sum, sum, sum) // sum^3 + return api.Mul(sumCubed, sumCubed, sum) +} + +func (m MiMCCipherGate) Degree() int { + return 7 +} diff --git a/std/gkr/registry.go b/std/gkr/registry.go deleted file mode 100644 index 07e60abc91..0000000000 --- a/std/gkr/registry.go +++ /dev/null @@ -1,29 +0,0 @@ -package gkr - -import "github.com/consensys/gnark/frontend" - -var RegisteredGates = map[string]Gate{ - "identity": IdentityGate{}, - "add": AddGate{}, - "mul": MulGate{}, - "mimc": MiMCCipherGate{Ark: 0}, //TODO: Add ark -} - -type MiMCCipherGate struct { - Ark frontend.Variable -} - -func (m MiMCCipherGate) Evaluate(api frontend.API, input ...frontend.Variable) frontend.Variable { - - if len(input) != 2 { - panic("mimc has fan-in 2") - } - sum := api.Add(input[0], input[1], m.Ark) - - sumCubed := api.Mul(sum, 
sum, sum) // sum^3 - return api.Mul(sumCubed, sumCubed, sum) -} - -func (m MiMCCipherGate) Degree() int { - return 7 -} diff --git a/std/groth16_bls12377/verifier.go b/std/groth16_bls12377/verifier.go index 6a43e585a8..e2f7f37338 100644 --- a/std/groth16_bls12377/verifier.go +++ b/std/groth16_bls12377/verifier.go @@ -1,5 +1,5 @@ /* -Copyright © 2020 ConsenSys +Copyright 2020 ConsenSys Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,10 +22,10 @@ import ( bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark/backend/groth16" + groth16_bls12377 "github.com/consensys/gnark/backend/groth16/bls12-377" "github.com/consensys/gnark/frontend" - groth16_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/groth16" - "github.com/consensys/gnark/std/algebra/fields_bls12377" - "github.com/consensys/gnark/std/algebra/sw_bls12377" + "github.com/consensys/gnark/std/algebra/native/fields_bls12377" + "github.com/consensys/gnark/std/algebra/native/sw_bls12377" ) // Proof represents a Groth16 proof @@ -49,6 +49,7 @@ type VerifyingKey struct { // [Kvk]1 G1 struct { K []sw_bls12377.G1Affine // The indexes correspond to the public wires + } } @@ -57,7 +58,8 @@ type VerifyingKey struct { // publicInputs do NOT contain the ONE_WIRE func Verify(api frontend.API, vk VerifyingKey, proof Proof, publicInputs []frontend.Variable) { if len(vk.G1.K) == 0 { - panic("innver verifying key needs at least one point; VerifyingKey.G1 must be initialized before compiling circuit") + panic("inner verifying key needs at least one point; VerifyingKey.G1 must be initialized before compiling circuit") + } // compute kSum = Σx.[Kvk(t)]1 @@ -71,14 +73,15 @@ func Verify(api frontend.API, vk VerifyingKey, proof Proof, publicInputs []front var ki sw_bls12377.G1Affine ki.ScalarMul(api, vk.G1.K[k+1], v) kSum.AddAssign(api, ki) + } // compute e(Σx.[Kvk(t)]1, -[γ]2) * e(Krs,δ) * e(Ar,Bs) - 
ml, _ := sw_bls12377.MillerLoop(api, []sw_bls12377.G1Affine{kSum, proof.Krs, proof.Ar}, []sw_bls12377.G2Affine{vk.G2.GammaNeg, vk.G2.DeltaNeg, proof.Bs}) - pairing := sw_bls12377.FinalExponentiation(api, ml) + pairing, _ := sw_bls12377.Pair(api, []sw_bls12377.G1Affine{kSum, proof.Krs, proof.Ar}, []sw_bls12377.G2Affine{vk.G2.GammaNeg, vk.G2.DeltaNeg, proof.Bs}) // vk.E must be equal to pairing vk.E.AssertIsEqual(api, pairing) + } // Assign values to the "in-circuit" VerifyingKey from a "out-of-circuit" VerifyingKey @@ -86,21 +89,51 @@ func (vk *VerifyingKey) Assign(_ovk groth16.VerifyingKey) { ovk, ok := _ovk.(*groth16_bls12377.VerifyingKey) if !ok { panic("expected *groth16_bls12377.VerifyingKey, got " + reflect.TypeOf(_ovk).String()) + } e, err := bls12377.Pair([]bls12377.G1Affine{ovk.G1.Alpha}, []bls12377.G2Affine{ovk.G2.Beta}) if err != nil { panic(err) + } vk.E.Assign(&e) vk.G1.K = make([]sw_bls12377.G1Affine, len(ovk.G1.K)) for i := 0; i < len(ovk.G1.K); i++ { vk.G1.K[i].Assign(&ovk.G1.K[i]) + } var deltaNeg, gammaNeg bls12377.G2Affine deltaNeg.Neg(&ovk.G2.Delta) gammaNeg.Neg(&ovk.G2.Gamma) vk.G2.DeltaNeg.Assign(&deltaNeg) vk.G2.GammaNeg.Assign(&gammaNeg) + +} + +// Allocate memory for the "in-circuit" VerifyingKey +// This is exposed so that the slices in the structure can be allocated +// before calling frontend.Compile(). 
+func (vk *VerifyingKey) Allocate(_ovk groth16.VerifyingKey) { + ovk, ok := _ovk.(*groth16_bls12377.VerifyingKey) + if !ok { + panic("expected *groth16_bls12377.VerifyingKey, got " + reflect.TypeOf(_ovk).String()) + + } + vk.G1.K = make([]sw_bls12377.G1Affine, len(ovk.G1.K)) + +} + +// Assign the proof values of Groth16 +func (proof *Proof) Assign(_oproof groth16.Proof) { + oproof, ok := _oproof.(*groth16_bls12377.Proof) + if !ok { + panic("expected *groth16_bls12377.Proof, got " + reflect.TypeOf(oproof).String()) + + } + proof.Ar.Assign(&oproof.Ar) + proof.Krs.Assign(&oproof.Krs) + proof.Bs.Assign(&oproof.Bs) + } diff --git a/std/groth16_bls12377/verifier_test.go b/std/groth16_bls12377/verifier_test.go index 10bcd53d8b..282984139e 100644 --- a/std/groth16_bls12377/verifier_test.go +++ b/std/groth16_bls12377/verifier_test.go @@ -1,5 +1,5 @@ /* -Copyright © 2020 ConsenSys +Copyright 2020 ConsenSys Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -22,13 +22,12 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/hash" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/constraint" - cs_bls12377 "github.com/consensys/gnark/constraint/bls12-377" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" - groth16_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/groth16" - "github.com/consensys/gnark/std/algebra/sw_bls12377" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" ) @@ -47,125 +46,161 @@ func (circuit *mimcCircuit) Define(api frontend.API) error { mimc, err := mimc.NewMiMC(api) if err != nil { return err + } mimc.Write(circuit.PreImage) api.AssertIsEqual(mimc.Sum(), circuit.Hash) return nil + +} + +// Calculate the expected output of MIMC through plain invocation +func preComputeMimc(preImage frontend.Variable) interface{} { + var expectedY fr.Element + expectedY.SetInterface(preImage) + // calc MiMC + goMimc := hash.MIMC_BLS12_377.New() + goMimc.Write(expectedY.Marshal()) + expectedh := goMimc.Sum(nil) + return expectedh + +} + +type verifierCircuit struct { + InnerProof Proof + InnerVk VerifyingKey + Hash frontend.Variable } -// Prepare the data for the inner proof. 
-// Returns the public inputs string of the inner proof -func generateBls12377InnerProof(t *testing.T, vk *groth16_bls12377.VerifyingKey, proof *groth16_bls12377.Proof) { +func (circuit *verifierCircuit) Define(api frontend.API) error { + // create the verifier cs + Verify(api, circuit.InnerVk, circuit.InnerProof, []frontend.Variable{circuit.Hash}) + + return nil + +} + +func TestVerifier(t *testing.T) { // create a mock cs: knowing the preimage of a hash using mimc - var circuit mimcCircuit - r1cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), r1cs.NewBuilder, &circuit) + var MimcCircuit mimcCircuit + r1cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), r1cs.NewBuilder, &MimcCircuit) if err != nil { t.Fatal(err) - } - // build the witness - var assignment mimcCircuit - assignment.PreImage = preImage - assignment.Hash = publicHash + } - witness, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) + var pre_assignment mimcCircuit + pre_assignment.PreImage = preImage + pre_assignment.Hash = publicHash + pre_witness, err := frontend.NewWitness(&pre_assignment, ecc.BLS12_377.ScalarField()) if err != nil { t.Fatal(err) + } - publicWitness, err := witness.Public() + innerPk, innerVk, err := groth16.Setup(r1cs) if err != nil { t.Fatal(err) + } - // generate the data to return for the bls12377 proof - var pk groth16_bls12377.ProvingKey - err = groth16_bls12377.Setup(r1cs.(*cs_bls12377.R1CS), &pk, vk) + proof, err := groth16.Prove(r1cs, innerPk, pre_witness) if err != nil { t.Fatal(err) + } - _proof, err := groth16_bls12377.Prove(r1cs.(*cs_bls12377.R1CS), &pk, witness.Vector().(fr.Vector), backend.ProverConfig{}) + publicWitness, err := pre_witness.Public() if err != nil { t.Fatal(err) + } - proof.Ar = _proof.Ar - proof.Bs = _proof.Bs - proof.Krs = _proof.Krs - // before returning verifies that the proof passes on bls12377 - if err := groth16_bls12377.Verify(proof, vk, publicWitness.Vector().(fr.Vector)); err != nil { + // Check that proof 
verifies before continuing + if err := groth16.Verify(proof, innerVk, publicWitness); err != nil { t.Fatal(err) + } -} + var circuit verifierCircuit + circuit.InnerVk.Allocate(innerVk) -type verifierCircuit struct { - InnerProof Proof - InnerVk VerifyingKey - Hash frontend.Variable -} + var witness verifierCircuit + witness.InnerProof.Assign(proof) + witness.InnerVk.Assign(innerVk) + witness.Hash = preComputeMimc(preImage) -func (circuit *verifierCircuit) Define(api frontend.API) error { - // create the verifier cs - Verify(api, circuit.InnerVk, circuit.InnerProof, []frontend.Variable{circuit.Hash}) + assert := test.NewAssert(t) + + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_761), test.WithBackends(backend.GROTH16)) - return nil } -func TestVerifier(t *testing.T) { +func BenchmarkCompile(b *testing.B) { - // get the data - var innerVk groth16_bls12377.VerifyingKey - var innerProof groth16_bls12377.Proof - generateBls12377InnerProof(t, &innerVk, &innerProof) // get public inputs of the inner proof + // create a mock cs: knowing the preimage of a hash using mimc + var MimcCircuit mimcCircuit + _r1cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), r1cs.NewBuilder, &MimcCircuit) + if err != nil { + b.Fatal(err) - // create an empty cs - var circuit verifierCircuit - circuit.InnerVk.G1.K = make([]sw_bls12377.G1Affine, len(innerVk.G1.K)) + } - // create assignment, the private part consists of the proof, - // the public part is exactly the public part of the inner proof, - // up to the renaming of the inner ONE_WIRE to not conflict with the one wire of the outer proof. 
- var witness verifierCircuit - witness.InnerProof.Ar.Assign(&innerProof.Ar) - witness.InnerProof.Krs.Assign(&innerProof.Krs) - witness.InnerProof.Bs.Assign(&innerProof.Bs) + var pre_assignment mimcCircuit + pre_assignment.PreImage = preImage + pre_assignment.Hash = publicHash + pre_witness, err := frontend.NewWitness(&pre_assignment, ecc.BLS12_377.ScalarField()) + if err != nil { + b.Fatal(err) - witness.InnerVk.Assign(&innerVk) - witness.Hash = publicHash + } - // verifies the cs - assert := test.NewAssert(t) + innerPk, innerVk, err := groth16.Setup(_r1cs) + if err != nil { + b.Fatal(err) - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_761)) -} + } -func BenchmarkCompile(b *testing.B) { - // get the data - var innerVk groth16_bls12377.VerifyingKey - var innerProof groth16_bls12377.Proof - generateBls12377InnerProof(nil, &innerVk, &innerProof) // get public inputs of the inner proof + proof, err := groth16.Prove(_r1cs, innerPk, pre_witness) + if err != nil { + b.Fatal(err) + + } + + publicWitness, err := pre_witness.Public() + if err != nil { + b.Fatal(err) + + } + + // Check that proof verifies before continuing + if err := groth16.Verify(proof, innerVk, publicWitness); err != nil { + b.Fatal(err) + + } - // create an empty cs var circuit verifierCircuit - circuit.InnerVk.G1.K = make([]sw_bls12377.G1Affine, len(innerVk.G1.K)) + circuit.InnerVk.Allocate(innerVk) var ccs constraint.ConstraintSystem - var err error b.ResetTimer() for i := 0; i < b.N; i++ { ccs, err = frontend.Compile(ecc.BW6_761.ScalarField(), r1cs.NewBuilder, &circuit) if err != nil { b.Fatal(err) + } + } + b.Log(ccs.GetNbConstraints()) + } var tVariable reflect.Type func init() { tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() + } diff --git a/std/groth16_bls24315/verifier.go b/std/groth16_bls24315/verifier.go index 09487acbb3..8e474a8979 100644 --- a/std/groth16_bls24315/verifier.go +++ b/std/groth16_bls24315/verifier.go @@ -22,10 +22,10 
@@ import ( bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark/backend/groth16" + groth16_bls24315 "github.com/consensys/gnark/backend/groth16/bls24-315" "github.com/consensys/gnark/frontend" - groth16_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/groth16" - "github.com/consensys/gnark/std/algebra/fields_bls24315" - "github.com/consensys/gnark/std/algebra/sw_bls24315" + "github.com/consensys/gnark/std/algebra/native/fields_bls24315" + "github.com/consensys/gnark/std/algebra/native/sw_bls24315" ) // Proof represents a Groth16 proof @@ -49,6 +49,7 @@ type VerifyingKey struct { // [Kvk]1 G1 struct { K []sw_bls24315.G1Affine // The indexes correspond to the public wires + } } @@ -57,7 +58,8 @@ type VerifyingKey struct { // publicInputs do NOT contain the ONE_WIRE func Verify(api frontend.API, vk VerifyingKey, proof Proof, publicInputs []frontend.Variable) { if len(vk.G1.K) == 0 { - panic("innver verifying key needs at least one point; VerifyingKey.G1 must be initialized before compiling circuit") + panic("inner verifying key needs at least one point; VerifyingKey.G1 must be initialized before compiling circuit") + } // compute kSum = Σx.[Kvk(t)]1 @@ -71,11 +73,11 @@ func Verify(api frontend.API, vk VerifyingKey, proof Proof, publicInputs []front var ki sw_bls24315.G1Affine ki.ScalarMul(api, vk.G1.K[k+1], v) kSum.AddAssign(api, ki) + } // compute e(Σx.[Kvk(t)]1, -[γ]2) * e(Krs,δ) * e(Ar,Bs) - ml, _ := sw_bls24315.MillerLoop(api, []sw_bls24315.G1Affine{kSum, proof.Krs, proof.Ar}, []sw_bls24315.G2Affine{vk.G2.GammaNeg, vk.G2.DeltaNeg, proof.Bs}) - pairing := sw_bls24315.FinalExponentiation(api, ml) + pairing, _ := sw_bls24315.Pair(api, []sw_bls24315.G1Affine{kSum, proof.Krs, proof.Ar}, []sw_bls24315.G2Affine{vk.G2.GammaNeg, vk.G2.DeltaNeg, proof.Bs}) // vk.E must be equal to pairing vk.E.AssertIsEqual(api, pairing) @@ -87,21 +89,51 @@ func (vk *VerifyingKey) Assign(_ovk groth16.VerifyingKey) { ovk, ok := 
_ovk.(*groth16_bls24315.VerifyingKey) if !ok { panic("expected *groth16_bls24315.VerifyingKey, got " + reflect.TypeOf(_ovk).String()) + } e, err := bls24315.Pair([]bls24315.G1Affine{ovk.G1.Alpha}, []bls24315.G2Affine{ovk.G2.Beta}) if err != nil { panic(err) + } vk.E.Assign(&e) vk.G1.K = make([]sw_bls24315.G1Affine, len(ovk.G1.K)) for i := 0; i < len(ovk.G1.K); i++ { vk.G1.K[i].Assign(&ovk.G1.K[i]) + } var deltaNeg, gammaNeg bls24315.G2Affine deltaNeg.Neg(&ovk.G2.Delta) gammaNeg.Neg(&ovk.G2.Gamma) vk.G2.DeltaNeg.Assign(&deltaNeg) vk.G2.GammaNeg.Assign(&gammaNeg) + +} + +// Allocate memory for the "in-circuit" VerifyingKey +// This is exposed so that the slices in the structure can be allocated +// before calling frontend.Compile(). +func (vk *VerifyingKey) Allocate(_ovk groth16.VerifyingKey) { + ovk, ok := _ovk.(*groth16_bls24315.VerifyingKey) + if !ok { + panic("expected *groth16_bls24315.VerifyingKey, got " + reflect.TypeOf(_ovk).String()) + + } + vk.G1.K = make([]sw_bls24315.G1Affine, len(ovk.G1.K)) + +} + +// Assign the proof values of Groth16 +func (proof *Proof) Assign(_oproof groth16.Proof) { + oproof, ok := _oproof.(*groth16_bls24315.Proof) + if !ok { + panic("expected *groth16_bls24315.Proof, got " + reflect.TypeOf(oproof).String()) + + } + proof.Ar.Assign(&oproof.Ar) + proof.Krs.Assign(&oproof.Krs) + proof.Bs.Assign(&oproof.Bs) + } diff --git a/std/groth16_bls24315/verifier_test.go b/std/groth16_bls24315/verifier_test.go index 7144014e8d..c3a5a9f3af 100644 --- a/std/groth16_bls24315/verifier_test.go +++ b/std/groth16_bls24315/verifier_test.go @@ -22,20 +22,16 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/hash" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/constraint" - cs_bls24315 "github.com/consensys/gnark/constraint/bls24-315" "github.com/consensys/gnark/frontend" 
"github.com/consensys/gnark/frontend/cs/r1cs" - groth16_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/groth16" - "github.com/consensys/gnark/std/algebra/sw_bls24315" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" ) -//-------------------------------------------------------------------- -// utils - const ( preImage = "4992816046196248432836492760315135318126925090839638585255611512962528270024" publicHash = "4875439939758844840941638351757981379945701574516438614845550995673793857363" @@ -50,121 +46,161 @@ func (circuit *mimcCircuit) Define(api frontend.API) error { mimc, err := mimc.NewMiMC(api) if err != nil { return err + } mimc.Write(circuit.PreImage) api.AssertIsEqual(mimc.Sum(), circuit.Hash) return nil + +} + +// Calculate the expected output of MIMC through plain invocation +func preComputeMimc(preImage frontend.Variable) interface{} { + var expectedY fr.Element + expectedY.SetInterface(preImage) + // calc MiMC + goMimc := hash.MIMC_BLS24_315.New() + goMimc.Write(expectedY.Marshal()) + expectedh := goMimc.Sum(nil) + return expectedh + +} + +type verifierCircuit struct { + InnerProof Proof + InnerVk VerifyingKey + Hash frontend.Variable } -// Prepare the data for the inner proof. 
-// Returns the public inputs string of the inner proof -func generateBls24315InnerProof(t *testing.T, vk *groth16_bls24315.VerifyingKey, proof *groth16_bls24315.Proof) { +func (circuit *verifierCircuit) Define(api frontend.API) error { + // create the verifier cs + Verify(api, circuit.InnerVk, circuit.InnerProof, []frontend.Variable{circuit.Hash}) + + return nil + +} + +func TestVerifier(t *testing.T) { // create a mock cs: knowing the preimage of a hash using mimc - var circuit, assignment mimcCircuit - r1cs, err := frontend.Compile(ecc.BLS24_315.ScalarField(), r1cs.NewBuilder, &circuit) + var MimcCircuit mimcCircuit + r1cs, err := frontend.Compile(ecc.BLS24_315.ScalarField(), r1cs.NewBuilder, &MimcCircuit) if err != nil { t.Fatal(err) - } - assignment.PreImage = preImage - assignment.Hash = publicHash + } - witness, err := frontend.NewWitness(&assignment, ecc.BLS24_315.ScalarField()) + var pre_assignment mimcCircuit + pre_assignment.PreImage = preImage + pre_assignment.Hash = publicHash + pre_witness, err := frontend.NewWitness(&pre_assignment, ecc.BLS24_315.ScalarField()) if err != nil { t.Fatal(err) + } - publicWitness, err := witness.Public() + innerPk, innerVk, err := groth16.Setup(r1cs) if err != nil { t.Fatal(err) + } - // generate the data to return for the bls24315 proof - var pk groth16_bls24315.ProvingKey - err = groth16_bls24315.Setup(r1cs.(*cs_bls24315.R1CS), &pk, vk) + proof, err := groth16.Prove(r1cs, innerPk, pre_witness) if err != nil { t.Fatal(err) + } - _proof, err := groth16_bls24315.Prove(r1cs.(*cs_bls24315.R1CS), &pk, witness.Vector().(fr.Vector), backend.ProverConfig{}) + publicWitness, err := pre_witness.Public() if err != nil { t.Fatal(err) + } - proof.Ar = _proof.Ar - proof.Bs = _proof.Bs - proof.Krs = _proof.Krs - // before returning verifies that the proof passes on bls24315 - if err := groth16_bls24315.Verify(proof, vk, publicWitness.Vector().(fr.Vector)); err != nil { + // Check that proof verifies before continuing + if err := 
groth16.Verify(proof, innerVk, publicWitness); err != nil { t.Fatal(err) + } -} -type verifierCircuit struct { - InnerProof Proof - InnerVk VerifyingKey - Hash frontend.Variable -} + var circuit verifierCircuit + circuit.InnerVk.Allocate(innerVk) -func (circuit *verifierCircuit) Define(api frontend.API) error { + var witness verifierCircuit + witness.InnerProof.Assign(proof) + witness.InnerVk.Assign(innerVk) + witness.Hash = preComputeMimc(preImage) - // create the verifier cs - Verify(api, circuit.InnerVk, circuit.InnerProof, []frontend.Variable{circuit.Hash}) + assert := test.NewAssert(t) + + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_633), test.WithBackends(backend.GROTH16)) - return nil } -func TestVerifier(t *testing.T) { +func BenchmarkCompile(b *testing.B) { - // get the data - var innerVk groth16_bls24315.VerifyingKey - var innerProof groth16_bls24315.Proof - generateBls24315InnerProof(t, &innerVk, &innerProof) // get public inputs of the inner proof + // create a mock cs: knowing the preimage of a hash using mimc + var MimcCircuit mimcCircuit + _r1cs, err := frontend.Compile(ecc.BLS24_315.ScalarField(), r1cs.NewBuilder, &MimcCircuit) + if err != nil { + b.Fatal(err) - // create an empty cs - var circuit verifierCircuit - circuit.InnerVk.G1.K = make([]sw_bls24315.G1Affine, len(innerVk.G1.K)) + } - // create assignment, the private part consists of the proof, - // the public part is exactly the public part of the inner proof, - // up to the renaming of the inner ONE_WIRE to not conflict with the one wire of the outer proof. 
- var witness verifierCircuit - witness.InnerProof.Ar.Assign(&innerProof.Ar) - witness.InnerProof.Krs.Assign(&innerProof.Krs) - witness.InnerProof.Bs.Assign(&innerProof.Bs) + var pre_assignment mimcCircuit + pre_assignment.PreImage = preImage + pre_assignment.Hash = publicHash + pre_witness, err := frontend.NewWitness(&pre_assignment, ecc.BLS24_315.ScalarField()) + if err != nil { + b.Fatal(err) + + } - witness.InnerVk.Assign(&innerVk) + innerPk, innerVk, err := groth16.Setup(_r1cs) + if err != nil { + b.Fatal(err) - witness.Hash = publicHash + } - // verifies the cs - assert := test.NewAssert(t) + proof, err := groth16.Prove(_r1cs, innerPk, pre_witness) + if err != nil { + b.Fatal(err) - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_633)) + } -} + publicWitness, err := pre_witness.Public() + if err != nil { + b.Fatal(err) -func BenchmarkCompile(b *testing.B) { - // get the data - var innerVk groth16_bls24315.VerifyingKey - var innerProof groth16_bls24315.Proof - generateBls24315InnerProof(nil, &innerVk, &innerProof) // get public inputs of the inner proof + } + + // Check that proof verifies before continuing + if err := groth16.Verify(proof, innerVk, publicWitness); err != nil { + b.Fatal(err) + + } - // create an empty cs var circuit verifierCircuit - circuit.InnerVk.G1.K = make([]sw_bls24315.G1Affine, len(innerVk.G1.K)) + circuit.InnerVk.Allocate(innerVk) var ccs constraint.ConstraintSystem b.ResetTimer() for i := 0; i < b.N; i++ { - ccs, _ = frontend.Compile(ecc.BW6_633.ScalarField(), r1cs.NewBuilder, &circuit) + ccs, err = frontend.Compile(ecc.BW6_633.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + b.Fatal(err) + + } + } + b.Log(ccs.GetNbConstraints()) + } var tVariable reflect.Type func init() { tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() + } diff --git a/std/hash/hash.go b/std/hash/hash.go index 55e8791084..67e0d91ce1 100644 --- a/std/hash/hash.go +++ b/std/hash/hash.go @@ -17,16 
+17,49 @@ limitations under the License. // Package hash provides an interface that hash functions (as gadget) should implement. package hash -import "github.com/consensys/gnark/frontend" - -type Hash interface { +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/uints" +) +// FieldHasher hashes inputs into a short digest. This interface mocks +// [BinaryHasher], but is more suitable in-circuit by assuming the inputs are +// scalar field elements and outputs digest as a field element. Such hash +// functions are for examle Poseidon, MiMC etc. +type FieldHasher interface { // Sum computes the hash of the internal state of the hash function. Sum() frontend.Variable - // Write populate the internal state of the hash function with data. + // Write populate the internal state of the hash function with data. The inputs are native field elements. Write(data ...frontend.Variable) // Reset empty the internal state and put the intermediate state to zero. Reset() } + +var BuilderRegistry = make(map[string]func(api frontend.API) (FieldHasher, error)) + +// BinaryHasher hashes inputs into a short digest. It takes as inputs bytes and +// outputs byte array whose length depends on the underlying hash function. For +// SNARK-native hash functions use [FieldHasher]. +type BinaryHasher interface { + // Sum finalises the current hash and returns the digest. + Sum() []uints.U8 + + // Write writes more bytes into the current hash state. + Write([]uints.U8) + + // Size returns the number of bytes this hash function returns in a call to + // [BinaryHasher.Sum]. + Size() int +} + +// BinaryFixedLengthHasher is like [BinaryHasher], but assumes the length of the +// input is not full length as defined during compile time. This allows to +// compute digest of variable-length input, unlike [BinaryHasher] which assumes +// the length of the input is the total number of bytes written. 
+type BinaryFixedLengthHasher interface { + BinaryHasher + // FixedLengthSum returns digest of the first length bytes. + FixedLengthSum(length frontend.Variable) []uints.U8 +} diff --git a/std/hash/sha2/sha2.go b/std/hash/sha2/sha2.go new file mode 100644 index 0000000000..edb261e621 --- /dev/null +++ b/std/hash/sha2/sha2.go @@ -0,0 +1,89 @@ +// Package sha2 implements SHA2 hash computation. +// +// This package extends the SHA2 permutation function [sha2] into a full SHA2 +// hash. +package sha2 + +import ( + "encoding/binary" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/math/uints" + "github.com/consensys/gnark/std/permutation/sha2" +) + +var _seed = uints.NewU32Array([]uint32{ + 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, +}) + +type digest struct { + uapi *uints.BinaryField[uints.U32] + in []uints.U8 +} + +func New(api frontend.API) (hash.BinaryHasher, error) { + uapi, err := uints.New[uints.U32](api) + if err != nil { + return nil, err + } + return &digest{uapi: uapi}, nil +} + +func (d *digest) Write(data []uints.U8) { + d.in = append(d.in, data...) +} + +func (d *digest) padded(bytesLen int) []uints.U8 { + zeroPadLen := 55 - bytesLen%64 + if zeroPadLen < 0 { + zeroPadLen += 64 + } + if cap(d.in) < len(d.in)+9+zeroPadLen { + // in case this is the first time this method is called increase the + // capacity of the slice to fit the padding. + d.in = append(d.in, make([]uints.U8, 9+zeroPadLen)...) + d.in = d.in[:len(d.in)-9-zeroPadLen] + } + buf := d.in + buf = append(buf, uints.NewU8(0x80)) + buf = append(buf, uints.NewU8Array(make([]uint8, zeroPadLen))...) + lenbuf := make([]uint8, 8) + binary.BigEndian.PutUint64(lenbuf, uint64(8*bytesLen)) + buf = append(buf, uints.NewU8Array(lenbuf)...) 
+ return buf +} + +func (d *digest) Sum() []uints.U8 { + var runningDigest [8]uints.U32 + var buf [64]uints.U8 + copy(runningDigest[:], _seed) + padded := d.padded(len(d.in)) + for i := 0; i < len(padded)/64; i++ { + copy(buf[:], padded[i*64:(i+1)*64]) + runningDigest = sha2.Permute(d.uapi, runningDigest, buf) + } + var ret []uints.U8 + for i := range runningDigest { + ret = append(ret, d.uapi.UnpackMSB(runningDigest[i])...) + } + return ret +} + +func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 { + panic("TODO") + // we need to do two things here -- first the padding has to be put to the + // right place. For that we need to know how many blocks we have used. We + // need to fit at least 9 more bytes (padding byte and 8 bytes for input + // length). Knowing the block, we have to keep running track if the current + // block is the expected one. + // + // idea - have a mask for blocks where 1 is only for the block we want to + // use. +} + +func (d *digest) Reset() { + d.in = nil +} + +func (d *digest) Size() int { return 32 } diff --git a/std/hash/sha2/sha2_test.go b/std/hash/sha2/sha2_test.go new file mode 100644 index 0000000000..d4acf5baf3 --- /dev/null +++ b/std/hash/sha2/sha2_test.go @@ -0,0 +1,50 @@ +package sha2 + +import ( + "crypto/sha256" + "fmt" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/uints" + "github.com/consensys/gnark/test" +) + +type sha2Circuit struct { + In []uints.U8 + Expected [32]uints.U8 +} + +func (c *sha2Circuit) Define(api frontend.API) error { + h, err := New(api) + if err != nil { + return err + } + uapi, err := uints.New[uints.U32](api) + if err != nil { + return err + } + h.Write(c.In) + res := h.Sum() + if len(res) != 32 { + return fmt.Errorf("not 32 bytes") + } + for i := range c.Expected { + uapi.ByteAssertEq(c.Expected[i], res[i]) + } + return nil +} + +func TestSHA2(t *testing.T) { + bts := make([]byte, 310) + dgst := 
sha256.Sum256(bts) + witness := sha2Circuit{ + In: uints.NewU8Array(bts), + } + copy(witness.Expected[:], uints.NewU8Array(dgst[:])) + err := test.IsSolved(&sha2Circuit{In: make([]uints.U8, len(bts))}, &witness, ecc.BN254.ScalarField()) + if err != nil { + t.Fatal(err) + } +} diff --git a/std/hints.go b/std/hints.go index 8fc1222781..d807b1f024 100644 --- a/std/hints.go +++ b/std/hints.go @@ -3,11 +3,16 @@ package std import ( "sync" - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/std/algebra/sw_bls12377" - "github.com/consensys/gnark/std/algebra/sw_bls24315" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/algebra/native/sw_bls12377" + "github.com/consensys/gnark/std/algebra/native/sw_bls24315" + "github.com/consensys/gnark/std/evmprecompiles" + "github.com/consensys/gnark/std/internal/logderivarg" "github.com/consensys/gnark/std/math/bits" + "github.com/consensys/gnark/std/math/bitslice" "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/rangecheck" + "github.com/consensys/gnark/std/selector" ) var registerOnce sync.Once @@ -21,14 +26,16 @@ func RegisterHints() { } func registerHints() { - // note that importing these packages may already trigger a call to hint.Register(...) - hint.Register(sw_bls24315.DecomposeScalarG1) - hint.Register(sw_bls12377.DecomposeScalarG1) - hint.Register(sw_bls24315.DecomposeScalarG2) - hint.Register(sw_bls12377.DecomposeScalarG2) - hint.Register(bits.NTrits) - hint.Register(bits.NNAF) - hint.Register(bits.IthBit) - hint.Register(bits.NBits) - hint.Register(emulated.GetHints()...) + // note that importing these packages may already trigger a call to solver.RegisterHint(...) + solver.RegisterHint(sw_bls24315.DecomposeScalarG1) + solver.RegisterHint(sw_bls12377.DecomposeScalarG1) + solver.RegisterHint(sw_bls24315.DecomposeScalarG2) + solver.RegisterHint(sw_bls12377.DecomposeScalarG2) + solver.RegisterHint(bits.GetHints()...) 
+ solver.RegisterHint(selector.GetHints()...) + solver.RegisterHint(emulated.GetHints()...) + solver.RegisterHint(rangecheck.GetHints()...) + solver.RegisterHint(evmprecompiles.GetHints()...) + solver.RegisterHint(logderivarg.GetHints()...) + solver.RegisterHint(bitslice.GetHints()...) } diff --git a/std/hints_test.go b/std/hints_test.go index f88779b231..64bf7b362d 100644 --- a/std/hints_test.go +++ b/std/hints_test.go @@ -10,7 +10,7 @@ func ExampleRegisterHints() { var ccs constraint.ConstraintSystem // since package bits is not imported, the hint NNAF is not registered - // --> hint.Register(bits.NNAF) + // --> solver.RegisterHint(bits.NNAF) // rather than to keep track on which hints are needed, a prover/solver service can register all // gnark/std hints with this call RegisterHints() diff --git a/std/internal/logderivarg/logderivarg.go b/std/internal/logderivarg/logderivarg.go new file mode 100644 index 0000000000..d0625ea993 --- /dev/null +++ b/std/internal/logderivarg/logderivarg.go @@ -0,0 +1,220 @@ +// Package logderivarg implements log-derivative argument. +// +// The log-derivative argument was described in [Haböck22] as an improvement +// over [BCG+18]. In [BCG+18], it was shown that to show inclusion of a multiset +// S in T, one can show +// +// ∏_{f∈F} (x-f)^count(f, S) == ∏_{s∈S} x-s, +// +// where function `count` counts the number of occurences of f in S. The problem +// with this approach is the high cost for exponentiating the left-hand side of +// the equation. However, in [Haböck22] it was shown that when avoiding the +// poles, we can perform the same check for the log-derivative variant of the +// equation: +// +// ∑_{f∈F} count(f,S)/(x-f) == ∑_{s∈S} 1/(x-s). +// +// Additionally, when the entries of both S and T are vectors, then instead we +// can check random linear combinations. 
So, when F is a matrix and S is a
+// multiset of its rows, we first generate random linear coefficients (r_1, ...,
+// r_n) and check
+//
+//	∑_{f∈F} count(f,S)/(x-∑_{i∈[n]}r_i*f_i) == ∑_{s∈S} 1/(x-∑_{i∈[n]}r_i*s_i).
+//
+// This package is a low-level primitive for building more extensive gadgets. It
+// only checks the last equation, but the tables and queries should be built by
+// the users.
+//
+// NB! The package doesn't check that the entries in table F are unique.
+//
+// [BCG+18]: https://eprint.iacr.org/2018/380
+// [Haböck22]: https://eprint.iacr.org/2022/1530
+package logderivarg
+
+// TODO: we handle both constant and variable tables. But for variable tables we
+// have to ensure that all the table entries differ! Right now it isn't a problem
+// because everywhere we build we also have indices which ensure uniqueness. I
+// guess the best approach is to have safe and unsafe versions where the safe
+// version performs additional sorting. But that is really really expensive as
+// we have to show that all sorted values are monotonically increasing.
+
+import (
+	"fmt"
+	"math/big"
+
+	"github.com/consensys/gnark/constraint/solver"
+	"github.com/consensys/gnark/frontend"
+	"github.com/consensys/gnark/std/hash/mimc"
+	"github.com/consensys/gnark/std/multicommit"
+)
+
+func init() {
+	solver.RegisterHint(GetHints()...)
+}
+
+// GetHints returns all hints used in this package
+func GetHints() []solver.Hint {
+	return []solver.Hint{countHint}
+}
+
+// Table is a vector of vectors.
+type Table [][]frontend.Variable
+
+// AsTable returns a vector as a single-column table.
+func AsTable(vector []frontend.Variable) Table {
+	ret := make([][]frontend.Variable, len(vector))
+	for i := range vector {
+		ret[i] = []frontend.Variable{vector[i]}
+	}
+	return ret
+}
+
+// Build builds the argument using the table and queries. If both table and
+// queries are multiple-column, then also samples coefficients for the random
+// linear combinations.
+func Build(api frontend.API, table Table, queries Table) error { + if len(table) == 0 { + return fmt.Errorf("table empty") + } + nbRow := len(table[0]) + constTable := true + countInputs := []frontend.Variable{len(table), nbRow} + for i := range table { + if len(table[i]) != nbRow { + return fmt.Errorf("table row length mismatch") + } + if constTable { + for j := range table[i] { + if _, isConst := api.Compiler().ConstantValue(table[i][j]); !isConst { + constTable = false + } + } + } + countInputs = append(countInputs, table[i]...) + } + for i := range queries { + if len(queries[i]) != nbRow { + return fmt.Errorf("query row length mismatch") + } + countInputs = append(countInputs, queries[i]...) + } + exps, err := api.NewHint(countHint, len(table), countInputs...) + if err != nil { + return fmt.Errorf("hint: %w", err) + } + + var toCommit []frontend.Variable + if !constTable { + for i := range table { + toCommit = append(toCommit, table[i]...) + } + } + for i := range queries { + toCommit = append(toCommit, queries[i]...) + } + toCommit = append(toCommit, exps...) + + multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + rowCoeffs, challenge := randLinearCoefficients(api, nbRow, commitment) + var lp frontend.Variable = 0 + for i := range table { + tmp := api.DivUnchecked(exps[i], api.Sub(challenge, randLinearCombination(api, rowCoeffs, table[i]))) + lp = api.Add(lp, tmp) + } + var rp frontend.Variable = 0 + for i := range queries { + tmp := api.Inverse(api.Sub(challenge, randLinearCombination(api, rowCoeffs, queries[i]))) + rp = api.Add(rp, tmp) + } + api.AssertIsEqual(lp, rp) + return nil + }, toCommit...) 
+ return nil +} + +func randLinearCoefficients(api frontend.API, nbRow int, commitment frontend.Variable) (rowCoeffs []frontend.Variable, challenge frontend.Variable) { + if nbRow == 1 { + return []frontend.Variable{1}, commitment + } + hasher, err := mimc.NewMiMC(api) + if err != nil { + panic(err) + } + rowCoeffs = make([]frontend.Variable, nbRow) + for i := 0; i < nbRow; i++ { + hasher.Reset() + hasher.Write(i+1, commitment) + rowCoeffs[i] = hasher.Sum() + } + return rowCoeffs, commitment +} + +func randLinearCombination(api frontend.API, rowCoeffs []frontend.Variable, row []frontend.Variable) frontend.Variable { + if len(rowCoeffs) != len(row) { + panic("coefficient count mismatch") + } + var res frontend.Variable = 0 + for i := range rowCoeffs { + res = api.Add(res, api.Mul(rowCoeffs[i], row[i])) + } + return res +} + +func countHint(m *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) <= 2 { + return fmt.Errorf("at least two input required") + } + if !inputs[0].IsInt64() { + return fmt.Errorf("first element must be length of table") + } + nbTable := int(inputs[0].Int64()) + if !inputs[1].IsInt64() { + return fmt.Errorf("first element must be length of row") + } + nbRow := int(inputs[1].Int64()) + if len(inputs) < 2+nbTable { + return fmt.Errorf("input doesn't fit table") + } + if len(outputs) != nbTable { + return fmt.Errorf("output not table size") + } + if (len(inputs)-2-nbTable*nbRow)%nbRow != 0 { + return fmt.Errorf("query count not full integer") + } + nbQueries := (len(inputs) - 2 - nbTable*nbRow) / nbRow + if nbQueries <= 0 { + return fmt.Errorf("at least one query required") + } + nbBytes := (m.BitLen() + 7) / 8 + buf := make([]byte, nbBytes*nbRow) + histo := make(map[string]int64, nbTable) // string key as big ints not comparable + for i := 0; i < nbTable; i++ { + for j := 0; j < nbRow; j++ { + inputs[2+nbRow*i+j].FillBytes(buf[j*nbBytes : (j+1)*nbBytes]) + } + k := string(buf) + if _, ok := histo[k]; ok { + return 
fmt.Errorf("duplicate key") + } + histo[k] = 0 + } + for i := 0; i < nbQueries; i++ { + for j := 0; j < nbRow; j++ { + inputs[2+nbRow*nbTable+nbRow*i+j].FillBytes(buf[j*nbBytes : (j+1)*nbBytes]) + } + k := string(buf) + v, ok := histo[k] + if !ok { + return fmt.Errorf("query element not in table") + } + v++ + histo[k] = v + } + for i := 0; i < nbTable; i++ { + for j := 0; j < nbRow; j++ { + inputs[2+nbRow*i+j].FillBytes(buf[j*nbBytes : (j+1)*nbBytes]) + } + outputs[i].Set(big.NewInt(histo[string(buf)])) + } + return nil +} diff --git a/std/internal/logderivprecomp/logderivprecomp.go b/std/internal/logderivprecomp/logderivprecomp.go new file mode 100644 index 0000000000..d768c10a32 --- /dev/null +++ b/std/internal/logderivprecomp/logderivprecomp.go @@ -0,0 +1,127 @@ +// Package logderivprecomp allows computing functions using precomputation. +// +// Instead of computing binary functions and checking that the result is +// correctly constrained, we instead can precompute all valid values of a +// function and then perform lookup to obtain the result. For example, for the +// XOR function we would naively otherwise have to split the inputs into bits, +// XOR one-by-one and recombine. +// +// With this package, we can instead compute all results for two inputs of +// length 8 bit and then just perform a lookup on the inputs. +// +// We use the [logderivarg] package for the actual log-derivative argument. +package logderivprecomp + +import ( + "fmt" + "math/big" + "reflect" + + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/kvstore" + "github.com/consensys/gnark/std/internal/logderivarg" +) + +type ctxPrecomputedKey struct{ fn uintptr } + +// Precomputed holds all precomputed function values and queries. +type Precomputed struct { + api frontend.API + compute solver.Hint + queries []frontend.Variable + rets []uint +} + +// New returns a new [Precomputed]. 
It defers the log-derivative argument. +func New(api frontend.API, fn solver.Hint, rets []uint) (*Precomputed, error) { + kv, ok := api.Compiler().(kvstore.Store) + if !ok { + panic("builder should implement key-value store") + } + ch := kv.GetKeyValue(ctxPrecomputedKey{fn: reflect.ValueOf(fn).Pointer()}) + if ch != nil { + if prt, ok := ch.(*Precomputed); ok { + return prt, nil + } else { + panic("stored rangechecker is not valid") + } + } + // check that the output lengths fit into a single element + var s uint = 16 + for _, v := range rets { + s += v + } + if s >= uint(api.Compiler().FieldBitLen()) { + return nil, fmt.Errorf("result doesn't fit into field element") + } + t := &Precomputed{ + api: api, + compute: fn, + queries: nil, + rets: rets, + } + kv.SetKeyValue(ctxPrecomputedKey{fn: reflect.ValueOf(fn).Pointer()}, t) + api.Compiler().Defer(t.build) + return t, nil +} + +func (t *Precomputed) pack(x, y frontend.Variable, rets []frontend.Variable) frontend.Variable { + shift := big.NewInt(1 << 8) + packed := t.api.Add(x, t.api.Mul(y, shift)) + for i := range t.rets { + shift.Lsh(shift, t.rets[i]) + packed = t.api.Add(packed, t.api.Mul(rets[i], shift)) + } + return packed +} + +// Query +func (t *Precomputed) Query(x, y frontend.Variable) []frontend.Variable { + // we don't have to check here. We assume the inputs are range checked and + // range check the output. 
+ rets, err := t.api.Compiler().NewHint(t.compute, len(t.rets), x, y) + if err != nil { + panic(err) + } + packed := t.pack(x, y, rets) + t.queries = append(t.queries, packed) + return rets +} + +func (t *Precomputed) buildTable() []frontend.Variable { + tmp := new(big.Int) + shift := new(big.Int) + tbl := make([]frontend.Variable, 65536) + inputs := []*big.Int{big.NewInt(0), big.NewInt(0)} + outputs := make([]*big.Int, len(t.rets)) + for i := range outputs { + outputs[i] = new(big.Int) + } + for x := int64(0); x < 256; x++ { + inputs[0].SetInt64(x) + for y := int64(0); y < 256; y++ { + shift.SetInt64(1 << 8) + i := x | (y << 8) + inputs[1].SetInt64(y) + if err := t.compute(t.api.Compiler().Field(), inputs, outputs); err != nil { + panic(err) + } + tblval := new(big.Int).SetInt64(i) + for j := range t.rets { + shift.Lsh(shift, t.rets[j]) + tblval.Add(tblval, tmp.Mul(outputs[j], shift)) + } + tbl[i] = tblval + } + } + return tbl +} + +func (t *Precomputed) build(api frontend.API) error { + if len(t.queries) == 0 { + return nil + } + table := t.buildTable() + return logderivarg.Build(t.api, logderivarg.AsTable(table), logderivarg.AsTable(t.queries)) +} diff --git a/std/internal/logderivprecomp/logderivprecomp_test.go b/std/internal/logderivprecomp/logderivprecomp_test.go new file mode 100644 index 0000000000..8de2755f1a --- /dev/null +++ b/std/internal/logderivprecomp/logderivprecomp_test.go @@ -0,0 +1,55 @@ +package logderivprecomp + +import ( + "crypto/rand" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type TestXORCircuit struct { + X, Y [100]frontend.Variable + Res [100]frontend.Variable +} + +func (c *TestXORCircuit) Define(api frontend.API) error { + tbl, err := New(api, xorHint, []uint{8}) + if err != nil { + return err + } + for i := range c.X { + res := 
tbl.Query(c.X[i], c.Y[i]) + api.AssertIsEqual(res[0], c.Res[i]) + } + return nil +} + +func xorHint(_ *big.Int, inputs, outputs []*big.Int) error { + outputs[0].Xor(inputs[0], inputs[1]) + return nil +} + +func TestXor(t *testing.T) { + assert := test.NewAssert(t) + bound := big.NewInt(255) + var xs, ys, ress [100]frontend.Variable + for i := range xs { + x, _ := rand.Int(rand.Reader, bound) + y, _ := rand.Int(rand.Reader, bound) + ress[i] = new(big.Int).Xor(x, y) + xs[i] = x + ys[i] = y + } + witness := &TestXORCircuit{X: xs, Y: ys, Res: ress} + assert.ProverSucceeded(&TestXORCircuit{}, witness, + test.WithBackends(backend.GROTH16), + test.WithSolverOpts(solver.WithHints(xorHint)), + test.NoFuzzing(), + test.NoSerialization(), + test.WithCurves(ecc.BN254)) +} diff --git a/std/lookup/logderivlookup/doc_test.go b/std/lookup/logderivlookup/doc_test.go new file mode 100644 index 0000000000..e8ab1a01e4 --- /dev/null +++ b/std/lookup/logderivlookup/doc_test.go @@ -0,0 +1,90 @@ +package logderivlookup_test + +import ( + "crypto/rand" + "fmt" + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/lookup/logderivlookup" +) + +type LookupCircuit struct { + Entries [1000]frontend.Variable + Queries, Expected [100]frontend.Variable +} + +func (c *LookupCircuit) Define(api frontend.API) error { + t := logderivlookup.New(api) + for i := range c.Entries { + t.Insert(c.Entries[i]) + } + results := t.Lookup(c.Queries[:]...) 
+	if len(results) != len(c.Expected) {
+		return fmt.Errorf("length mismatch")
+	}
+	for i := range results {
+		api.AssertIsEqual(results[i], c.Expected[i])
+	}
+	return nil
+}
+
+func Example() {
+	field := ecc.BN254.ScalarField()
+	witness := LookupCircuit{}
+	bound := big.NewInt(int64(len(witness.Entries)))
+	for i := range witness.Entries {
+		witness.Entries[i], _ = rand.Int(rand.Reader, field)
+	}
+	for i := range witness.Queries {
+		q, _ := rand.Int(rand.Reader, bound)
+		witness.Queries[i] = q
+		witness.Expected[i] = new(big.Int).Set(witness.Entries[q.Int64()].(*big.Int))
+	}
+	ccs, err := frontend.Compile(field, r1cs.NewBuilder, &LookupCircuit{})
+	if err != nil {
+		panic(err)
+	} else {
+		fmt.Println("compiled")
+	}
+	pk, vk, err := groth16.Setup(ccs)
+	if err != nil {
+		panic(err)
+	} else {
+		fmt.Println("setup done")
+	}
+	secretWitness, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField())
+	if err != nil {
+		panic(err)
+	} else {
+		fmt.Println("secret witness")
+	}
+	publicWitness, err := secretWitness.Public()
+	if err != nil {
+		panic(err)
+	} else {
+		fmt.Println("public witness")
+	}
+	proof, err := groth16.Prove(ccs, pk, secretWitness)
+	if err != nil {
+		panic(err)
+	} else {
+		fmt.Println("proof")
+	}
+	err = groth16.Verify(proof, vk, publicWitness)
+	if err != nil {
+		panic(err)
+	} else {
+		fmt.Println("verify")
+	}
+	// Output:
+	// compiled
+	// setup done
+	// secret witness
+	// public witness
+	// proof
+	// verify
+}
diff --git a/std/lookup/logderivlookup/logderivlookup.go b/std/lookup/logderivlookup/logderivlookup.go
new file mode 100644
index 0000000000..85d3f8ea04
--- /dev/null
+++ b/std/lookup/logderivlookup/logderivlookup.go
@@ -0,0 +1,152 @@
+// Package logderivlookup implements append-only lookups using log-derivative
+// argument.
+//
+// The lookup is based on log-derivative argument as described in [logderivarg].
+// The lookup table is a matrix where first column is the index and the second +// column the stored values: +// +// 1 x_1 +// 2 x_2 +// ... +// n x_n +// +// When performing a query for index i, the prover returns x_i from memory and +// stores (i, x_i) as a query. During the log-derivative argument building we +// check that all queried tuples (i, x_i) are included in the table. +// +// The complexity of the lookups is linear in the size of the table and the +// number of queries (O(n+m)). +package logderivlookup + +import ( + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/internal/logderivarg" +) + +// Table holds all the entries and queries. +type Table struct { + api frontend.API + + entries []frontend.Variable + immutable bool + results []result + + // each table has a unique blueprint + // the blueprint stores the lookup table entries once + // such that each query only need to store the indexes to lookup + bID constraint.BlueprintID + blueprint constraint.BlueprintLookupHint +} + +type result struct { + ind frontend.Variable + val frontend.Variable +} + +// New returns a new [*Table]. It additionally defers building the +// log-derivative argument. +func New(api frontend.API) *Table { + t := &Table{api: api} + api.Compiler().Defer(t.commit) + + // each table has a unique blueprint + t.bID = api.Compiler().AddBlueprint(&t.blueprint) + return t +} + +// Insert inserts variable val into the lookup table and returns its index as a +// constant. It panics if the table is already committed. 
+func (t *Table) Insert(val frontend.Variable) (index int) { + if t.immutable { + panic("inserting into committed lookup table") + } + t.entries = append(t.entries, val) + + // each time we insert a new entry, we update the blueprint + v := t.api.Compiler().ToCanonicalVariable(val) + v.Compress(&t.blueprint.EntriesCalldata) + + return len(t.entries) - 1 +} + +// Lookup lookups up values from the lookup tables given by the indices inds. It +// returns a variable for every index. It panics during compile time when +// looking up from a committed or empty table. It panics during solving time +// when the index is out of bounds. +func (t *Table) Lookup(inds ...frontend.Variable) (vals []frontend.Variable) { + if t.immutable { + panic("looking up from a committed lookup table") + } + if len(inds) == 0 { + return nil + } + if len(t.entries) == 0 { + panic("looking up from empty table") + } + return t.performLookup(inds) +} + +// performLookup performs the lookup and returns the resulting variables. +// underneath, it does use the blueprint to encode the lookup hint. +func (t *Table) performLookup(inds []frontend.Variable) []frontend.Variable { + // to build the instruction, we need to first encode its dependency as a calldata []uint32 slice. + // * calldata[0] is the length of the calldata, + // * calldata[1] is the number of entries in the table we consider. 
+ // * calldata[2] is the number of queries (which is the number of indices we are looking up and the number of outputs we expect) + compiler := t.api.Compiler() + + calldata := make([]uint32, 3, 3+len(inds)*2+2) + calldata[1] = uint32(len(t.entries)) + calldata[2] = uint32(len(inds)) + + // encode inputs + for _, in := range inds { + v := compiler.ToCanonicalVariable(in) + v.Compress(&calldata) + } + + // by convention, first calldata is len of inputs + calldata[0] = uint32(len(calldata)) + + // now what we are left to do is add an instruction to the constraint system + // such that at solving time the blueprint can properly execute the lookup logic. + outputs := compiler.AddInstruction(t.bID, calldata) + + // sanity check + if len(outputs) != len(inds) { + panic("sanity check") + } + + // we need to return the variables corresponding to the outputs + internalVariables := make([]frontend.Variable, len(inds)) + lookupResult := make([]result, len(inds)) + + // we need to store the result of the lookup in the table + for i := range inds { + internalVariables[i] = compiler.InternalVariable(outputs[i]) + lookupResult[i] = result{ind: inds[i], val: internalVariables[i]} + } + t.results = append(t.results, lookupResult...) 
+ return internalVariables +} + +func (t *Table) entryTable() [][]frontend.Variable { + tbl := make([][]frontend.Variable, len(t.entries)) + for i := range t.entries { + tbl[i] = []frontend.Variable{i, t.entries[i]} + } + return tbl +} + +func (t *Table) resultsTable() [][]frontend.Variable { + tbl := make([][]frontend.Variable, len(t.results)) + for i := range t.results { + tbl[i] = []frontend.Variable{t.results[i].ind, t.results[i].val} + } + return tbl +} + +func (t *Table) commit(api frontend.API) error { + return logderivarg.Build(api, t.entryTable(), t.resultsTable()) +} diff --git a/std/lookup/logderivlookup/logderivlookup_test.go b/std/lookup/logderivlookup/logderivlookup_test.go new file mode 100644 index 0000000000..0070f18352 --- /dev/null +++ b/std/lookup/logderivlookup/logderivlookup_test.go @@ -0,0 +1,65 @@ +package logderivlookup + +import ( + "crypto/rand" + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/test" +) + +type LookupCircuit struct { + Entries [1000]frontend.Variable + Queries, Expected [100]frontend.Variable +} + +func (c *LookupCircuit) Define(api frontend.API) error { + t := New(api) + for i := range c.Entries { + t.Insert(c.Entries[i]) + } + results := t.Lookup(c.Queries[:]...) 
+ if len(results) != len(c.Expected) { + return fmt.Errorf("length mismatch") + } + for i := range results { + api.AssertIsEqual(results[i], c.Expected[i]) + } + return nil +} + +func TestLookup(t *testing.T) { + assert := test.NewAssert(t) + field := ecc.BN254.ScalarField() + witness := LookupCircuit{} + bound := big.NewInt(int64(len(witness.Entries))) + for i := range witness.Entries { + witness.Entries[i], _ = rand.Int(rand.Reader, field) + } + for i := range witness.Queries { + q, _ := rand.Int(rand.Reader, bound) + witness.Queries[i] = q + witness.Expected[i] = new(big.Int).Set(witness.Entries[q.Int64()].(*big.Int)) + } + + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &LookupCircuit{}) + assert.NoError(err) + + w, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + assert.NoError(err) + + _, err = ccs.Solve(w) + assert.NoError(err) + + err = test.IsSolved(&LookupCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + + assert.ProverSucceeded(&LookupCircuit{}, &witness, + test.WithCurves(ecc.BN254), + test.WithBackends(backend.GROTH16, backend.PLONK)) +} diff --git a/std/math/bits/conversion.go b/std/math/bits/conversion.go index 20040b82d8..ca82c2773e 100644 --- a/std/math/bits/conversion.go +++ b/std/math/bits/conversion.go @@ -56,10 +56,10 @@ type baseConversionConfig struct { // BaseConversionOption configures the behaviour of scalar decomposition. type BaseConversionOption func(opt *baseConversionConfig) error -// WithNbDigits set the resulting number of digits to be used in the base conversion. -// nbDigits must be > 0. If nbDigits is lower than the length of full -// decomposition, then nbDigits least significant digits are returned. If the -// option is not set, then the full decomposition is returned. +// WithNbDigits sets the resulting number of digits (nbDigits) to be used in the base conversion. +// nbDigits must be > 0. 
If nbDigits is lower than the length of full decomposition and +// WithUnconstrainedOutputs option is not used, then this function generates an unsatisfiable +// constraint. If WithNbDigits option is not set, then the full decomposition is returned. func WithNbDigits(nbDigits int) BaseConversionOption { return func(opt *baseConversionConfig) error { if nbDigits <= 0 { diff --git a/std/math/bits/conversion_binary.go b/std/math/bits/conversion_binary.go index 5d93af06c6..693f477b3f 100644 --- a/std/math/bits/conversion_binary.go +++ b/std/math/bits/conversion_binary.go @@ -3,16 +3,9 @@ package bits import ( "math/big" - "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend" ) -func init() { - // register hints - hint.Register(IthBit) - hint.Register(NBits) -} - // ToBinary is an alias of ToBase(api, Binary, v, opts) func ToBinary(api frontend.API, v frontend.Variable, opts ...BaseConversionOption) []frontend.Variable { return ToBase(api, Binary, v, opts...) @@ -63,18 +56,16 @@ func toBinary(api frontend.API, v frontend.Variable, opts ...BaseConversionOptio } } - // if a is a constant, work with the big int value. - if c, ok := api.Compiler().ConstantValue(v); ok { - bits := make([]frontend.Variable, cfg.NbDigits) - for i := 0; i < len(bits); i++ { - bits[i] = c.Bit(i) - } - return bits + // when cfg.NbDigits == 1, v itself has to be a binary digit. This if clause + // saves one constraint. + if cfg.NbDigits == 1 { + api.AssertIsBoolean(v) + return []frontend.Variable{v} } c := big.NewInt(1) - bits, err := api.Compiler().NewHint(NBits, cfg.NbDigits, v) + bits, err := api.Compiler().NewHint(nBits, cfg.NbDigits, v) if err != nil { panic(err) } @@ -94,27 +85,3 @@ func toBinary(api frontend.API, v frontend.Variable, opts ...BaseConversionOptio return bits } - -// IthBit returns the i-tb bit the input. The function expects exactly two -// integer inputs i and n, takes the little-endian bit representation of n and -// returns its i-th bit. 
-func IthBit(_ *big.Int, inputs []*big.Int, results []*big.Int) error { - result := results[0] - if !inputs[1].IsUint64() { - result.SetUint64(0) - return nil - } - - result.SetUint64(uint64(inputs[0].Bit(int(inputs[1].Uint64())))) - return nil -} - -// NBits returns the first bits of the input. The number of returned bits is -// defined by the length of the results slice. -func NBits(_ *big.Int, inputs []*big.Int, results []*big.Int) error { - n := inputs[0] - for i := 0; i < len(results); i++ { - results[i].SetUint64(uint64(n.Bit(i))) - } - return nil -} diff --git a/std/math/bits/conversion_ternary.go b/std/math/bits/conversion_ternary.go index 302868361b..d38095d2b2 100644 --- a/std/math/bits/conversion_ternary.go +++ b/std/math/bits/conversion_ternary.go @@ -4,18 +4,9 @@ import ( "math" "math/big" - "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend" ) -// NTrits returns the first trits of the input. The number of returned trits is -// defined by the length of the results slice. -var NTrits = nTrits - -func init() { - hint.Register(NTrits) -} - // ToTernary is an alias of ToBase(api, Ternary, v, opts...) func ToTernary(api frontend.API, v frontend.Variable, opts ...BaseConversionOption) []frontend.Variable { return ToBase(api, Ternary, v, opts...) @@ -67,26 +58,10 @@ func toTernary(api frontend.API, v frontend.Variable, opts ...BaseConversionOpti } } - // if a is a constant, work with the big int value. 
- if c, ok := api.Compiler().ConstantValue(v); ok { - trits := make([]frontend.Variable, cfg.NbDigits) - // TODO using big.Int Text is likely not cheap - base3 := c.Text(3) - i := 0 - for j := len(base3) - 1; j >= 0 && i < len(trits); j-- { - trits[i] = int(base3[j] - 48) - i++ - } - for ; i < len(trits); i++ { - trits[i] = 0 - } - return trits - } - c := big.NewInt(1) b := big.NewInt(3) - trits, err := api.Compiler().NewHint(NTrits, cfg.NbDigits, v) + trits, err := api.Compiler().NewHint(nTrits, cfg.NbDigits, v) if err != nil { panic(err) } @@ -121,19 +96,3 @@ func AssertIsTrit(api frontend.API, v frontend.Variable) { y := api.Mul(api.Sub(1, v), api.Sub(2, v)) api.AssertIsEqual(api.Mul(v, y), 0) } - -func nTrits(_ *big.Int, inputs []*big.Int, results []*big.Int) error { - n := inputs[0] - // TODO using big.Int Text method is likely not cheap - base3 := n.Text(3) - i := 0 - for j := len(base3) - 1; j >= 0 && i < len(results); j-- { - results[i].SetUint64(uint64(base3[j] - 48)) - i++ - } - for ; i < len(results); i++ { - results[i].SetUint64(0) - } - - return nil -} diff --git a/std/math/bits/conversion_test.go b/std/math/bits/conversion_test.go index a43ae6489d..47455b236d 100644 --- a/std/math/bits/conversion_test.go +++ b/std/math/bits/conversion_test.go @@ -31,6 +31,13 @@ func (c *toBinaryCircuit) Define(api frontend.API) error { func TestToBinary(t *testing.T) { assert := test.NewAssert(t) assert.ProverSucceeded(&toBinaryCircuit{}, &toBinaryCircuit{A: 5, B0: 1, B1: 0, B2: 1}) + + assert.ProverSucceeded(&toBinaryCircuit{}, &toBinaryCircuit{A: 3, B0: 1, B1: 1, B2: 0}) + + // prover fails when the binary representation of A has more than 3 bits + assert.ProverFailed(&toBinaryCircuit{}, &toBinaryCircuit{A: 8, B0: 0, B1: 0, B2: 0}) + + assert.ProverFailed(&toBinaryCircuit{}, &toBinaryCircuit{A: 10, B0: 0, B1: 1, B2: 0}) } type toTernaryCircuit struct { diff --git a/std/math/bits/hints.go b/std/math/bits/hints.go new file mode 100644 index 0000000000..bb3da6d13c --- 
/dev/null
+++ b/std/math/bits/hints.go
@@ -0,0 +1,114 @@
+package bits
+
+import (
+	"errors"
+	"math/big"
+
+	"github.com/consensys/gnark/constraint/solver"
+)
+
+func GetHints() []solver.Hint {
+	return []solver.Hint{
+		ithBit,
+		nBits,
+		nTrits,
+		nNaf,
+	}
+}
+
+func init() {
+	solver.RegisterHint(GetHints()...)
+}
+
+// ithBit returns the i-th bit of the input. The function expects exactly two
+// integer inputs i and n, takes the little-endian bit representation of n and
+// returns its i-th bit.
+func ithBit(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
+	result := results[0]
+	if !inputs[1].IsUint64() {
+		result.SetUint64(0)
+		return nil
+	}
+
+	result.SetUint64(uint64(inputs[0].Bit(int(inputs[1].Uint64()))))
+	return nil
+}
+
+// nBits returns the first bits of the input. The number of returned bits is
+// defined by the length of the results slice.
+func nBits(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
+	n := inputs[0]
+	for i := 0; i < len(results); i++ {
+		results[i].SetUint64(uint64(n.Bit(i)))
+	}
+	return nil
+}
+
+// nTrits returns the first trits of the input. The number of returned trits is
+// defined by the length of the results slice.
+func nTrits(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
+	n := inputs[0]
+	// TODO using big.Int Text method is likely not cheap
+	base3 := n.Text(3)
+	i := 0
+	for j := len(base3) - 1; j >= 0 && i < len(results); j-- {
+		results[i].SetUint64(uint64(base3[j] - 48))
+		i++
+	}
+	for ; i < len(results); i++ {
+		results[i].SetUint64(0)
+	}
+
+	return nil
+}
+
+// nNaf returns the NAF decomposition of the input. The number of digits is
+// defined by the number of elements in the results slice.
+func nNaf(_ *big.Int, inputs []*big.Int, results []*big.Int) error { + n := inputs[0] + return nafDecomposition(n, results) +} + +// nafDecomposition gets the naf decomposition of a big number +func nafDecomposition(a *big.Int, results []*big.Int) error { + if a == nil || a.Sign() == -1 { + return errors.New("invalid input to naf decomposition; negative (or nil) big.Int not supported") + } + + var zero, one, three big.Int + + one.SetUint64(1) + three.SetUint64(3) + + n := 0 + + // some buffers + var buf, aCopy big.Int + aCopy.Set(a) + + for aCopy.Cmp(&zero) != 0 && n < len(results) { + + // if aCopy % 2 == 0 + buf.And(&aCopy, &one) + + // aCopy even + if buf.Cmp(&zero) == 0 { + results[n].SetUint64(0) + } else { // aCopy odd + buf.And(&aCopy, &three) + if buf.IsUint64() && buf.Uint64() == 3 { + results[n].SetInt64(-1) + aCopy.Add(&aCopy, &one) + } else { + results[n].SetUint64(1) + } + } + aCopy.Rsh(&aCopy, 1) + n++ + } + for ; n < len(results); n++ { + results[n].SetUint64(0) + } + + return nil +} diff --git a/std/math/bits/naf.go b/std/math/bits/naf.go index b8e21ed965..56b4b9e468 100644 --- a/std/math/bits/naf.go +++ b/std/math/bits/naf.go @@ -1,21 +1,11 @@ package bits import ( - "errors" "math/big" - "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend" ) -// NNAF returns the NAF decomposition of the input. The number of digits is -// defined by the number of elements in the results slice. -var NNAF = nNaf - -func init() { - hint.Register(NNAF) -} - // ToNAF returns the NAF decomposition of given input. // The non-adjacent form (NAF) of a number is a unique signed-digit representation, // in which non-zero values cannot be adjacent. For example, NAF(13) = [1, 0, -1, 0, 1]. @@ -32,25 +22,9 @@ func ToNAF(api frontend.API, v frontend.Variable, opts ...BaseConversionOption) } } - // if v is a constant, work with the big int value. 
- if c, ok := api.Compiler().ConstantValue(v); ok { - bits := make([]*big.Int, cfg.NbDigits) - for i := 0; i < len(bits); i++ { - bits[i] = big.NewInt(0) - } - if err := nafDecomposition(c, bits); err != nil { - panic(err) - } - res := make([]frontend.Variable, len(bits)) - for i := 0; i < len(bits); i++ { - res[i] = bits[i] - } - return res - } - c := big.NewInt(1) - bits, err := api.Compiler().NewHint(NNAF, cfg.NbDigits, v) + bits, err := api.Compiler().NewHint(nNaf, cfg.NbDigits, v) if err != nil { panic(err) } @@ -74,52 +48,3 @@ func ToNAF(api frontend.API, v frontend.Variable, opts ...BaseConversionOption) return bits } - -func nNaf(_ *big.Int, inputs []*big.Int, results []*big.Int) error { - n := inputs[0] - return nafDecomposition(n, results) -} - -// nafDecomposition gets the naf decomposition of a big number -func nafDecomposition(a *big.Int, results []*big.Int) error { - if a == nil || a.Sign() == -1 { - return errors.New("invalid input to naf decomposition; negative (or nil) big.Int not supported") - } - - var zero, one, three big.Int - - one.SetUint64(1) - three.SetUint64(3) - - n := 0 - - // some buffers - var buf, aCopy big.Int - aCopy.Set(a) - - for aCopy.Cmp(&zero) != 0 && n < len(results) { - - // if aCopy % 2 == 0 - buf.And(&aCopy, &one) - - // aCopy even - if buf.Cmp(&zero) == 0 { - results[n].SetUint64(0) - } else { // aCopy odd - buf.And(&aCopy, &three) - if buf.IsUint64() && buf.Uint64() == 3 { - results[n].SetInt64(-1) - aCopy.Add(&aCopy, &one) - } else { - results[n].SetUint64(1) - } - } - aCopy.Rsh(&aCopy, 1) - n++ - } - for ; n < len(results); n++ { - results[n].SetUint64(0) - } - - return nil -} diff --git a/std/math/bitslice/hints.go b/std/math/bitslice/hints.go new file mode 100644 index 0000000000..d9797c6a9a --- /dev/null +++ b/std/math/bitslice/hints.go @@ -0,0 +1,34 @@ +package bitslice + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark/constraint/solver" +) + +func init() { + solver.RegisterHint(GetHints()...) 
+} + +func GetHints() []solver.Hint { + return []solver.Hint{ + partitionHint, + } +} + +func partitionHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 2 { + return fmt.Errorf("expecting two inputs") + } + if len(outputs) != 2 { + return fmt.Errorf("expecting two outputs") + } + if !inputs[0].IsUint64() { + return fmt.Errorf("split location must be int") + } + split := uint(inputs[0].Uint64()) + div := new(big.Int).Lsh(big.NewInt(1), split) + outputs[0].QuoRem(inputs[1], div, outputs[1]) + return nil +} diff --git a/std/math/bitslice/opts.go b/std/math/bitslice/opts.go new file mode 100644 index 0000000000..fa333efb1a --- /dev/null +++ b/std/math/bitslice/opts.go @@ -0,0 +1,37 @@ +package bitslice + +import "fmt" + +type opt struct { + digits int + nocheck bool +} + +func parseOpts(opts ...Option) (*opt, error) { + o := new(opt) + for _, apply := range opts { + if err := apply(o); err != nil { + return nil, err + } + } + return o, nil +} + +type Option func(*opt) error + +func WithNbDigits(nbDigits int) Option { + return func(o *opt) error { + if nbDigits < 1 { + return fmt.Errorf("given number of digits %d smaller than 1", nbDigits) + } + o.digits = nbDigits + return nil + } +} + +func WithUnconstrainedOutputs() Option { + return func(o *opt) error { + o.nocheck = true + return nil + } +} diff --git a/std/math/bitslice/partition.go b/std/math/bitslice/partition.go new file mode 100644 index 0000000000..f20c6d3513 --- /dev/null +++ b/std/math/bitslice/partition.go @@ -0,0 +1,69 @@ +package bitslice + +import ( + "math/big" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/rangecheck" +) + +// Partition partitions v into two parts splitted at bit numbered split. The +// following holds +// +// v = lower + 2^split * upper. +// +// The method enforces that lower < 2^split and upper < 2^split', where +// split'=nbScalar-split. 
When giving the option [WithNbDigits], we instead use +// the bound split'=nbDigits-split. +func Partition(api frontend.API, v frontend.Variable, split uint, opts ...Option) (lower, upper frontend.Variable) { + opt, err := parseOpts(opts...) + if err != nil { + panic(err) + } + // handle constant case + if vc, ok := api.Compiler().ConstantValue(v); ok { + if opt.digits > 0 && vc.BitLen() > opt.digits { + panic("input larger than bound") + } + if split == 0 { + return 0, vc + } + div := new(big.Int).Lsh(big.NewInt(1), split) + l, u := new(big.Int), new(big.Int) + u.QuoRem(vc, div, l) + return l, u + } + rh := rangecheck.New(api) + if split == 0 { + if opt.digits > 0 { + rh.Check(v, opt.digits) + } + return 0, v + } + ret, err := api.Compiler().NewHint(partitionHint, 2, split, v) + if err != nil { + panic(err) + } + + upper = ret[0] + lower = ret[1] + + if opt.nocheck { + if opt.digits > 0 { + rh.Check(v, opt.digits) + } + return + } + upperBound := api.Compiler().FieldBitLen() + if opt.digits > 0 { + upperBound = opt.digits + } + rh.Check(upper, upperBound) + rh.Check(lower, int(split)) + + m := big.NewInt(1) + m.Lsh(m, split) + composed := api.Add(lower, api.Mul(upper, m)) + api.AssertIsEqual(composed, v) + return +} diff --git a/std/math/bitslice/partition_test.go b/std/math/bitslice/partition_test.go new file mode 100644 index 0000000000..891c01e658 --- /dev/null +++ b/std/math/bitslice/partition_test.go @@ -0,0 +1,30 @@ +package bitslice + +import ( + "testing" + + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type partitionCircuit struct { + Split uint + In, ExpLower, ExpUpper frontend.Variable +} + +func (c *partitionCircuit) Define(api frontend.API) error { + lower, upper := Partition(api, c.In, c.Split) + api.AssertIsEqual(lower, c.ExpLower) + api.AssertIsEqual(upper, c.ExpUpper) + return nil +} + +func TestPartition(t *testing.T) { + assert := test.NewAssert(t) + // TODO: for some 
reason next fails with PLONK+FRI + assert.ProverSucceeded(&partitionCircuit{Split: 16}, &partitionCircuit{Split: 16, ExpUpper: 0xffff, ExpLower: 0x1234, In: 0xffff1234}, test.WithBackends(backend.GROTH16, backend.PLONK)) + assert.ProverSucceeded(&partitionCircuit{Split: 0}, &partitionCircuit{Split: 0, ExpUpper: 0xffff1234, ExpLower: 0, In: 0xffff1234}, test.WithBackends(backend.GROTH16, backend.PLONK)) + assert.ProverSucceeded(&partitionCircuit{Split: 32}, &partitionCircuit{Split: 32, ExpUpper: 0, ExpLower: 0xffff1234, In: 0xffff1234}, test.WithBackends(backend.GROTH16, backend.PLONK)) + assert.ProverSucceeded(&partitionCircuit{Split: 4}, &partitionCircuit{Split: 4, ExpUpper: 0xffff123, ExpLower: 4, In: 0xffff1234}, test.WithBackends(backend.GROTH16, backend.PLONK)) +} diff --git a/std/math/emulated/doc_example_field_test.go b/std/math/emulated/doc_example_field_test.go index b6f3b8a18e..338c18b4ec 100644 --- a/std/math/emulated/doc_example_field_test.go +++ b/std/math/emulated/doc_example_field_test.go @@ -6,6 +6,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/std/math/emulated" @@ -61,7 +62,7 @@ func ExampleField() { } else { fmt.Println("setup done") } - proof, err := groth16.Prove(ccs, pk, witnessData, backend.WithHints(emulated.GetHints()...)) + proof, err := groth16.Prove(ccs, pk, witnessData, backend.WithSolverOptions(solver.WithHints(emulated.GetHints()...))) if err != nil { panic(err) } else { diff --git a/std/math/emulated/element_test.go b/std/math/emulated/element_test.go index aa7d5f8cd7..4f4e3eb4bb 100644 --- a/std/math/emulated/element_test.go +++ b/std/math/emulated/element_test.go @@ -87,6 +87,50 @@ func testAssertIsLessEqualThan[T FieldParams](t *testing.T) { }, testName[T]()) } +type 
AssertIsLessEqualThanConstantCiruit[T FieldParams] struct { + L Element[T] + R *big.Int +} + +func (c *AssertIsLessEqualThanConstantCiruit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + R := f.NewElement(c.R) + f.AssertIsLessOrEqual(&c.L, R) + return nil +} + +func testAssertIsLessEqualThanConstant[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness AssertIsLessEqualThanConstantCiruit[T] + R, _ := rand.Int(rand.Reader, fp.Modulus()) + L, _ := rand.Int(rand.Reader, R) + circuit.R = R + witness.R = R + witness.L = ValueOf[T](L) + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) + assert.Run(func(assert *test.Assert) { + var circuit, witness AssertIsLessEqualThanConstantCiruit[T] + R := new(big.Int).Set(fp.Modulus()) + L, _ := rand.Int(rand.Reader, R) + circuit.R = R + witness.R = R + witness.L = ValueOf[T](L) + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, fmt.Sprintf("overflow/%s", testName[T]())) +} + +func TestAssertIsLessEqualThanConstant(t *testing.T) { + testAssertIsLessEqualThanConstant[Goldilocks](t) + testAssertIsLessEqualThanConstant[Secp256k1Fp](t) + testAssertIsLessEqualThanConstant[BN254Fp](t) +} + type AddCircuit[T FieldParams] struct { A, B, C Element[T] } @@ -800,3 +844,115 @@ func TestIssue348UnconstrainedLimbs(t *testing.T) { // inputs. 
assert.Error(err) } + +type AssertInRangeCircuit[T FieldParams] struct { + X Element[T] +} + +func (c *AssertInRangeCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + f.AssertIsInRange(&c.X) + return nil +} + +func TestAssertInRange(t *testing.T) { + testAssertIsInRange[Goldilocks](t) + testAssertIsInRange[Secp256k1Fp](t) + testAssertIsInRange[BN254Fp](t) +} + +func testAssertIsInRange[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + X, _ := rand.Int(rand.Reader, fp.Modulus()) + circuit := AssertInRangeCircuit[T]{} + witness := AssertInRangeCircuit[T]{X: ValueOf[T](X)} + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + witness2 := AssertInRangeCircuit[T]{X: ValueOf[T](0)} + t := 0 + for i := 0; i < int(fp.NbLimbs())-1; i++ { + L := new(big.Int).Lsh(big.NewInt(1), fp.BitsPerLimb()) + L.Sub(L, big.NewInt(1)) + witness2.X.Limbs[i] = L + t += int(fp.BitsPerLimb()) + } + highlimb := fp.Modulus().BitLen() - t + L := new(big.Int).Lsh(big.NewInt(1), uint(highlimb)) + L.Sub(L, big.NewInt(1)) + witness2.X.Limbs[fp.NbLimbs()-1] = L + assert.ProverFailed(&circuit, &witness2, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +type IsZeroCircuit[T FieldParams] struct { + X, Y Element[T] + Zero frontend.Variable +} + +func (c *IsZeroCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + R := f.Add(&c.X, &c.Y) + api.AssertIsEqual(c.Zero, f.IsZero(R)) + return nil +} + +func TestIsZero(t *testing.T) { + testIsZero[Goldilocks](t) + testIsZero[Secp256k1Fp](t) + testIsZero[BN254Fp](t) +} + +func testIsZero[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + X, _ := 
rand.Int(rand.Reader, fp.Modulus()) + Y := new(big.Int).Sub(fp.Modulus(), X) + circuit := IsZeroCircuit[T]{} + assert.ProverSucceeded(&circuit, &IsZeroCircuit[T]{X: ValueOf[T](X), Y: ValueOf[T](Y), Zero: 1}, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + assert.ProverSucceeded(&circuit, &IsZeroCircuit[T]{X: ValueOf[T](X), Y: ValueOf[T](0), Zero: 0}, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +type SqrtCircuit[T FieldParams] struct { + X, Expected Element[T] +} + +func (c *SqrtCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + res := f.Sqrt(&c.X) + f.AssertIsEqual(res, &c.Expected) + return nil +} + +func TestSqrt(t *testing.T) { + testSqrt[Goldilocks](t) + testSqrt[Secp256k1Fp](t) + testSqrt[BN254Fp](t) +} + +func testSqrt[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var X *big.Int + exp := new(big.Int) + for { + X, _ = rand.Int(rand.Reader, fp.Modulus()) + if exp.ModSqrt(X, fp.Modulus()) != nil { + break + } + } + assert.ProverSucceeded(&SqrtCircuit[T]{}, &SqrtCircuit[T]{X: ValueOf[T](X), Expected: ValueOf[T](exp)}, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} diff --git a/std/math/emulated/emparams/emparams.go b/std/math/emulated/emparams/emparams.go new file mode 100644 index 0000000000..9b888d6a4d --- /dev/null +++ b/std/math/emulated/emparams/emparams.go @@ -0,0 +1,201 @@ +// Package emparams contains emulation parameters for well known fields. +// +// We define some well-known parameters in this package for compatibility and +// ease of use. When needing to use parameters not defined in this package it is +// sufficient to define a new type implementing [FieldParams]. 
For example, as: +// +// type SmallField struct {} +// func (SmallField) NbLimbs() uint { return 1 } +// func (SmallField) BitsPerLimb() uint { return 11 } +// func (SmallField) IsPrime() bool { return true } +// func (SmallField) Modulus() *big.Int { return big.NewInt(1032) } +package emparams + +import ( + "crypto/elliptic" + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/goldilocks" +) + +type fourLimbPrimeField struct{} + +func (fourLimbPrimeField) NbLimbs() uint { return 4 } +func (fourLimbPrimeField) BitsPerLimb() uint { return 64 } +func (fourLimbPrimeField) IsPrime() bool { return true } + +type sixLimbPrimeField struct{} + +func (sixLimbPrimeField) NbLimbs() uint { return 6 } +func (sixLimbPrimeField) BitsPerLimb() uint { return 64 } +func (sixLimbPrimeField) IsPrime() bool { return true } + +// Goldilocks provides type parametrization for field emulation: +// - limbs: 1 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0xffffffff00000001 (base 16) +// 18446744069414584321 (base 10) +type Goldilocks struct{} + +func (fp Goldilocks) NbLimbs() uint { return 1 } +func (fp Goldilocks) BitsPerLimb() uint { return 64 } +func (fp Goldilocks) IsPrime() bool { return true } +func (fp Goldilocks) Modulus() *big.Int { return goldilocks.Modulus() } + +// Secp256k1Fp provides type parametrization for field emulation: +// - limbs: 4 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f (base 16) +// 115792089237316195423570985008687907853269984665640564039457584007908834671663 (base 10) +// +// This is the base field of the SECP256k1 curve. 
+type Secp256k1Fp struct{ fourLimbPrimeField } + +func (fp Secp256k1Fp) Modulus() *big.Int { return ecc.SECP256K1.BaseField() } + +// Secp256k1Fr provides type parametrization for field emulation: +// - limbs: 4 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141 (base 16) +// 115792089237316195423570985008687907852837564279074904382605163141518161494337 (base 10) +// +// This is the scalar field of the SECP256k1 curve. +type Secp256k1Fr struct{ fourLimbPrimeField } + +func (fp Secp256k1Fr) Modulus() *big.Int { return ecc.SECP256K1.ScalarField() } + +// BN254Fp provides type parametrization for field emulation: +// - limbs: 4 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47 (base 16) +// 21888242871839275222246405745257275088696311157297823662689037894645226208583 (base 10) +// +// This is the base field of the BN254 curve. +type BN254Fp struct{ fourLimbPrimeField } + +func (fp BN254Fp) Modulus() *big.Int { return ecc.BN254.BaseField() } + +// BN254Fr provides type parametrization for field emulation: +// - limbs: 4 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 (base 16) +// 21888242871839275222246405745257275088548364400416034343698204186575808495617 (base 10) +// +// This is the scalar field of the BN254 curve. 
+type BN254Fr struct{ fourLimbPrimeField } + +func (fp BN254Fr) Modulus() *big.Int { return ecc.BN254.ScalarField() } + +// BLS12377Fp provides type parametrization for field emulation: +// - limbs: 6 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001 (base 16) +// 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 (base 10) +// +// This is the base field of the BLS12-377 curve. +type BLS12377Fp struct{ sixLimbPrimeField } + +func (fp BLS12377Fp) Modulus() *big.Int { return ecc.BLS12_377.BaseField() } + +// BLS12381Fp provides type parametrization for field emulation: +// - limbs: 6 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab (base 16) +// 4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787 (base 10) +// +// This is the base field of the BLS12-381 curve. +type BLS12381Fp struct{ sixLimbPrimeField } + +func (fp BLS12381Fp) Modulus() *big.Int { return ecc.BLS12_381.BaseField() } + +// BLS12381Fr provides type parametrization for field emulation: +// - limbs: 4 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001 (base 16) +// 52435875175126190479447740508185965837690552500527637822603658699938581184513 (base 10) +// +// This is the scalar field of the BLS12-381 curve. 
+type BLS12381Fr struct{ fourLimbPrimeField } + +func (fp BLS12381Fr) Modulus() *big.Int { return ecc.BLS12_381.ScalarField() } + +// P256Fp provides type parametrization for field emulation: +// - limbs: 4 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff (base 16) +// 115792089210356248762697446949407573530086143415290314195533631308867097853951 (base 10) +// +// This is the base field of the P-256 (also SECP256r1) curve. +type P256Fp struct{ fourLimbPrimeField } + +func (P256Fp) Modulus() *big.Int { return elliptic.P256().Params().P } + +// P256Fr provides type parametrization for field emulation: +// - limbs: 4 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551 (base 16) +// 115792089210356248762697446949407573529996955224135760342422259061068512044369 (base 10) +// +// This is the scalar field of the P-256 (also SECP256r1) curve. +type P256Fr struct{ fourLimbPrimeField } + +func (P256Fr) Modulus() *big.Int { return elliptic.P256().Params().N } + +// P384Fp provides type parametrization for field emulation: +// - limbs: 6 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000ffffffff (base 16) +// 39402006196394479212279040100143613805079739270465446667948293404245721771496870329047266088258938001861606973112319 (base 10) +// +// This is the base field of the P-384 (also SECP384r1) curve. 
+type P384Fp struct{ sixLimbPrimeField } + +func (P384Fp) Modulus() *big.Int { return elliptic.P384().Params().P } + +// P384Fr provides type parametrization for field emulation: +// - limbs: 6 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0xffffffffffffffffffffffffffffffffffffffffffffffffc7634d81f4372ddf581a0db248b0a77aecec196accc52973 (base 16) +// 39402006196394479212279040100143613805079739270465446667946905279627659399113263569398956308152294913554433653942643 (base 10) +// +// This is the scalar field of the P-384 (also SECP384r1) curve. +type P384Fr struct{ sixLimbPrimeField } + +func (P384Fr) Modulus() *big.Int { return elliptic.P384().Params().N } diff --git a/std/math/emulated/field.go b/std/math/emulated/field.go index e4c0c3d357..124f4e9b37 100644 --- a/std/math/emulated/field.go +++ b/std/math/emulated/field.go @@ -6,7 +6,9 @@ import ( "sync" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" + "github.com/consensys/gnark/std/rangecheck" "github.com/rs/zerolog" "golang.org/x/exp/constraints" ) @@ -27,16 +29,19 @@ type Field[T FieldParams] struct { maxOfOnce sync.Once // constants for often used elements n, 0 and 1. 
Allocated only once - nConstOnce sync.Once - nConst *Element[T] - zeroConstOnce sync.Once - zeroConst *Element[T] - oneConstOnce sync.Once - oneConst *Element[T] + nConstOnce sync.Once + nConst *Element[T] + nprevConstOnce sync.Once + nprevConst *Element[T] + zeroConstOnce sync.Once + zeroConst *Element[T] + oneConstOnce sync.Once + oneConst *Element[T] log zerolog.Logger constrainedLimbs map[uint64]struct{} + checker frontend.Rangechecker } // NewField returns an object to be used in-circuit to perform emulated @@ -52,6 +57,7 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) { api: native, log: logger.Logger(), constrainedLimbs: make(map[uint64]struct{}), + checker: rangecheck.New(native), } // ensure prime is correctly set @@ -89,7 +95,9 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) { // NewElement builds a new Element[T] from input v. // - if v is a Element[T] or *Element[T] it clones it // - if v is a constant this is equivalent to calling emulated.ValueOf[T] -// - if this methods interpret v (frontend.Variable or []frontend.Variable) as being the limbs; and constrain the limbs following the parameters of the Field. +// - if this methods interprets v as being the limbs (frontend.Variable or []frontend.Variable), +// it constructs a new Element[T] with v as limbs and constraints the limbs to the parameters +// of the Field[T]. 
func (f *Field[T]) NewElement(v interface{}) *Element[T] { if e, ok := v.(Element[T]); ok { return e.copy() @@ -101,11 +109,6 @@ func (f *Field[T]) NewElement(v interface{}) *Element[T] { return f.packLimbs([]frontend.Variable{v}, true) } if e, ok := v.([]frontend.Variable); ok { - for _, sv := range e { - if !frontend.IsCanonical(sv) { - panic("[]frontend.Variable that are not canonical (known to the compiler) is not a valid input") - } - } return f.packLimbs(e, true) } c := ValueOf[T](v) @@ -136,6 +139,14 @@ func (f *Field[T]) Modulus() *Element[T] { return f.nConst } +// modulusPrev returns modulus-1 as a constant. +func (f *Field[T]) modulusPrev() *Element[T] { + f.nprevConstOnce.Do(func() { + f.nprevConst = newConstElement[T](new(big.Int).Sub(f.fParams.Modulus(), big.NewInt(1))) + }) + return f.nprevConst +} + // packLimbs returns an element from the given limbs. // If strict is true, the most significant limb will be constrained to have width of the most // significant limb of the modulus, which may have less bits than the other limbs. In which case, @@ -157,14 +168,27 @@ func (f *Field[T]) enforceWidthConditional(a *Element[T]) (didConstrain bool) { return false } if _, isConst := f.constantValue(a); isConst { + // enforce constant element limbs not to be large. + for i := range a.Limbs { + val := utils.FromInterface(a.Limbs[i]) + if val.BitLen() > int(f.fParams.BitsPerLimb()) { + panic("constant element limb wider than emulated parameter") + } + } // constant values are constant return false } for i := range a.Limbs { if !frontend.IsCanonical(a.Limbs[i]) { - // this is not a variable. This may happen when some limbs are - // constant and some variables. A strange case but lets try to cover - // it anyway. + // this is not a canonical variable, nor a constant. This may happen + // when some limbs are constant and some variables. Or if we are + // running in a test engine. 
In either case, we must check that if + // this limb is a [*big.Int] that its bitwidth is less than the + // NbBits. + val := utils.FromInterface(a.Limbs[i]) + if val.BitLen() > int(f.fParams.BitsPerLimb()) { + panic("non-canonical integer limb wider than emulated parameter") + } continue } if vv, ok := a.Limbs[i].(interface{ HashCode() uint64 }); ok { @@ -259,7 +283,7 @@ func (f *Field[T]) compactLimbs(e *Element[T], groupSize, bitsPerLimb uint) []fr // then the limbs may overflow the native field. func (f *Field[T]) maxOverflow() uint { f.maxOfOnce.Do(func() { - f.maxOf = uint(f.api.Compiler().FieldBitLen()-1) - f.fParams.BitsPerLimb() + f.maxOf = uint(f.api.Compiler().FieldBitLen()-2) - f.fParams.BitsPerLimb() }) return f.maxOf } diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go index c8daae2bf5..2a02715900 100644 --- a/std/math/emulated/field_assert.go +++ b/std/math/emulated/field_assert.go @@ -5,13 +5,12 @@ import ( "math/big" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/math/bits" ) // assertLimbsEqualitySlow is the main routine in the package. It asserts that the // two slices of limbs represent the same integer value. This is also the most // costly operation in the package as it does bit decomposition of the limbs. -func assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits, nbCarryBits uint) { +func (f *Field[T]) assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits, nbCarryBits uint) { nbLimbs := max(len(l), len(r)) maxValue := new(big.Int).Lsh(big.NewInt(1), nbBits+nbCarryBits) @@ -33,52 +32,29 @@ func assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits, // carry is stored in the highest bits of diff[nbBits:nbBits+nbCarryBits+1] // we know that diff[:nbBits] are 0 bits, but still need to constrain them. 
// to do both; we do a "clean" right shift and only need to boolean constrain the carry part - carry = rsh(api, diff, int(nbBits), int(nbBits+nbCarryBits+1)) + carry = f.rsh(diff, int(nbBits), int(nbBits+nbCarryBits+1)) } api.AssertIsEqual(carry, maxValueShift) } -// rsh right shifts a variable endDigit-startDigit bits and returns it. -func rsh(api frontend.API, v frontend.Variable, startDigit, endDigit int) frontend.Variable { +func (f *Field[T]) rsh(v frontend.Variable, startDigit, endDigit int) frontend.Variable { // if v is a constant, work with the big int value. - if c, ok := api.Compiler().ConstantValue(v); ok { + if c, ok := f.api.Compiler().ConstantValue(v); ok { bits := make([]frontend.Variable, endDigit-startDigit) for i := 0; i < len(bits); i++ { bits[i] = c.Bit(i + startDigit) } return bits } - - bits, err := api.Compiler().NewHint(NBitsShifted, endDigit-startDigit, v, startDigit) + shifted, err := f.api.Compiler().NewHint(RightShift, 1, startDigit, v) if err != nil { - panic(err) - } - - // we compute 2 sums; - // Σbi ensures that "ignoring" the lowest bits (< startDigit) still is a valid bit decomposition. 
- // that is, it ensures that bits from startDigit to endDigit * corresponding coefficients (powers of 2 shifted) - // are equal to the input variable - // ΣbiRShift computes the actual result; that is, the Σ (2**i * b[i]) - Σbi := frontend.Variable(0) - ΣbiRShift := frontend.Variable(0) - - cRShift := big.NewInt(1) - c := big.NewInt(1) - c.Lsh(c, uint(startDigit)) - - for i := 0; i < len(bits); i++ { - Σbi = api.MulAcc(Σbi, bits[i], c) - ΣbiRShift = api.MulAcc(ΣbiRShift, bits[i], cRShift) - - c.Lsh(c, 1) - cRShift.Lsh(cRShift, 1) - api.AssertIsBoolean(bits[i]) + panic(fmt.Sprintf("right shift: %v", err)) } - - // constraint Σ (2**i_shift * b[i]) == v - api.AssertIsEqual(Σbi, v) - return ΣbiRShift - + f.checker.Check(shifted[0], endDigit-startDigit) + shift := new(big.Int).Lsh(big.NewInt(1), uint(startDigit)) + composed := f.api.Mul(shifted[0], shift) + f.api.AssertIsEqual(composed, v) + return shifted[0] } // AssertLimbsEquality asserts that the limbs represent a same integer value. @@ -107,9 +83,9 @@ func (f *Field[T]) AssertLimbsEquality(a, b *Element[T]) { // TODO: we previously assumed that one side was "larger" than the other // side, but I think this assumption is not valid anymore if a.overflow > b.overflow { - assertLimbsEqualitySlow(f.api, ca, cb, bitsPerLimb, a.overflow) + f.assertLimbsEqualitySlow(f.api, ca, cb, bitsPerLimb, a.overflow) } else { - assertLimbsEqualitySlow(f.api, cb, ca, bitsPerLimb, b.overflow) + f.assertLimbsEqualitySlow(f.api, cb, ca, bitsPerLimb, b.overflow) } } @@ -133,16 +109,14 @@ func (f *Field[T]) enforceWidth(a *Element[T], modWidth bool) { // take only required bits from the most significant limb limbNbBits = ((f.fParams.Modulus().BitLen() - 1) % int(f.fParams.BitsPerLimb())) + 1 } - // bits.ToBinary restricts the least significant NbDigits to be equal to - // the limb value. This is sufficient to restrict for the bitlength and - // we can discard the bits themselves. 
- bits.ToBinary(f.api, a.Limbs[i], bits.WithNbDigits(limbNbBits)) + f.checker.Check(a.Limbs[i], limbNbBits) } } // AssertIsEqual ensures that a is equal to b modulo the modulus. func (f *Field[T]) AssertIsEqual(a, b *Element[T]) { - // we omit width assertion as it is done in Sub below + f.enforceWidthConditional(a) + f.enforceWidthConditional(b) ba, aConst := f.constantValue(a) bb, bConst := f.constantValue(b) if aConst && bConst { @@ -154,7 +128,7 @@ func (f *Field[T]) AssertIsEqual(a, b *Element[T]) { return } - diff := f.Sub(b, a) + diff := f.subNoReduce(b, a) // we compute k such that diff / p == k // so essentially, we say "I know an element k such that k*p == diff" @@ -170,7 +144,9 @@ func (f *Field[T]) AssertIsEqual(a, b *Element[T]) { f.AssertLimbsEquality(diff, kp) } -// AssertIsLessOrEqual ensures that e is less or equal than a. +// AssertIsLessOrEqual ensures that e is less or equal than a. For proper +// bitwise comparison first reduce the element using [Reduce] and then assert +// that its value is less than the modulus using [AssertIsInRange]. func (f *Field[T]) AssertIsLessOrEqual(e, a *Element[T]) { // we omit conditional width assertion as is done in ToBits below if e.overflow+a.overflow > 0 { @@ -181,7 +157,7 @@ func (f *Field[T]) AssertIsLessOrEqual(e, a *Element[T]) { ff := func(xbits, ybits []frontend.Variable) []frontend.Variable { diff := len(xbits) - len(ybits) ybits = append(ybits, make([]frontend.Variable, diff)...) - for i := len(ybits) - diff - 1; i < len(ybits); i++ { + for i := len(ybits) - diff; i < len(ybits); i++ { ybits[i] = 0 } return ybits @@ -202,3 +178,62 @@ func (f *Field[T]) AssertIsLessOrEqual(e, a *Element[T]) { f.api.AssertIsEqual(ll, 0) } } + +// AssertIsInRange ensures that a is less than the emulated modulus. When we +// call [Reduce] then we only ensure that the result is width-constrained, but +// not actually less than the modulus. This means that the actual value may be +// either x or x + p. 
For arithmetic it is sufficient, but for binary comparison +// it is not. For binary comparison the values have both to be below the +// modulus. +func (f *Field[T]) AssertIsInRange(a *Element[T]) { + // we omit conditional width assertion as is done in ToBits down the calling stack + f.AssertIsLessOrEqual(a, f.modulusPrev()) +} + +// IsZero returns a boolean indicating if the element is strictly zero. The +// method internally reduces the element and asserts that the value is less than +// the modulus. +func (f *Field[T]) IsZero(a *Element[T]) frontend.Variable { + ca := f.Reduce(a) + f.AssertIsInRange(ca) + res := f.api.IsZero(ca.Limbs[0]) + for i := 1; i < len(ca.Limbs); i++ { + f.api.Mul(res, f.api.IsZero(ca.Limbs[i])) + } + return res +} + +// // Cmp returns: +// // - -1 if a < b +// // - 0 if a = b +// // - 1 if a > b +// // +// // The method internally reduces the element and asserts that the value is less +// // than the modulus. +// func (f *Field[T]) Cmp(a, b *Element[T]) frontend.Variable { +// ca := f.Reduce(a) +// f.AssertIsInRange(ca) +// cb := f.Reduce(b) +// f.AssertIsInRange(cb) +// var res frontend.Variable = 0 +// for i := int(f.fParams.NbLimbs() - 1); i >= 0; i-- { +// lmbCmp := f.api.Cmp(ca.Limbs[i], cb.Limbs[i]) +// res = f.api.Select(f.api.IsZero(res), lmbCmp, res) +// } +// return res +// } + +// TODO(@ivokub) +// func (f *Field[T]) AssertIsDifferent(a, b *Element[T]) { +// ca := f.Reduce(a) +// f.AssertIsInRange(ca) +// cb := f.Reduce(b) +// f.AssertIsInRange(cb) +// var res frontend.Variable = 0 +// for i := 0; i < int(f.fParams.NbLimbs()); i++ { +// cmp := f.api.Cmp(ca.Limbs[i], cb.Limbs[i]) +// cmpsq := f.api.Mul(cmp, cmp) +// res = f.api.Add(res, cmpsq) +// } +// f.api.AssertIsDifferent(res, 0) +// } diff --git a/std/math/emulated/field_binary.go b/std/math/emulated/field_binary.go index 213f1654c2..8c949af2e4 100644 --- a/std/math/emulated/field_binary.go +++ b/std/math/emulated/field_binary.go @@ -14,7 +14,11 @@ func (f *Field[T]) 
ToBits(a *Element[T]) []frontend.Variable { f.enforceWidthConditional(a) ba, aConst := f.constantValue(a) if aConst { - return f.api.ToBinary(ba, int(f.fParams.BitsPerLimb()*f.fParams.NbLimbs())) + res := make([]frontend.Variable, f.fParams.BitsPerLimb()*f.fParams.NbLimbs()) + for i := range res { + res[i] = ba.Bit(i) + } + return res } var carry frontend.Variable = 0 var fullBits []frontend.Variable diff --git a/std/math/emulated/field_hint.go b/std/math/emulated/field_hint.go new file mode 100644 index 0000000000..52b8e9ddfc --- /dev/null +++ b/std/math/emulated/field_hint.go @@ -0,0 +1,109 @@ +package emulated + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" +) + +func (f *Field[T]) wrapHint(nonnativeInputs ...*Element[T]) []frontend.Variable { + res := []frontend.Variable{f.fParams.BitsPerLimb(), f.fParams.NbLimbs()} + res = append(res, f.Modulus().Limbs...) + res = append(res, len(nonnativeInputs)) + for i := range nonnativeInputs { + res = append(res, len(nonnativeInputs[i].Limbs)) + res = append(res, nonnativeInputs[i].Limbs...) + } + return res +} + +// UnwrapHint unwraps the native inputs into nonnative inputs. Then it calls +// nonnativeHint function with nonnative inputs. After nonnativeHint returns, it +// decomposes the outputs into limbs. 
+func UnwrapHint(nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hint) error { + if len(nativeInputs) < 2 { + return fmt.Errorf("hint wrapper header is 2 elements") + } + if !nativeInputs[0].IsInt64() || !nativeInputs[1].IsInt64() { + return fmt.Errorf("header must be castable to int64") + } + nbBits := int(nativeInputs[0].Int64()) + nbLimbs := int(nativeInputs[1].Int64()) + if len(nativeInputs) < 2+nbLimbs { + return fmt.Errorf("hint wrapper header is 2+nbLimbs elements") + } + nonnativeMod := new(big.Int) + if err := recompose(nativeInputs[2:2+nbLimbs], uint(nbBits), nonnativeMod); err != nil { + return fmt.Errorf("cannot recover nonnative mod: %w", err) + } + if !nativeInputs[2+nbLimbs].IsInt64() { + return fmt.Errorf("number of nonnative elements must be castable to int64") + } + nbInputs := int(nativeInputs[2+nbLimbs].Int64()) + nonnativeInputs := make([]*big.Int, nbInputs) + readPtr := 3 + nbLimbs + for i := 0; i < nbInputs; i++ { + if len(nativeInputs) < readPtr+1 { + return fmt.Errorf("can not read %d-th native input", i) + } + if !nativeInputs[readPtr].IsInt64() { + return fmt.Errorf("corrupted %d-th native input", i) + } + currentInputLen := int(nativeInputs[readPtr].Int64()) + if len(nativeInputs) < (readPtr + 1 + currentInputLen) { + return fmt.Errorf("cannot read %d-th nonnative element", i) + } + nonnativeInputs[i] = new(big.Int) + if err := recompose(nativeInputs[readPtr+1:readPtr+1+currentInputLen], uint(nbBits), nonnativeInputs[i]); err != nil { + return fmt.Errorf("recompose %d-th element: %w", i, err) + } + readPtr += 1 + currentInputLen + } + if len(nativeOutputs)%nbLimbs != 0 { + return fmt.Errorf("output count doesn't divide limb count") + } + nonnativeOutputs := make([]*big.Int, len(nativeOutputs)/nbLimbs) + for i := range nonnativeOutputs { + nonnativeOutputs[i] = new(big.Int) + } + if err := nonnativeHint(nonnativeMod, nonnativeInputs, nonnativeOutputs); err != nil { + return fmt.Errorf("nonnative hint: %w", err) + } + for i := 
range nonnativeOutputs { + nonnativeOutputs[i].Mod(nonnativeOutputs[i], nonnativeMod) + if err := decompose(nonnativeOutputs[i], uint(nbBits), nativeOutputs[i*nbLimbs:(i+1)*nbLimbs]); err != nil { + return fmt.Errorf("decompose %d-th element: %w", i, err) + } + } + return nil +} + +// NewHint allows to call the emulation hint function hf on inputs, expecting +// nbOutputs results. This function splits internally the emulated element into +// limbs and passes these to the hint function. There is [UnwrapHint] function +// which performs corresponding recomposition of limbs into integer values (and +// vice verse for output). +// +// The hint function for this method is defined as: +// +// func HintFn(mod *big.Int, inputs, outputs []*big.Int) error { +// return emulated.UnwrapHint(inputs, outputs, func(mod *big.Int, inputs, outputs []*big.Int) error { // NB we shadow initial input, output, mod to avoid accidental overwrite! +// // here all inputs and outputs are modulo nonnative mod. we decompose them automatically +// })} +// +// See the example for full written example. +func (f *Field[T]) NewHint(hf solver.Hint, nbOutputs int, inputs ...*Element[T]) ([]*Element[T], error) { + nativeInputs := f.wrapHint(inputs...) + nbNativeOutputs := int(f.fParams.NbLimbs()) * nbOutputs + nativeOutputs, err := f.api.Compiler().NewHint(hf, nbNativeOutputs, nativeInputs...) 
+ if err != nil { + return nil, fmt.Errorf("call hint: %w", err) + } + outputs := make([]*Element[T], nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = f.packLimbs(nativeOutputs[i*int(f.fParams.NbLimbs()):(i+1)*int(f.fParams.NbLimbs())], true) + } + return outputs, nil +} diff --git a/std/math/emulated/field_hint_test.go b/std/math/emulated/field_hint_test.go new file mode 100644 index 0000000000..9f2d286ec8 --- /dev/null +++ b/std/math/emulated/field_hint_test.go @@ -0,0 +1,116 @@ +package emulated_test + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/math/emulated" +) + +// HintExample is a hint for field emulation which returns the division of the +// first and second input. +func HintExample(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + // nativeInputs are the limbs of the input non-native elements. We wrap the + // actual hint function with [emulated.UnwrapHint] to get actual [*big.Int] + // values of the non-native elements. + return emulated.UnwrapHint(nativeInputs, nativeOutputs, func(mod *big.Int, inputs, outputs []*big.Int) error { + // this hint computes the division of first and second input and returns it. + nominator := inputs[0] + denominator := inputs[1] + res := new(big.Int).ModInverse(denominator, mod) + if res == nil { + return fmt.Errorf("no modular inverse") + } + res.Mul(res, nominator) + res.Mod(res, mod) + outputs[0].Set(res) + return nil + }) + // when the internal hint function returns, the UnwrapHint function + // decomposes the non-native value into limbs. 
+} + +type emulationHintCircuit[T emulated.FieldParams] struct { + Nominator emulated.Element[T] + Denominator emulated.Element[T] + Expected emulated.Element[T] +} + +func (c *emulationHintCircuit[T]) Define(api frontend.API) error { + field, err := emulated.NewField[T](api) + if err != nil { + return err + } + res, err := field.NewHint(HintExample, 1, &c.Nominator, &c.Denominator) + if err != nil { + return err + } + m := field.Mul(res[0], &c.Denominator) + field.AssertIsEqual(m, &c.Nominator) + field.AssertIsEqual(res[0], &c.Expected) + return nil +} + +// Example of using hints with emulated elements. +func ExampleField_NewHint() { + var a, b, c fr.Element + a.SetRandom() + b.SetRandom() + c.Div(&a, &b) + + circuit := emulationHintCircuit[emulated.BN254Fr]{} + witness := emulationHintCircuit[emulated.BN254Fr]{ + Nominator: emulated.ValueOf[emulated.BN254Fr](a), + Denominator: emulated.ValueOf[emulated.BN254Fr](b), + Expected: emulated.ValueOf[emulated.BN254Fr](c), + } + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + panic(err) + } else { + fmt.Println("compiled") + } + witnessData, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } else { + fmt.Println("secret witness parsed") + } + publicWitnessData, err := witnessData.Public() + if err != nil { + panic(err) + } else { + fmt.Println("public witness parsed") + } + pk, vk, err := groth16.Setup(ccs) + if err != nil { + panic(err) + } else { + fmt.Println("setup done") + } + proof, err := groth16.Prove(ccs, pk, witnessData, backend.WithSolverOptions(solver.WithHints(HintExample))) + if err != nil { + panic(err) + } else { + fmt.Println("proved") + } + err = groth16.Verify(proof, vk, publicWitnessData) + if err != nil { + panic(err) + } else { + fmt.Println("verified") + } + // Output: compiled + // secret witness parsed + // public witness parsed + // setup done + // proved + // verified +} diff --git 
a/std/math/emulated/field_ops.go b/std/math/emulated/field_ops.go index baa7154f5d..26125e7eaf 100644 --- a/std/math/emulated/field_ops.go +++ b/std/math/emulated/field_ops.go @@ -44,6 +44,21 @@ func (f *Field[T]) Inverse(a *Element[T]) *Element[T] { return e } +// Sqrt computes square root of a and returns it. It uses [SqrtHint]. +func (f *Field[T]) Sqrt(a *Element[T]) *Element[T] { + // omit width assertion as is done in Mul below + if !f.fParams.IsPrime() { + panic("modulus not a prime") + } + res, err := f.NewHint(SqrtHint, 1, a) + if err != nil { + panic(fmt.Sprintf("compute sqrt: %v", err)) + } + _a := f.Mul(res[0], res[0]) + f.AssertIsEqual(_a, a) + return res[0] +} + // Add computes a+b and returns it. If the result wouldn't fit into Element, then // first reduces the inputs (larger first) and tries again. Doesn't mutate // inputs. @@ -178,19 +193,19 @@ func (f *Field[T]) mul(a, b *Element[T], nextOverflow uint) *Element[T] { w := new(big.Int) for c := 1; c <= len(mulResult); c++ { w.SetInt64(1) // c^i - l := a.Limbs[0] - r := b.Limbs[0] - o := mulResult[0] + l := f.api.Mul(a.Limbs[0], 1) + r := f.api.Mul(b.Limbs[0], 1) + o := f.api.Mul(mulResult[0], 1) for i := 1; i < len(mulResult); i++ { w.Lsh(w, uint(c)) if i < len(a.Limbs) { - l = f.api.Add(l, f.api.Mul(a.Limbs[i], w)) + l = f.api.MulAcc(l, a.Limbs[i], w) } if i < len(b.Limbs) { - r = f.api.Add(r, f.api.Mul(b.Limbs[i], w)) + r = f.api.MulAcc(r, b.Limbs[i], w) } - o = f.api.Add(o, f.api.Mul(mulResult[i], w)) + o = f.api.MulAcc(o, mulResult[i], w) } f.api.AssertIsEqual(f.api.Mul(l, r), o) } @@ -214,7 +229,6 @@ func (f *Field[T]) Reduce(a *Element[T]) *Element[T] { if err != nil { panic(fmt.Sprintf("reduction hint: %v", err)) } - // TODO @gbotrel fixme: AssertIsEqual(a, e) crashes Pairing test f.AssertIsEqual(e, a) return e } @@ -225,9 +239,20 @@ func (f *Field[T]) Sub(a, b *Element[T]) *Element[T] { return f.reduceAndOp(f.sub, f.subPreCond, a, b) } +// subReduce returns a-b and returns it. 
Contrary to [Field[T].Sub] method this +// method does not reduce the inputs if the result would overflow. This method +// is currently only used as a subroutine in [Field[T].Reduce] method to avoid +// infinite recursion when we are working exactly on the overflow limits. +func (f *Field[T]) subNoReduce(a, b *Element[T]) *Element[T] { + nextOverflow, _ := f.subPreCond(a, b) + // we ignore error as it only indicates if we should reduce or not. But we + // are in non-reducing version of sub. + return f.sub(a, b, nextOverflow) +} + func (f *Field[T]) subPreCond(a, b *Element[T]) (nextOverflow uint, err error) { - reduceRight := a.overflow < b.overflow+2 - nextOverflow = max(b.overflow+2, a.overflow) + reduceRight := a.overflow < (b.overflow + 1) + nextOverflow = max(b.overflow+1, a.overflow) + 1 if nextOverflow > f.maxOverflow() { err = overflowError{op: "sub", nextOverflow: nextOverflow, maxOverflow: f.maxOverflow(), reduceRight: reduceRight} } @@ -263,7 +288,7 @@ func (f *Field[T]) Neg(a *Element[T]) *Element[T] { return f.Sub(f.Zero(), a) } -// Select sets e to a if selector == 0 and to b otherwise. Sets the number of +// Select sets e to a if selector == 1 and to b otherwise. Sets the number of // limbs and overflow of the result to be the maximum of the limb lengths and // overflows. If the inputs are strongly unbalanced, then it would better to // reduce the result after the operation. 
diff --git a/std/math/emulated/field_test.go b/std/math/emulated/field_test.go index 927b23bc59..d1124f4e82 100644 --- a/std/math/emulated/field_test.go +++ b/std/math/emulated/field_test.go @@ -8,7 +8,7 @@ import ( bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/std/algebra/sw_bls12377" + "github.com/consensys/gnark/std/algebra/native/sw_bls12377" "github.com/consensys/gnark/test" ) diff --git a/std/math/emulated/hints.go b/std/math/emulated/hints.go index 9e799e4e17..e0aeff00ab 100644 --- a/std/math/emulated/hints.go +++ b/std/math/emulated/hints.go @@ -4,7 +4,7 @@ import ( "fmt" "math/big" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) @@ -12,18 +12,19 @@ import ( // inside a func, then it becomes anonymous and hint identification is screwed. func init() { - hint.Register(GetHints()...) + solver.RegisterHint(GetHints()...) } // GetHints returns all hint functions used in the package. -func GetHints() []hint.Function { - return []hint.Function{ +func GetHints() []solver.Hint { + return []solver.Hint{ DivHint, QuoHint, InverseHint, MultiplicationHint, RemHint, - NBitsShifted, + RightShift, + SqrtHint, } } @@ -287,13 +288,38 @@ func parseHintDivInputs(inputs []*big.Int) (uint, int, *big.Int, *big.Int, error return nbBits, nbLimbs, x, y, nil } -// NBitsShifted returns the first bits of the input, with a shift. The number of returned bits is -// defined by the length of the results slice. -func NBitsShifted(_ *big.Int, inputs []*big.Int, results []*big.Int) error { - n := inputs[0] - shift := inputs[1].Uint64() // TODO @gbotrel validate input vs perf in large circuits. - for i := 0; i < len(results); i++ { - results[i].SetUint64(uint64(n.Bit(i + int(shift)))) - } +// RightShift shifts input by the given number of bits. 
Expects two inputs: +// - first input is the shift, will be represented as uint64; +// - second input is the value to be shifted. +// +// Returns a single output which is the value shifted. Errors if number of +// inputs is not 2 and number of outputs is not 1. +func RightShift(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 2 { + return fmt.Errorf("expecting two inputs") + } + if len(outputs) != 1 { + return fmt.Errorf("expecting single output") + } + shift := inputs[0].Uint64() + outputs[0].Rsh(inputs[1], uint(shift)) return nil } + +// SqrtHint computes the square root of the input. +func SqrtHint(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { + return UnwrapHint(inputs, outputs, func(field *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) != 1 { + return fmt.Errorf("expecting single input") + } + if len(outputs) != 1 { + return fmt.Errorf("expecting single output") + } + res := new(big.Int) + if res.ModSqrt(inputs[0], field) == nil { + return fmt.Errorf("no square root") + } + outputs[0].Set(res) + return nil + }) +} diff --git a/std/math/emulated/params.go b/std/math/emulated/params.go index 9f89b0da09..fc3d8abfc1 100644 --- a/std/math/emulated/params.go +++ b/std/math/emulated/params.go @@ -3,10 +3,20 @@ package emulated import ( "math/big" - "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/std/math/emulated/emparams" ) -// FieldParams describes the emulated field characteristics +// FieldParams describes the emulated field characteristics. For a list of +// included built-in emulation params refer to the [emparams] package.
+// For backwards compatibility, the current package contains the following +// parameters: +// - [Goldilocks] +// - [Secp256k1Fp] and [Secp256k1Fr] +// - [BN254Fp] and [BN254Fr] +// - [BLS12377Fp] +// - [BLS12381Fp] and [BLS12381Fr] +// - [P256Fp] and [P256Fr] +// - [P384Fp] and [P384Fr] type FieldParams interface { NbLimbs() uint // number of limbs to represent field element BitsPerLimb() uint // number of bits per limb. Top limb may contain less than limbSize bits. @@ -14,75 +24,17 @@ type FieldParams interface { Modulus() *big.Int // returns modulus. Do not modify. } -var ( - qSecp256k1, rSecp256k1 *big.Int - qGoldilocks *big.Int +type ( + Goldilocks = emparams.Goldilocks + Secp256k1Fp = emparams.Secp256k1Fp + Secp256k1Fr = emparams.Secp256k1Fr + BN254Fp = emparams.BN254Fp + BN254Fr = emparams.BN254Fr + BLS12377Fp = emparams.BLS12377Fp + BLS12381Fp = emparams.BLS12381Fp + BLS12381Fr = emparams.BLS12381Fr + P256Fp = emparams.P256Fp + P256Fr = emparams.P256Fr + P384Fp = emparams.P384Fp + P384Fr = emparams.P384Fr ) - -func init() { - qSecp256k1, _ = new(big.Int).SetString("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f", 16) - rSecp256k1, _ = new(big.Int).SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16) - qGoldilocks, _ = new(big.Int).SetString("ffffffff00000001", 16) -} - -// Goldilocks provide type parametrization for emulated field on 1 limb of width 64bits -// for modulus 0xffffffff00000001 -type Goldilocks struct{} - -func (fp Goldilocks) NbLimbs() uint { return 1 } -func (fp Goldilocks) BitsPerLimb() uint { return 64 } -func (fp Goldilocks) IsPrime() bool { return true } -func (fp Goldilocks) Modulus() *big.Int { return qGoldilocks } - -// Secp256k1Fp provide type parametrization for emulated field on 4 limb of width 64bits -// for modulus 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f. 
-// This is the base field of secp256k1 curve -type Secp256k1Fp struct{} - -func (fp Secp256k1Fp) NbLimbs() uint { return 4 } -func (fp Secp256k1Fp) BitsPerLimb() uint { return 64 } -func (fp Secp256k1Fp) IsPrime() bool { return true } -func (fp Secp256k1Fp) Modulus() *big.Int { return qSecp256k1 } - -// Secp256k1Fr provides type parametrization for emulated field on 4 limbs of width 64bits -// for modulus 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141. -// This is the scalar field of secp256k1 curve. -type Secp256k1Fr struct{} - -func (fp Secp256k1Fr) NbLimbs() uint { return 4 } -func (fp Secp256k1Fr) BitsPerLimb() uint { return 64 } -func (fp Secp256k1Fr) IsPrime() bool { return true } -func (fp Secp256k1Fr) Modulus() *big.Int { return rSecp256k1 } - -// BN254Fp provide type parametrization for emulated field on 4 limb of width -// 64bits for modulus -// 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47. This is -// the base field of the BN254 curve. -type BN254Fp struct{} - -func (fp BN254Fp) NbLimbs() uint { return 4 } -func (fp BN254Fp) BitsPerLimb() uint { return 64 } -func (fp BN254Fp) IsPrime() bool { return true } -func (fp BN254Fp) Modulus() *big.Int { return ecc.BN254.BaseField() } - -// BN254Fr provides type parametrisation for emulated field on 4 limbs of width -// 64bits for modulus -// 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001. This is -// the scalar field of the BN254 curve. -type BN254Fr struct{} - -func (fp BN254Fr) NbLimbs() uint { return 4 } -func (fp BN254Fr) BitsPerLimb() uint { return 64 } -func (fp BN254Fr) IsPrime() bool { return true } -func (fp BN254Fr) Modulus() *big.Int { return ecc.BN254.ScalarField() } - -// BLS12377Fp provide type parametrization for emulated field on 6 limb of width -// 64bits for modulus -// 0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001. -// This is the base field of the BLS12-377 curve. 
-type BLS12377Fp struct{} - -func (fp BLS12377Fp) NbLimbs() uint { return 6 } -func (fp BLS12377Fp) BitsPerLimb() uint { return 64 } -func (fp BLS12377Fp) IsPrime() bool { return true } -func (fp BLS12377Fp) Modulus() *big.Int { return ecc.BLS12_377.BaseField() } diff --git a/std/math/uints/hints.go b/std/math/uints/hints.go new file mode 100644 index 0000000000..40dbbd9917 --- /dev/null +++ b/std/math/uints/hints.go @@ -0,0 +1,53 @@ +package uints + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark/constraint/solver" +) + +func init() { + solver.RegisterHint(GetHints()...) +} + +func GetHints() []solver.Hint { + return []solver.Hint{ + andHint, + xorHint, + toBytes, + } +} + +func xorHint(_ *big.Int, inputs, outputs []*big.Int) error { + outputs[0].Xor(inputs[0], inputs[1]) + return nil +} + +func andHint(_ *big.Int, inputs, outputs []*big.Int) error { + outputs[0].And(inputs[0], inputs[1]) + return nil +} + +func toBytes(m *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 2 { + return fmt.Errorf("input must be 2 elements") + } + if !inputs[0].IsUint64() { + return fmt.Errorf("first input must be uint64") + } + nbLimbs := int(inputs[0].Uint64()) + if len(outputs) != nbLimbs { + return fmt.Errorf("output must be 8 elements") + } + if !inputs[1].IsUint64() { + return fmt.Errorf("input must be 64 bits") + } + base := new(big.Int).Lsh(big.NewInt(1), uint(8)) + tmp := new(big.Int).Set(inputs[1]) + for i := 0; i < nbLimbs; i++ { + outputs[i].Mod(tmp, base) + tmp.Rsh(tmp, 8) + } + return nil +} diff --git a/std/math/uints/uint8.go b/std/math/uints/uint8.go new file mode 100644 index 0000000000..cec591d10c --- /dev/null +++ b/std/math/uints/uint8.go @@ -0,0 +1,342 @@ +// Package uints implements optimised byte and long integer operations. +// +// Usually arithmetic in a circuit is performed in the native field, which is of +// prime order. 
However, for compatibility with native operations we rely on +// operating on smaller primitive types as 8-bit, 32-bit and 64-bit integer. +// Naively, these operations have to be implemented bitwise as there are no +// closed equations for boolean operations (XOR, AND, OR). +// +// However, the bitwise approach is very inefficient and leads to several +// constraints per bit. Accumulating over a long integer, it leads to very +// inefficients circuits. +// +// This package performs boolean operations using lookup tables on bytes. So, +// long integers are split into 4 or 8 bytes and we perform the operations +// bytewise. In the lookup tables, we store results for all possible 2^8×2^8 +// inputs. With this approach, every bytewise operation costs as single lookup, +// which depending on the backend is relatively cheap (one to three +// constraints). +// +// NB! The package is still work in progress. The interfaces and implementation +// details most certainly changes over time. We cannot ensure the soundness of +// the operations. +package uints + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/internal/logderivprecomp" + "github.com/consensys/gnark/std/math/bitslice" + "github.com/consensys/gnark/std/rangecheck" +) + +// TODO: if internal then enforce range check! + +// TODO: all operations can take rand linear combinations instead. Then instead +// of one check can perform multiple at the same time. + +// TODO: implement versions which take multiple inputs. Maybe can combine multiple together + +// TODO: instantiate tables only when we first query. Maybe do not need to build! + +// TODO: maybe can store everything in a single table? Later! Or if we have a +// lot of queries then makes sense to extract into separate table? + +// TODO: in ValueOf ensure consistency + +// TODO: distinguish between when we set constant in-circuit or witness +// assignment. 
For constant we don't have to range check but for witness +// assignment we have to. + +// TODO: add something which allows to store array in native element + +// TODO: add methods for checking if U8/Long is constant. + +// TODO: should something for byte-only ops. Implement a type and then embed it in BinaryField + +// TODO: add helper method to call hints which allows to pass in uint8s (bytes) +// and returns bytes. Then can to byte array manipluation nicely. It is useful +// for X509. For the implementation we want to pack as much bytes into a field +// element as possible. + +// TODO: methods for converting uint array into emulated element and native +// element. Most probably should add the implementation for non-native in its +// package, but for native we should add it here. + +type U8 struct { + Val frontend.Variable + internal bool +} + +// GnarkInitHook describes how to initialise the element. +func (e *U8) GnarkInitHook() { + if e.Val == nil { + e.Val = 0 + e.internal = false // we need to constrain in later. + } +} + +type U64 [8]U8 +type U32 [4]U8 + +type Long interface{ U32 | U64 } + +type BinaryField[T U32 | U64] struct { + api frontend.API + xorT, andT *logderivprecomp.Precomputed + rchecker frontend.Rangechecker + allOne U8 +} + +func New[T Long](api frontend.API) (*BinaryField[T], error) { + xorT, err := logderivprecomp.New(api, xorHint, []uint{8}) + if err != nil { + return nil, fmt.Errorf("new xor table: %w", err) + } + andT, err := logderivprecomp.New(api, andHint, []uint{8}) + if err != nil { + return nil, fmt.Errorf("new and table: %w", err) + } + rchecker := rangecheck.New(api) + bf := &BinaryField[T]{ + api: api, + xorT: xorT, + andT: andT, + rchecker: rchecker, + } + // TODO: this is const. 
add way to init constants + allOne := bf.ByteValueOf(0xff) + bf.allOne = allOne + return bf, nil +} + +func NewU8(v uint8) U8 { + // TODO: don't have to check constants + return U8{Val: v, internal: true} +} + +func NewU32(v uint32) U32 { + return [4]U8{ + NewU8(uint8((v >> (0 * 8)) & 0xff)), + NewU8(uint8((v >> (1 * 8)) & 0xff)), + NewU8(uint8((v >> (2 * 8)) & 0xff)), + NewU8(uint8((v >> (3 * 8)) & 0xff)), + } +} + +func NewU64(v uint64) U64 { + return [8]U8{ + NewU8(uint8((v >> (0 * 8)) & 0xff)), + NewU8(uint8((v >> (1 * 8)) & 0xff)), + NewU8(uint8((v >> (2 * 8)) & 0xff)), + NewU8(uint8((v >> (3 * 8)) & 0xff)), + NewU8(uint8((v >> (4 * 8)) & 0xff)), + NewU8(uint8((v >> (5 * 8)) & 0xff)), + NewU8(uint8((v >> (6 * 8)) & 0xff)), + NewU8(uint8((v >> (7 * 8)) & 0xff)), + } +} + +func NewU8Array(v []uint8) []U8 { + ret := make([]U8, len(v)) + for i := range v { + ret[i] = NewU8(v[i]) + } + return ret +} + +func NewU32Array(v []uint32) []U32 { + ret := make([]U32, len(v)) + for i := range v { + ret[i] = NewU32(v[i]) + } + return ret +} + +func NewU64Array(v []uint64) []U64 { + ret := make([]U64, len(v)) + for i := range v { + ret[i] = NewU64(v[i]) + } + return ret +} + +func (bf *BinaryField[T]) ByteValueOf(a frontend.Variable) U8 { + bf.rchecker.Check(a, 8) + return U8{Val: a, internal: true} +} + +func (bf *BinaryField[T]) ValueOf(a frontend.Variable) T { + var r T + bts, err := bf.api.Compiler().NewHint(toBytes, len(r), len(r), a) + if err != nil { + panic(err) + } + // TODO: add constraint which ensures that map back to + for i := range bts { + r[i] = bf.ByteValueOf(bts[i]) + } + return r +} + +func (bf *BinaryField[T]) ToValue(a T) frontend.Variable { + v := make([]frontend.Variable, len(a)) + for i := range v { + v[i] = bf.api.Mul(a[i].Val, 1<<(i*8)) + } + vv := bf.api.Add(v[0], v[1], v[2:]...) 
+ return vv +} + +func (bf *BinaryField[T]) PackMSB(a ...U8) T { + var ret T + for i := range a { + ret[len(a)-i-1] = a[i] + } + return ret +} + +func (bf *BinaryField[T]) PackLSB(a ...U8) T { + var ret T + for i := range a { + ret[i] = a[i] + } + return ret +} + +func (bf *BinaryField[T]) UnpackMSB(a T) []U8 { + ret := make([]U8, len(a)) + for i := 0; i < len(a); i++ { + ret[len(a)-i-1] = a[i] + } + return ret +} + +func (bf *BinaryField[T]) UnpackLSB(a T) []U8 { + // cannot deduce that a can be cast to []U8 + ret := make([]U8, len(a)) + for i := 0; i < len(a); i++ { + ret[i] = a[i] + } + return ret +} + +func (bf *BinaryField[T]) twoArgFn(tbl *logderivprecomp.Precomputed, a ...U8) U8 { + ret := tbl.Query(a[0].Val, a[1].Val)[0] + for i := 2; i < len(a); i++ { + ret = tbl.Query(ret, a[i].Val)[0] + } + return U8{Val: ret} +} + +func (bf *BinaryField[T]) twoArgWideFn(tbl *logderivprecomp.Precomputed, a ...T) T { + var r T + for i, v := range reslice(a) { + r[i] = bf.twoArgFn(tbl, v...) + } + return r +} + +func (bf *BinaryField[T]) And(a ...T) T { return bf.twoArgWideFn(bf.andT, a...) } +func (bf *BinaryField[T]) Xor(a ...T) T { return bf.twoArgWideFn(bf.xorT, a...) } + +func (bf *BinaryField[T]) not(a U8) U8 { + ret := bf.xorT.Query(a.Val, bf.allOne.Val) + return U8{Val: ret[0]} +} + +func (bf *BinaryField[T]) Not(a T) T { + var r T + for i := 0; i < len(a); i++ { + r[i] = bf.not(a[i]) + } + return r +} + +func (bf *BinaryField[T]) Add(a ...T) T { + va := make([]frontend.Variable, len(a)) + for i := range a { + va[i] = bf.ToValue(a[i]) + } + vres := bf.api.Add(va[0], va[1], va[2:]...) + res := bf.ValueOf(vres) + // TODO: should also check the that carry we omitted is correct. 
+ return res +} + +func (bf *BinaryField[T]) Lrot(a T, c int) T { + l := len(a) + if c < 0 { + c = l*8 + c + } + shiftBl := c / 8 + shiftBt := c % 8 + revShiftBt := 8 - shiftBt + if revShiftBt == 8 { + revShiftBt = 0 + } + partitioned := make([][2]frontend.Variable, l) + for i := range partitioned { + lower, upper := bitslice.Partition(bf.api, a[i].Val, uint(revShiftBt), bitslice.WithNbDigits(8)) + partitioned[i] = [2]frontend.Variable{lower, upper} + } + var ret T + for i := 0; i < l; i++ { + if shiftBt != 0 { + ret[(i+shiftBl)%l].Val = bf.api.Add(bf.api.Mul(1<<(shiftBt), partitioned[i][0]), partitioned[(i+l-1)%l][1]) + } else { + ret[(i+shiftBl)%l].Val = partitioned[i][1] + } + } + return ret +} + +func (bf *BinaryField[T]) Rshift(a T, c int) T { + shiftBl := c / 8 + shiftBt := c % 8 + partitioned := make([][2]frontend.Variable, len(a)-shiftBl) + for i := range partitioned { + lower, upper := bitslice.Partition(bf.api, a[i+shiftBl].Val, uint(shiftBt), bitslice.WithNbDigits(8)) + partitioned[i] = [2]frontend.Variable{lower, upper} + } + var ret T + for i := 0; i < len(a)-shiftBl-1; i++ { + if shiftBt != 0 { + ret[i].Val = bf.api.Add(partitioned[i][1], bf.api.Mul(1<<(8-shiftBt), partitioned[i+1][0])) + } else { + ret[i].Val = partitioned[i][1] + } + } + ret[len(a)-shiftBl-1].Val = partitioned[len(a)-shiftBl-1][1] + for i := len(a) - shiftBl; i < len(ret); i++ { + ret[i] = NewU8(0) + } + return ret +} + +func (bf *BinaryField[T]) ByteAssertEq(a, b U8) { + bf.api.AssertIsEqual(a.Val, b.Val) +} + +func (bf *BinaryField[T]) AssertEq(a, b T) { + for i := 0; i < len(a); i++ { + bf.ByteAssertEq(a[i], b[i]) + } +} + +func reslice[T U32 | U64](in []T) [][]U8 { + if len(in) == 0 { + panic("zero-length input") + } + ret := make([][]U8, len(in[0])) + for i := range ret { + ret[i] = make([]U8, len(in)) + } + for i := 0; i < len(in); i++ { + for j := 0; j < len(in[0]); j++ { + ret[j][i] = in[i][j] + } + } + return ret +} diff --git a/std/math/uints/uint8_test.go 
b/std/math/uints/uint8_test.go new file mode 100644 index 0000000000..bc24a8f4db --- /dev/null +++ b/std/math/uints/uint8_test.go @@ -0,0 +1,81 @@ +package uints + +import ( + "math/bits" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type lrotCirc struct { + In U32 + Out U32 + Shift int +} + +func (c *lrotCirc) Define(api frontend.API) error { + uapi, err := New[U32](api) + if err != nil { + return err + } + res := uapi.Lrot(c.In, c.Shift) + uapi.AssertEq(c.Out, res) + return nil +} + +func TestLeftRotation(t *testing.T) { + assert := test.NewAssert(t) + var err error + err = test.IsSolved(&lrotCirc{Shift: 4}, &lrotCirc{In: NewU32(0x12345678), Shift: 4, Out: NewU32(bits.RotateLeft32(0x12345678, 4))}, ecc.BN254.ScalarField()) + assert.NoError(err) + err = test.IsSolved(&lrotCirc{Shift: 14}, &lrotCirc{In: NewU32(0x12345678), Shift: 14, Out: NewU32(bits.RotateLeft32(0x12345678, 14))}, ecc.BN254.ScalarField()) + assert.NoError(err) + err = test.IsSolved(&lrotCirc{Shift: 3}, &lrotCirc{In: NewU32(0x12345678), Shift: 3, Out: NewU32(bits.RotateLeft32(0x12345678, 3))}, ecc.BN254.ScalarField()) + assert.NoError(err) + err = test.IsSolved(&lrotCirc{Shift: 11}, &lrotCirc{In: NewU32(0x12345678), Shift: 11, Out: NewU32(bits.RotateLeft32(0x12345678, 11))}, ecc.BN254.ScalarField()) + assert.NoError(err) + // full block + err = test.IsSolved(&lrotCirc{Shift: 16}, &lrotCirc{In: NewU32(0x12345678), Shift: 16, Out: NewU32(bits.RotateLeft32(0x12345678, 16))}, ecc.BN254.ScalarField()) + assert.NoError(err) + // negative rotations + err = test.IsSolved(&lrotCirc{Shift: -4}, &lrotCirc{In: NewU32(0x12345678), Shift: -4, Out: NewU32(bits.RotateLeft32(0x12345678, -4))}, ecc.BN254.ScalarField()) + assert.NoError(err) + err = test.IsSolved(&lrotCirc{Shift: -14}, &lrotCirc{In: NewU32(0x12345678), Shift: -14, Out: NewU32(bits.RotateLeft32(0x12345678, -14))}, ecc.BN254.ScalarField()) + assert.NoError(err) 
+ err = test.IsSolved(&lrotCirc{Shift: -3}, &lrotCirc{In: NewU32(0x12345678), Shift: -3, Out: NewU32(bits.RotateLeft32(0x12345678, -3))}, ecc.BN254.ScalarField()) + assert.NoError(err) + err = test.IsSolved(&lrotCirc{Shift: -11}, &lrotCirc{In: NewU32(0x12345678), Shift: -11, Out: NewU32(bits.RotateLeft32(0x12345678, -11))}, ecc.BN254.ScalarField()) + assert.NoError(err) + err = test.IsSolved(&lrotCirc{Shift: -16}, &lrotCirc{In: NewU32(0x12345678), Shift: -16, Out: NewU32(bits.RotateLeft32(0x12345678, -16))}, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type rshiftCircuit struct { + In, Expected U32 + Shift int +} + +func (c *rshiftCircuit) Define(api frontend.API) error { + uapi, err := New[U32](api) + if err != nil { + return err + } + res := uapi.Rshift(c.In, c.Shift) + uapi.AssertEq(res, c.Expected) + return nil +} + +func TestRshift(t *testing.T) { + assert := test.NewAssert(t) + var err error + err = test.IsSolved(&rshiftCircuit{Shift: 4}, &rshiftCircuit{Shift: 4, In: NewU32(0x12345678), Expected: NewU32(0x12345678 >> 4)}, ecc.BN254.ScalarField()) + assert.NoError(err) + err = test.IsSolved(&rshiftCircuit{Shift: 12}, &rshiftCircuit{Shift: 12, In: NewU32(0x12345678), Expected: NewU32(0x12345678 >> 12)}, ecc.BN254.ScalarField()) + assert.NoError(err) + err = test.IsSolved(&rshiftCircuit{Shift: 3}, &rshiftCircuit{Shift: 3, In: NewU32(0x12345678), Expected: NewU32(0x12345678 >> 3)}, ecc.BN254.ScalarField()) + assert.NoError(err) + err = test.IsSolved(&rshiftCircuit{Shift: 11}, &rshiftCircuit{Shift: 11, In: NewU32(0x12345678), Expected: NewU32(0x12345678 >> 11)}, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/multicommit/doc_test.go b/std/multicommit/doc_test.go new file mode 100644 index 0000000000..6d031facf6 --- /dev/null +++ b/std/multicommit/doc_test.go @@ -0,0 +1,104 @@ +package multicommit_test + +import ( + "fmt" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + 
"github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/multicommit" +) + +// MultipleCommitmentsCircuit is an example circuit showing usage of multiple +// independent commitments in-circuit. +type MultipleCommitmentsCircuit struct { + Secrets [4]frontend.Variable +} + +func (c *MultipleCommitmentsCircuit) Define(api frontend.API) error { + // first callback receives first unique commitment derived from the root commitment + multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + // compute (X-s[0]) * (X-s[1]) for a random X + res := api.Mul(api.Sub(commitment, c.Secrets[0]), api.Sub(commitment, c.Secrets[1])) + api.AssertIsDifferent(res, 0) + return nil + }, c.Secrets[:2]...) + + // second callback receives second unique commitment derived from the root commitment + multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + // compute (X-s[2]) * (X-s[3]) for a random X + res := api.Mul(api.Sub(commitment, c.Secrets[2]), api.Sub(commitment, c.Secrets[3])) + api.AssertIsDifferent(res, 0) + return nil + }, c.Secrets[2:4]...)
+ + // we do not have to pass any variables in if other calls to [WithCommitment] have already done so + multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + // compute (X-s[0]) for a random X + api.AssertIsDifferent(api.Sub(commitment, c.Secrets[0]), 0) + return nil + }) + + // we can share variables between the callbacks + var shared, stored frontend.Variable + multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + shared = api.Add(c.Secrets[0], commitment) + stored = commitment + return nil + }) + multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + api.AssertIsEqual(api.Sub(shared, stored), c.Secrets[0]) + return nil + }) + return nil +} + +// Full example of how to use multiple commitments in a circuit. +func ExampleWithCommitment() { + circuit := MultipleCommitmentsCircuit{} + assignment := MultipleCommitmentsCircuit{Secrets: [4]frontend.Variable{1, 2, 3, 4}} + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + panic(err) + } else { + fmt.Println("compiled") + } + pk, vk, err := groth16.Setup(ccs) + if err != nil { + panic(err) + } else { + fmt.Println("setup done") + } + secretWitness, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } else { + fmt.Println("secret witness") + } + publicWitness, err := secretWitness.Public() + if err != nil { + panic(err) + } else { + fmt.Println("public witness") + } + proof, err := groth16.Prove(ccs, pk, secretWitness) + if err != nil { + panic(err) + } else { + fmt.Println("proof") + } + err = groth16.Verify(proof, vk, publicWitness) + if err != nil { + panic(err) + } else { + fmt.Println("verify") + } + // Output: + // compiled + // setup done + // secret witness + // public witness + // proof + // verify +} diff --git a/std/multicommit/nativecommit.go b/std/multicommit/nativecommit.go new file mode 100644 index
0000000000..f3f81f2a4b --- /dev/null +++ b/std/multicommit/nativecommit.go @@ -0,0 +1,137 @@ +// Package multicommit implements commitment expansion. +// +// If the builder implements [frontend.Committer] interface, then we can commit +// to the variables and get a commitment which can be used as a unique +// randomness in the circuit. For current builders implementing this interface, +// the function can only be called once in a circuit. This makes it difficult to +// compose different gadgets which require randomness. +// +// This package extends the commitment interface by allowing several +// functions to each receive a unique commitment. It does this by collecting all +// variables to commit and the callbacks which want to access a commitment. Then +// we internally defer a function which computes the commitment over all input +// committed variables and then uses this commitment to derive a per-callback +// unique commitment. The callbacks are then called with these unique derived +// commitments instead. +package multicommit + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/kvstore" + "github.com/consensys/gnark/std/hash/mimc" +) + +type multicommitter struct { + closed bool + vars []frontend.Variable + cbs []WithCommitmentFn +} + +type ctxMulticommitterKey struct{} + +// Initialize creates a multicommitter in the cache and defers its finalization. +// This can be useful in a context where `api.Defer` is already called and where +// calls to `WithCommitment` are deferred. Panics if the multicommit is already +// initialized. +func Initialize(api frontend.API) { + kv, ok := api.(kvstore.Store) + if !ok { + // if the builder doesn't implement key-value store then cannot store + // multi-committer in cache.
+ panic("builder should implement key-value store") + } + + // check if the multicommit is already initialized + mc := kv.GetKeyValue(ctxMulticommitterKey{}) + if mc != nil { + panic("multicommit is already initialized") + } + + // initialize the multicommit + mct := &multicommitter{} + kv.SetKeyValue(ctxMulticommitterKey{}, mct) + api.Compiler().Defer(mct.commitAndCall) +} + +// getCached gets the cached committer from the key-value storage. If it is not +// there then creates, stores and defers it, and then returns. +func getCached(api frontend.API) *multicommitter { + kv, ok := api.(kvstore.Store) + if !ok { + // if the builder doesn't implement key-value store then cannot store + // multi-committer in cache. + panic("builder should implement key-value store") + } + mc := kv.GetKeyValue(ctxMulticommitterKey{}) + if mc != nil { + if mct, ok := mc.(*multicommitter); ok { + return mct + } else { + panic("stored multicommiter is of invalid type") + } + } + mct := &multicommitter{} + kv.SetKeyValue(ctxMulticommitterKey{}, mct) + api.Compiler().Defer(mct.commitAndCall) + return mct +} + +func (mct *multicommitter) commitAndCall(api frontend.API) error { + // close collecting input in case anyone wants to check more variables to commit to. + mct.closed = true + if len(mct.cbs) == 0 { + // shouldn't happen. we defer this function on creating multicommitter + // instance. It is probably some race. + panic("calling commiter with zero callbacks") + } + commiter, ok := api.Compiler().(frontend.Committer) + if !ok { + panic("compiler doesn't implement frontend.Committer") + } + cmt, err := commiter.Commit(mct.vars...) 
+ if err != nil { + return fmt.Errorf("commit: %w", err) + } + if len(mct.cbs) == 1 { + if err = mct.cbs[0](api, cmt); err != nil { + return fmt.Errorf("single callback: %w", err) + } + } else { + hasher, err := mimc.NewMiMC(api) + if err != nil { + return fmt.Errorf("new hasher: %w", err) + } + for i, cb := range mct.cbs { + hasher.Reset() + hasher.Write(i+1, cmt) + localcmt := hasher.Sum() + if err = cb(api, localcmt); err != nil { + return fmt.Errorf("with commitment callback %d: %w", i, err) + } + } + } + return nil +} + +// WithCommitmentFn is the function which is called asynchronously after all +// variables have been committed to. See [WithCommitment] for scheduling a +// function of this type. Every called function receives a distinct commitment +// built from a single root. +// +// It is invalid to call [WithCommitment] in this method recursively and this +// leads to a panic. However, the method can call defer for other callbacks. +type WithCommitmentFn func(api frontend.API, commitment frontend.Variable) error + +// WithCommitment schedules the function cb to be called with a unique +// commitment. We append the variables committedVariables to be committed to +// with the native [frontend.Committer] interface. +func WithCommitment(api frontend.API, cb WithCommitmentFn, committedVariables ...frontend.Variable) { + mct := getCached(api) + if mct.closed { + panic("called WithCommitment recursively") + } + mct.vars = append(mct.vars, committedVariables...)
+ mct.cbs = append(mct.cbs, cb) +} diff --git a/std/multicommit/nativecommit_test.go b/std/multicommit/nativecommit_test.go new file mode 100644 index 0000000000..4b570b7f33 --- /dev/null +++ b/std/multicommit/nativecommit_test.go @@ -0,0 +1,72 @@ +package multicommit + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/test" +) + +type noRecursionCircuit struct { + X frontend.Variable +} + +func (c *noRecursionCircuit) Define(api frontend.API) error { + WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { return nil }, commitment) + return nil + }, c.X) + return nil +} + +func TestNoRecursion(t *testing.T) { + circuit := noRecursionCircuit{} + assert := test.NewAssert(t) + _, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + assert.Error(err) +} + +type multipleCommitmentCircuit struct { + X frontend.Variable +} + +func (c *multipleCommitmentCircuit) Define(api frontend.API) error { + var stored frontend.Variable + // first callback receives first unique commitment derived from the root commitment + WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + api.AssertIsDifferent(c.X, commitment) + stored = commitment + return nil + }, c.X) + WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + api.AssertIsDifferent(stored, commitment) + return nil + }, c.X) + return nil +} + +func TestMultipleCommitments(t *testing.T) { + circuit := multipleCommitmentCircuit{} + assignment := multipleCommitmentCircuit{X: 10} + assert := test.NewAssert(t) + assert.ProverSucceeded(&circuit, &assignment, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16)) // right now PLONK doesn't implement commitment +} + +type 
noCommitVariable struct { + X frontend.Variable +} + +func (c *noCommitVariable) Define(api frontend.API) error { + WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { return nil }) + return nil +} + +func TestNoCommitVariable(t *testing.T) { + circuit := noCommitVariable{} + assert := test.NewAssert(t) + _, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + assert.Error(err) +} diff --git a/std/permutation/keccakf/keccak_test.go b/std/permutation/keccakf/keccak_test.go index fe4fca829d..b7151a893c 100644 --- a/std/permutation/keccakf/keccak_test.go +++ b/std/permutation/keccakf/keccak_test.go @@ -16,7 +16,13 @@ type keccakfCircuit struct { } func (c *keccakfCircuit) Define(api frontend.API) error { - res := keccakf.Permute(api, c.In) + var res [25]frontend.Variable + for i := range res { + res[i] = c.In[i] + } + for i := 0; i < 2; i++ { + res = keccakf.Permute(api, res) + } for i := range res { api.AssertIsEqual(res[i], c.Expected[i]) } @@ -25,15 +31,22 @@ func (c *keccakfCircuit) Define(api frontend.API) error { func TestKeccakf(t *testing.T) { var nativeIn [25]uint64 + var res [25]uint64 for i := range nativeIn { nativeIn[i] = 2 + res[i] = 2 + } + for i := 0; i < 2; i++ { + res = keccakF1600(res) } - nativeOut := keccakF1600(nativeIn) witness := keccakfCircuit{} for i := range nativeIn { witness.In[i] = nativeIn[i] - witness.Expected[i] = nativeOut[i] + witness.Expected[i] = res[i] } assert := test.NewAssert(t) - assert.ProverSucceeded(&keccakfCircuit{}, &witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16, backend.PLONK)) + assert.ProverSucceeded(&keccakfCircuit{}, &witness, + test.WithCurves(ecc.BN254), + test.WithBackends(backend.GROTH16, backend.PLONK), + test.NoFuzzing()) } diff --git a/std/permutation/keccakf/keccakf.go b/std/permutation/keccakf/keccakf.go index fbaedebb72..8f5e3ae346 100644 --- a/std/permutation/keccakf/keccakf.go +++ b/std/permutation/keccakf/keccakf.go @@ -12,33 
+12,34 @@ package keccakf import ( "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/uints" ) -var rc = [24]xuint64{ - constUint64(0x0000000000000001), - constUint64(0x0000000000008082), - constUint64(0x800000000000808A), - constUint64(0x8000000080008000), - constUint64(0x000000000000808B), - constUint64(0x0000000080000001), - constUint64(0x8000000080008081), - constUint64(0x8000000000008009), - constUint64(0x000000000000008A), - constUint64(0x0000000000000088), - constUint64(0x0000000080008009), - constUint64(0x000000008000000A), - constUint64(0x000000008000808B), - constUint64(0x800000000000008B), - constUint64(0x8000000000008089), - constUint64(0x8000000000008003), - constUint64(0x8000000000008002), - constUint64(0x8000000000000080), - constUint64(0x000000000000800A), - constUint64(0x800000008000000A), - constUint64(0x8000000080008081), - constUint64(0x8000000000008080), - constUint64(0x0000000080000001), - constUint64(0x8000000080008008), +var rc = [24]uints.U64{ + uints.NewU64(0x0000000000000001), + uints.NewU64(0x0000000000008082), + uints.NewU64(0x800000000000808A), + uints.NewU64(0x8000000080008000), + uints.NewU64(0x000000000000808B), + uints.NewU64(0x0000000080000001), + uints.NewU64(0x8000000080008081), + uints.NewU64(0x8000000000008009), + uints.NewU64(0x000000000000008A), + uints.NewU64(0x0000000000000088), + uints.NewU64(0x0000000080008009), + uints.NewU64(0x000000008000000A), + uints.NewU64(0x000000008000808B), + uints.NewU64(0x800000000000008B), + uints.NewU64(0x8000000000008089), + uints.NewU64(0x8000000000008003), + uints.NewU64(0x8000000000008002), + uints.NewU64(0x8000000000000080), + uints.NewU64(0x000000000000800A), + uints.NewU64(0x800000008000000A), + uints.NewU64(0x8000000080008081), + uints.NewU64(0x8000000000008080), + uints.NewU64(0x0000000080000001), + uints.NewU64(0x8000000080008008), } var rotc = [24]int{ 1, 3, 6, 10, 15, 21, 28, 36, 45, 55, 2, 14, @@ -53,32 +54,34 @@ var piln = [24]int{ // vector. 
The input array must consist of 64-bit (unsigned) integers. The // returned array also contains 64-bit unsigned integers. func Permute(api frontend.API, a [25]frontend.Variable) [25]frontend.Variable { - var in [25]xuint64 - uapi := newUint64API(api) + var in [25]uints.U64 + uapi, err := uints.New[uints.U64](api) + if err != nil { + panic(err) // TODO: return error instead + } for i := range a { - in[i] = uapi.asUint64(a[i]) + in[i] = uapi.ValueOf(a[i]) } - res := permute(api, in) + res := permute(uapi, in) var out [25]frontend.Variable for i := range out { - out[i] = uapi.fromUint64(res[i]) + out[i] = uapi.ToValue(res[i]) } return out } -func permute(api frontend.API, st [25]xuint64) [25]xuint64 { - uapi := newUint64API(api) - var t xuint64 - var bc [5]xuint64 +func permute(uapi *uints.BinaryField[uints.U64], st [25]uints.U64) [25]uints.U64 { + var t uints.U64 + var bc [5]uints.U64 for r := 0; r < 24; r++ { // theta for i := 0; i < 5; i++ { - bc[i] = uapi.xor(st[i], st[i+5], st[i+10], st[i+15], st[i+20]) + bc[i] = uapi.Xor(st[i], st[i+5], st[i+10], st[i+15], st[i+20]) } for i := 0; i < 5; i++ { - t = uapi.xor(bc[(i+4)%5], uapi.lrot(bc[(i+1)%5], 1)) + t = uapi.Xor(bc[(i+4)%5], uapi.Lrot(bc[(i+1)%5], 1)) for j := 0; j < 25; j += 5 { - st[j+i] = uapi.xor(st[j+i], t) + st[j+i] = uapi.Xor(st[j+i], t) } } // rho pi @@ -86,7 +89,7 @@ func permute(api frontend.API, st [25]xuint64) [25]xuint64 { for i := 0; i < 24; i++ { j := piln[i] bc[0] = st[j] - st[j] = uapi.lrot(t, rotc[i]) + st[j] = uapi.Lrot(t, rotc[i]) t = bc[0] } @@ -96,11 +99,11 @@ func permute(api frontend.API, st [25]xuint64) [25]xuint64 { bc[i] = st[j+i] } for i := 0; i < 5; i++ { - st[j+i] = uapi.xor(st[j+i], uapi.and(uapi.not(bc[(i+1)%5]), bc[(i+2)%5])) + st[j+i] = uapi.Xor(st[j+i], uapi.And(uapi.Not(bc[(i+1)%5]), bc[(i+2)%5])) } } // iota - st[0] = uapi.xor(st[0], rc[r]) + st[0] = uapi.Xor(st[0], rc[r]) } return st } diff --git a/std/permutation/keccakf/uint64api.go b/std/permutation/keccakf/uint64api.go 
deleted file mode 100644 index c7c2246bb2..0000000000 --- a/std/permutation/keccakf/uint64api.go +++ /dev/null @@ -1,101 +0,0 @@ -package keccakf - -import ( - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/math/bits" -) - -// uint64api performs binary operations on xuint64 variables. In the -// future possibly using lookup tables. -// -// TODO: we could possibly optimise using hints if working over many inputs. For -// example, if we OR many bits, then the result is 0 if the sum of the bits is -// larger than 1. And AND is 1 if the sum of bits is the number of inputs. BUt -// this probably helps only if we have a lot of similar operations in a row -// (more than 4). We could probably unroll the whole permutation and expand all -// the formulas to see. But long term tables are still better. -type uint64api struct { - api frontend.API -} - -func newUint64API(api frontend.API) *uint64api { - return &uint64api{ - api: api, - } -} - -// varUint64 represents 64-bit unsigned integer. We use this type to ensure that -// we work over constrained bits. Do not initialize directly, use [wideBinaryOpsApi.asUint64]. 
-type xuint64 [64]frontend.Variable - -func constUint64(a uint64) xuint64 { - var res xuint64 - for i := 0; i < 64; i++ { - res[i] = (a >> i) & 1 - } - return res -} - -func (w *uint64api) asUint64(in frontend.Variable) xuint64 { - bits := bits.ToBinary(w.api, in, bits.WithNbDigits(64)) - var res xuint64 - copy(res[:], bits) - return res -} - -func (w *uint64api) fromUint64(in xuint64) frontend.Variable { - return bits.FromBinary(w.api, in[:], bits.WithUnconstrainedInputs()) -} - -func (w *uint64api) and(in ...xuint64) xuint64 { - var res xuint64 - for i := range res { - res[i] = 1 - } - for i := range res { - for _, v := range in { - res[i] = w.api.And(res[i], v[i]) - } - } - return res -} - -func (w *uint64api) xor(in ...xuint64) xuint64 { - var res xuint64 - for i := range res { - res[i] = 0 - } - for i := range res { - for _, v := range in { - res[i] = w.api.Xor(res[i], v[i]) - } - } - return res -} - -func (w *uint64api) lrot(in xuint64, shift int) xuint64 { - var res xuint64 - for i := range res { - res[i] = in[(i-shift+64)%64] - } - return res -} - -func (w *uint64api) not(in xuint64) xuint64 { - // TODO: it would be better to have separate method for it. If we have - // native API support, then in R1CS would be free (1-X) and in PLONK 1 - // constraint (1-X). But if we do XOR, then we always have a constraint with - // R1CS (not sure if 1-2 with PLONK). If we do 1-X ourselves, then compiler - // marks as binary which is 1-2 (R1CS-PLONK). 
- var res xuint64 - for i := range res { - res[i] = w.api.Xor(in[i], 1) - } - return res -} - -func (w *uint64api) assertEq(a, b xuint64) { - for i := range a { - w.api.AssertIsEqual(a[i], b[i]) - } -} diff --git a/std/permutation/keccakf/uint64api_test.go b/std/permutation/keccakf/uint64api_test.go deleted file mode 100644 index 0e39794695..0000000000 --- a/std/permutation/keccakf/uint64api_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package keccakf - -import ( - "testing" - - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/test" -) - -type lrotCirc struct { - In frontend.Variable - Shift int - Out frontend.Variable -} - -func (c *lrotCirc) Define(api frontend.API) error { - uapi := newUint64API(api) - in := uapi.asUint64(c.In) - out := uapi.asUint64(c.Out) - res := uapi.lrot(in, c.Shift) - uapi.assertEq(out, res) - return nil -} - -func TestLeftRotation(t *testing.T) { - assert := test.NewAssert(t) - // err := test.IsSolved(&lrotCirc{Shift: 2}, &lrotCirc{In: 6, Shift: 2, Out: 24}, ecc.BN254.ScalarField()) - // assert.NoError(err) - assert.ProverSucceeded(&lrotCirc{Shift: 2}, &lrotCirc{In: 6, Shift: 2, Out: 24}) -} diff --git a/std/permutation/sha2/sha2block.go b/std/permutation/sha2/sha2block.go new file mode 100644 index 0000000000..a3991230b3 --- /dev/null +++ b/std/permutation/sha2/sha2block.go @@ -0,0 +1,90 @@ +package sha2 + +import ( + "github.com/consensys/gnark/std/math/uints" +) + +var _K = uints.NewU32Array([]uint32{ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 
0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +}) + +func Permute(uapi *uints.BinaryField[uints.U32], currentHash [8]uints.U32, p [64]uints.U8) (newHash [8]uints.U32) { + var w [64]uints.U32 + + for i := 0; i < 16; i++ { + w[i] = uapi.PackMSB(p[4*i], p[4*i+1], p[4*i+2], p[4*i+3]) + } + + for i := 16; i < 64; i++ { + v1 := w[i-2] + t1 := uapi.Xor( + uapi.Lrot(v1, -17), + uapi.Lrot(v1, -19), + uapi.Rshift(v1, 10), + ) + v2 := w[i-15] + t2 := uapi.Xor( + uapi.Lrot(v2, -7), + uapi.Lrot(v2, -18), + uapi.Rshift(v2, 3), + ) + + w[i] = uapi.Add(t1, w[i-7], t2, w[i-16]) + } + + a, b, c, d, e, f, g, h := currentHash[0], currentHash[1], currentHash[2], currentHash[3], currentHash[4], currentHash[5], currentHash[6], currentHash[7] + + for i := 0; i < 64; i++ { + t1 := uapi.Add( + h, + uapi.Xor( + uapi.Lrot(e, -6), + uapi.Lrot(e, -11), + uapi.Lrot(e, -25)), + uapi.Xor( + uapi.And(e, f), + uapi.And( + uapi.Not(e), + g)), + _K[i], + w[i], + ) + t2 := uapi.Add( + uapi.Xor( + uapi.Lrot(a, -2), + uapi.Lrot(a, -13), + uapi.Lrot(a, -22)), + uapi.Xor( + uapi.And(a, b), + uapi.And(a, c), + uapi.And(b, c)), + ) + + h = g + g = f + f = e + e = uapi.Add(d, t1) + d = c + c = b + b = a + a = uapi.Add(t1, t2) + } + + currentHash[0] = uapi.Add(currentHash[0], a) + currentHash[1] = uapi.Add(currentHash[1], b) + currentHash[2] = uapi.Add(currentHash[2], c) + currentHash[3] = uapi.Add(currentHash[3], d) + currentHash[4] = uapi.Add(currentHash[4], e) + currentHash[5] = uapi.Add(currentHash[5], f) + currentHash[6] = uapi.Add(currentHash[6], g) + currentHash[7] = uapi.Add(currentHash[7], h) + + return currentHash +} diff --git a/std/permutation/sha2/sha2block_test.go b/std/permutation/sha2/sha2block_test.go new file mode 100644 index 0000000000..b634ab6624 --- /dev/null +++ 
b/std/permutation/sha2/sha2block_test.go @@ -0,0 +1,116 @@ +package sha2_test + +import ( + "math/bits" + "math/rand" + "testing" + "time" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/uints" + "github.com/consensys/gnark/std/permutation/sha2" + "github.com/consensys/gnark/test" +) + +var _K = []uint32{ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +} + +const ( + chunk = 64 +) + +type digest struct { + h [8]uint32 +} + +func blockGeneric(dig *digest, p []byte) { + var w [64]uint32 + h0, h1, h2, h3, h4, h5, h6, h7 := dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4], dig.h[5], dig.h[6], dig.h[7] + for len(p) >= chunk { + // Can interlace the computation of w with the + // rounds below if needed for speed. 
+ for i := 0; i < 16; i++ { + j := i * 4 + w[i] = uint32(p[j])<<24 | uint32(p[j+1])<<16 | uint32(p[j+2])<<8 | uint32(p[j+3]) + } + for i := 16; i < 64; i++ { + v1 := w[i-2] + t1 := (bits.RotateLeft32(v1, -17)) ^ (bits.RotateLeft32(v1, -19)) ^ (v1 >> 10) + v2 := w[i-15] + t2 := (bits.RotateLeft32(v2, -7)) ^ (bits.RotateLeft32(v2, -18)) ^ (v2 >> 3) + w[i] = t1 + w[i-7] + t2 + w[i-16] + } + + a, b, c, d, e, f, g, h := h0, h1, h2, h3, h4, h5, h6, h7 + + for i := 0; i < 64; i++ { + t1 := h + ((bits.RotateLeft32(e, -6)) ^ (bits.RotateLeft32(e, -11)) ^ (bits.RotateLeft32(e, -25))) + ((e & f) ^ (^e & g)) + _K[i] + w[i] + + t2 := ((bits.RotateLeft32(a, -2)) ^ (bits.RotateLeft32(a, -13)) ^ (bits.RotateLeft32(a, -22))) + ((a & b) ^ (a & c) ^ (b & c)) + + h, g, f, e, d, c, b, a = g, f, e, d+t1, c, b, a, t1+t2 + } + + h0 += a + h1 += b + h2 += c + h3 += d + h4 += e + h5 += f + h6 += g + h7 += h + + p = p[chunk:] + } + + dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4], dig.h[5], dig.h[6], dig.h[7] = h0, h1, h2, h3, h4, h5, h6, h7 +} + +type circuitBlock struct { + CurrentDig [8]uints.U32 + In [64]uints.U8 + Expected [8]uints.U32 +} + +func (c *circuitBlock) Define(api frontend.API) error { + uapi, err := uints.New[uints.U32](api) + if err != nil { + return err + } + res := sha2.Permute(uapi, c.CurrentDig, c.In) + for i := range c.Expected { + uapi.AssertEq(c.Expected[i], res[i]) + } + return nil +} + +func TestBlockGeneric(t *testing.T) { + assert := test.NewAssert(t) + s := rand.New(rand.NewSource(time.Now().Unix())) //nolint G404, test code + witness := circuitBlock{} + dig := digest{} + var in [chunk]byte + for i := range dig.h { + dig.h[i] = s.Uint32() + witness.CurrentDig[i] = uints.NewU32(dig.h[i]) + } + for i := range in { + in[i] = byte(s.Uint32() & 0xff) + witness.In[i] = uints.NewU8(in[i]) + } + blockGeneric(&dig, in[:]) + for i := range dig.h { + witness.Expected[i] = uints.NewU32(dig.h[i]) + } + err := test.IsSolved(&circuitBlock{}, &witness, 
ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/polynomial/polynomial.go b/std/polynomial/polynomial.go index 72cc329f9f..0953cb3ac7 100644 --- a/std/polynomial/polynomial.go +++ b/std/polynomial/polynomial.go @@ -9,30 +9,71 @@ import ( type Polynomial []frontend.Variable type MultiLin []frontend.Variable +var minFoldScaledLogSize = 16 + // Evaluate assumes len(m) = 1 << len(at) +// it doesn't modify m func (m MultiLin) Evaluate(api frontend.API, at []frontend.Variable) frontend.Variable { - - eqs := make([]frontend.Variable, len(m)) - eqs[0] = 1 - for i, rI := range at { - prevSize := 1 << i - for j := prevSize - 1; j >= 0; j-- { - eqs[2*j+1] = api.Mul(rI, eqs[j]) - eqs[2*j] = api.Sub(eqs[j], eqs[2*j+1]) // eq[2j] == (1 - rI) * eq[j] + _m := m.Clone() + + /*minFoldScaledLogSize := 16 + if api is r1cs { + minFoldScaledLogSize = math.MaxInt64 // no scaling for r1cs + }*/ + + scaleCorrectionFactor := frontend.Variable(1) + // at each iteration fold by at[i] + for len(_m) > 1 { + if len(_m) >= minFoldScaledLogSize { + scaleCorrectionFactor = api.Mul(scaleCorrectionFactor, _m.foldScaled(api, at[0])) + } else { + _m.fold(api, at[0]) } + _m = _m[:len(_m)/2] + at = at[1:] + } + + if len(at) != 0 { + panic("incompatible evaluation vector size") + } + + return api.Mul(_m[0], scaleCorrectionFactor) +} + +// fold fixes the value of m's first variable to at, thus halving m's required bookkeeping table size +// WARNING: The user should halve m themselves after the call +func (m MultiLin) fold(api frontend.API, at frontend.Variable) { + zero := m[:len(m)/2] + one := m[len(m)/2:] + for j := range zero { + diff := api.Sub(one[j], zero[j]) + zero[j] = api.MulAcc(zero[j], diff, at) } +} - evaluation := frontend.Variable(0) - for j := range m { - evaluation = api.MulAcc(evaluation, eqs[j], m[j]) +// foldScaled(m, at) = fold(m, at) / (1 - at) +// it returns 1 - at, for convenience +func (m MultiLin) foldScaled(api frontend.API, at frontend.Variable) (denom 
frontend.Variable) { + denom = api.Sub(1, at) + coeff := api.Div(at, denom) + zero := m[:len(m)/2] + one := m[len(m)/2:] + for j := range zero { + zero[j] = api.MulAcc(zero[j], one[j], coeff) } - return evaluation + return } func (m MultiLin) NumVars() int { return bits.TrailingZeros(uint(len(m))) } +func (m MultiLin) Clone() MultiLin { + clone := make(MultiLin, len(m)) + copy(clone, m) + return clone +} + func (p Polynomial) Eval(api frontend.API, at frontend.Variable) (pAt frontend.Variable) { pAt = 0 diff --git a/std/polynomial/polynomial_test.go b/std/polynomial/polynomial_test.go index 92656d779a..24b706a3f6 100644 --- a/std/polynomial/polynomial_test.go +++ b/std/polynomial/polynomial_test.go @@ -1,10 +1,13 @@ package polynomial import ( + "errors" "fmt" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/test" "testing" ) @@ -70,6 +73,31 @@ func TestEvalDeltasQuadratic(t *testing.T) { testEvalDeltas(t, 3, []int64{1, -3, 3}) } +type foldMultiLinCircuit struct { + M []frontend.Variable + At frontend.Variable + Result []frontend.Variable +} + +func (c *foldMultiLinCircuit) Define(api frontend.API) error { + if len(c.M) != 2*len(c.Result) { + return errors.New("folding size mismatch") + } + m := MultiLin(c.M) + m.fold(api, c.At) + for i := range c.Result { + api.AssertIsEqual(m[i], c.Result[i]) + } + return nil +} + +func TestFoldSmall(t *testing.T) { + test.NewAssert(t).SolvingSucceeded( + &foldMultiLinCircuit{M: make([]frontend.Variable, 4), Result: make([]frontend.Variable, 2)}, + &foldMultiLinCircuit{M: []frontend.Variable{0, 1, 2, 3}, At: 2, Result: []frontend.Variable{4, 5}}, + ) +} + type evalMultiLinCircuit struct { M []frontend.Variable `gnark:",public"` At []frontend.Variable `gnark:",secret"` @@ -204,3 +232,25 @@ func int64SliceToVariableSlice(slice []int64) 
[]frontend.Variable { } return res } + +func ExampleMultiLin_Evaluate() { + const logSize = 20 + const size = 1 << logSize + m := MultiLin(make([]frontend.Variable, size)) + e := MultiLin(make([]frontend.Variable, logSize)) + + cs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &evalMultiLinCircuit{M: m, At: e, Evaluation: 0}) + if err != nil { + panic(err) + } + fmt.Println("r1cs size:", cs.GetNbConstraints()) + + cs, err = frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &evalMultiLinCircuit{M: m, At: e, Evaluation: 0}) + if err != nil { + panic(err) + } + fmt.Println("scs size:", cs.GetNbConstraints()) + + // Output: r1cs size: 1048627 + //scs size: 2097226 +} diff --git a/std/rangecheck/rangecheck.go b/std/rangecheck/rangecheck.go new file mode 100644 index 0000000000..8aac734dcb --- /dev/null +++ b/std/rangecheck/rangecheck.go @@ -0,0 +1,37 @@ +// Package rangecheck implements range checking gadget +// +// This package chooses the most optimal path for performing range checks: +// - if the backend supports native range checking and the frontend exports the variables in the proprietary format by implementing [frontend.Rangechecker], then use it directly; +// - if the backend supports creating a commitment of variables by implementing [frontend.Committer], then we use the log-derivative variant [[Haböck22]] of the product argument as in [[BCG+18]] . [r1cs.NewBuilder] returns a builder which implements this interface; +// - lacking these, we perform binary decomposition of variable into bits. +// +// [BCG+18]: https://eprint.iacr.org/2018/380 +// [Haböck22]: https://eprint.iacr.org/2022/1530 +package rangecheck + +import ( + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" +) + +// only for documentation purposes. If we import the package then godoc knows +// how to refer to package r1cs and we get nice links in godoc. 
We import the +// package anyway in test. +var _ = r1cs.NewBuilder + +// New returns a new range checker depending on the frontend capabilities. +func New(api frontend.API) frontend.Rangechecker { + if rc, ok := api.(frontend.Rangechecker); ok { + return rc + } + if _, ok := api.(frontend.Committer); ok { + return newCommitRangechecker(api) + } + return plainChecker{api: api} +} + +// GetHints returns all hints used in this package +func GetHints() []solver.Hint { + return []solver.Hint{DecomposeHint} +} diff --git a/std/rangecheck/rangecheck_commit.go b/std/rangecheck/rangecheck_commit.go new file mode 100644 index 0000000000..6af6eb5be6 --- /dev/null +++ b/std/rangecheck/rangecheck_commit.go @@ -0,0 +1,175 @@ +package rangecheck + +import ( + "fmt" + "math" + "math/big" + + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/frontendtype" + "github.com/consensys/gnark/internal/kvstore" + "github.com/consensys/gnark/std/internal/logderivarg" +) + +type ctxCheckerKey struct{} + +func init() { + solver.RegisterHint(DecomposeHint) +} + +type checkedVariable struct { + v frontend.Variable + bits int +} + +type commitChecker struct { + collected []checkedVariable + closed bool +} + +func newCommitRangechecker(api frontend.API) *commitChecker { + kv, ok := api.Compiler().(kvstore.Store) + if !ok { + panic("builder should implement key-value store") + } + ch := kv.GetKeyValue(ctxCheckerKey{}) + if ch != nil { + if cht, ok := ch.(*commitChecker); ok { + return cht + } else { + panic("stored rangechecker is not valid") + } + } + cht := &commitChecker{} + kv.SetKeyValue(ctxCheckerKey{}, cht) + api.Compiler().Defer(cht.commit) + return cht +} + +func (c *commitChecker) Check(in frontend.Variable, bits int) { + if c.closed { + panic("checker already closed") + } + c.collected = append(c.collected, checkedVariable{v: in, bits: bits}) +} + +func (c *commitChecker) buildTable(nbTable int) 
[]frontend.Variable { + tbl := make([]frontend.Variable, nbTable) + for i := 0; i < nbTable; i++ { + tbl[i] = i + } + return tbl +} + +func (c *commitChecker) commit(api frontend.API) error { + if c.closed { + return nil + } + defer func() { c.closed = true }() + if len(c.collected) == 0 { + return nil + } + baseLength := c.getOptimalBasewidth(api) + // decompose into smaller limbs + decomposed := make([]frontend.Variable, 0, len(c.collected)) + collected := make([]frontend.Variable, len(c.collected)) + base := new(big.Int).Lsh(big.NewInt(1), uint(baseLength)) + for i := range c.collected { + // collect all vars for commitment input + collected[i] = c.collected[i].v + // decompose value into limbs + nbLimbs := decompSize(c.collected[i].bits, baseLength) + limbs, err := api.Compiler().NewHint(DecomposeHint, int(nbLimbs), c.collected[i].bits, baseLength, c.collected[i].v) + if err != nil { + panic(fmt.Sprintf("decompose %v", err)) + } + // store all limbs for counting + decomposed = append(decomposed, limbs...) + // check that limbs are correct. We check the sizes of the limbs later + var composed frontend.Variable = 0 + for j := range limbs { + composed = api.Add(composed, api.Mul(limbs[j], new(big.Int).Exp(base, big.NewInt(int64(j)), nil))) + } + api.AssertIsEqual(composed, c.collected[i].v) + } + nbTable := 1 << baseLength + return logderivarg.Build(api, logderivarg.AsTable(c.buildTable(nbTable)), logderivarg.AsTable(decomposed)) +} + +func decompSize(varSize int, limbSize int) int { + return (varSize + limbSize - 1) / limbSize +} + +// DecomposeHint is a hint used for range checking with commitment. It +// decomposes large variables into chunks which can be individually range-check +// in the native range. 
+func DecomposeHint(m *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 3 { + return fmt.Errorf("input must be 3 elements") + } + if !inputs[0].IsUint64() || !inputs[1].IsUint64() { + return fmt.Errorf("first two inputs have to be uint64") + } + varSize := int(inputs[0].Int64()) + limbSize := int(inputs[1].Int64()) + val := inputs[2] + nbLimbs := decompSize(varSize, limbSize) + if len(outputs) != nbLimbs { + return fmt.Errorf("need %d outputs instead to decompose", nbLimbs) + } + base := new(big.Int).Lsh(big.NewInt(1), uint(limbSize)) + tmp := new(big.Int).Set(val) + for i := 0; i < len(outputs); i++ { + outputs[i].Mod(tmp, base) + tmp.Rsh(tmp, uint(limbSize)) + } + return nil +} + +func (c *commitChecker) getOptimalBasewidth(api frontend.API) int { + if ft, ok := api.(frontendtype.FrontendTyper); ok { + switch ft.FrontendType() { + case frontendtype.R1CS: + return optimalWidth(nbR1CSConstraints, c.collected) + case frontendtype.SCS: + return optimalWidth(nbPLONKConstraints, c.collected) + } + } + return optimalWidth(nbR1CSConstraints, c.collected) +} + +func optimalWidth(countFn func(baseLength int, collected []checkedVariable) int, collected []checkedVariable) int { + min := math.MaxInt64 + minVal := 0 + for j := 2; j < 18; j++ { + current := countFn(j, collected) + if current < min { + min = current + minVal = j + } + } + return minVal +} + +func nbR1CSConstraints(baseLength int, collected []checkedVariable) int { + nbDecomposed := 0 + for i := range collected { + nbDecomposed += int(decompSize(collected[i].bits, baseLength)) + } + eqs := len(collected) // correctness of decomposition + nbRight := nbDecomposed // inverse per decomposed + nbleft := (1 << baseLength) // div per table + return nbleft + nbRight + eqs + 1 +} + +func nbPLONKConstraints(baseLength int, collected []checkedVariable) int { + nbDecomposed := 0 + for i := range collected { + nbDecomposed += int(decompSize(collected[i].bits, baseLength)) + } + eqs := nbDecomposed 
// check correctness of every decomposition. this is nbDecomp adds + eq cost per collected + nbRight := 3 * nbDecomposed // denominator sub, inv and large sum per table entry + nbleft := 3 * (1 << baseLength) // denominator sub, div and large sum per table entry + return nbleft + nbRight + eqs + 1 // and the final assert +} diff --git a/std/rangecheck/rangecheck_plain.go b/std/rangecheck/rangecheck_plain.go new file mode 100644 index 0000000000..6f20418f2d --- /dev/null +++ b/std/rangecheck/rangecheck_plain.go @@ -0,0 +1,14 @@ +package rangecheck + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/bits" +) + +type plainChecker struct { + api frontend.API +} + +func (pl plainChecker) Check(v frontend.Variable, nbBits int) { + bits.ToBinary(pl.api, v, bits.WithNbDigits(nbBits)) +} diff --git a/std/rangecheck/rangecheck_test.go b/std/rangecheck/rangecheck_test.go new file mode 100644 index 0000000000..42a26827d8 --- /dev/null +++ b/std/rangecheck/rangecheck_test.go @@ -0,0 +1,46 @@ +package rangecheck + +import ( + "crypto/rand" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/test" +) + +type CheckCircuit struct { + Vals []frontend.Variable + bits int +} + +func (c *CheckCircuit) Define(api frontend.API) error { + r := newCommitRangechecker(api) + for i := range c.Vals { + r.Check(c.Vals[i], c.bits) + } + return nil +} + +func TestCheck(t *testing.T) { + assert := test.NewAssert(t) + var err error + bits := 64 + nbVals := 100000 + bound := new(big.Int).Lsh(big.NewInt(1), uint(bits)) + vals := make([]frontend.Variable, nbVals) + for i := range vals { + vals[i], err = rand.Int(rand.Reader, bound) + if err != nil { + t.Fatal(err) + } + } + witness := CheckCircuit{Vals: vals, bits: bits} + circuit := CheckCircuit{Vals: make([]frontend.Variable, len(vals)), bits: bits} + err = 
test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + _, err = frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit, frontend.WithCompressThreshold(100)) + assert.NoError(err) +} diff --git a/std/selector/doc_map_test.go b/std/selector/doc_map_test.go new file mode 100644 index 0000000000..f70c353f61 --- /dev/null +++ b/std/selector/doc_map_test.go @@ -0,0 +1,79 @@ +package selector_test + +import ( + "fmt" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/selector" +) + +// MapCircuit is a minimal circuit using a selector map. +type MapCircuit struct { + QueryKey frontend.Variable + Keys [10]frontend.Variable // we use array in witness to allocate sufficiently sized vector + Values [10]frontend.Variable // we use array in witness to allocate sufficiently sized vector + ExpectedValue frontend.Variable +} + +// Define defines the arithmetic circuit. +func (c *MapCircuit) Define(api frontend.API) error { + result := selector.Map(api, c.QueryKey, c.Keys[:], c.Values[:]) + api.AssertIsEqual(result, c.ExpectedValue) + return nil +} + +// ExampleMap gives an example on how to use map selector. 
+func ExampleMap() { + circuit := MapCircuit{} + witness := MapCircuit{ + QueryKey: 55, + Keys: [10]frontend.Variable{0, 11, 22, 33, 44, 55, 66, 77, 88, 99}, + Values: [10]frontend.Variable{0, 2, 4, 6, 8, 10, 12, 14, 16, 18}, + ExpectedValue: 10, // element in values which corresponds to the position of 55 in keys + } + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + panic(err) + } else { + fmt.Println("compiled") + } + pk, vk, err := groth16.Setup(ccs) + if err != nil { + panic(err) + } else { + fmt.Println("setup done") + } + secretWitness, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } else { + fmt.Println("secret witness") + } + publicWitness, err := secretWitness.Public() + if err != nil { + panic(err) + } else { + fmt.Println("public witness") + } + proof, err := groth16.Prove(ccs, pk, secretWitness) + if err != nil { + panic(err) + } else { + fmt.Println("proof") + } + err = groth16.Verify(proof, vk, publicWitness) + if err != nil { + panic(err) + } else { + fmt.Println("verify") + } + // Output: compiled + // setup done + // secret witness + // public witness + // proof + // verify +} diff --git a/std/selector/doc_mux_test.go b/std/selector/doc_mux_test.go new file mode 100644 index 0000000000..5a7b75ba36 --- /dev/null +++ b/std/selector/doc_mux_test.go @@ -0,0 +1,77 @@ +package selector_test + +import ( + "fmt" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/selector" +) + +// MuxCircuit is a minimal circuit using a selector mux. +type MuxCircuit struct { + Selector frontend.Variable + In [10]frontend.Variable // we use array in witness to allocate sufficiently sized vector + Expected frontend.Variable +} + +// Define defines the arithmetic circuit. 
+func (c *MuxCircuit) Define(api frontend.API) error { + result := selector.Mux(api, c.Selector, c.In[:]...) // Note Mux takes var-arg input, need to expand the input vector + api.AssertIsEqual(result, c.Expected) + return nil +} + +// ExampleMux gives an example on how to use mux selector. +func ExampleMux() { + circuit := MuxCircuit{} + witness := MuxCircuit{ + Selector: 5, + In: [10]frontend.Variable{0, 2, 4, 6, 8, 10, 12, 14, 16, 18}, + Expected: 10, // 5-th element in vector + } + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + panic(err) + } else { + fmt.Println("compiled") + } + pk, vk, err := groth16.Setup(ccs) + if err != nil { + panic(err) + } else { + fmt.Println("setup done") + } + secretWitness, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } else { + fmt.Println("secret witness") + } + publicWitness, err := secretWitness.Public() + if err != nil { + panic(err) + } else { + fmt.Println("public witness") + } + proof, err := groth16.Prove(ccs, pk, secretWitness) + if err != nil { + panic(err) + } else { + fmt.Println("proof") + } + err = groth16.Verify(proof, vk, publicWitness) + if err != nil { + panic(err) + } else { + fmt.Println("verify") + } + // Output: compiled + // setup done + // secret witness + // public witness + // proof + // verify +} diff --git a/std/selector/doc_partition_test.go b/std/selector/doc_partition_test.go new file mode 100644 index 0000000000..2a9b4ea548 --- /dev/null +++ b/std/selector/doc_partition_test.go @@ -0,0 +1,80 @@ +package selector_test + +import "github.com/consensys/gnark/frontend" + +import ( + "fmt" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/selector" +) + +// adderCircuit adds first Count number of its input array In. 
+type adderCircuit struct { + Count frontend.Variable + In [10]frontend.Variable + ExpectedSum frontend.Variable +} + +// Define defines the arithmetic circuit. +func (c *adderCircuit) Define(api frontend.API) error { + selectedPart := selector.Partition(api, c.Count, false, c.In[:]) + sum := api.Add(selectedPart[0], selectedPart[1], selectedPart[2:]...) + api.AssertIsEqual(sum, c.ExpectedSum) + return nil +} + +// ExamplePartition gives an example on how to use selector.Partition to make a circuit that accepts a Count and an +// input array In, and then calculates the sum of first Count numbers of the input array. +func ExamplePartition() { + circuit := adderCircuit{} + witness := adderCircuit{ + Count: 6, + In: [10]frontend.Variable{0, 2, 4, 6, 8, 10, 12, 14, 16, 18}, + ExpectedSum: 30, + } + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + panic(err) + } else { + fmt.Println("compiled") + } + pk, vk, err := groth16.Setup(ccs) + if err != nil { + panic(err) + } else { + fmt.Println("setup done") + } + secretWitness, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } else { + fmt.Println("secret witness") + } + publicWitness, err := secretWitness.Public() + if err != nil { + panic(err) + } else { + fmt.Println("public witness") + } + proof, err := groth16.Prove(ccs, pk, secretWitness) + if err != nil { + panic(err) + } else { + fmt.Println("proof") + } + err = groth16.Verify(proof, vk, publicWitness) + if err != nil { + panic(err) + } else { + fmt.Println("verify") + } + // Output: compiled + // setup done + // secret witness + // public witness + // proof + // verify +} diff --git a/std/selector/multiplexer.go b/std/selector/multiplexer.go new file mode 100644 index 0000000000..5f9b0d7114 --- /dev/null +++ b/std/selector/multiplexer.go @@ -0,0 +1,169 @@ +// Package selector provides a lookup table and map, based on linear scan. 
+// +// The native [frontend.API] provides 1- and 2-bit lookups through the interface +// methods Select and Lookup2. This package extends the lookups to +// arbitrary-sized vectors. The lookups can be performed using the index of the +// elements (function [Mux]) or using a key, for which the user needs to provide +// the slice of keys (function [Map]). +// +// The implementation uses linear scan over all inputs. +package selector + +import ( + "fmt" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/bits" + "math/big" + binary "math/bits" +) + +func init() { + // register hints + solver.RegisterHint(GetHints()...) +} + +// GetHints returns all hint functions used in this package. This method is +// useful for registering all hints in the solver. +func GetHints() []solver.Hint { + return []solver.Hint{stepOutput, muxIndicators, mapIndicators} +} + +// Map is a key value associative array: the output will be values[i] such that +// keys[i] == queryKey. If keys does not contain queryKey, no proofs can be +// generated. If keys has more than one key that equals to queryKey, the output +// will be undefined, and the output could be any linear combination of all the +// corresponding values with that queryKey. +// +// In case keys and values do not have the same length, this function will +// panic. +func Map(api frontend.API, queryKey frontend.Variable, + keys []frontend.Variable, values []frontend.Variable) frontend.Variable { + // we don't need this check, but we added it to produce more informative errors + // and disallow len(keys) < len(values) which is supported by generateSelector. + if len(keys) != len(values) { + panic(fmt.Sprintf("The number of keys and values must be equal (%d != %d)", len(keys), len(values))) + } + return dotProduct(api, values, KeyDecoder(api, queryKey, keys)) +} + +// Mux is an n to 1 multiplexer: out = inputs[sel]. 
In other words, it selects +// exactly one of its inputs based on sel. The index of inputs starts from zero. +// +// sel needs to be between 0 and n - 1 (inclusive), where n is the number of +// inputs, otherwise the proof will fail. +func Mux(api frontend.API, sel frontend.Variable, inputs ...frontend.Variable) frontend.Variable { + // we use BinaryMux when len(inputs) is a power of 2. + if binary.OnesCount(uint(len(inputs))) == 1 { + selBits := bits.ToBinary(api, sel, bits.WithNbDigits(binary.Len(uint(len(inputs)))-1)) + return BinaryMux(api, selBits, inputs) + } + return dotProduct(api, inputs, Decoder(api, len(inputs), sel)) +} + +// KeyDecoder is a decoder that associates keys to its output wires. It outputs +// 1 on the wire that is associated to a key that equals to queryKey. In other +// words: +// +// if keys[i] == queryKey +// out[i] = 1 +// else +// out[i] = 0 +// +// If keys has more than one key that equals to queryKey, the output is +// undefined. However, the output is guaranteed to be zero for the wires that +// are associated with a key which is not equal to queryKey. +func KeyDecoder(api frontend.API, queryKey frontend.Variable, keys []frontend.Variable) []frontend.Variable { + return generateDecoder(api, false, 0, queryKey, keys) +} + +// Decoder is a decoder with n outputs. It outputs 1 on the wire with index sel, +// and 0 otherwise. Indices start from zero. In other words: +// +// if i == sel +// out[i] = 1 +// else +// out[i] = 0 +// +// sel needs to be between 0 and n - 1 (inclusive) otherwise no proof can be +// generated. +func Decoder(api frontend.API, n int, sel frontend.Variable) []frontend.Variable { + return generateDecoder(api, true, n, sel, nil) +} + +// generateDecoder generates a circuit for a decoder which indicates the +// selected index. If sequential is true, an ordinary decoder of size n is +// generated, and keys are ignored. 
If sequential is false, a key based decoder +// is generated, and len(keys) is used to determine the size of the output. n +// will be ignored in this case. +func generateDecoder(api frontend.API, sequential bool, n int, sel frontend.Variable, + keys []frontend.Variable) []frontend.Variable { + + var indicators []frontend.Variable + var err error + if sequential { + indicators, err = api.Compiler().NewHint(muxIndicators, n, sel) + } else { + indicators, err = api.Compiler().NewHint(mapIndicators, len(keys), append(keys, sel)...) + } + if err != nil { + panic(fmt.Sprintf("error in calling Mux/Map hint: %v", err)) + } + + indicatorsSum := frontend.Variable(0) + for i := 0; i < len(indicators); i++ { + // Check that all indicators for inputs that are not selected, are zero. + if sequential { + // indicators[i] * (sel - i) == 0 + api.AssertIsEqual(api.Mul(indicators[i], api.Sub(sel, i)), 0) + } else { + // indicators[i] * (sel - keys[i]) == 0 + api.AssertIsEqual(api.Mul(indicators[i], api.Sub(sel, keys[i])), 0) + } + indicatorsSum = api.Add(indicatorsSum, indicators[i]) + } + // We need to check that the indicator of the selected input is exactly 1. We + // use a sum constraint, because usually it is cheap. + api.AssertIsEqual(indicatorsSum, 1) + return indicators +} + +func dotProduct(api frontend.API, a, b []frontend.Variable) frontend.Variable { + out := frontend.Variable(0) + for i := 0; i < len(a); i++ { + // out += indicators[i] * values[i] + out = api.MulAcc(out, a[i], b[i]) + } + return out +} + +// muxIndicators is a hint function used within [Mux] function. It must be +// provided to the prover when circuit uses it. +func muxIndicators(_ *big.Int, inputs []*big.Int, results []*big.Int) error { + sel := inputs[0] + for i := 0; i < len(results); i++ { + // i is an int which can be int32 or int64. We convert i to int64 then to + // bigInt, which is safe. We should not convert sel to int64. 
+ if sel.Cmp(big.NewInt(int64(i))) == 0 { + results[i].SetUint64(1) + } else { + results[i].SetUint64(0) + } + } + return nil +} + +// mapIndicators is a hint function used within [Map] function. It must be +// provided to the prover when circuit uses it. +func mapIndicators(_ *big.Int, inputs []*big.Int, results []*big.Int) error { + key := inputs[len(inputs)-1] + // We must make sure that we are initializing all elements of results + for i := 0; i < len(results); i++ { + if key.Cmp(inputs[i]) == 0 { + results[i].SetUint64(1) + } else { + results[i].SetUint64(0) + } + } + return nil +} diff --git a/std/selector/multiplexer_test.go b/std/selector/multiplexer_test.go new file mode 100644 index 0000000000..2c01f08a01 --- /dev/null +++ b/std/selector/multiplexer_test.go @@ -0,0 +1,240 @@ +package selector_test + +import ( + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend/cs/r1cs" + "testing" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/selector" + "github.com/consensys/gnark/test" +) + +type muxCircuit struct { + SEL frontend.Variable + I0, I1, I2, I3, I4 frontend.Variable + OUT frontend.Variable +} + +func (c *muxCircuit) Define(api frontend.API) error { + + out := selector.Mux(api, c.SEL, c.I0, c.I1, c.I2, c.I3, c.I4) + + api.AssertIsEqual(out, c.OUT) + + return nil +} + +// The output of this circuit is ignored and the only way its proof can fail is by providing invalid inputs. 
+type ignoredOutputMuxCircuit struct { + SEL frontend.Variable + I0, I1, I2 frontend.Variable +} + +func (c *ignoredOutputMuxCircuit) Define(api frontend.API) error { + // We ignore the output + _ = selector.Mux(api, c.SEL, c.I0, c.I1, c.I2) + + return nil +} + +type mux2to1Circuit struct { + SEL frontend.Variable + I0, I1 frontend.Variable + OUT frontend.Variable +} + +func (c *mux2to1Circuit) Define(api frontend.API) error { + // We ignore the output + out := selector.Mux(api, c.SEL, c.I0, c.I1) + api.AssertIsEqual(out, c.OUT) + return nil +} + +type mux4to1Circuit struct { + SEL frontend.Variable + In [4]frontend.Variable + OUT frontend.Variable +} + +func (c *mux4to1Circuit) Define(api frontend.API) error { + out := selector.Mux(api, c.SEL, c.In[:]...) + api.AssertIsEqual(out, c.OUT) + return nil +} + +func TestMux(t *testing.T) { + assert := test.NewAssert(t) + + assert.ProverSucceeded(&muxCircuit{}, &muxCircuit{SEL: 2, I0: 10, I1: 11, I2: 12, I3: 13, I4: 14, OUT: 12}) + + assert.ProverSucceeded(&muxCircuit{}, &muxCircuit{SEL: 0, I0: 10, I1: 11, I2: 12, I3: 13, I4: 14, OUT: 10}) + + assert.ProverSucceeded(&muxCircuit{}, &muxCircuit{SEL: 4, I0: 20, I1: 21, I2: 22, I3: 23, I4: 24, OUT: 24}) + + // Failures + assert.ProverFailed(&muxCircuit{}, &muxCircuit{SEL: 5, I0: 20, I1: 21, I2: 22, I3: 23, I4: 24, OUT: 24}) + + assert.ProverFailed(&muxCircuit{}, &muxCircuit{SEL: 0, I0: 20, I1: 21, I2: 22, I3: 23, I4: 24, OUT: 21}) + + // Ignoring the circuit's output: + assert.ProverSucceeded(&ignoredOutputMuxCircuit{}, &ignoredOutputMuxCircuit{SEL: 0, I0: 0, I1: 1, I2: 2}) + + assert.ProverSucceeded(&ignoredOutputMuxCircuit{}, &ignoredOutputMuxCircuit{SEL: 2, I0: 0, I1: 1, I2: 2}) + + // Failures + assert.ProverFailed(&ignoredOutputMuxCircuit{}, &ignoredOutputMuxCircuit{SEL: 3, I0: 0, I1: 1, I2: 2}) + + assert.ProverFailed(&ignoredOutputMuxCircuit{}, &ignoredOutputMuxCircuit{SEL: -1, I0: 0, I1: 1, I2: 2}) + + // 2 to 1 mux + assert.ProverSucceeded(&mux2to1Circuit{}, 
&mux2to1Circuit{SEL: 1, I0: 10, I1: 20, OUT: 20}) + + assert.ProverSucceeded(&mux2to1Circuit{}, &mux2to1Circuit{SEL: 0, I0: 10, I1: 20, OUT: 10}) + + assert.ProverFailed(&mux2to1Circuit{}, &mux2to1Circuit{SEL: 2, I0: 10, I1: 20, OUT: 20}) + + // 4 to 1 mux + assert.ProverSucceeded(&mux4to1Circuit{}, &mux4to1Circuit{ + SEL: 3, + In: [4]frontend.Variable{11, 22, 33, 44}, + OUT: 44, + }) + + assert.ProverSucceeded(&mux4to1Circuit{}, &mux4to1Circuit{ + SEL: 1, + In: [4]frontend.Variable{11, 22, 33, 44}, + OUT: 22, + }) + + assert.ProverSucceeded(&mux4to1Circuit{}, &mux4to1Circuit{ + SEL: 0, + In: [4]frontend.Variable{11, 22, 33, 44}, + OUT: 11, + }) + + assert.ProverFailed(&mux4to1Circuit{}, &mux4to1Circuit{ + SEL: 4, + In: [4]frontend.Variable{11, 22, 33, 44}, + OUT: 44, + }) + + cs, _ := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &mux4to1Circuit{}) + // (4 - 1) + (2 + 1) + 1 == 7 + assert.Equal(7, cs.GetNbConstraints()) +} + +// Map tests: +type mapCircuit struct { + SEL frontend.Variable + K0, K1, K2, K3 frontend.Variable + V0, V1, V2, V3 frontend.Variable + OUT frontend.Variable +} + +func (c *mapCircuit) Define(api frontend.API) error { + + out := selector.Map(api, c.SEL, + []frontend.Variable{c.K0, c.K1, c.K2, c.K3}, + []frontend.Variable{c.V0, c.V1, c.V2, c.V3}) + + api.AssertIsEqual(out, c.OUT) + + return nil +} + +type ignoredOutputMapCircuit struct { + SEL frontend.Variable + K0, K1 frontend.Variable + V0, V1 frontend.Variable +} + +func (c *ignoredOutputMapCircuit) Define(api frontend.API) error { + + _ = selector.Map(api, c.SEL, + []frontend.Variable{c.K0, c.K1}, + []frontend.Variable{c.V0, c.V1}) + + return nil +} + +func TestMap(t *testing.T) { + assert := test.NewAssert(t) + assert.ProverSucceeded(&mapCircuit{}, + &mapCircuit{ + SEL: 100, + K0: 100, K1: 111, K2: 222, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 0, + }) + + assert.ProverSucceeded(&mapCircuit{}, + &mapCircuit{ + SEL: 222, + K0: 100, K1: 111, K2: 222, K3: 333, + V0: 0, V1: 
1, V2: 2, V3: 3, + OUT: 2, + }) + + assert.ProverSucceeded(&mapCircuit{}, + &mapCircuit{ + SEL: 333, + K0: 100, K1: 111, K2: 222, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 3, + }) + + // Duplicated key, success: + assert.ProverSucceeded(&mapCircuit{}, + &mapCircuit{ + SEL: 333, + K0: 222, K1: 222, K2: 222, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 3, + }) + + // Duplicated key, UNDEFINED behavior: (with our hint implementation it fails) + assert.ProverFailed(&mapCircuit{}, + &mapCircuit{ + SEL: 333, + K0: 100, K1: 111, K2: 333, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 3, + }) + + assert.ProverFailed(&mapCircuit{}, + &mapCircuit{ + SEL: 77, + K0: 100, K1: 111, K2: 222, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 3, + }) + + assert.ProverFailed(&mapCircuit{}, + &mapCircuit{ + SEL: 111, + K0: 100, K1: 111, K2: 222, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 2, + }) + + // Ignoring the circuit's output: + assert.ProverSucceeded(&ignoredOutputMapCircuit{}, + &ignoredOutputMapCircuit{SEL: 5, + K0: 5, K1: 7, + V0: 10, V1: 11, + }) + + assert.ProverFailed(&ignoredOutputMapCircuit{}, + &ignoredOutputMapCircuit{SEL: 5, + K0: 5, K1: 5, + V0: 10, V1: 11, + }) + + assert.ProverFailed(&ignoredOutputMapCircuit{}, + &ignoredOutputMapCircuit{SEL: 6, + K0: 5, K1: 7, + V0: 10, V1: 11, + }) + +} diff --git a/std/selector/mux.go b/std/selector/mux.go new file mode 100644 index 0000000000..d437430353 --- /dev/null +++ b/std/selector/mux.go @@ -0,0 +1,45 @@ +package selector + +import ( + "fmt" + "github.com/consensys/gnark/frontend" +) + +// BinaryMux is a 2^k to 1 multiplexer which uses a binary selector. selBits are +// the selector bits, and the input at the index equal to the binary number +// represented by the selector bits will be selected. More precisely the output +// will be: +// +// inputs[selBits[0]+selBits[1]*(1<<1)+selBits[2]*(1<<2)+...] +// +// len(inputs) must be 2^len(selBits). 
+func BinaryMux(api frontend.API, selBits, inputs []frontend.Variable) frontend.Variable {
+	if len(inputs) != 1<<len(selBits) {
+		panic(fmt.Sprintf("invalid number of inputs %d; must be 2^len(selBits) = %d", len(inputs), 1<<len(selBits)))
+	}
+	return binaryMuxRecursive(api, selBits, inputs)
+}
+
+// binaryMuxRecursive selects between the left and right half of inputs using
+// the most significant selector bit, and recurses with the remaining bits.
+// When the input vector is shorter than the pivot, everything lies on the
+// left side and the most significant bit is ignored for that level.
+func binaryMuxRecursive(api frontend.API, selBits, inputs []frontend.Variable) frontend.Variable {
+	if len(selBits) == 0 {
+		return inputs[0]
+	}
+
+	msb := selBits[len(selBits)-1]
+	nextSelBits := selBits[:len(selBits)-1]
+	pivot := 1 << (len(selBits) - 1)
+	if pivot >= len(inputs) {
+		return binaryMuxRecursive(api, nextSelBits, inputs)
+	}
+
+	left := binaryMuxRecursive(api, nextSelBits, inputs[:pivot])
+	right := binaryMuxRecursive(api, nextSelBits, inputs[pivot:])
+	return api.Add(left, api.Mul(msb, api.Sub(right, left)))
+}
diff --git a/std/selector/mux_test.go b/std/selector/mux_test.go
new file mode 100644
index 0000000000..f590c23f02
--- /dev/null
+++ b/std/selector/mux_test.go
@@ -0,0 +1,96 @@
+package selector
+
+import (
+	"github.com/consensys/gnark/frontend"
+	"github.com/consensys/gnark/test"
+	"testing"
+)
+
+type binaryMuxCircuit struct {
+	Sel [4]frontend.Variable
+	In  [10]frontend.Variable
+	Out frontend.Variable
+}
+
+func (c *binaryMuxCircuit) Define(api frontend.API) error {
+
+	out := binaryMuxRecursive(api, c.Sel[:], c.In[:])
+
+	api.AssertIsEqual(out, c.Out)
+
+	return nil
+}
+
+type binary7to1MuxCircuit struct {
+	Sel [3]frontend.Variable
+	In  [7]frontend.Variable
+	Out frontend.Variable
+}
+
+func (c *binary7to1MuxCircuit) Define(api frontend.API) error {
+
+	out := binaryMuxRecursive(api, c.Sel[:], c.In[:])
+
+	api.AssertIsEqual(out, c.Out)
+
+	return nil
+}
+
+func Test_binaryMuxRecursive(t *testing.T) {
+	assert := test.NewAssert(t)
+
+	assert.ProverSucceeded(&binaryMuxCircuit{}, &binaryMuxCircuit{
+		Sel: [4]frontend.Variable{0, 0, 1, 0},
+		In:  [10]frontend.Variable{100, 111, 122, 133, 144, 155, 166, 177, 188, 199},
+		Out: 144,
+	})
+
+	assert.ProverSucceeded(&binaryMuxCircuit{}, &binaryMuxCircuit{
+		Sel: [4]frontend.Variable{0, 0, 0, 0},
+		In:  [10]frontend.Variable{100, 111, 122, 133, 144, 155, 166, 177, 188, 199},
+		Out: 100,
+	})
+
+	assert.ProverSucceeded(&binaryMuxCircuit{}, &binaryMuxCircuit{
+		Sel: [4]frontend.Variable{1, 0, 0, 1},
+		In:  [10]frontend.Variable{100, 111, 122, 133, 144, 155, 166, 177, 188, 199},
+		Out: 199,
+	})
+
+	assert.ProverSucceeded(&binaryMuxCircuit{},
&binaryMuxCircuit{ + Sel: [4]frontend.Variable{0, 1, 0, 0}, + In: [10]frontend.Variable{100, 111, 122, 133, 144, 155, 166, 177, 188, 199}, + Out: 122, + }) + + assert.ProverSucceeded(&binaryMuxCircuit{}, &binaryMuxCircuit{ + Sel: [4]frontend.Variable{0, 0, 0, 1}, + In: [10]frontend.Variable{100, 111, 122, 133, 144, 155, 166, 177, 188, 199}, + Out: 188, + }) + + // 7 to 1 + assert.ProverSucceeded(&binary7to1MuxCircuit{}, &binary7to1MuxCircuit{ + Sel: [3]frontend.Variable{0, 0, 1}, + In: [7]frontend.Variable{5, 3, 10, 6, 0, 9, 1}, + Out: 0, + }) + + assert.ProverSucceeded(&binary7to1MuxCircuit{}, &binary7to1MuxCircuit{ + Sel: [3]frontend.Variable{0, 1, 1}, + In: [7]frontend.Variable{5, 3, 10, 6, 0, 9, 1}, + Out: 1, + }) + + assert.ProverSucceeded(&binary7to1MuxCircuit{}, &binary7to1MuxCircuit{ + Sel: [3]frontend.Variable{0, 0, 0}, + In: [7]frontend.Variable{5, 3, 10, 6, 0, 9, 1}, + Out: 5, + }) + + assert.ProverSucceeded(&binary7to1MuxCircuit{}, &binary7to1MuxCircuit{ + Sel: [3]frontend.Variable{1, 0, 1}, + In: [7]frontend.Variable{5, 3, 10, 6, 0, 9, 1}, + Out: 9, + }) +} diff --git a/std/selector/slice.go b/std/selector/slice.go new file mode 100644 index 0000000000..b21be6059a --- /dev/null +++ b/std/selector/slice.go @@ -0,0 +1,105 @@ +package selector + +import ( + "fmt" + "github.com/consensys/gnark/frontend" + "math/big" +) + +// Slice selects a slice of the input array at indices [start, end), and zeroes the array at other +// indices. More precisely, for each i we have: +// +// if i >= start and i < end +// out[i] = input[i] +// else +// out[i] = 0 +// +// We must have start >= 0 and end <= len(input), otherwise a proof cannot be generated. +func Slice(api frontend.API, start, end frontend.Variable, input []frontend.Variable) []frontend.Variable { + // it appears that this is the most efficient implementation. 
There is also another implementation + // which creates the mask by adding two stepMask outputs, however that would not work correctly when + // end < start. + out := Partition(api, end, false, input) + out = Partition(api, start, true, out) + return out +} + +// Partition selects left or right side of the input array, with respect to the pivotPosition. +// More precisely when rightSide is false, for each i we have: +// +// if i < pivotPosition +// out[i] = input[i] +// else +// out[i] = 0 +// +// and when rightSide is true, for each i we have: +// +// if i >= pivotPosition +// out[i] = input[i] +// else +// out[i] = 0 +// +// We must have pivotPosition >= 0 and pivotPosition <= len(input), otherwise a proof cannot be generated. +func Partition(api frontend.API, pivotPosition frontend.Variable, rightSide bool, + input []frontend.Variable) (out []frontend.Variable) { + out = make([]frontend.Variable, len(input)) + var mask []frontend.Variable + // we create a bit mask to multiply with the input. + if rightSide { + mask = stepMask(api, len(input), pivotPosition, 0, 1) + } else { + mask = stepMask(api, len(input), pivotPosition, 1, 0) + } + for i := 0; i < len(out); i++ { + out[i] = api.Mul(mask[i], input[i]) + } + return +} + +// stepMask generates a step like function into an output array of a given length. +// The output is an array of length outputLen, +// such that its first stepPosition elements are equal to startValue and the remaining elements are equal to +// endValue. Note that outputLen cannot be a circuit variable. +// +// We must have stepPosition >= 0 and stepPosition <= outputLen, otherwise a proof cannot be generated. +// This function panics when outputLen is less than 2. 
+func stepMask(api frontend.API, outputLen int, + stepPosition, startValue, endValue frontend.Variable) []frontend.Variable { + if outputLen < 2 { + panic("the output len of StepMask must be >= 2") + } + // get the output as a hint + out, err := api.Compiler().NewHint(stepOutput, outputLen, stepPosition, startValue, endValue) + if err != nil { + panic(fmt.Sprintf("error in calling StepMask hint: %v", err)) + } + + // add the boundary constraints: + // (out[0] - startValue) * stepPosition == 0 + api.AssertIsEqual(api.Mul(api.Sub(out[0], startValue), stepPosition), 0) + // (out[len(out)-1] - endValue) * (len(out) - stepPosition) == 0 + api.AssertIsEqual(api.Mul(api.Sub(out[len(out)-1], endValue), api.Sub(len(out), stepPosition)), 0) + + // add constraints for the correct form of a step function that steps at the stepPosition + for i := 1; i < len(out); i++ { + // (out[i] - out[i-1]) * (i - stepPosition) == 0 + api.AssertIsEqual(api.Mul(api.Sub(out[i], out[i-1]), api.Sub(i, stepPosition)), 0) + } + return out +} + +// stepOutput is a hint function used within [StepMask] function. It must be +// provided to the prover when circuit uses it. 
+func stepOutput(_ *big.Int, inputs, results []*big.Int) error { + stepPos := inputs[0] + startValue := inputs[1] + endValue := inputs[2] + for i := 0; i < len(results); i++ { + if i < int(stepPos.Int64()) { + results[i].Set(startValue) + } else { + results[i].Set(endValue) + } + } + return nil +} diff --git a/std/selector/slice_test.go b/std/selector/slice_test.go new file mode 100644 index 0000000000..35715c886e --- /dev/null +++ b/std/selector/slice_test.go @@ -0,0 +1,205 @@ +package selector_test + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/selector" + "github.com/consensys/gnark/test" + "testing" +) + +type partitionerCircuit struct { + Pivot frontend.Variable `gnark:",public"` + In [6]frontend.Variable `gnark:",public"` + WantLeft [6]frontend.Variable `gnark:",public"` + WantRight [6]frontend.Variable `gnark:",public"` +} + +func (c *partitionerCircuit) Define(api frontend.API) error { + gotLeft := selector.Partition(api, c.Pivot, false, c.In[:]) + for i, want := range c.WantLeft { + api.AssertIsEqual(gotLeft[i], want) + } + + gotRight := selector.Partition(api, c.Pivot, true, c.In[:]) + for i, want := range c.WantRight { + api.AssertIsEqual(gotRight[i], want) + } + + return nil +} + +type ignoredOutputPartitionerCircuit struct { + Pivot frontend.Variable `gnark:",public"` + In [2]frontend.Variable `gnark:",public"` +} + +func (c *ignoredOutputPartitionerCircuit) Define(api frontend.API) error { + _ = selector.Partition(api, c.Pivot, false, c.In[:]) + _ = selector.Partition(api, c.Pivot, true, c.In[:]) + return nil +} + +func TestPartition(t *testing.T) { + assert := test.NewAssert(t) + + assert.ProverSucceeded(&partitionerCircuit{}, &partitionerCircuit{ + Pivot: 3, + In: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantLeft: [6]frontend.Variable{10, 20, 30, 0, 0, 0}, + WantRight: [6]frontend.Variable{0, 0, 0, 40, 50, 60}, + }) + + assert.ProverSucceeded(&partitionerCircuit{}, &partitionerCircuit{ + Pivot: 1, + In: 
[6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantLeft: [6]frontend.Variable{10, 0, 0, 0, 0, 0}, + WantRight: [6]frontend.Variable{0, 20, 30, 40, 50, 60}, + }) + + assert.ProverSucceeded(&partitionerCircuit{}, &partitionerCircuit{ + Pivot: 5, + In: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantLeft: [6]frontend.Variable{10, 20, 30, 40, 50, 0}, + WantRight: [6]frontend.Variable{0, 0, 0, 0, 0, 60}, + }) + + assert.ProverFailed(&partitionerCircuit{}, &partitionerCircuit{ + Pivot: 5, + In: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantLeft: [6]frontend.Variable{10, 20, 30, 40, 0, 0}, + WantRight: [6]frontend.Variable{0, 0, 0, 0, 0, 0}, + }) + + assert.ProverSucceeded(&partitionerCircuit{}, &partitionerCircuit{ + Pivot: 6, + In: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantLeft: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantRight: [6]frontend.Variable{0, 0, 0, 0, 0, 0}, + }) + + assert.ProverSucceeded(&partitionerCircuit{}, &partitionerCircuit{ + Pivot: 0, + In: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantLeft: [6]frontend.Variable{0, 0, 0, 0, 0, 0}, + WantRight: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + }) + + // Pivot is outside and the prover fails: + assert.ProverFailed(&partitionerCircuit{}, &partitionerCircuit{ + Pivot: 7, + In: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantLeft: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantRight: [6]frontend.Variable{0, 0, 0, 0, 0, 0}, + }) + + assert.ProverFailed(&partitionerCircuit{}, &partitionerCircuit{ + Pivot: -1, + In: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + WantLeft: [6]frontend.Variable{0, 0, 0, 0, 0, 0}, + WantRight: [6]frontend.Variable{10, 20, 30, 40, 50, 60}, + }) + + // tests by ignoring the output: + assert.ProverSucceeded(&ignoredOutputPartitionerCircuit{}, &ignoredOutputPartitionerCircuit{ + Pivot: 1, + In: [2]frontend.Variable{10, 20}, + }) + + assert.ProverFailed(&ignoredOutputPartitionerCircuit{}, &ignoredOutputPartitionerCircuit{ + Pivot: -1, + In: 
[2]frontend.Variable{10, 20}, + }) + + assert.ProverFailed(&ignoredOutputPartitionerCircuit{}, &ignoredOutputPartitionerCircuit{ + Pivot: 3, + In: [2]frontend.Variable{10, 20}, + }) +} + +type slicerCircuit struct { + Start, End frontend.Variable `gnark:",public"` + In [7]frontend.Variable `gnark:",public"` + WantSlice [7]frontend.Variable `gnark:",public"` +} + +func (c *slicerCircuit) Define(api frontend.API) error { + gotSlice := selector.Slice(api, c.Start, c.End, c.In[:]) + for i, want := range c.WantSlice { + api.AssertIsEqual(gotSlice[i], want) + } + return nil +} + +func TestSlice(t *testing.T) { + assert := test.NewAssert(t) + + assert.ProverSucceeded(&slicerCircuit{}, &slicerCircuit{ + Start: 2, + End: 5, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 0, 2, 3, 4, 0, 0}, + }) + + assert.ProverSucceeded(&slicerCircuit{}, &slicerCircuit{ + Start: 3, + End: 4, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 0, 0, 3, 0, 0, 0}, + }) + + assert.ProverSucceeded(&slicerCircuit{}, &slicerCircuit{ + Start: 3, + End: 3, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 0, 0, 0, 0, 0, 0}, + }) + + assert.ProverSucceeded(&slicerCircuit{}, &slicerCircuit{ + Start: 3, + End: 1, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 0, 0, 0, 0, 0, 0}, + }) + + assert.ProverSucceeded(&slicerCircuit{}, &slicerCircuit{ + Start: 3, + End: 6, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 0, 0, 3, 4, 5, 0}, + }) + + assert.ProverSucceeded(&slicerCircuit{}, &slicerCircuit{ + Start: 3, + End: 7, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 0, 0, 3, 4, 5, 6}, + }) + + assert.ProverSucceeded(&slicerCircuit{}, &slicerCircuit{ + Start: 0, + End: 2, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 1, 0, 0, 0, 0, 0}, + }) + + 
assert.ProverSucceeded(&slicerCircuit{}, &slicerCircuit{ + Start: 0, + End: 7, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + }) + + assert.ProverFailed(&slicerCircuit{}, &slicerCircuit{ + Start: 3, + End: 8, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 0, 0, 3, 4, 5, 6}, + }) + + assert.ProverFailed(&slicerCircuit{}, &slicerCircuit{ + Start: -1, + End: 2, + In: [7]frontend.Variable{0, 1, 2, 3, 4, 5, 6}, + WantSlice: [7]frontend.Variable{0, 1, 0, 0, 0, 0, 0}, + }) +} diff --git a/std/signature/ecdsa/ecdsa.go b/std/signature/ecdsa/ecdsa.go index 5279f69930..cac32d7db1 100644 --- a/std/signature/ecdsa/ecdsa.go +++ b/std/signature/ecdsa/ecdsa.go @@ -1,7 +1,7 @@ /* Package ecdsa implements ECDSA signature verification over any elliptic curve. -The package depends on the [weierstrass] package for elliptic curve group +The package depends on the [emulated/sw_emulated] package for elliptic curve group operations using non-native arithmetic. Thus we can verify ECDSA signatures over any curve. The cost for a single secp256k1 signature verification is approximately 4M constraints in R1CS and 10M constraints in PLONKish. @@ -15,7 +15,7 @@ package ecdsa import ( "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/weierstrass" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" "github.com/consensys/gnark/std/math/emulated" ) @@ -25,14 +25,14 @@ type Signature[Scalar emulated.FieldParams] struct { } // PublicKey represents the public key to verify the signature for. -type PublicKey[Base, Scalar emulated.FieldParams] weierstrass.AffinePoint[Base] +type PublicKey[Base, Scalar emulated.FieldParams] sw_emulated.AffinePoint[Base] // Verify asserts that the signature sig verifies for the message msg and public // key pk. The curve parameters params define the elliptic curve. 
// // We assume that the message msg is already hashed to the scalar field. -func (pk PublicKey[T, S]) Verify(api frontend.API, params weierstrass.CurveParams, msg *emulated.Element[S], sig *Signature[S]) { - cr, err := weierstrass.New[T, S](api, params) +func (pk PublicKey[T, S]) Verify(api frontend.API, params sw_emulated.CurveParams, msg *emulated.Element[S], sig *Signature[S]) { + cr, err := sw_emulated.New[T, S](api, params) if err != nil { // TODO: softer handling. panic(err) @@ -45,14 +45,13 @@ func (pk PublicKey[T, S]) Verify(api frontend.API, params weierstrass.CurveParam if err != nil { panic(err) } - pkpt := weierstrass.AffinePoint[T](pk) + pkpt := sw_emulated.AffinePoint[T](pk) sInv := scalarApi.Inverse(&sig.S) msInv := scalarApi.MulMod(msg, sInv) rsInv := scalarApi.MulMod(&sig.R, sInv) - qa := cr.ScalarMul(cr.Generator(), msInv) - qb := cr.ScalarMul(&pkpt, rsInv) - q := cr.Add(qa, qb) + // q = [rsInv]pkpt + [msInv]g + q := cr.JointScalarMulBase(&pkpt, rsInv, msInv) qx := baseApi.Reduce(&q.X) qxBits := baseApi.ToBits(qx) rbits := scalarApi.ToBits(&sig.R) diff --git a/std/signature/ecdsa/ecdsa_secpr_test.go b/std/signature/ecdsa/ecdsa_secpr_test.go new file mode 100644 index 0000000000..139865e48b --- /dev/null +++ b/std/signature/ecdsa/ecdsa_secpr_test.go @@ -0,0 +1,111 @@ +package ecdsa + +import ( + cryptoecdsa "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/sha512" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/test" + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" +) + +func TestEcdsaP256PreHashed(t *testing.T) { + + // generate parameters + privKey, _ := cryptoecdsa.GenerateKey(elliptic.P256(), rand.Reader) + publicKey := privKey.PublicKey + + // sign + msg := []byte("testing ECDSA (pre-hashed)") + msgHash := sha256.Sum256(msg) + sigBin, _ := privKey.Sign(rand.Reader, msgHash[:], nil) + + // 
check that the signature is correct + var ( + r, s = &big.Int{}, &big.Int{} + inner cryptobyte.String + ) + input := cryptobyte.String(sigBin) + if !input.ReadASN1(&inner, asn1.SEQUENCE) || + !input.Empty() || + !inner.ReadASN1Integer(r) || + !inner.ReadASN1Integer(s) || + !inner.Empty() { + panic("invalid sig") + } + flag := cryptoecdsa.Verify(&publicKey, msgHash[:], r, s) + if !flag { + t.Errorf("can't verify signature") + } + + circuit := EcdsaCircuit[emulated.P256Fp, emulated.P256Fr]{} + witness := EcdsaCircuit[emulated.P256Fp, emulated.P256Fr]{ + Sig: Signature[emulated.P256Fr]{ + R: emulated.ValueOf[emulated.P256Fr](r), + S: emulated.ValueOf[emulated.P256Fr](s), + }, + Msg: emulated.ValueOf[emulated.P256Fr](msgHash[:]), + Pub: PublicKey[emulated.P256Fp, emulated.P256Fr]{ + X: emulated.ValueOf[emulated.P256Fp](privKey.PublicKey.X), + Y: emulated.ValueOf[emulated.P256Fp](privKey.PublicKey.Y), + }, + } + assert := test.NewAssert(t) + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} + +func TestEcdsaP384PreHashed(t *testing.T) { + + // generate parameters + privKey, _ := cryptoecdsa.GenerateKey(elliptic.P384(), rand.Reader) + publicKey := privKey.PublicKey + + // sign + msg := []byte("testing ECDSA (pre-hashed)") + msgHash := sha512.Sum384(msg) + sigBin, _ := privKey.Sign(rand.Reader, msgHash[:], nil) + + // check that the signature is correct + var ( + r, s = &big.Int{}, &big.Int{} + inner cryptobyte.String + ) + input := cryptobyte.String(sigBin) + if !input.ReadASN1(&inner, asn1.SEQUENCE) || + !input.Empty() || + !inner.ReadASN1Integer(r) || + !inner.ReadASN1Integer(s) || + !inner.Empty() { + panic("invalid sig") + } + flag := cryptoecdsa.Verify(&publicKey, msgHash[:], r, s) + if !flag { + t.Errorf("can't verify signature") + } + + circuit := EcdsaCircuit[emulated.P384Fp, emulated.P384Fr]{} + witness := EcdsaCircuit[emulated.P384Fp, emulated.P384Fr]{ + Sig: Signature[emulated.P384Fr]{ + R: 
emulated.ValueOf[emulated.P384Fr](r), + S: emulated.ValueOf[emulated.P384Fr](s), + }, + Msg: emulated.ValueOf[emulated.P384Fr](msgHash[:]), + Pub: PublicKey[emulated.P384Fp, emulated.P384Fr]{ + X: emulated.ValueOf[emulated.P384Fp](privKey.PublicKey.X), + Y: emulated.ValueOf[emulated.P384Fp](privKey.PublicKey.Y), + }, + } + assert := test.NewAssert(t) + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + +} diff --git a/std/signature/ecdsa/ecdsa_test.go b/std/signature/ecdsa/ecdsa_test.go index 711ec69554..57ff1a4406 100644 --- a/std/signature/ecdsa/ecdsa_test.go +++ b/std/signature/ecdsa/ecdsa_test.go @@ -9,7 +9,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/secp256k1/ecdsa" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/algebra/weierstrass" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" "github.com/consensys/gnark/std/math/emulated" "github.com/consensys/gnark/test" ) @@ -21,7 +21,7 @@ type EcdsaCircuit[T, S emulated.FieldParams] struct { } func (c *EcdsaCircuit[T, S]) Define(api frontend.API) error { - c.Pub.Verify(api, weierstrass.GetCurveParams[T](), &c.Msg, &c.Sig) + c.Pub.Verify(api, sw_emulated.GetCurveParams[T](), &c.Msg, &c.Sig) return nil } @@ -134,7 +134,7 @@ func ExamplePublicKey_Verify() { Y: emulated.ValueOf[emulated.Secp256k1Fp](puby), } // signature verification assertion is done in-circuit - Pub.Verify(api, weierstrass.GetCurveParams[emulated.Secp256k1Fp](), &Msg, &Sig) + Pub.Verify(api, sw_emulated.GetCurveParams[emulated.Secp256k1Fp](), &Msg, &Sig) } // Example how to create a valid signature for secp256k1 diff --git a/std/signature/eddsa/eddsa.go b/std/signature/eddsa/eddsa.go index e285f00496..b0841de4ad 100644 --- a/std/signature/eddsa/eddsa.go +++ b/std/signature/eddsa/eddsa.go @@ -24,7 +24,7 @@ import ( "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/frontend" - 
"github.com/consensys/gnark/std/algebra/twistededwards" + "github.com/consensys/gnark/std/algebra/native/twistededwards" tedwards "github.com/consensys/gnark-crypto/ecc/twistededwards" @@ -55,7 +55,7 @@ type Signature struct { // Verify verifies an eddsa signature using MiMC hash function // cf https://en.wikipedia.org/wiki/EdDSA -func Verify(curve twistededwards.Curve, sig Signature, msg frontend.Variable, pubKey PublicKey, hash hash.Hash) error { +func Verify(curve twistededwards.Curve, sig Signature, msg frontend.Variable, pubKey PublicKey, hash hash.FieldHasher) error { // compute H(R, A, M) hash.Write(sig.R.X) diff --git a/std/signature/eddsa/eddsa_test.go b/std/signature/eddsa/eddsa_test.go index 657d337010..762bfb5ef0 100644 --- a/std/signature/eddsa/eddsa_test.go +++ b/std/signature/eddsa/eddsa_test.go @@ -27,7 +27,7 @@ import ( "github.com/consensys/gnark-crypto/signature/eddsa" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/std/algebra/twistededwards" + "github.com/consensys/gnark/std/algebra/native/twistededwards" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" ) diff --git a/std/utils/algo_utils/algo_utils.go b/std/utils/algo_utils/algo_utils.go new file mode 100644 index 0000000000..37dbc8c464 --- /dev/null +++ b/std/utils/algo_utils/algo_utils.go @@ -0,0 +1,175 @@ +package algo_utils + +import "github.com/bits-and-blooms/bitset" + +// this package provides some generic (in both senses of the word) algorithmic conveniences. 
+ +// Permute operates in-place but is not thread-safe; it uses the permutation for scratching +// permutation[i] signifies which index slice[i] is going to +func Permute[T any](slice []T, permutation []int) { + var cached T + for next := 0; next < len(permutation); next++ { + + cached = slice[next] + j := permutation[next] + permutation[next] = ^j + for j >= 0 { + cached, slice[j] = slice[j], cached + j, permutation[j] = permutation[j], ^permutation[j] + } + permutation[next] = ^permutation[next] + } + for i := range permutation { + permutation[i] = ^permutation[i] + } +} + +func Map[T, S any](in []T, f func(T) S) []S { + out := make([]S, len(in)) + for i, t := range in { + out[i] = f(t) + } + return out +} + +func MapRange[S any](begin, end int, f func(int) S) []S { + out := make([]S, end-begin) + for i := begin; i < end; i++ { + out[i] = f(i) + } + return out +} + +func SliceAt[T any](slice []T) func(int) T { + return func(i int) T { + return slice[i] + } +} + +func SlicePtrAt[T any](slice []T) func(int) *T { + return func(i int) *T { + return &slice[i] + } +} + +func MapAt[K comparable, V any](mp map[K]V) func(K) V { + return func(k K) V { + return mp[k] + } +} + +// InvertPermutation input permutation must contain exactly 0, ..., len(permutation)-1 +func InvertPermutation(permutation []int) []int { + res := make([]int, len(permutation)) + for i := range permutation { + res[permutation[i]] = i + } + return res +} + +// TODO: Move this to gnark-crypto and use it for gkr there as well + +// TopologicalSort takes a list of lists of dependencies and proposes a sorting of the lists in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// As a bonus, it returns for each list its "unique" outputs. That is, a list of its outputs with no duplicates. 
+// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input. +// If performance was bad, consider using a heap for finding the value "leastReady". +// WARNING: Due to the current implementation of intSet, it is ALWAYS O(n^2). +func TopologicalSort(inputs [][]int) (sorted []int, uniqueOutputs [][]int) { + data := newTopSortData(inputs) + sorted = make([]int, len(inputs)) + + for i := range inputs { + sorted[i] = data.leastReady + data.markDone(data.leastReady) + } + + return sorted, data.uniqueOutputs +} + +type topSortData struct { + uniqueOutputs [][]int + inputs [][]int + status []int // status > 0 indicates number of unique inputs left to be ready. status = 0 means ready. status = -1 means done + leastReady int +} + +func newTopSortData(inputs [][]int) topSortData { + size := len(inputs) + res := topSortData{ + uniqueOutputs: make([][]int, size), + inputs: inputs, + status: make([]int, size), + leastReady: 0, + } + + inputsISet := bitset.New(uint(size)) + for i := range res.uniqueOutputs { + if i != 0 { + inputsISet.ClearAll() + } + cpt := 0 + for _, in := range inputs[i] { + if !inputsISet.Test(uint(in)) { + inputsISet.Set(uint(in)) + cpt++ + res.uniqueOutputs[in] = append(res.uniqueOutputs[in], i) + } + } + res.status[i] = cpt + } + + for res.status[res.leastReady] != 0 { + res.leastReady++ + } + + return res +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.uniqueOutputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +// BinarySearch looks for toFind in a sorted slice, and returns the index at which it either is or would be were it to be inserted. 
+func BinarySearch(slice []int, toFind int) int { + var start int + for end := len(slice); start != end; { + mid := (start + end) / 2 + if toFind >= slice[mid] { + start = mid + } + if toFind <= slice[mid] { + end = mid + } + } + return start +} + +// BinarySearchFunc looks for toFind in an increasing function of domain 0 ... (end-1), and returns the index at which it either is or would be were it to be inserted. +func BinarySearchFunc(eval func(int) int, end int, toFind int) int { + var start int + for start != end { + mid := (start + end) / 2 + val := eval(mid) + if toFind >= val { + start = mid + } + if toFind <= val { + end = mid + } + } + return start +} diff --git a/std/utils/algo_utils/algo_utils_test.go b/std/utils/algo_utils/algo_utils_test.go new file mode 100644 index 0000000000..85ab4bf294 --- /dev/null +++ b/std/utils/algo_utils/algo_utils_test.go @@ -0,0 +1,69 @@ +package algo_utils + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func SliceLen[T any](slice []T) int { + return len(slice) +} + +func testTopSort(t *testing.T, inputs [][]int, expectedSorted, expectedNbUniqueOuts []int) { + sorted, uniqueOuts := TopologicalSort(inputs) + nbUniqueOut := Map(uniqueOuts, SliceLen[int]) + assert.Equal(t, expectedSorted, sorted) + assert.Equal(t, expectedNbUniqueOuts, nbUniqueOut) +} + +func TestTopSortTrivial(t *testing.T) { + testTopSort(t, [][]int{ + {1}, + {}, + }, []int{1, 0}, []int{0, 1}) +} + +func TestTopSortSingleGate(t *testing.T) { + inputs := [][]int{{1, 2}, {}, {}} + expectedSorted := []int{1, 2, 0} + expectedNbUniqueOuts := []int{0, 1, 1} + testTopSort(t, inputs, expectedSorted, expectedNbUniqueOuts) +} + +func TestTopSortDeep(t *testing.T) { + inputs := [][]int{{2}, {3}, {}, {0}} + expectedSorted := []int{2, 0, 3, 1} + expectedNbUniqueOuts := []int{1, 0, 1, 1} + + testTopSort(t, inputs, expectedSorted, expectedNbUniqueOuts) +} + +func TestTopSortWide(t *testing.T) { + inputs := [][]int{ + {3, 8}, + {6}, + {4}, + {}, + {}, + 
{9}, + {9}, + {9, 5, 2, 2}, + {4, 3}, + {}, + } + expectedSorted := []int{3, 4, 2, 8, 0, 9, 5, 6, 1, 7} + expectedNbUniqueOut := []int{0, 0, 1, 2, 2, 1, 1, 0, 1, 3} + + testTopSort(t, inputs, expectedSorted, expectedNbUniqueOut) +} + +func TestPermute(t *testing.T) { + list := []int{34, 65, 23, 2, 5} + permutation := []int{2, 0, 1, 4, 3} + permutationCopy := make([]int, len(permutation)) + copy(permutationCopy, permutation) + + Permute(list, permutation) + assert.Equal(t, []int{65, 23, 34, 5, 2}, list) + assert.Equal(t, permutationCopy, permutation) +} diff --git a/std/utils/test_vectors_utils/test_vector_utils.go b/std/utils/test_vectors_utils/test_vector_utils.go new file mode 100644 index 0000000000..2f5dbc4a38 --- /dev/null +++ b/std/utils/test_vectors_utils/test_vector_utils.go @@ -0,0 +1,260 @@ +package test_vector_utils + +import ( + "encoding/json" + "github.com/consensys/gnark/frontend" + "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "strconv" + "strings" + "testing" +) + +// These data structures fail to equate different representations of the same number. i.e. 5 = -10/-2 +// @Tabaie TODO Replace with proper lookup tables + +type Map struct { + keys []frontend.Variable + values []frontend.Variable +} + +func getDelta(api frontend.API, x frontend.Variable, deltaIndex int, keys []frontend.Variable) frontend.Variable { + num := frontend.Variable(1) + den := frontend.Variable(1) + + for i, key := range keys { + if i != deltaIndex { + num = api.Mul(num, api.Sub(key, x)) + den = api.Mul(den, api.Sub(key, keys[deltaIndex])) + } + } + + return api.Div(num, den) +} + +// Get returns garbage if key is not present +func (m Map) Get(api frontend.API, key frontend.Variable) frontend.Variable { + res := frontend.Variable(0) + + for i := range m.keys { + deltaI := getDelta(api, key, i, m.keys) + res = api.Add(res, api.Mul(deltaI, m.values[i])) + } + + return res +} + +// The keys in a DoubleMap must be constant. i.e. 
known at setup time +type DoubleMap struct { + keys1 []frontend.Variable + keys2 []frontend.Variable + values [][]frontend.Variable +} + +// Get is very inefficient. Do not use outside testing +func (m DoubleMap) Get(api frontend.API, key1, key2 frontend.Variable) frontend.Variable { + deltas1 := make([]frontend.Variable, len(m.keys1)) + deltas2 := make([]frontend.Variable, len(m.keys2)) + + for i := range deltas1 { + deltas1[i] = getDelta(api, key1, i, m.keys1) + } + + for j := range deltas2 { + deltas2[j] = getDelta(api, key2, j, m.keys2) + } + + res := frontend.Variable(0) + + for i := range deltas1 { + for j := range deltas2 { + if m.values[i][j] != nil { + deltaIJ := api.Mul(deltas1[i], deltas2[j], m.values[i][j]) + res = api.Add(res, deltaIJ) + } + } + } + + return res +} + +func register[K comparable](m map[K]int, key K) { + if _, ok := m[key]; !ok { + m[key] = len(m) + } +} + +func orderKeys[K comparable](order map[K]int) (ordered []K) { + ordered = make([]K, len(order)) + for k, i := range order { + ordered[i] = k + } + return +} + +type ElementMap struct { + single Map + double DoubleMap +} + +func ReadMap(in map[string]interface{}) ElementMap { + single := Map{ + keys: make([]frontend.Variable, 0), + values: make([]frontend.Variable, 0), + } + + keys1 := make(map[string]int) + keys2 := make(map[string]int) + + for k, v := range in { + + kSep := strings.Split(k, ",") + switch len(kSep) { + case 1: + single.keys = append(single.keys, k) + single.values = append(single.values, ToVariable(v)) + case 2: + + register(keys1, kSep[0]) + register(keys2, kSep[1]) + + default: + panic("too many keys") + } + } + + vals := make([][]frontend.Variable, len(keys1)) + for i := range vals { + vals[i] = make([]frontend.Variable, len(keys2)) + } + + for k, v := range in { + kSep := strings.Split(k, ",") + if len(kSep) == 2 { + i1 := keys1[kSep[0]] + i2 := keys2[kSep[1]] + vals[i1][i2] = ToVariable(v) + } + } + + double := DoubleMap{ + keys1: 
ToVariableSlice(orderKeys(keys1)), + keys2: ToVariableSlice(orderKeys(keys2)), + values: vals, + } + + return ElementMap{ + single: single, + double: double, + } +} + +func ToVariable(v interface{}) frontend.Variable { + switch vT := v.(type) { + case float64: + return int(vT) + default: + return v + } +} + +func ToVariableSlice[V any](slice []V) (variableSlice []frontend.Variable) { + variableSlice = make([]frontend.Variable, len(slice)) + for i := range slice { + variableSlice[i] = ToVariable(slice[i]) + } + return +} + +func ToVariableSliceSlice[V any](sliceSlice [][]V) (variableSliceSlice [][]frontend.Variable) { + variableSliceSlice = make([][]frontend.Variable, len(sliceSlice)) + for i := range sliceSlice { + variableSliceSlice[i] = ToVariableSlice(sliceSlice[i]) + } + return +} + +func ToMap(keys1, keys2, values []frontend.Variable) map[string]interface{} { + res := make(map[string]interface{}, len(keys1)) + for i := range keys1 { + str := strconv.Itoa(keys1[i].(int)) + "," + strconv.Itoa(keys2[i].(int)) + res[str] = values[i].(int) + } + return res +} + +var MapCache = make(map[string]ElementMap) // @Tabaie: global bad? 
+ +func ElementMapFromFile(path string) (ElementMap, error) { + path, err := filepath.Abs(path) + if err != nil { + return ElementMap{}, err + } + if h, ok := MapCache[path]; ok { + return h, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var asMap map[string]interface{} + if err = json.Unmarshal(bytes, &asMap); err != nil { + return ElementMap{}, err + } + + res := ReadMap(asMap) + MapCache[path] = res + return res, nil + + } else { + return ElementMap{}, err + } +} + +type MapHash struct { + Map ElementMap + state frontend.Variable + API frontend.API + stateValid bool +} + +func (m *MapHash) Sum() frontend.Variable { + return m.state +} + +func (m *MapHash) Write(data ...frontend.Variable) { + for _, x := range data { + m.write(x) + } +} + +func (m *MapHash) Reset() { + m.stateValid = false +} + +func (m *MapHash) write(x frontend.Variable) { + if m.stateValid { + m.state = m.Map.double.Get(m.API, x, m.state) + } else { + m.state = m.Map.single.Get(m.API, x) + } + m.stateValid = true +} + +func AssertSliceEqual[T comparable](t *testing.T, expected, seen []T) { + assert.Equal(t, len(expected), len(seen)) + for i := range seen { + assert.True(t, expected[i] == seen[i], "@%d: %v != %v", i, expected[i], seen[i]) // assert.Equal is not strict enough when comparing pointers, i.e. 
it compares what they refer to + } +} + +func SliceEqual[T comparable](expected, seen []T) bool { + if len(expected) != len(seen) { + return false + } + for i := range seen { + if expected[i] != seen[i] { + return false + } + } + return true +} diff --git a/std/utils/test_vectors_utils/test_vector_utils_test.go b/std/utils/test_vectors_utils/test_vector_utils_test.go new file mode 100644 index 0000000000..2193fb789e --- /dev/null +++ b/std/utils/test_vectors_utils/test_vector_utils_test.go @@ -0,0 +1,148 @@ +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" + "github.com/stretchr/testify/assert" + "testing" +) + +type TestSingleMapCircuit struct { + M Map `gnark:"-"` + Values []frontend.Variable +} + +func (c *TestSingleMapCircuit) Define(api frontend.API) error { + + for i, k := range c.M.keys { + v := c.M.Get(api, k) + api.AssertIsEqual(v, c.Values[i]) + } + + return nil +} + +func TestSingleMap(t *testing.T) { + m := map[string]interface{}{ + "1": -2, + "4": 1, + "6": 7, + } + single := ReadMap(m).single + + assignment := TestSingleMapCircuit{ + M: single, + Values: single.values, + } + + circuit := TestSingleMapCircuit{ + M: single, + Values: make([]frontend.Variable, len(m)), // Okay to use the same object? 
+ } + + test.NewAssert(t).ProverSucceeded(&circuit, &assignment, test.WithBackends(backend.GROTH16)) +} + +type TestDoubleMapCircuit struct { + M DoubleMap `gnark:"-"` + Values []frontend.Variable + Keys1 []frontend.Variable `gnark:"-"` + Keys2 []frontend.Variable `gnark:"-"` +} + +func (c *TestDoubleMapCircuit) Define(api frontend.API) error { + + for i := range c.Keys1 { + v := c.M.Get(api, c.Keys1[i], c.Keys2[i]) + api.AssertIsEqual(v, c.Values[i]) + } + + return nil +} + +func TestReadDoubleMap(t *testing.T) { + keys1 := []frontend.Variable{1, 2} + keys2 := []frontend.Variable{1, 0} + values := []frontend.Variable{3, 1} + + for i := 0; i < 100; i++ { + m := ToMap(keys1, keys2, values) + double := ReadMap(m).double + valuesOrdered := [][]frontend.Variable{{3, nil}, {nil, 1}} + + assert.True(t, double.keys1[0] == "1" && double.keys1[1] == "2" || double.keys1[0] == "2" && double.keys1[1] == "1") + assert.True(t, double.keys2[0] == "1" && double.keys2[1] == "0" || double.keys2[0] == "0" && double.keys2[1] == "1") + + if double.keys1[0] != "1" { + valuesOrdered[0], valuesOrdered[1] = valuesOrdered[1], valuesOrdered[0] + } + + if double.keys2[0] != "1" { + valuesOrdered[0][0], valuesOrdered[0][1] = valuesOrdered[0][1], valuesOrdered[0][0] + valuesOrdered[1][0], valuesOrdered[1][1] = valuesOrdered[1][1], valuesOrdered[1][0] + } + + assert.True(t, slice2Eq(valuesOrdered, double.values)) + + } + +} + +func slice2Eq(s1, s2 [][]frontend.Variable) bool { + if len(s1) != len(s2) { + return false + } + for i := range s1 { + if !sliceEq(s1[i], s2[i]) { + return false + } + } + return true +} + +func sliceEq(s1, s2 []frontend.Variable) bool { + if len(s1) != len(s2) { + return false + } + for i := range s1 { + if s1[i] != s2[i] { + return false + } + } + return true +} + +func TestDoubleMap(t *testing.T) { + keys1 := []frontend.Variable{1, 5, 5, 3} + keys2 := []frontend.Variable{1, -5, 4, 4} + values := []frontend.Variable{0, 2, 3, 0} + + m := ToMap(keys1, keys2, values) + 
double := ReadMap(m).double + + fmt.Println(double) + + assignment := TestDoubleMapCircuit{ + M: double, + Values: values, + Keys1: keys1, + Keys2: keys2, + } + + circuit := TestDoubleMapCircuit{ + M: double, + Keys1: keys1, + Keys2: keys2, + Values: make([]frontend.Variable, len(m)), // Okay to use the same object? + } + + test.NewAssert(t).ProverSucceeded(&circuit, &assignment, test.WithBackends(backend.GROTH16)) +} + +func TestDoubleMapManyTimes(t *testing.T) { + for i := 0; i < 100; i++ { + TestDoubleMap(t) + } +} diff --git a/test/assert.go b/test/assert.go index b6a077d2a1..10d1ebf977 100644 --- a/test/assert.go +++ b/test/assert.go @@ -17,8 +17,10 @@ limitations under the License. package test import ( + "bytes" "errors" "fmt" + "io" "reflect" "strings" "testing" @@ -35,6 +37,7 @@ import ( "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/frontend/schema" + gnarkio "github.com/consensys/gnark/io" "github.com/stretchr/testify/require" ) @@ -44,6 +47,10 @@ var ( ErrInvalidWitnessVerified = errors.New("invalid witness resulted in a valid proof") ) +// SerializationThreshold is the number of constraints above which we don't +// do a systematic round-trip serialization check for the proving and verifying keys. 
+const SerializationThreshold = 1000 + // Assert is a helper to test circuits type Assert struct { t *testing.T @@ -149,6 +156,16 @@ func (assert *Assert) ProverSucceeded(circuit frontend.Circuit, validAssignment case backend.GROTH16: pk, vk, err := groth16.Setup(ccs) checkError(err) + if ccs.GetNbConstraints() <= SerializationThreshold { + pkReconstructed := groth16.NewProvingKey(curve) + roundTripCheck(assert.t, pk, pkReconstructed) + pkReconstructed = groth16.NewProvingKey(curve) + roundTripCheckRaw(assert.t, pk, pkReconstructed) + vkReconstructed := groth16.NewVerifyingKey(curve) + roundTripCheck(assert.t, vk, vkReconstructed) + vkReconstructed = groth16.NewVerifyingKey(curve) + roundTripCheckRaw(assert.t, vk, vkReconstructed) + } // ensure prove / verify works well with valid witnesses @@ -158,19 +175,35 @@ func (assert *Assert) ProverSucceeded(circuit frontend.Circuit, validAssignment err = groth16.Verify(proof, vk, validPublicWitness) checkError(err) + if opt.solidity && curve == ecc.BN254 && vk.NbPublicWitness() > 0 { + // check that the proof can be verified by gnark-solidity-checker + assert.solidityVerification(b, vk, proof, validPublicWitness) + } + case backend.PLONK: srs, err := NewKZGSRS(ccs) checkError(err) pk, vk, err := plonk.Setup(ccs, srs) checkError(err) + if ccs.GetNbConstraints() <= SerializationThreshold { + pkReconstructed := plonk.NewProvingKey(curve) + roundTripCheck(assert.t, pk, pkReconstructed) + vkReconstructed := plonk.NewVerifyingKey(curve) + roundTripCheck(assert.t, vk, vkReconstructed) + } - correctProof, err := plonk.Prove(ccs, pk, validWitness, opt.proverOpts...) + proof, err := plonk.Prove(ccs, pk, validWitness, opt.proverOpts...) 
checkError(err) - err = plonk.Verify(correctProof, vk, validPublicWitness) + err = plonk.Verify(proof, vk, validPublicWitness) checkError(err) + if opt.solidity && curve == ecc.BN254 { + // check that the proof can be verified by gnark-solidity-checker + assert.solidityVerification(b, vk, proof, validPublicWitness) + } + case backend.PLONKFRI: pk, vk, err := plonkfri.Setup(ccs) checkError(err) @@ -208,15 +241,11 @@ func (assert *Assert) ProverFailed(circuit frontend.Circuit, invalidAssignment f opt := assert.options(opts...) - popts := append(opt.proverOpts, backend.IgnoreSolverError()) - for _, curve := range opt.curves { // parse assignment invalidWitness, err := frontend.NewWitness(invalidAssignment, curve.ScalarField()) assert.NoError(err, "can't parse invalid assignment") - invalidPublicWitness, err := frontend.NewWitness(invalidAssignment, curve.ScalarField(), frontend.PublicOnly()) - assert.NoError(err, "can't parse invalid assignment") for _, b := range opt.backends { curve := curve @@ -235,42 +264,8 @@ func (assert *Assert) ProverFailed(circuit frontend.Circuit, invalidAssignment f mustError(err) assert.t.Parallel() - err = ccs.IsSolved(invalidPublicWitness) + err = ccs.IsSolved(invalidWitness) mustError(err) - - switch b { - case backend.GROTH16: - pk, vk, err := groth16.Setup(ccs) - checkError(err) - - proof, _ := groth16.Prove(ccs, pk, invalidWitness, popts...) - - err = groth16.Verify(proof, vk, invalidPublicWitness) - mustError(err) - - case backend.PLONK: - srs, err := NewKZGSRS(ccs) - checkError(err) - - pk, vk, err := plonk.Setup(ccs, srs) - checkError(err) - - incorrectProof, _ := plonk.Prove(ccs, pk, invalidWitness, popts...) - err = plonk.Verify(incorrectProof, vk, invalidPublicWitness) - mustError(err) - - case backend.PLONKFRI: - - pk, vk, err := plonkfri.Setup(ccs) - checkError(err) - - incorrectProof, _ := plonkfri.Prove(ccs, pk, invalidWitness, popts...) 
- err = plonkfri.Verify(incorrectProof, vk, invalidPublicWitness) - mustError(err) - - default: - panic("backend not implemented") - } }, curve.String(), b.String()) } } @@ -306,7 +301,7 @@ func (assert *Assert) solvingSucceeded(circuit frontend.Circuit, validAssignment err = IsSolved(circuit, validAssignment, curve.ScalarField()) checkError(err) - err = ccs.IsSolved(validWitness, opt.proverOpts...) + err = ccs.IsSolved(validWitness, opt.solverOpts...) checkError(err) } @@ -352,7 +347,7 @@ func (assert *Assert) solvingFailed(circuit frontend.Circuit, invalidAssignment err = IsSolved(circuit, invalidAssignment, curve.ScalarField()) mustError(err) - err = ccs.IsSolved(invalidWitness, opt.proverOpts...) + err = ccs.IsSolved(invalidWitness, opt.solverOpts...) mustError(err) } @@ -405,8 +400,22 @@ func (assert *Assert) fuzzer(fuzzer filler, circuit, w frontend.Circuit, b backe errConsts := IsSolved(circuit, w, curve.ScalarField(), SetAllVariablesAsConstants()) if (errVars == nil) != (errConsts == nil) { + w, err := frontend.NewWitness(w, curve.ScalarField()) + if err != nil { + panic(err) + } + s, err := frontend.NewSchema(circuit) + if err != nil { + panic(err) + } + bb, err := w.ToJSON(s) + if err != nil { + panic(err) + } + assert.Log("errVars", errVars) assert.Log("errConsts", errConsts) + assert.Log("fuzzer witness", string(bb)) assert.FailNow("solving circuit with values as constants vs non-constants mismatched result") } @@ -578,3 +587,45 @@ func (assert *Assert) marshalWitnessJSON(w witness.Witness, s *schema.Schema, cu witnessMatch := reflect.DeepEqual(w, witness) assert.True(witnessMatch, "round trip marshaling failed") } + +func roundTripCheck(t *testing.T, from io.WriterTo, reconstructed io.ReaderFrom) { + var buf bytes.Buffer + written, err := from.WriteTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, 
reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} + +func roundTripCheckRaw(t *testing.T, from gnarkio.WriterRawTo, reconstructed io.ReaderFrom) { + var buf bytes.Buffer + written, err := from.WriteRawTo(&buf) + if err != nil { + t.Fatal("couldn't serialize", err) + } + + read, err := reconstructed.ReadFrom(&buf) + if err != nil { + t.Fatal("couldn't deserialize", err) + } + + if !reflect.DeepEqual(from, reconstructed) { + t.Fatal("reconstructed object don't match original") + } + + if written != read { + t.Fatal("bytes written / read don't match") + } +} diff --git a/test/assert_solidity.go b/test/assert_solidity.go new file mode 100644 index 0000000000..1f027ac2b7 --- /dev/null +++ b/test/assert_solidity.go @@ -0,0 +1,95 @@ +package test + +import ( + "bytes" + "encoding/hex" + "io" + "os" + "os/exec" + "path/filepath" + "strconv" + + "github.com/consensys/gnark/backend" + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + plonk_bn254 "github.com/consensys/gnark/backend/plonk/bn254" + "github.com/consensys/gnark/backend/witness" +) + +// solidityVerification checks that the exported solidity contract can verify the proof +// and that the proof is valid. +// It uses gnark-solidity-checker see test.WithSolidity option. +func (assert *Assert) solidityVerification(b backend.ID, vk interface { + NbPublicWitness() int + ExportSolidity(io.Writer) error +}, + proof any, + validPublicWitness witness.Witness) { + if !solcCheck { + return // nothing to check, will make solc fail. 
+ } + assert.t.Helper() + + // make temp dir + tmpDir, err := os.MkdirTemp("", "gnark-solidity-check*") + assert.NoError(err) + defer os.RemoveAll(tmpDir) + + // export solidity contract + fSolidity, err := os.Create(filepath.Join(tmpDir, "gnark_verifier.sol")) + assert.NoError(err) + + err = vk.ExportSolidity(fSolidity) + assert.NoError(err) + + err = fSolidity.Close() + assert.NoError(err) + + // generate assets + // gnark-solidity-checker generate --dir tmpdir --solidity contract_g16.sol + cmd := exec.Command("gnark-solidity-checker", "generate", "--dir", tmpDir, "--solidity", "gnark_verifier.sol") + assert.t.Log("running ", cmd.String()) + out, err := cmd.CombinedOutput() + assert.NoError(err, string(out)) + + // proof to hex + var proofStr string + var optBackend string + + if b == backend.GROTH16 { + optBackend = "--groth16" + var buf bytes.Buffer + _proof := proof.(*groth16_bn254.Proof) + _, err = _proof.WriteRawTo(&buf) + assert.NoError(err) + proofStr = hex.EncodeToString(buf.Bytes()) + } else if b == backend.PLONK { + optBackend = "--plonk" + _proof := proof.(*plonk_bn254.Proof) + // TODO @gbotrel make a single Marshal function for PlonK proof. + proofStr = hex.EncodeToString(_proof.MarshalSolidity()) + } else { + panic("not implemented") + } + + // public witness to hex + bPublicWitness, err := validPublicWitness.MarshalBinary() + assert.NoError(err) + // that's quite dirty... 
+ // first 4 bytes -> nbPublic + // next 4 bytes -> nbSecret + // next 4 bytes -> nb elements in the vector (== nbPublic + nbSecret) + bPublicWitness = bPublicWitness[12:] + publicWitnessStr := hex.EncodeToString(bPublicWitness) + + // verify proof + // gnark-solidity-checker verify --dir tmpdir --groth16 --nb-public-inputs 1 --proof 1234 --public-inputs dead + cmd = exec.Command("gnark-solidity-checker", "verify", + "--dir", tmpDir, + optBackend, + "--nb-public-inputs", strconv.Itoa(vk.NbPublicWitness()), + "--proof", proofStr, + "--public-inputs", publicWitnessStr) + assert.t.Log("running ", cmd.String()) + out, err = cmd.CombinedOutput() + assert.NoError(err, string(out)) +} diff --git a/test/blueprint_solver.go b/test/blueprint_solver.go new file mode 100644 index 0000000000..9ced90ee6a --- /dev/null +++ b/test/blueprint_solver.go @@ -0,0 +1,140 @@ +package test + +import ( + "math/big" + + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/internal/utils" +) + +// blueprintSolver is a constraint.Solver that can be used to test a circuit +// it is a separate type to avoid method collisions with the engine.
+type blueprintSolver struct { + internalVariables []*big.Int + q *big.Int +} + +// implements constraint.Solver + +func (s *blueprintSolver) SetValue(vID uint32, f constraint.Element) { + if int(vID) > len(s.internalVariables) { + panic("out of bounds") + } + v := s.ToBigInt(f) + s.internalVariables[vID].Set(v) +} + +func (s *blueprintSolver) GetValue(cID, vID uint32) constraint.Element { + panic("not implemented in test.Engine") +} +func (s *blueprintSolver) GetCoeff(cID uint32) constraint.Element { + panic("not implemented in test.Engine") +} + +func (s *blueprintSolver) IsSolved(vID uint32) bool { + panic("not implemented in test.Engine") +} + +// implements constraint.Field + +func (s *blueprintSolver) FromInterface(i interface{}) constraint.Element { + b := utils.FromInterface(i) + return s.toElement(&b) +} + +func (s *blueprintSolver) ToBigInt(f constraint.Element) *big.Int { + r := new(big.Int) + fBytes := f.Bytes() + r.SetBytes(fBytes[:]) + return r +} +func (s *blueprintSolver) Mul(a, b constraint.Element) constraint.Element { + ba, bb := s.ToBigInt(a), s.ToBigInt(b) + ba.Mul(ba, bb).Mod(ba, s.q) + return s.toElement(ba) +} +func (s *blueprintSolver) Add(a, b constraint.Element) constraint.Element { + ba, bb := s.ToBigInt(a), s.ToBigInt(b) + ba.Add(ba, bb).Mod(ba, s.q) + return s.toElement(ba) +} +func (s *blueprintSolver) Sub(a, b constraint.Element) constraint.Element { + ba, bb := s.ToBigInt(a), s.ToBigInt(b) + ba.Sub(ba, bb).Mod(ba, s.q) + return s.toElement(ba) +} +func (s *blueprintSolver) Neg(a constraint.Element) constraint.Element { + ba := s.ToBigInt(a) + ba.Neg(ba).Mod(ba, s.q) + return s.toElement(ba) +} +func (s *blueprintSolver) Inverse(a constraint.Element) (constraint.Element, bool) { + ba := s.ToBigInt(a) + r := ba.ModInverse(ba, s.q) + return s.toElement(ba), r != nil +} +func (s *blueprintSolver) One() constraint.Element { + b := new(big.Int).SetUint64(1) + return s.toElement(b) +} +func (s *blueprintSolver) IsOne(a constraint.Element) 
bool { + b := s.ToBigInt(a) + return b.IsUint64() && b.Uint64() == 1 +} + +func (s *blueprintSolver) String(a constraint.Element) string { + b := s.ToBigInt(a) + return b.String() +} + +func (s *blueprintSolver) Uint64(a constraint.Element) (uint64, bool) { + b := s.ToBigInt(a) + return b.Uint64(), b.IsUint64() +} + +func (s *blueprintSolver) Read(calldata []uint32) (constraint.Element, int) { + // We encoded big.Int as constraint.Element on 12 uint32 words. + var r constraint.Element + for i := 0; i < len(r); i++ { + index := i * 2 + r[i] = uint64(calldata[index])<<32 | uint64(calldata[index+1]) + } + return r, len(r) * 2 +} + +func (s *blueprintSolver) toElement(b *big.Int) constraint.Element { + return bigIntToElement(b) +} + +func bigIntToElement(b *big.Int) constraint.Element { + if b.Sign() == -1 { + panic("negative value") + } + bytes := b.Bytes() + if len(bytes) > 48 { + panic("value too big") + } + var paddedBytes [48]byte + copy(paddedBytes[48-len(bytes):], bytes[:]) + + var r constraint.Element + r.SetBytes(paddedBytes) + + return r +} + +// wrappedBigInt is a wrapper around big.Int to implement the frontend.CanonicalVariable interface +type wrappedBigInt struct { + *big.Int +} + +func (w wrappedBigInt) Compress(to *[]uint32) { + // convert to Element. 
+ e := bigIntToElement(w.Int) + + // append the uint32 words to the slice + for i := 0; i < len(e); i++ { + *to = append(*to, uint32(e[i]>>32)) + *to = append(*to, uint32(e[i]&0xffffffff)) + } +} diff --git a/test/blueprint_solver_test.go b/test/blueprint_solver_test.go new file mode 100644 index 0000000000..0cc42b6a27 --- /dev/null +++ b/test/blueprint_solver_test.go @@ -0,0 +1,64 @@ +package test + +import ( + "math/big" + "math/rand" + "testing" + "time" + + "github.com/consensys/gnark-crypto/ecc" +) + +func TestBigIntToElement(t *testing.T) { + t.Parallel() + // sample a random big.Int, convert it to an element, and back + // to a big.Int, and check that it's the same + s := blueprintSolver{q: ecc.BN254.ScalarField()} + b := big.NewInt(0) + for i := 0; i < 50; i++ { + b.Rand(rand.New(rand.NewSource(time.Now().Unix())), s.q) //#nosec G404 -- This is a false positive + e := s.toElement(b) + b2 := s.ToBigInt(e) + if b.Cmp(b2) != 0 { + t.Fatal("b != b2") + } + } + +} + +func TestBigIntToUint32Slice(t *testing.T) { + t.Parallel() + // sample a random big.Int, write it to a uint32 slice, and back + // to a big.Int, and check that it's the same + s := blueprintSolver{q: ecc.BN254.ScalarField()} + b1 := big.NewInt(0) + b2 := big.NewInt(0) + + for i := 0; i < 50; i++ { + b1.Rand(rand.New(rand.NewSource(time.Now().Unix())), s.q) //#nosec G404 -- This is a false positive + b2.Rand(rand.New(rand.NewSource(time.Now().Unix())), s.q) //#nosec G404 -- This is a false positive + wb1 := wrappedBigInt{b1} + wb2 := wrappedBigInt{b2} + var to []uint32 + wb1.Compress(&to) + wb2.Compress(&to) + + if len(to) != 24 { + t.Fatal("wrong length: expected 2*len of constraint.Element (uint32 words)") + } + + e1, n := s.Read(to) + if n != 12 { + t.Fatal("wrong length: expected 1 len of constraint.Element (uint32 words)") + } + e2, n := s.Read(to[n:]) + if n != 12 { + t.Fatal("wrong length: expected 1 len of constraint.Element (uint32 words)") + } + rb1, rb2 := s.ToBigInt(e1), s.ToBigInt(e2) + 
if rb1.Cmp(b1) != 0 || rb2.Cmp(b2) != 0 { + t.Fatal("rb1 != b1 || rb2 != b2") + } + } + +} diff --git a/test/commitments_test.go b/test/commitments_test.go new file mode 100644 index 0000000000..2b62bb8f09 --- /dev/null +++ b/test/commitments_test.go @@ -0,0 +1,260 @@ +package test + +import ( + "fmt" + "github.com/consensys/gnark/backend" + groth16 "github.com/consensys/gnark/backend/groth16/bn254" + "github.com/consensys/gnark/backend/witness" + cs "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/stretchr/testify/require" + "reflect" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/stretchr/testify/assert" +) + +type noCommitmentCircuit struct { + X frontend.Variable +} + +func (c *noCommitmentCircuit) Define(api frontend.API) error { + api.AssertIsEqual(c.X, 1) + api.AssertIsEqual(c.X, 1) + return nil +} + +type commitmentCircuit struct { + Public []frontend.Variable `gnark:",public"` + X []frontend.Variable +} + +func (c *commitmentCircuit) Define(api frontend.API) error { + + commitment, err := api.(frontend.Committer).Commit(c.X...) 
+ if err != nil { + return err + } + sum := frontend.Variable(0) + for i, x := range c.X { + sum = api.Add(sum, api.Mul(x, i+1)) + } + for _, p := range c.Public { + sum = api.Add(sum, p) + } + api.AssertIsDifferent(commitment, sum) + return nil +} + +type committedConstantCircuit struct { + X frontend.Variable +} + +func (c *committedConstantCircuit) Define(api frontend.API) error { + commitment, err := api.(frontend.Committer).Commit(1, c.X) + if err != nil { + return err + } + api.AssertIsDifferent(commitment, c.X) + return nil +} + +type committedPublicCircuit struct { + X frontend.Variable `gnark:",public"` +} + +func (c *committedPublicCircuit) Define(api frontend.API) error { + commitment, err := api.(frontend.Committer).Commit(c.X) + if err != nil { + return err + } + api.AssertIsDifferent(commitment, c.X) + return nil +} + +type independentCommitsCircuit struct { + X []frontend.Variable +} + +func (c *independentCommitsCircuit) Define(api frontend.API) error { + committer := api.(frontend.Committer) + for i := range c.X { + if ch, err := committer.Commit(c.X[i]); err != nil { + return err + } else { + api.AssertIsDifferent(ch, c.X[i]) + } + } + return nil +} + +type twoCommitCircuit struct { + X []frontend.Variable + Y frontend.Variable +} + +func (c *twoCommitCircuit) Define(api frontend.API) error { + c0, err := api.(frontend.Committer).Commit(c.X...) 
+ if err != nil { + return err + } + var c1 frontend.Variable + if c1, err = api.(frontend.Committer).Commit(c0, c.Y); err != nil { + return err + } + api.AssertIsDifferent(c1, c.Y) + return nil +} + +type doubleCommitCircuit struct { + X, Y frontend.Variable +} + +func (c *doubleCommitCircuit) Define(api frontend.API) error { + var c0, c1 frontend.Variable + var err error + if c0, err = api.(frontend.Committer).Commit(c.X); err != nil { + return err + } + if c1, err = api.(frontend.Committer).Commit(c.X, c.Y); err != nil { + return err + } + api.AssertIsDifferent(c0, c1) + return nil +} + +func TestHollow(t *testing.T) { + + run := func(c, expected frontend.Circuit) func(t *testing.T) { + return func(t *testing.T) { + seen := hollow(c) + assert.Equal(t, expected, seen) + } + } + + assignments := []frontend.Circuit{ + &committedConstantCircuit{1}, + &commitmentCircuit{X: []frontend.Variable{1}, Public: []frontend.Variable{}}, + } + + expected := []frontend.Circuit{ + &committedConstantCircuit{nil}, + &commitmentCircuit{X: []frontend.Variable{nil}, Public: []frontend.Variable{}}, + } + + for i := range assignments { + t.Run(removePackageName(reflect.TypeOf(assignments[i]).String()), run(assignments[i], expected[i])) + } +} + +type commitUniquenessCircuit struct { + X []frontend.Variable +} + +func (c *commitUniquenessCircuit) Define(api frontend.API) error { + var err error + + ch := make([]frontend.Variable, len(c.X)) + for i := range c.X { + if ch[i], err = api.(frontend.Committer).Commit(c.X[i]); err != nil { + return err + } + for j := 0; j < i; j++ { + api.AssertIsDifferent(ch[i], ch[j]) + } + } + return nil +} + +func TestCommitUniquenessZerosScs(t *testing.T) { // TODO @Tabaie Randomize Groth16 commitments for real + + w, err := frontend.NewWitness(&commitUniquenessCircuit{[]frontend.Variable{0, 0}}, ecc.BN254.ScalarField()) + assert.NoError(t, err) + + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, 
&commitUniquenessCircuit{[]frontend.Variable{nil, nil}}) + assert.NoError(t, err) + + _, err = ccs.Solve(w) + assert.NoError(t, err) +} + +var commitmentTestCircuits []frontend.Circuit + +func init() { + commitmentTestCircuits = []frontend.Circuit{ + &noCommitmentCircuit{1}, + &commitmentCircuit{X: []frontend.Variable{1}, Public: []frontend.Variable{}}, // single commitment + &commitmentCircuit{X: []frontend.Variable{1, 2}, Public: []frontend.Variable{}}, // two commitments + &commitmentCircuit{X: []frontend.Variable{1, 2, 3, 4, 5}, Public: []frontend.Variable{}}, // five commitments + &commitmentCircuit{X: []frontend.Variable{0}, Public: []frontend.Variable{1}}, // single commitment single public + &commitmentCircuit{X: []frontend.Variable{0, 1, 2, 3, 4}, Public: []frontend.Variable{1, 2, 3, 4, 5}}, // five commitments five public + &committedConstantCircuit{1}, // single committed constant + &committedPublicCircuit{1}, // single committed public + &independentCommitsCircuit{X: []frontend.Variable{1, 1}}, // two independent commitments + &twoCommitCircuit{X: []frontend.Variable{1, 2}, Y: 3}, // two commitments, second depending on first + &doubleCommitCircuit{X: 1, Y: 2}, // double committing to the same variable + } +} + +func TestCommitment(t *testing.T) { + t.Parallel() + + for _, assignment := range commitmentTestCircuits { + NewAssert(t).ProverSucceeded(hollow(assignment), assignment, WithBackends(backend.GROTH16, backend.PLONK)) + } +} + +func TestCommitmentDummySetup(t *testing.T) { + t.Parallel() + + run := func(assignment frontend.Circuit) func(t *testing.T) { + return func(t *testing.T) { + // just test the prover + _cs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, hollow(assignment)) + require.NoError(t, err) + _r1cs := _cs.(*cs.R1CS) + var ( + dPk, pk groth16.ProvingKey + vk groth16.VerifyingKey + w witness.Witness + ) + require.NoError(t, groth16.Setup(_r1cs, &pk, &vk)) + require.NoError(t, groth16.DummySetup(_r1cs, &dPk)) + + 
comparePkSizes(t, dPk, pk) + + w, err = frontend.NewWitness(assignment, ecc.BN254.ScalarField()) + require.NoError(t, err) + _, err = groth16.Prove(_r1cs, &pk, w) + require.NoError(t, err) + } + } + + for _, assignment := range commitmentTestCircuits { + name := removePackageName(reflect.TypeOf(assignment).String()) + if c, ok := assignment.(*commitmentCircuit); ok { + name += fmt.Sprintf(":%dprivate %dpublic", len(c.X), len(c.Public)) + } + t.Run(name, run(assignment)) + } +} + +func comparePkSizes(t *testing.T, pk1, pk2 groth16.ProvingKey) { + // skipping the domain + require.Equal(t, len(pk1.G1.A), len(pk2.G1.A)) + require.Equal(t, len(pk1.G1.B), len(pk2.G1.B)) + require.Equal(t, len(pk1.G1.Z), len(pk2.G1.Z)) + require.Equal(t, len(pk1.G1.K), len(pk2.G1.K)) + + require.Equal(t, len(pk1.G2.B), len(pk2.G2.B)) + + require.Equal(t, len(pk1.InfinityA), len(pk2.InfinityA)) + require.Equal(t, len(pk1.InfinityB), len(pk2.InfinityB)) + require.Equal(t, pk1.NbInfinityA, pk2.NbInfinityA) + require.Equal(t, pk1.NbInfinityB, pk2.NbInfinityB) + + require.Equal(t, len(pk1.CommitmentKeys), len(pk2.CommitmentKeys)) // TODO @Tabaie Compare the commitment keys +} diff --git a/test/end_to_end.go b/test/end_to_end.go new file mode 100644 index 0000000000..5da06b75c9 --- /dev/null +++ b/test/end_to_end.go @@ -0,0 +1,57 @@ +package test + +import ( + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + "reflect" + "strings" + "testing" +) + +func makeOpts(opt TestingOption, curves []ecc.ID) []TestingOption { + if len(curves) > 0 { + return []TestingOption{opt, WithCurves(curves[0], curves[1:]...)} + } + return []TestingOption{opt} +} + +func testPlonk(t *testing.T, assignment frontend.Circuit, curves ...ecc.ID) { + NewAssert(t).ProverSucceeded(hollow(assignment), assignment, makeOpts(WithBackends(backend.PLONK), curves)...) 
+} + +func testGroth16(t *testing.T, assignment frontend.Circuit, curves ...ecc.ID) { + NewAssert(t).ProverSucceeded(hollow(assignment), assignment, makeOpts(WithBackends(backend.GROTH16), curves)...) +} + +func testAll(t *testing.T, assignment frontend.Circuit) { + NewAssert(t).ProverSucceeded(hollow(assignment), assignment, WithBackends(backend.GROTH16, backend.PLONK)) +} + +// hollow takes a gnark circuit and removes all the witness data. The resulting circuit can be used for compilation purposes +// Its purpose is to make testing more convenient. For example, as opposed to SolvingSucceeded(circuit, assignment), +// one can write SolvingSucceeded(hollow(assignment), assignment), obviating the creation of a separate circuit object. +func hollow(c frontend.Circuit) frontend.Circuit { + cV := reflect.ValueOf(c).Elem() + t := reflect.TypeOf(c).Elem() + res := reflect.New(t) // a new object of the same type as c + resE := res.Elem() + resC := res.Interface().(frontend.Circuit) + + frontendVar := reflect.TypeOf((*frontend.Variable)(nil)).Elem() + + for i := 0; i < t.NumField(); i++ { + fieldT := t.Field(i).Type + if fieldT.Kind() == reflect.Slice && fieldT.Elem().Implements(frontendVar) { // create empty slices for witness slices + resE.Field(i).Set(reflect.ValueOf(make([]frontend.Variable, cV.Field(i).Len()))) + } else if fieldT != frontendVar { // copy non-witness variables + resE.Field(i).Set(cV.Field(i)) + } + } + + return resC +} + +func removePackageName(s string) string { + return s[strings.LastIndex(s, ".")+1:] +} diff --git a/test/engine.go b/test/engine.go index ab459573df..afdcd1457c 100644 --- a/test/engine.go +++ b/test/engine.go @@ -18,22 +18,27 @@ package test import ( "fmt" + "github.com/consensys/gnark/constraint" "math/big" "path/filepath" "reflect" "runtime" "strconv" "strings" + "sync/atomic" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend/schema" 
"github.com/consensys/gnark/logger" + "golang.org/x/crypto/sha3" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/field/pool" "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/circuitdefer" + "github.com/consensys/gnark/internal/kvstore" "github.com/consensys/gnark/internal/utils" ) @@ -48,25 +53,15 @@ type engine struct { q *big.Int opt backend.ProverConfig // mHintsFunctions map[hint.ID]hintFunction - constVars bool - apiWrapper ApiWrapper + constVars bool + kvstore.Store + blueprints []constraint.Blueprint + internalVariables []*big.Int } // TestEngineOption defines an option for the test engine. type TestEngineOption func(e *engine) error -// ApiWrapper defines a function which wraps the API given to the circuit. -type ApiWrapper func(frontend.API) frontend.API - -// WithApiWrapper is a test engine option which which wraps the API before -// calling the Define method in circuit. If not set, then API is not wrapped. -func WithApiWrapper(wrapper ApiWrapper) TestEngineOption { - return func(e *engine) error { - e.apiWrapper = wrapper - return nil - } -} - // SetAllVariablesAsConstants is a test engine option which makes the calls to // IsConstant() and ConstantValue() always return true. If this test engine // option is not set, then all variables are considered as non-constant, @@ -99,10 +94,10 @@ func WithBackendProverOptions(opts ...backend.ProverOption) TestEngineOption { // This is an experimental feature. 
func IsSolved(circuit, witness frontend.Circuit, field *big.Int, opts ...TestEngineOption) (err error) { e := &engine{ - curveID: utils.FieldToCurve(field), - q: new(big.Int).Set(field), - apiWrapper: func(a frontend.API) frontend.API { return a }, - constVars: false, + curveID: utils.FieldToCurve(field), + q: new(big.Int).Set(field), + constVars: false, + Store: kvstore.New(), } for _, opt := range opts { if err := opt(e); err != nil { @@ -131,8 +126,13 @@ func IsSolved(circuit, witness frontend.Circuit, field *big.Int, opts ...TestEng log := logger.Logger() log.Debug().Msg("running circuit in test engine") cptAdd, cptMul, cptSub, cptToBinary, cptFromBinary, cptAssertIsEqual = 0, 0, 0, 0, 0, 0 - api := e.apiWrapper(e) - err = c.Define(api) + if err = c.Define(e); err != nil { + return fmt.Errorf("define: %w", err) + } + if err = callDeferred(e); err != nil { + return fmt.Errorf("deferred: %w", err) + } + log.Debug().Uint64("add", cptAdd). Uint64("sub", cptSub). Uint64("mul", cptMul). @@ -143,14 +143,23 @@ func IsSolved(circuit, witness frontend.Circuit, field *big.Int, opts ...TestEng return } +func callDeferred(builder *engine) error { + for i := 0; i < len(circuitdefer.GetAll[func(frontend.API) error](builder)); i++ { + if err := circuitdefer.GetAll[func(frontend.API) error](builder)[i](builder); err != nil { + return fmt.Errorf("defer fn %d: %w", i, err) + } + } + return nil +} + var cptAdd, cptMul, cptSub, cptToBinary, cptFromBinary, cptAssertIsEqual uint64 func (e *engine) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - cptAdd++ + atomic.AddUint64(&cptAdd, 1) res := new(big.Int) res.Add(e.toBigInt(i1), e.toBigInt(i2)) for i := 0; i < len(in); i++ { - cptAdd++ + atomic.AddUint64(&cptAdd, 1) res.Add(res, e.toBigInt(in[i])) } res.Mod(res, e.modulus()) @@ -161,19 +170,20 @@ func (e *engine) MulAcc(a, b, c frontend.Variable) frontend.Variable { bc := pool.BigInt.Get() bc.Mul(e.toBigInt(b), e.toBigInt(c)) + res := new(big.Int) _a := 
e.toBigInt(a) - _a.Add(_a, bc).Mod(_a, e.modulus()) + res.Add(_a, bc).Mod(res, e.modulus()) pool.BigInt.Put(bc) - return _a + return res } func (e *engine) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - cptSub++ + atomic.AddUint64(&cptSub, 1) res := new(big.Int) res.Sub(e.toBigInt(i1), e.toBigInt(i2)) for i := 0; i < len(in); i++ { - cptSub++ + atomic.AddUint64(&cptSub, 1) res.Sub(res, e.toBigInt(in[i])) } res.Mod(res, e.modulus()) @@ -188,7 +198,7 @@ func (e *engine) Neg(i1 frontend.Variable) frontend.Variable { } func (e *engine) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - cptMul++ + atomic.AddUint64(&cptMul, 1) b2 := e.toBigInt(i2) if len(in) == 0 && b2.IsUint64() && b2.Uint64() <= 1 { // special path to avoid useless allocations @@ -202,7 +212,7 @@ func (e *engine) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend res.Mul(b1, b2) res.Mod(res, e.modulus()) for i := 0; i < len(in); i++ { - cptMul++ + atomic.AddUint64(&cptMul, 1) res.Mul(res, e.toBigInt(in[i])) res.Mod(res, e.modulus()) } @@ -242,7 +252,7 @@ func (e *engine) Inverse(i1 frontend.Variable) frontend.Variable { } func (e *engine) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { - cptToBinary++ + atomic.AddUint64(&cptToBinary, 1) nbBits := e.FieldBitLen() if len(n) == 1 { nbBits = n[0] @@ -274,7 +284,7 @@ func (e *engine) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { } func (e *engine) FromBinary(v ...frontend.Variable) frontend.Variable { - cptFromBinary++ + atomic.AddUint64(&cptFromBinary, 1) bits := make([]bool, len(v)) for i := 0; i < len(v); i++ { be := e.toBigInt(v[i]) @@ -371,7 +381,7 @@ func (e *engine) Cmp(i1, i2 frontend.Variable) frontend.Variable { } func (e *engine) AssertIsEqual(i1, i2 frontend.Variable) { - cptAssertIsEqual++ + atomic.AddUint64(&cptAssertIsEqual, 1) b1, b2 := e.toBigInt(i1), e.toBigInt(i2) if b1.Cmp(b2) != 0 { panic(fmt.Sprintf("[assertIsEqual] %s == %s", 
b1.String(), b2.String())) @@ -448,7 +458,7 @@ func (e *engine) print(sbb *strings.Builder, x interface{}) { } } -func (e *engine) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { +func (e *engine) NewHint(f solver.Hint, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { if nbOutputs <= 0 { return nil, fmt.Errorf("hint function must return at least one output") @@ -479,6 +489,14 @@ func (e *engine) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Vari return out, nil } +func (e *engine) NewHintForId(id solver.HintID, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { + if f := solver.GetRegisteredHint(id); f != nil { + return e.NewHint(f, nbOutputs, inputs...) + } + + return nil, fmt.Errorf("no hint registered with id #%d. Use solver.RegisterHint or solver.RegisterNamedHint", id) +} + // IsConstant returns true if v is a constant known at compile time func (e *engine) IsConstant(v frontend.Variable) bool { return e.constVars @@ -514,7 +532,7 @@ func (e *engine) toBigInt(i1 frontend.Variable) *big.Int { } } -// bitLen returns the number of bits needed to represent a fr.Element +// FieldBitLen returns the number of bits needed to represent a fr.Element func (e *engine) FieldBitLen() int { return e.q.BitLen() } @@ -585,5 +603,82 @@ func (e *engine) Compiler() frontend.Compiler { } func (e *engine) Commit(v ...frontend.Variable) (frontend.Variable, error) { - panic("not implemented") + nb := (e.FieldBitLen() + 7) / 8 + buf := make([]byte, nb) + hasher := sha3.NewCShake128(nil, []byte("gnark test engine")) + for i := range v { + vs := e.toBigInt(v[i]) + bs := vs.FillBytes(buf) + hasher.Write(bs) + } + hasher.Read(buf) + res := new(big.Int).SetBytes(buf) + res.Mod(res, e.modulus()) + return res, nil +} + +func (e *engine) Defer(cb func(frontend.API) error) { + circuitdefer.Put(e, cb) +} + +// AddInstruction is used to add custom instructions to the constraint 
system. +// In constraint system, this is asynchronous. In here, we do it synchronously. +func (e *engine) AddInstruction(bID constraint.BlueprintID, calldata []uint32) []uint32 { + blueprint := e.blueprints[bID].(constraint.BlueprintSolvable) + + // create a dummy instruction + inst := constraint.Instruction{ + Calldata: calldata, + WireOffset: uint32(len(e.internalVariables)), + } + + // blueprint declared nbOutputs; add as many internal variables + // and return their indices + nbOutputs := blueprint.NbOutputs(inst) + var r []uint32 + for i := 0; i < nbOutputs; i++ { + r = append(r, uint32(len(e.internalVariables))) + e.internalVariables = append(e.internalVariables, new(big.Int)) + } + + // solve the blueprint synchronously + s := blueprintSolver{ + internalVariables: e.internalVariables, + q: e.q, + } + if err := blueprint.Solve(&s, inst); err != nil { + panic(err) + } + + return r +} + +// AddBlueprint adds a custom blueprint to the constraint system. +func (e *engine) AddBlueprint(b constraint.Blueprint) constraint.BlueprintID { + if _, ok := b.(constraint.BlueprintSolvable); !ok { + panic("unsupported blueprint in test engine") + } + e.blueprints = append(e.blueprints, b) + return constraint.BlueprintID(len(e.blueprints) - 1) +} + +// InternalVariable returns the value of an internal variable. This is used in custom blueprints. +// The variableID is the index of the variable in the internalVariables slice, as +// filled by AddInstruction. 
+func (e *engine) InternalVariable(vID uint32) frontend.Variable { + if vID >= uint32(len(e.internalVariables)) { + panic("internal variable not found") + } + return new(big.Int).Set(e.internalVariables[vID]) +} + +// ToCanonicalVariable converts a frontend.Variable to a frontend.CanonicalVariable +// this is used in custom blueprints to return a variable that can be encoded in blueprints +func (e *engine) ToCanonicalVariable(v frontend.Variable) frontend.CanonicalVariable { + r := e.toBigInt(v) + return wrappedBigInt{r} +} + +func (e *engine) SetGkrInfo(info constraint.GkrInfo) error { + return fmt.Errorf("not implemented") } diff --git a/test/engine_test.go b/test/engine_test.go index 26b232b46e..16ed830c12 100644 --- a/test/engine_test.go +++ b/test/engine_test.go @@ -6,7 +6,9 @@ import ( "testing" "github.com/consensys/gnark" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark-crypto/ecc" + + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/math/bits" ) @@ -16,24 +18,24 @@ type hintCircuit struct { } func (circuit *hintCircuit) Define(api frontend.API) error { - res, err := api.Compiler().NewHint(bits.IthBit, 1, circuit.A, 3) + res, err := api.Compiler().NewHint(bits.GetHints()[0], 1, circuit.A, 3) if err != nil { return fmt.Errorf("IthBit circuitA 3: %w", err) } a3b := res[0] - res, err = api.Compiler().NewHint(bits.IthBit, 1, circuit.A, 25) + res, err = api.Compiler().NewHint(bits.GetHints()[0], 1, circuit.A, 25) if err != nil { return fmt.Errorf("IthBit circuitA 25: %w", err) } a25b := res[0] - res, err = api.Compiler().NewHint(hint.InvZero, 1, circuit.A) + res, err = api.Compiler().NewHint(solver.InvZeroHint, 1, circuit.A) if err != nil { return fmt.Errorf("IsZero CircuitA: %w", err) } aInvZero := res[0] - res, err = api.Compiler().NewHint(hint.InvZero, 1, circuit.B) + res, err = api.Compiler().NewHint(solver.InvZeroHint, 1, circuit.B) if err != nil { return
fmt.Errorf("IsZero, CircuitB") } @@ -69,3 +71,33 @@ func TestBuiltinHints(t *testing.T) { } } + +var isDeferCalled bool + +type EmptyCircuit struct { + X frontend.Variable +} + +func (c *EmptyCircuit) Define(api frontend.API) error { + api.AssertIsEqual(c.X, 0) + api.Compiler().Defer(func(api frontend.API) error { + isDeferCalled = true + return nil + }) + return nil +} + +func TestPreCompileHook(t *testing.T) { + c := &EmptyCircuit{} + w := &EmptyCircuit{ + X: 0, + } + isDeferCalled = false + err := IsSolved(c, w, ecc.BN254.ScalarField()) + if err != nil { + t.Fatal(err) + } + if !isDeferCalled { + t.Error("callback not called") + } +} diff --git a/test/fuzz.go b/test/fuzz.go index a4c59eb69f..60a43a7471 100644 --- a/test/fuzz.go +++ b/test/fuzz.go @@ -74,7 +74,7 @@ func zeroFiller(w frontend.Circuit, curve ecc.ID) { } func binaryFiller(w frontend.Circuit, curve ecc.ID) { - mrand.Seed(time.Now().Unix()) + mrand := mrand.New(mrand.NewSource(time.Now().Unix())) //#nosec G404 weak rng is fine here fill(w, func() interface{} { return int(mrand.Uint32() % 2) //#nosec G404 weak rng is fine here @@ -83,7 +83,7 @@ func binaryFiller(w frontend.Circuit, curve ecc.ID) { func seedFiller(w frontend.Circuit, curve ecc.ID) { - mrand.Seed(time.Now().Unix()) + mrand := mrand.New(mrand.NewSource(time.Now().Unix())) //#nosec G404 weak rng is fine here m := curve.ScalarField() @@ -96,8 +96,6 @@ func seedFiller(w frontend.Circuit, curve ecc.ID) { func randomFiller(w frontend.Circuit, curve ecc.ID) { - mrand.Seed(time.Now().Unix()) - r := mrand.New(mrand.NewSource(time.Now().Unix())) //#nosec G404 weak rng is fine here m := curve.ScalarField() diff --git a/test/hint_test.go b/test/hint_test.go new file mode 100644 index 0000000000..d2c6650f7f --- /dev/null +++ b/test/hint_test.go @@ -0,0 +1,58 @@ +package test + +import ( + "fmt" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "math/big" + "testing" +) + +const id = solver.HintID(123454321) + 
+func identityHint(_ *big.Int, in, out []*big.Int) error { + if len(in) != len(out) { + return fmt.Errorf("len(in) = %d ≠ %d = len(out)", len(in), len(out)) + } + for i := range in { + out[i].Set(in[i]) + } + return nil +} + +type customNamedHintCircuit struct { + X []frontend.Variable +} + +func (c *customNamedHintCircuit) Define(api frontend.API) error { + y, err := api.Compiler().NewHintForId(id, len(c.X), c.X...) + + if err != nil { + return err + } + for i := range y { + api.AssertIsEqual(c.X[i], y[i]) + } + + return nil +} + +var assignment customNamedHintCircuit + +func init() { + solver.RegisterNamedHint(identityHint, id) + assignment = customNamedHintCircuit{X: []frontend.Variable{1, 2, 3, 4, 5}} +} + +func TestHintWithCustomNamePlonk(t *testing.T) { + testPlonk(t, &assignment) +} + +func TestHintWithCustomNameGroth16(t *testing.T) { + testGroth16(t, &assignment) +} + +func TestHintWithCustomNameEngine(t *testing.T) { + circuit := hollow(&assignment) + NewAssert(t).SolvingSucceeded(circuit, &assignment) +} diff --git a/test/kzg_srs.go b/test/kzg_srs.go index 4002762919..c67b89be7f 100644 --- a/test/kzg_srs.go +++ b/test/kzg_srs.go @@ -51,7 +51,6 @@ func NewKZGSRS(ccs constraint.ConstraintSystem) (kzg.SRS, error) { } return newKZGSRS(utils.FieldToCurve(ccs.Field()), kzgSize) - } var srsCache map[ecc.ID]kzg.SRS diff --git a/test/options.go b/test/options.go index bcdae0eee8..65aa8a00c2 100644 --- a/test/options.go +++ b/test/options.go @@ -19,10 +19,11 @@ package test import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" ) -// TestingOption defines option for altering the behaviour of Assert methods. +// TestingOption defines option for altering the behavior of Assert methods. // See the descriptions of functions returning instances of this type for // particular options. 
type TestingOption func(*testingConfig) error @@ -31,9 +32,11 @@ type testingConfig struct { backends []backend.ID curves []ecc.ID witnessSerialization bool + solverOpts []solver.Option proverOpts []backend.ProverOption compileOpts []frontend.CompileOption fuzzing bool + solidity bool } // WithBackends is testing option which restricts the backends the assertions are @@ -83,6 +86,16 @@ func WithProverOpts(proverOpts ...backend.ProverOption) TestingOption { } } +// WithSolverOpts is a testing option which uses the given solverOpts when +// calling constraint system solver. +func WithSolverOpts(solverOpts ...solver.Option) TestingOption { + return func(opt *testingConfig) error { + opt.proverOpts = append(opt.proverOpts, backend.WithSolverOptions(solverOpts...)) + opt.solverOpts = solverOpts + return nil + } +} + // WithCompileOpts is a testing option which uses the given compileOpts when // calling frontend.Compile in assertions. func WithCompileOpts(compileOpts ...frontend.CompileOption) TestingOption { @@ -91,3 +104,16 @@ func WithCompileOpts(compileOpts ...frontend.CompileOption) TestingOption { return nil } } + +// WithSolidity is a testing option which enables solidity tests in assertions. +// If the build tag "solccheck" is not set, this option is ignored. +// When the tag is set; this requires gnark-solidity-checker to be installed, which in turns +// requires solc and abigen to be reachable in the PATH. +// +// See https://github.com/ConsenSys/gnark-solidity-checker for more details. 
+func WithSolidity() TestingOption { + return func(opt *testingConfig) error { + opt.solidity = true && solcCheck + return nil + } +} diff --git a/test/solccheck_with.go b/test/solccheck_with.go new file mode 100644 index 0000000000..9bf23579e1 --- /dev/null +++ b/test/solccheck_with.go @@ -0,0 +1,5 @@ +//go:build solccheck + +package test + +const solcCheck = true diff --git a/test/solccheck_without.go b/test/solccheck_without.go new file mode 100644 index 0000000000..a6414c27de --- /dev/null +++ b/test/solccheck_without.go @@ -0,0 +1,5 @@ +//go:build !solccheck + +package test + +const solcCheck = false diff --git a/test/solver_test.go b/test/solver_test.go index 0c11baed7b..88e0d39b92 100644 --- a/test/solver_test.go +++ b/test/solver_test.go @@ -1,23 +1,24 @@ package test import ( - "errors" "fmt" + "io" "math/big" "reflect" "strconv" "strings" "testing" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/hint" - cs "github.com/consensys/gnark/constraint/tinyfield" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/frontend/schema" "github.com/consensys/gnark/internal/backend/circuits" + "github.com/consensys/gnark/internal/kvstore" "github.com/consensys/gnark/internal/tinyfield" "github.com/consensys/gnark/internal/utils" ) @@ -25,6 +26,11 @@ import ( // ignore witness size larger than this bound const permutterBound = 3 +// r1cs + sparser1cs +const nbSystems = 2 + +var builders [2]frontend.NewBuilder + func TestSolverConsistency(t *testing.T) { if testing.Short() { t.Skip("skipping R1CS solver test with testing.Short() flag set") @@ -49,15 +55,59 @@ func TestSolverConsistency(t *testing.T) { } } -type permutter struct { - circuit frontend.Circuit - r1cs 
*cs.R1CS - scs *cs.SparseR1CS - witness []tinyfield.Element - hints []hint.Function +// witness used for the permutter. It implements the Witness interface +// using mock methods (only the underlying vector is required). +type permutterWitness struct { + vector any +} + +func (pw *permutterWitness) WriteTo(w io.Writer) (int64, error) { + return 0, nil +} + +func (pw *permutterWitness) ReadFrom(r io.Reader) (int64, error) { + return 0, nil +} + +func (pw *permutterWitness) MarshalBinary() ([]byte, error) { + return nil, nil +} - // used to avoid allocations in R1CS solver - a, b, c []tinyfield.Element +func (pw *permutterWitness) UnmarshalBinary([]byte) error { + return nil +} + +func (pw *permutterWitness) Public() (witness.Witness, error) { + return pw, nil +} + +func (pw *permutterWitness) Vector() any { + return pw.vector +} + +func (pw *permutterWitness) ToJSON(s *schema.Schema) ([]byte, error) { + return nil, nil +} + +func (pw *permutterWitness) FromJSON(s *schema.Schema, data []byte) error { + return nil +} + +func (pw *permutterWitness) Fill(nbPublic, nbSecret int, values <-chan any) error { + return nil +} + +func newPermutterWitness(pv tinyfield.Vector) witness.Witness { + return &permutterWitness{ + vector: pv, + } +} + +type permutter struct { + circuit frontend.Circuit + constraintSystems [2]constraint.ConstraintSystem + witness []tinyfield.Element + hints []solver.Hint } // note that circuit will be mutated and this is not thread safe @@ -66,30 +116,36 @@ func (p *permutter) permuteAndTest(index int) error { for i := 0; i < len(tinyfieldElements); i++ { p.witness[index].SetUint64(tinyfieldElements[i]) if index == len(p.witness)-1 { + // we have a unique permutation + var errorSystems [2]error + var errorEngines [2]error + + // 2 constraints systems + for k := 0; k < nbSystems; k++ { - // solve the cs using R1CS solver - errR1CS := p.solveR1CS() - errSCS := p.solveSCS() + errorSystems[k] = p.solve(k) - // solve the cs using test engine - // first copy 
the witness in the circuit - copyWitnessFromVector(p.circuit, p.witness) - errEngine1 := isSolvedEngine(p.circuit, tinyfield.Modulus()) + // solve the cs using test engine + // first copy the witness in the circuit + copyWitnessFromVector(p.circuit, p.witness) + errorEngines[0] = isSolvedEngine(p.circuit, tinyfield.Modulus()) - copyWitnessFromVector(p.circuit, p.witness) - errEngine2 := isSolvedEngine(p.circuit, tinyfield.Modulus(), SetAllVariablesAsConstants()) + copyWitnessFromVector(p.circuit, p.witness) + errorEngines[1] = isSolvedEngine(p.circuit, tinyfield.Modulus(), SetAllVariablesAsConstants()) - if (errR1CS == nil) != (errEngine1 == nil) || - (errSCS == nil) != (errEngine1 == nil) || - (errEngine1 == nil) != (errEngine2 == nil) { + } + if (errorSystems[0] == nil) != (errorEngines[0] == nil) || + (errorSystems[1] == nil) != (errorEngines[0] == nil) || + (errorEngines[0] == nil) != (errorEngines[1] == nil) { return fmt.Errorf("errSCS :%s\nerrR1CS :%s\nerrEngine(const=false): %s\nerrEngine(const=true): %s\nwitness: %s", - formatError(errSCS), - formatError(errR1CS), - formatError(errEngine1), - formatError(errEngine2), + formatError(errorSystems[0]), + formatError(errorSystems[1]), + formatError(errorEngines[0]), + formatError(errorEngines[1]), formatWitness(p.witness)) } + } else { // recurse if err := p.permuteAndTest(index + 1); err != nil { @@ -123,38 +179,19 @@ func formatWitness(witness []tinyfield.Element) string { return sbb.String() } -func (p *permutter) solveSCS() error { - opt, err := backend.NewProverConfig(backend.WithHints(p.hints...)) - if err != nil { - return err - } - - _, err = p.scs.Solve(p.witness, opt) - return err -} - -func (p *permutter) solveR1CS() error { - opt, err := backend.NewProverConfig(backend.WithHints(p.hints...)) - if err != nil { - return err - } - - for i := 0; i < len(p.r1cs.Constraints); i++ { - p.a[i].SetZero() - p.b[i].SetZero() - p.c[i].SetZero() - } - _, err = p.r1cs.Solve(p.witness, p.a, p.b, p.c, opt) +func (p 
*permutter) solve(i int) error {
+	pw := newPermutterWitness(p.witness)
+	_, err := p.constraintSystems[i].Solve(pw, solver.WithHints(p.hints...))
 	return err
 }
 
 // isSolvedEngine behaves like test.IsSolved except it doesn't clone the circuit
 func isSolvedEngine(c frontend.Circuit, field *big.Int, opts ...TestEngineOption) (err error) {
 	e := &engine{
-		curveID:    utils.FieldToCurve(field),
-		q:          new(big.Int).Set(field),
-		apiWrapper: func(a frontend.API) frontend.API { return a },
-		constVars:  false,
+		curveID:   utils.FieldToCurve(field),
+		q:         new(big.Int).Set(field),
+		constVars: false,
+		Store:     kvstore.New(),
 	}
 	for _, opt := range opts {
 		if err := opt(e); err != nil {
@@ -168,8 +205,12 @@ func isSolvedEngine(c frontend.Circuit, field *big.Int, opts ...TestEngineOption
 		}
 	}()
 
-	api := e.apiWrapper(e)
-	err = c.Define(api)
+	if err = c.Define(e); err != nil {
+		return fmt.Errorf("define: %w", err)
+	}
+	if err = callDeferred(e); err != nil {
+		return fmt.Errorf("deferred: %w", err)
+	}
 
 	return
 }
@@ -180,7 +221,7 @@ func copyWitnessFromVector(to frontend.Circuit, from []tinyfield.Element) {
 	i := 0
 	schema.Walk(to, tVariable, func(f schema.LeafInfo, tInput reflect.Value) error {
 		if f.Visibility == schema.Public {
-			tInput.Set(reflect.ValueOf((from[i])))
+			tInput.Set(reflect.ValueOf(from[i]))
 			i++
 		}
 		return nil
@@ -188,7 +229,7 @@ func copyWitnessFromVector(to frontend.Circuit, from []tinyfield.Element) {
 
 	schema.Walk(to, tVariable, func(f schema.LeafInfo, tInput reflect.Value) error {
 		if f.Visibility == schema.Secret {
-			tInput.Set(reflect.ValueOf((from[i])))
+			tInput.Set(reflect.ValueOf(from[i]))
 			i++
 		}
 		return nil
@@ -198,41 +239,30 @@ func copyWitnessFromVector(to frontend.Circuit, from []tinyfield.Element) {
 // ConsistentSolver solves given circuit with all possible witness combinations using internal/tinyfield
 //
 // Since the goal of this method is to flag potential solver issues, it is not exposed as an API for now
-func consistentSolver(circuit frontend.Circuit, hintFunctions 
[]hint.Function) error { +func consistentSolver(circuit frontend.Circuit, hintFunctions []solver.Hint) error { p := permutter{ circuit: circuit, hints: hintFunctions, } - // compile R1CS - ccs, err := frontend.Compile(tinyfield.Modulus(), r1cs.NewBuilder, circuit) - if err != nil { - return err - } + // compile the systems + for i := 0; i < nbSystems; i++ { - p.r1cs = ccs.(*cs.R1CS) - - // witness len - n := len(p.r1cs.Public) - 1 + len(p.r1cs.Secret) - if n > permutterBound { - return nil - } - - p.a = make([]tinyfield.Element, p.r1cs.GetNbConstraints()) - p.b = make([]tinyfield.Element, p.r1cs.GetNbConstraints()) - p.c = make([]tinyfield.Element, p.r1cs.GetNbConstraints()) - p.witness = make([]tinyfield.Element, n) + ccs, err := frontend.Compile(tinyfield.Modulus(), builders[i], circuit) + if err != nil { + return err + } + p.constraintSystems[i] = ccs - // compile SparseR1CS - ccs, err = frontend.Compile(tinyfield.Modulus(), scs.NewBuilder, circuit) - if err != nil { - return err - } + if i == 0 { // the -1 is only for r1cs... + n := ccs.GetNbPublicVariables() - 1 + ccs.GetNbSecretVariables() + if n > permutterBound { + return nil + } + p.witness = make([]tinyfield.Element, n) + } - p.scs = ccs.(*cs.SparseR1CS) - if (len(p.scs.Public) + len(p.scs.Secret)) != n { - return errors.New("mismatch of witness size for same circuit") } return p.permuteAndTest(0) @@ -247,4 +277,7 @@ func init() { for i := uint64(0); i < n; i++ { tinyfieldElements[i] = i } + + builders[0] = r1cs.NewBuilder + builders[1] = scs.NewBuilder }