From 6d764cab85492283c346b69b3999978cb7cd9bf9 Mon Sep 17 00:00:00 2001
From: Ming
Date: Wed, 26 Feb 2025 16:37:25 +0700
Subject: [PATCH] add whir as a crate in ceno via git subtree (#841)

Re-open of #840 due to the base branch change. This integrates whir as a git subtree.

### why subtree
- This enables us to plug `whir` in as a crate in ceno, so it can share most of its libraries with ceno
- review everything in just one ceno PR
- retains an approach for syncing back to upstream via the command `git subtree push --prefix=whir https://github.com/scroll-tech/whir`. The timing of doing so is decoupled from our development in Ceno; we only need it when we want to sync back to upstream

With this PR, we can start refactoring whir, e.g.
- migrate field to plonky3
- unify transcript
- share MLE + sumcheck from Ceno
- ...

---------

Co-authored-by: Yuncong Zhang
---
 Cargo.lock | 1 -
 Cargo.toml | 1 +
 mpcs/Cargo.toml | 2 +-
 whir/.github/workflows/rust.yml | 24 +
 whir/.gitignore | 12 +
 whir/Cargo.lock | 1300 +++++++++++++++++
 whir/Cargo.toml | 56 +
 whir/LICENSE-APACHE | 201 +++
 whir/LICENSE-MIT | 21 +
 whir/README.md | 47 +
 whir/rust-toolchain.toml | 2 +
 whir/src/bin/benchmark.rs | 446 ++++++
 whir/src/bin/main.rs | 438 ++++++
 whir/src/ceno_binding/merkle_config.rs | 292 ++++
 whir/src/ceno_binding/mod.rs | 80 +
 whir/src/ceno_binding/pcs.rs | 442 ++++++
 whir/src/cmdline_utils.rs | 75 +
 whir/src/crypto/fields.rs | 92 ++
 whir/src/crypto/merkle_tree/blake3.rs | 158 ++
 whir/src/crypto/merkle_tree/keccak.rs | 158 ++
 whir/src/crypto/merkle_tree/mock.rs | 67 +
 whir/src/crypto/merkle_tree/mod.rs | 73 +
 whir/src/crypto/mod.rs | 2 +
 whir/src/domain.rs | 144 ++
 whir/src/fs_utils.rs | 38 +
 whir/src/lib.rs | 13 +
 whir/src/ntt/matrix.rs | 187 +++
 whir/src/ntt/mod.rs | 61 +
 whir/src/ntt/ntt_impl.rs | 408 ++++++
 whir/src/ntt/transpose.rs | 550 +++++++
 whir/src/ntt/utils.rs | 143 ++
 whir/src/ntt/wavelet.rs | 90 ++
 whir/src/parameters.rs | 213 +++
 whir/src/poly_utils/coeffs.rs | 467 ++++++
 whir/src/poly_utils/evals.rs | 95 ++
 whir/src/poly_utils/fold.rs | 223 +++
 whir/src/poly_utils/gray_lag_poly.rs | 151 ++
 whir/src/poly_utils/hypercube.rs | 40 +
 whir/src/poly_utils/mod.rs | 325 +++++
 whir/src/poly_utils/sequential_lag_poly.rs | 174 +++
 .../poly_utils/streaming_evaluation_helper.rs | 82 ++
 whir/src/sumcheck/mod.rs | 321 ++++
 whir/src/sumcheck/proof.rs | 101 ++
 whir/src/sumcheck/prover_batched.rs | 307 ++++
 whir/src/sumcheck/prover_core.rs | 151 ++
 whir/src/sumcheck/prover_not_skipping.rs | 329 +++++
 .../sumcheck/prover_not_skipping_batched.rs | 186 +++
 whir/src/sumcheck/prover_single.rs | 299 ++++
 whir/src/utils.rs | 152 ++
 whir/src/whir/batch/committer.rs | 184 +++
 whir/src/whir/batch/iopattern.rs | 124 ++
 whir/src/whir/batch/mod.rs | 8 +
 whir/src/whir/batch/prover.rs | 540 +++++++
 whir/src/whir/batch/utils.rs | 101 ++
 whir/src/whir/batch/verifier.rs | 726 +++++++++
 whir/src/whir/committer.rs | 128 ++
 whir/src/whir/fs_utils.rs | 39 +
 whir/src/whir/iopattern.rs | 95 ++
 whir/src/whir/mod.rs | 369 +++++
 whir/src/whir/parameters.rs | 633 ++++++++
 whir/src/whir/prover.rs | 456 ++++++
 whir/src/whir/verifier.rs | 631 ++++++++
 62 files changed, 13272 insertions(+), 2 deletions(-)
 create mode 100644 whir/.github/workflows/rust.yml
 create mode 100644 whir/.gitignore
 create mode 100644 whir/Cargo.lock
 create mode 100644 whir/Cargo.toml
 create mode 100644 whir/LICENSE-APACHE
 create mode 100644 whir/LICENSE-MIT
 create mode 100644 whir/README.md
 create mode 100644 whir/rust-toolchain.toml
 create mode 100644 whir/src/bin/benchmark.rs
 create mode 100644
whir/src/bin/main.rs create mode 100644 whir/src/ceno_binding/merkle_config.rs create mode 100644 whir/src/ceno_binding/mod.rs create mode 100644 whir/src/ceno_binding/pcs.rs create mode 100644 whir/src/cmdline_utils.rs create mode 100644 whir/src/crypto/fields.rs create mode 100644 whir/src/crypto/merkle_tree/blake3.rs create mode 100644 whir/src/crypto/merkle_tree/keccak.rs create mode 100644 whir/src/crypto/merkle_tree/mock.rs create mode 100644 whir/src/crypto/merkle_tree/mod.rs create mode 100644 whir/src/crypto/mod.rs create mode 100644 whir/src/domain.rs create mode 100644 whir/src/fs_utils.rs create mode 100644 whir/src/lib.rs create mode 100644 whir/src/ntt/matrix.rs create mode 100644 whir/src/ntt/mod.rs create mode 100644 whir/src/ntt/ntt_impl.rs create mode 100644 whir/src/ntt/transpose.rs create mode 100644 whir/src/ntt/utils.rs create mode 100644 whir/src/ntt/wavelet.rs create mode 100644 whir/src/parameters.rs create mode 100644 whir/src/poly_utils/coeffs.rs create mode 100644 whir/src/poly_utils/evals.rs create mode 100644 whir/src/poly_utils/fold.rs create mode 100644 whir/src/poly_utils/gray_lag_poly.rs create mode 100644 whir/src/poly_utils/hypercube.rs create mode 100644 whir/src/poly_utils/mod.rs create mode 100644 whir/src/poly_utils/sequential_lag_poly.rs create mode 100644 whir/src/poly_utils/streaming_evaluation_helper.rs create mode 100644 whir/src/sumcheck/mod.rs create mode 100644 whir/src/sumcheck/proof.rs create mode 100644 whir/src/sumcheck/prover_batched.rs create mode 100644 whir/src/sumcheck/prover_core.rs create mode 100644 whir/src/sumcheck/prover_not_skipping.rs create mode 100644 whir/src/sumcheck/prover_not_skipping_batched.rs create mode 100644 whir/src/sumcheck/prover_single.rs create mode 100644 whir/src/utils.rs create mode 100644 whir/src/whir/batch/committer.rs create mode 100644 whir/src/whir/batch/iopattern.rs create mode 100644 whir/src/whir/batch/mod.rs create mode 100644 whir/src/whir/batch/prover.rs create mode 100644 whir/src/whir/batch/utils.rs create mode 100644 whir/src/whir/batch/verifier.rs create mode 100644 whir/src/whir/committer.rs create mode 100644 whir/src/whir/fs_utils.rs create mode 100644 whir/src/whir/iopattern.rs create mode 100644 whir/src/whir/mod.rs create mode 100644 whir/src/whir/parameters.rs create mode 100644 whir/src/whir/prover.rs create mode 100644 whir/src/whir/verifier.rs diff --git a/Cargo.lock b/Cargo.lock index 29a5ffb7f..a025f7e29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3153,7 +3153,6 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/whir?branch=feat%2Fceno-binding-batch#cc05cba75d96a5c3caa78e06a6279ca741954c2b" dependencies = [ "ark-crypto-primitives", "ark-ff", diff --git a/Cargo.toml b/Cargo.toml index 462acb8c5..615ed199f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ members = [ "poseidon", "sumcheck", "transcript", + "whir", ] resolver = "2" diff --git a/mpcs/Cargo.toml b/mpcs/Cargo.toml index 318ff6a31..54305ecc7 100644 --- a/mpcs/Cargo.toml +++ b/mpcs/Cargo.toml @@ -35,7 +35,7 @@ rand_chacha.workspace = true rayon = { workspace = true, optional = true } serde.workspace = true transcript = { path = "../transcript" } -whir = { git = "https://github.com/scroll-tech/whir", branch = "feat/ceno-binding-batch", features = ["ceno"] } +whir = { path = "../whir", features = ["ceno"] } zeroize = "1.8" [dev-dependencies] diff --git a/whir/.github/workflows/rust.yml b/whir/.github/workflows/rust.yml new file mode 100644 index 
000000000..11b7199e8 --- /dev/null +++ b/whir/.github/workflows/rust.yml @@ -0,0 +1,24 @@ +name: Rust + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Switch toolchain + run: rustup update nightly && rustup default nightly + - name: Build + run: cargo build --release --verbose + - name: Run tests + run: cargo test --release --verbose diff --git a/whir/.gitignore b/whir/.gitignore new file mode 100644 index 000000000..83fff18b6 --- /dev/null +++ b/whir/.gitignore @@ -0,0 +1,12 @@ +/target +scripts/temp/ +bench_utils/target +*_proof +artifacts +outputs/temp/ +*.pdf +scripts/__pycache__/ +.DS_Store +outputs/ +.idea +.vscode \ No newline at end of file diff --git a/whir/Cargo.lock b/whir/Cargo.lock new file mode 100644 index 000000000..c3207f271 --- /dev/null +++ b/whir/Cargo.lock @@ -0,0 +1,1300 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "allocator-api2" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" + +[[package]] +name = "anstream" +version = "0.6.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" + +[[package]] +name = "anstyle-parse" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +dependencies = [ + "anstyle", + "windows-sys", +] + +[[package]] +name = "ark-crypto-primitives" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e0c292754729c8a190e50414fd1a37093c786c709899f29c9f7daccecfa855e" +dependencies = [ + "ahash", + "ark-crypto-primitives-macros", + "ark-ec", + "ark-ff", + "ark-relations", + "ark-serialize", + "ark-snark", + "ark-std", + "blake2", + "derivative", + "digest", + "fnv", + "hashbrown 0.14.5", + "merlin", + "rayon", + "sha2", +] + +[[package]] +name = "ark-crypto-primitives-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7e89fe77d1f0f4fe5b96dfc940923d88d17b6a773808124f21e764dfb063c6a" +dependencies = [ + 
"proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "ark-ec" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d68f2d516162846c1238e755a7c4d131b892b70cc70c471a8e3ca3ed818fce" +dependencies = [ + "ahash", + "ark-ff", + "ark-poly", + "ark-serialize", + "ark-std", + "educe", + "fnv", + "hashbrown 0.15.1", + "itertools 0.13.0", + "num-bigint", + "num-integer", + "num-traits", + "rayon", + "zeroize", +] + +[[package]] +name = "ark-ff" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a177aba0ed1e0fbb62aa9f6d0502e9b46dad8c2eab04c14258a1212d2557ea70" +dependencies = [ + "ark-ff-asm", + "ark-ff-macros", + "ark-serialize", + "ark-std", + "arrayvec", + "digest", + "educe", + "itertools 0.13.0", + "num-bigint", + "num-traits", + "paste", + "rayon", + "zeroize", +] + +[[package]] +name = "ark-ff-asm" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62945a2f7e6de02a31fe400aa489f0e0f5b2502e69f95f853adb82a96c7a6b60" +dependencies = [ + "quote", + "syn 2.0.87", +] + +[[package]] +name = "ark-ff-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09be120733ee33f7693ceaa202ca41accd5653b779563608f1234f78ae07c4b3" +dependencies = [ + "num-bigint", + "num-traits", + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "ark-poly" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "579305839da207f02b89cd1679e50e67b4331e2f9294a57693e5051b7703fe27" +dependencies = [ + "ahash", + "ark-ff", + "ark-serialize", + "ark-std", + "educe", + "fnv", + "hashbrown 0.15.1", + "rayon", +] + +[[package]] +name = "ark-relations" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec46ddc93e7af44bcab5230937635b06fb5744464dd6a7e7b083e80ebd274384" +dependencies = [ + "ark-ff", + "ark-std", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "ark-serialize" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f4d068aaf107ebcd7dfb52bc748f8030e0fc930ac8e360146ca54c1203088f7" +dependencies = [ + "ark-serialize-derive", + "ark-std", + "arrayvec", + "digest", + "num-bigint", + "rayon", +] + +[[package]] +name = "ark-serialize-derive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213888f660fddcca0d257e88e54ac05bca01885f258ccdf695bafd77031bb69d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "ark-snark" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d368e2848c2d4c129ce7679a7d0d2d612b6a274d3ea6a13bad4445d61b381b88" +dependencies = [ + "ark-ff", + "ark-relations", + "ark-serialize", + "ark-std", +] + +[[package]] +name = "ark-std" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a" +dependencies = [ + "colored", + "num-traits", + "rand", + "rayon", +] + +[[package]] +name = "ark-test-curves" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cc137bb3271671a597ae79157030f88affe9fa7975207c3b781a01cb9ed0372" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-std", +] + +[[package]] +name = "arrayref" +version = "0.3.9" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + +[[package]] +name = "blake2b_simd" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23285ad32269793932e830392f2fe2f83e26488fd3ec778883a93c8323735780" +dependencies = [ + "arrayref", + "arrayvec", + "constant_time_eq", +] + +[[package]] +name = "blake3" +version = "1.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d82033247fd8e890df8f740e407ad4d038debb9eb1f40533fffb32e7d17dc6f7" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", +] + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "bytemuck" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "cc" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" +dependencies = [ + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "clap" +version = "4.5.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "clap_lex" +version = "0.7.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" + +[[package]] +name = "colorchoice" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" + +[[package]] +name = "colored" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c" +dependencies = [ + "lazy_static", + "windows-sys", +] + +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + +[[package]] +name = "cpufeatures" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ca741a962e1b0bff6d724a1a0958b686406e853bb14061f218562e1896f95e6" +dependencies = [ + "libc", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive_more" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", + "unicode-xid", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + +[[package]] +name = "educe" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d7bc049e1bd8cdeb31b68bbd586a9464ecf9f3944af3958a7a9d0f8b9799417" +dependencies = [ + "enum-ordinalize", + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + 
+[[package]] +name = "enum-ordinalize" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea0dcfa4e54eeb516fe454635a95753ddd39acda650ce703031c6973e315dd5" +dependencies = [ + "enum-ordinalize-derive", +] + +[[package]] +name = "enum-ordinalize-derive" +version = "4.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "ff" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ded41244b729663b1e574f1b4fb731469f69f79c17667b5d776b16cda0479449" +dependencies = [ + "bitvec", + "rand_core", + "subtle", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "goldilocks" +version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno-Goldilocks#29a15d186ce4375dab346a3cc9eca6e43540cb8d" +dependencies = [ + "ff", + "halo2curves", + "itertools 0.12.1", + "rand_core", + "serde", + "subtle", +] + +[[package]] +name = "group" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" +dependencies = [ + "ff", + "rand_core", + "subtle", +] + +[[package]] +name = "halo2curves" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6b1142bd1059aacde1b477e0c80c142910f1ceae67fc619311d6a17428007ab" +dependencies = [ + "blake2b_simd", + "ff", + "group", + "lazy_static", + "num-bigint", + "num-traits", + "pasta_curves", + "paste", + "rand", + "rand_core", + "static_assertions", + "subtle", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "allocator-api2", +] + +[[package]] +name = "hashbrown" +version = "0.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" +dependencies = [ + "allocator-api2", +] + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "keccak" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654" +dependencies = [ + "cpufeatures", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +dependencies = [ + "spin", +] + +[[package]] +name = "libc" +version = "0.2.162" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" + +[[package]] +name = "log" +version = "0.4.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" + +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "merlin" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58c38e2799fc0978b65dfff8023ec7843e2330bb462f19198840b34b6582397d" +dependencies = [ + "byteorder", + "keccak", + "rand_core", + "zeroize", +] + +[[package]] +name = "nimue" +version = "0.2.0" +source = "git+https://github.com/arkworks-rs/nimue#b28eb124420eede1a890e3c64b37adfff94938d3" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-serialize", + "digest", + "generic-array", + "hex", + "keccak", + "log", + "rand", + "zeroize", +] + +[[package]] +name = "nimue-pow" +version = "0.1.0" +source = "git+https://github.com/arkworks-rs/nimue#b28eb124420eede1a890e3c64b37adfff94938d3" +dependencies = [ + "blake3", + "bytemuck", + "keccak", + "nimue", + "rayon", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" + +[[package]] +name = "pasta_curves" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3e57598f73cc7e1b2ac63c79c517b31a0877cd7c402cdcaa311b5208de7a095" +dependencies = [ + "blake2b_simd", + "ff", + "group", + "lazy_static", + "rand", + "static_assertions", + "subtle", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "pin-project-lite" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" + +[[package]] +name = "ppv-lite86" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "serde" +version = "1.0.215" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.215" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "serde_json" +version = "1.0.132" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "sha3" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +dependencies = [ + "digest", + "keccak", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn 
2.0.87", +] + +[[package]] +name = "tracing" +version = "0.1.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +dependencies = [ + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-subscriber" +version = "0.2.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e0d2eaa99c3c2e41547cfa109e910a68ea03823cccad4a0525dcbc9b01e8c71" +dependencies = [ + "tracing-core", +] + +[[package]] +name = "transpose" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e" +dependencies = [ + "num-integer", + "strength_reduce", +] + +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + +[[package]] +name = "unicode-ident" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" + +[[package]] +name = "unicode-xid" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "whir" +version = "0.1.0" +dependencies = [ + "ark-crypto-primitives", + "ark-ff", + "ark-poly", + "ark-serialize", + "ark-std", + "ark-test-curves", + "blake2", + "blake3", + "clap", + "derivative", + "derive_more", + "goldilocks", + "itertools 0.14.0", + "lazy_static", + "nimue", + "nimue-pow", + "rand", + "rand_chacha", + "rayon", + "serde", + "serde_json", + "sha3", + "thiserror", + "transpose", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] 
+ +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] diff --git a/whir/Cargo.toml b/whir/Cargo.toml new file mode 100644 index 000000000..66124c294 --- /dev/null +++ b/whir/Cargo.toml @@ -0,0 +1,56 @@ +[package] +categories = ["cryptography", "zk", "blockchain", "pcs"] +description = "Multilinear Polynomial Commitment Scheme" +edition = "2021" +keywords = ["cryptography", "zk", "blockchain", "pcs"] +license = "MIT OR Apache-2.0" +name = "whir" +readme = "README.md" +repository = "https://github.com/WizardOfMenlo/whir/" +version = "0.1.0" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +default-run = "main" + 
+[dependencies] +ark-crypto-primitives = { version = "0.5", features = ["merkle_tree"] } +ark-ff = { version = "0.5", features = ["asm", "std"] } +ark-poly = "0.5" +ark-serialize = "0.5" +ark-std = { version = "0.5", features = ["std"] } +ark-test-curves = { version = "0.5", features = ["bls12_381_curve"] } +blake2 = "0.10" +blake3 = "1.5.0" +clap = { version = "4.4.17", features = ["derive"] } +derivative = { version = "2", features = ["use_core"] } +lazy_static = "1.4" +nimue = { git = "https://github.com/arkworks-rs/nimue", features = ["ark"] } +nimue-pow = { git = "https://github.com/arkworks-rs/nimue" } +rand = "0.8" +rand_chacha = "0.3" +rayon = { version = "1.10.0", optional = true } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +sha3 = "0.10" +transpose = "0.2.3" + +derive_more = { version = "1.0.0", features = ["debug"] } +goldilocks = { git = "https://github.com/scroll-tech/ceno-Goldilocks" } +itertools = "0.14.0" +thiserror = "1" + +[profile.release] +debug = true + +[features] +asm = [] +ceno = [] +default = ["parallel", "ceno"] +parallel = [ + "dep:rayon", + "ark-poly/parallel", + "ark-ff/parallel", + "ark-crypto-primitives/parallel", +] +print-trace = ["ark-std/print-trace"] +rayon = ["dep:rayon"] diff --git a/whir/LICENSE-APACHE b/whir/LICENSE-APACHE new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/whir/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/whir/LICENSE-MIT b/whir/LICENSE-MIT new file mode 100644 index 000000000..c67929c8b --- /dev/null +++ b/whir/LICENSE-MIT @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Gal Arnon, Alessandro Chiesa, Giacomo Fenzi, Eylon Yogev + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/whir/README.md b/whir/README.md new file mode 100644 index 000000000..ef91b05da --- /dev/null +++ b/whir/README.md @@ -0,0 +1,47 @@ +
+# WHIR 🌪️
+
+This library was developed using the [arkworks](https://arkworks.rs) ecosystem to accompany [WHIR 🌪️](https://eprint.iacr.org/2024/1586).
+By [Gal Arnon](https://galarnon42.github.io/), [Alessandro Chiesa](https://ic-people.epfl.ch/~achiesa/), [Giacomo Fenzi](https://gfenzi.io), and [Eylon Yogev](https://www.eylonyogev.com/about).
+
+**WARNING:** This is an academic prototype and has not received careful code review. This implementation is NOT ready for production use.
+
+
+# Usage
+```
+cargo run --release -- --help
+
+Usage: main [OPTIONS]
+
+Options:
+ -t, --type [default: PCS]
+ -l, --security-level [default: 100]
+ -p, --pow-bits
+ -d, --num-variables [default: 20]
+ -e, --evaluations [default: 1]
+ -r, --rate [default: 1]
+ --reps [default: 1000]
+ -k, --fold [default: 4]
+ --sec [default: ConjectureList]
+ --fold_type [default: ProverHelps]
+ -f, --field [default: Goldilocks2]
+ --hash [default: Blake3]
+ -h, --help Print help
+ -V, --version Print version
+```
+
+Options:
+- `-t` can be either `PCS` or `LDT` to run as a (multilinear) PCS or an LDT
+- `-l` sets the (overall) security level of the scheme
+- `-p` sets the number of PoW bits (used for the query-phase). PoW bits for proximity gaps are set automatically.
+- `-d` sets the number of variables of the scheme.
+- `-e` sets the number of evaluations to prove. Only meaningful in PCS mode.
+- `-r` sets the log_inv of the rate
+- `-k` sets the number of variables to fold at each iteration.
+- `--sec` sets the settings used to compute security. Available `UniqueDecoding`, `ProvableList`, `ConjectureList`
+- `--fold_type` sets the settings used to compute folds. Available `Naive`, `ProverHelps`
+- `-f` sets the field used, available are `Goldilocks2, Goldilocks3, Field192, Field256`.
+- `--hash` sets the hash used for the Merkle tree, available are `SHA3` and `Blake3`
diff --git a/whir/rust-toolchain.toml b/whir/rust-toolchain.toml
new file mode 100644
index 000000000..5d1274a3e
--- /dev/null
+++ b/whir/rust-toolchain.toml
@@ -0,0 +1,2 @@
+[toolchain]
+channel = "nightly-2024-10-03"
diff --git a/whir/src/bin/benchmark.rs b/whir/src/bin/benchmark.rs
new file mode 100644
index 000000000..a9c388124
--- /dev/null
+++ b/whir/src/bin/benchmark.rs
@@ -0,0 +1,446 @@
+use std::{
+    fs::OpenOptions,
+    time::{Duration, Instant},
+};
+
+use ark_crypto_primitives::{
+    crh::{CRHScheme, TwoToOneCRHScheme},
+    merkle_tree::Config,
+};
+use ark_ff::{FftField, Field};
+use ark_serialize::CanonicalSerialize;
+use nimue::{Arthur, DefaultHash, IOPattern, Merlin};
+use nimue_pow::blake3::Blake3PoW;
+use whir::{
+    cmdline_utils::{AvailableFields, AvailableMerkle},
+    crypto::{
+        fields,
+        merkle_tree::{self, HashCounter},
+    },
+    parameters::*,
+    poly_utils::coeffs::CoefficientList,
+    whir::Statement,
+};
+
+use serde::Serialize;
+
+use clap::Parser;
+use whir::whir::{
+    fs_utils::{DigestReader, DigestWriter},
+    iopattern::DigestIOPattern,
+};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    #[arg(short = 'l', long, default_value = "100")]
+    security_level: usize,
+
+    #[arg(short = 'p', long)]
+    pow_bits: Option<usize>,
+
+    #[arg(short = 'd', long, default_value = "20")]
+    num_variables: usize,
+
+    #[arg(short = 'e', long = "evaluations", default_value = "1")]
+    num_evaluations: usize,
+
+    #[arg(short = 'r', long, default_value = "1")]
+    rate: usize,
+
+    #[arg(long = "reps", default_value = "1000")]
+    verifier_repetitions: usize,
+
+    #[arg(short = 'i', long = "initfold", default_value = "4")]
+    first_round_folding_factor: usize,
+
+    #[arg(short = 'k', long = "fold", default_value = "4")]
+    folding_factor: usize,
+
+    #[arg(long = "sec", default_value = "ConjectureList")]
+    soundness_type: SoundnessType,
+
+    #[arg(long = "fold_type", default_value = "ProverHelps")]
+    fold_optimisation: FoldType,
+
+    #[arg(short = 'f', long = "field", default_value = "Goldilocks2")]
+    field: AvailableFields,
+
+    #[arg(long = "hash", default_value = "Blake3")]
+    merkle_tree: AvailableMerkle,
+}
+
+#[derive(Debug, Serialize)] +struct BenchmarkOutput { + security_level: usize, + pow_bits: usize, + starting_rate: usize, + num_variables: usize, + repetitions: usize, + folding_factor: usize, + soundness_type: SoundnessType, + field: AvailableFields, + merkle_tree: AvailableMerkle, + + // Whir + whir_evaluations: usize, + whir_argument_size: usize, + whir_prover_time: Duration, + whir_prover_hashes: usize, + whir_verifier_time: Duration, + whir_verifier_hashes: usize, + + // Whir LDT + whir_ldt_argument_size: usize, + whir_ldt_prover_time: Duration, + whir_ldt_prover_hashes: usize, + whir_ldt_verifier_time: Duration, + whir_ldt_verifier_hashes: usize, +} + +type PowStrategy = Blake3PoW; + +fn main() { + let mut args = Args::parse(); + let field = args.field; + let merkle = args.merkle_tree; + + if args.pow_bits.is_none() { + args.pow_bits = Some(default_max_pow(args.num_variables, args.rate)); + } + + let mut rng = ark_std::test_rng(); + + match (field, merkle) { + (AvailableFields::Goldilocks1, AvailableMerkle::Blake3) => { + use fields::Field64 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks1, AvailableMerkle::Keccak256) => { + use fields::Field64 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks2, AvailableMerkle::Blake3) => { + use fields::Field64_2 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks2, AvailableMerkle::Keccak256) => { + use fields::Field64_2 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks3, AvailableMerkle::Blake3) => { + use fields::Field64_3 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks3, AvailableMerkle::Keccak256) => { + use fields::Field64_3 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field128, AvailableMerkle::Blake3) => { + use fields::Field128 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field128, AvailableMerkle::Keccak256) => { + use fields::Field128 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field192, AvailableMerkle::Blake3) => { + use fields::Field192 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field192, AvailableMerkle::Keccak256) => { + use fields::Field192 as F; + use merkle_tree::keccak as mt; + + let 
(leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field256, AvailableMerkle::Blake3) => { + use fields::Field256 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field256, AvailableMerkle::Keccak256) => { + use fields::Field256 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + } +} + +fn run_whir( + args: Args, + leaf_hash_params: <::LeafHash as CRHScheme>::Parameters, + two_to_one_params: <::TwoToOneHash as TwoToOneCRHScheme>::Parameters, +) where + F: FftField + CanonicalSerialize, + MerkleConfig: Config + Clone, + MerkleConfig::InnerDigest: AsRef<[u8]> + From<[u8; 32]>, + IOPattern: DigestIOPattern, + Merlin: DigestWriter, + for<'a> Arthur<'a>: DigestReader, +{ + let security_level = args.security_level; + let pow_bits = args.pow_bits.unwrap(); + let num_variables = args.num_variables; + let starting_rate = args.rate; + let reps = args.verifier_repetitions; + let folding_factor = args.folding_factor; + let first_round_folding_factor = args.first_round_folding_factor; + let soundness_type = args.soundness_type; + let fold_optimisation = args.fold_optimisation; + + std::fs::create_dir_all("outputs").unwrap(); + + let num_coeffs = 1 << num_variables; + + let mv_params = MultivariateParameters::::new(num_variables); + + let whir_params = WhirParameters:: { + initial_statement: true, + security_level, + pow_bits, + folding_factor: FoldingFactor::ConstantFromSecondRound( + first_round_folding_factor, + folding_factor, + ), + leaf_hash_params, + two_to_one_params, + soundness_type, + fold_optimisation, + _pow_parameters: Default::default(), + starting_log_inv_rate: starting_rate, + }; + + let polynomial = CoefficientList::new( + (0..num_coeffs) + .map(::BasePrimeField::from) + .collect(), + ); + + let ( + whir_ldt_prover_time, + whir_ldt_argument_size, + whir_ldt_prover_hashes, + whir_ldt_verifier_time, + whir_ldt_verifier_hashes, + ) = { + // Run LDT + use whir::whir::{ + committer::Committer, iopattern::WhirIOPattern, parameters::WhirConfig, prover::Prover, + verifier::Verifier, whir_proof_size, + }; + + let whir_params = WhirParameters:: { + initial_statement: false, + ..whir_params.clone() + }; + let params = + WhirConfig::::new(mv_params, whir_params.clone()); + if !params.check_pow_bits() { + println!("WARN: more PoW bits required than what specified."); + } + + let io = IOPattern::::new("🌪️") + .commit_statement(¶ms) + .add_whir_proof(¶ms) + .clone(); + + let mut merlin = io.to_merlin(); + + let whir_ldt_prover_time = Instant::now(); + + HashCounter::reset(); + + let committer = Committer::new(params.clone()); + let witness = committer.commit(&mut merlin, polynomial.clone()).unwrap(); + + let prover = Prover(params.clone()); + + let proof = prover + .prove(&mut merlin, Statement::default(), witness) + .unwrap(); + + let whir_ldt_prover_time = whir_ldt_prover_time.elapsed(); + let whir_ldt_argument_size = whir_proof_size(merlin.transcript(), &proof); + let whir_ldt_prover_hashes = HashCounter::get(); + + // Just not to count that initial inversion (which could be precomputed) + let verifier = Verifier::new(params); + + HashCounter::reset(); + let whir_ldt_verifier_time = Instant::now(); + for _ in 
0..reps { + let mut arthur = io.to_arthur(merlin.transcript()); + verifier + .verify(&mut arthur, &Statement::default(), &proof) + .unwrap(); + } + + let whir_ldt_verifier_time = whir_ldt_verifier_time.elapsed(); + let whir_ldt_verifier_hashes = HashCounter::get() / reps; + + ( + whir_ldt_prover_time, + whir_ldt_argument_size, + whir_ldt_prover_hashes, + whir_ldt_verifier_time, + whir_ldt_verifier_hashes, + ) + }; + + let ( + whir_prover_time, + whir_argument_size, + whir_prover_hashes, + whir_verifier_time, + whir_verifier_hashes, + ) = { + // Run PCS + use whir::{ + poly_utils::MultilinearPoint, + whir::{ + committer::Committer, iopattern::WhirIOPattern, parameters::WhirConfig, + prover::Prover, verifier::Verifier, whir_proof_size, + }, + }; + + let params = WhirConfig::::new(mv_params, whir_params); + if !params.check_pow_bits() { + println!("WARN: more PoW bits required than what specified."); + } + + let io = IOPattern::::new("🌪️") + .commit_statement(¶ms) + .add_whir_proof(¶ms) + .clone(); + + let mut merlin = io.to_merlin(); + + let points: Vec<_> = (0..args.num_evaluations) + .map(|i| MultilinearPoint(vec![F::from(i as u64); num_variables])) + .collect(); + let evaluations = points + .iter() + .map(|point| polynomial.evaluate_at_extension(point)) + .collect(); + let statement = Statement { + points, + evaluations, + }; + + HashCounter::reset(); + let whir_prover_time = Instant::now(); + + let committer = Committer::new(params.clone()); + let witness = committer.commit(&mut merlin, polynomial.clone()).unwrap(); + + let prover = Prover(params.clone()); + + let proof = prover + .prove(&mut merlin, statement.clone(), witness) + .unwrap(); + + let whir_prover_time = whir_prover_time.elapsed(); + let whir_argument_size = whir_proof_size(merlin.transcript(), &proof); + let whir_prover_hashes = HashCounter::get(); + + // Just not to count that initial inversion (which could be precomputed) + let verifier = Verifier::new(params); + + HashCounter::reset(); + let whir_verifier_time = Instant::now(); + for _ in 0..reps { + let mut arthur = io.to_arthur(merlin.transcript()); + verifier.verify(&mut arthur, &statement, &proof).unwrap(); + } + + let whir_verifier_time = whir_verifier_time.elapsed(); + let whir_verifier_hashes = HashCounter::get() / reps; + + ( + whir_prover_time, + whir_argument_size, + whir_prover_hashes, + whir_verifier_time, + whir_verifier_hashes, + ) + }; + + let output = BenchmarkOutput { + security_level, + pow_bits, + starting_rate, + num_variables, + repetitions: reps, + folding_factor, + soundness_type, + field: args.field, + merkle_tree: args.merkle_tree, + + // Whir + whir_evaluations: args.num_evaluations, + whir_prover_time, + whir_argument_size, + whir_prover_hashes, + whir_verifier_time, + whir_verifier_hashes, + + // Whir LDT + whir_ldt_prover_time, + whir_ldt_argument_size, + whir_ldt_prover_hashes, + whir_ldt_verifier_time, + whir_ldt_verifier_hashes, + }; + + let mut out_file = OpenOptions::new() + .append(true) + .create(true) + .open("outputs/bench_output.json") + .unwrap(); + use std::io::Write; + writeln!(out_file, "{}", serde_json::to_string(&output).unwrap()).unwrap(); +} diff --git a/whir/src/bin/main.rs b/whir/src/bin/main.rs new file mode 100644 index 000000000..0fc804d3a --- /dev/null +++ b/whir/src/bin/main.rs @@ -0,0 +1,438 @@ +use std::time::Instant; + +use ark_crypto_primitives::{ + crh::{CRHScheme, TwoToOneCRHScheme}, + merkle_tree::Config, +}; +use ark_ff::FftField; +use ark_serialize::CanonicalSerialize; +use nimue::{Arthur, DefaultHash, 
IOPattern, Merlin}; +use whir::{ + cmdline_utils::{AvailableFields, AvailableMerkle, WhirType}, + crypto::{ + fields, + merkle_tree::{self, HashCounter}, + }, + parameters::*, + poly_utils::{MultilinearPoint, coeffs::CoefficientList}, + whir::Statement, +}; + +use nimue_pow::blake3::Blake3PoW; + +use clap::Parser; +use whir::whir::{ + fs_utils::{DigestReader, DigestWriter}, + iopattern::DigestIOPattern, +}; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + #[arg(short = 't', long = "type", default_value = "PCS")] + protocol_type: WhirType, + + #[arg(short = 'l', long, default_value = "100")] + security_level: usize, + + #[arg(short = 'p', long)] + pow_bits: Option, + + #[arg(short = 'd', long, default_value = "20")] + num_variables: usize, + + #[arg(short = 'e', long = "evaluations", default_value = "1")] + num_evaluations: usize, + + #[arg(short = 'r', long, default_value = "1")] + rate: usize, + + #[arg(long = "reps", default_value = "1000")] + verifier_repetitions: usize, + + #[arg(short = 'i', long = "initfold", default_value = "4")] + first_round_folding_factor: usize, + + #[arg(short = 'k', long = "fold", default_value = "4")] + folding_factor: usize, + + #[arg(long = "sec", default_value = "ConjectureList")] + soundness_type: SoundnessType, + + #[arg(long = "fold_type", default_value = "ProverHelps")] + fold_optimisation: FoldType, + + #[arg(short = 'f', long = "field", default_value = "Goldilocks2")] + field: AvailableFields, + + #[arg(long = "hash", default_value = "Blake3")] + merkle_tree: AvailableMerkle, +} + +type PowStrategy = Blake3PoW; + +fn main() { + let mut args = Args::parse(); + let field = args.field; + let merkle = args.merkle_tree; + + if args.pow_bits.is_none() { + args.pow_bits = Some(default_max_pow(args.num_variables, args.rate)); + } + + let mut rng = ark_std::test_rng(); + + match (field, merkle) { + (AvailableFields::Goldilocks1, AvailableMerkle::Blake3) => { + use fields::Field64 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks1, AvailableMerkle::Keccak256) => { + use fields::Field64 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks2, AvailableMerkle::Blake3) => { + use fields::Field64_2 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks2, AvailableMerkle::Keccak256) => { + use fields::Field64_2 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks3, AvailableMerkle::Blake3) => { + use fields::Field64_3 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks3, AvailableMerkle::Keccak256) => { + use fields::Field64_3 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + 
(AvailableFields::Field128, AvailableMerkle::Blake3) => { + use fields::Field128 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field128, AvailableMerkle::Keccak256) => { + use fields::Field128 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field192, AvailableMerkle::Blake3) => { + use fields::Field192 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field192, AvailableMerkle::Keccak256) => { + use fields::Field192 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field256, AvailableMerkle::Blake3) => { + use fields::Field256 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field256, AvailableMerkle::Keccak256) => { + use fields::Field256 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + } +} + +fn run_whir( + args: Args, + leaf_hash_params: <::LeafHash as CRHScheme>::Parameters, + two_to_one_params: <::TwoToOneHash as TwoToOneCRHScheme>::Parameters, +) where + F: FftField + CanonicalSerialize, + MerkleConfig: Config + Clone, + MerkleConfig::InnerDigest: AsRef<[u8]> + From<[u8; 32]>, + IOPattern: DigestIOPattern, + Merlin: DigestWriter, + for<'a> Arthur<'a>: DigestReader, +{ + match args.protocol_type { + WhirType::PCS => run_whir_pcs::(args, leaf_hash_params, two_to_one_params), + WhirType::LDT => { + run_whir_as_ldt::(args, leaf_hash_params, two_to_one_params) + } + } +} + +fn run_whir_as_ldt( + args: Args, + leaf_hash_params: <::LeafHash as CRHScheme>::Parameters, + two_to_one_params: <::TwoToOneHash as TwoToOneCRHScheme>::Parameters, +) where + F: FftField + CanonicalSerialize, + MerkleConfig: Config + Clone, + MerkleConfig::InnerDigest: AsRef<[u8]> + From<[u8; 32]>, + IOPattern: DigestIOPattern, + Merlin: DigestWriter, + for<'a> Arthur<'a>: DigestReader, +{ + use whir::whir::{ + committer::Committer, iopattern::WhirIOPattern, parameters::WhirConfig, prover::Prover, + verifier::Verifier, + }; + + // Runs as a LDT + let security_level = args.security_level; + let pow_bits = args.pow_bits.unwrap(); + let num_variables = args.num_variables; + let starting_rate = args.rate; + let reps = args.verifier_repetitions; + let first_round_folding_factor = args.first_round_folding_factor; + let folding_factor = args.folding_factor; + let fold_optimisation = args.fold_optimisation; + let soundness_type = args.soundness_type; + + if args.num_evaluations > 1 { + println!("Warning: running as LDT but a number of evaluations to be proven was specified."); + } + + let num_coeffs = 1 << num_variables; + + let mv_params = MultivariateParameters::::new(num_variables); + + let whir_params = WhirParameters:: { + initial_statement: false, + security_level, + pow_bits, + folding_factor: 
FoldingFactor::ConstantFromSecondRound( + first_round_folding_factor, + folding_factor, + ), + leaf_hash_params, + two_to_one_params, + soundness_type, + fold_optimisation, + _pow_parameters: Default::default(), + starting_log_inv_rate: starting_rate, + }; + + let params = WhirConfig::::new(mv_params, whir_params.clone()); + + let io = IOPattern::::new("🌪️") + .commit_statement(¶ms) + .add_whir_proof(¶ms); + + let mut merlin = io.to_merlin(); + + println!("========================================="); + println!("Whir (LDT) 🌪️"); + println!("Field: {:?} and MT: {:?}", args.field, args.merkle_tree); + println!("{}", params); + if !params.check_pow_bits() { + println!("WARN: more PoW bits required than what specified."); + } + + use ark_ff::Field; + let polynomial = CoefficientList::new( + (0..num_coeffs) + .map(::BasePrimeField::from) + .collect(), + ); + + let whir_prover_time = Instant::now(); + + let committer = Committer::new(params.clone()); + let witness = committer.commit(&mut merlin, polynomial).unwrap(); + + let prover = Prover(params.clone()); + + let proof = prover + .prove(&mut merlin, Statement::default(), witness) + .unwrap(); + + dbg!(whir_prover_time.elapsed()); + + // Serialize proof + let transcript = merlin.transcript().to_vec(); + let mut proof_bytes = vec![]; + proof.serialize_compressed(&mut proof_bytes).unwrap(); + + let proof_size = transcript.len() + proof_bytes.len(); + dbg!(proof_size); + + // Just not to count that initial inversion (which could be precomputed) + let verifier = Verifier::new(params.clone()); + + HashCounter::reset(); + let whir_verifier_time = Instant::now(); + for _ in 0..reps { + let mut arthur = io.to_arthur(&transcript); + verifier + .verify(&mut arthur, &Statement::default(), &proof) + .unwrap(); + } + dbg!(whir_verifier_time.elapsed() / reps as u32); + dbg!(HashCounter::get() as f64 / reps as f64); +} + +fn run_whir_pcs( + args: Args, + leaf_hash_params: <::LeafHash as CRHScheme>::Parameters, + two_to_one_params: <::TwoToOneHash as TwoToOneCRHScheme>::Parameters, +) where + F: FftField + CanonicalSerialize, + MerkleConfig: Config + Clone, + MerkleConfig::InnerDigest: AsRef<[u8]> + From<[u8; 32]>, + IOPattern: DigestIOPattern, + Merlin: DigestWriter, + for<'a> Arthur<'a>: DigestReader, +{ + use whir::whir::{ + Statement, committer::Committer, iopattern::WhirIOPattern, parameters::WhirConfig, + prover::Prover, verifier::Verifier, whir_proof_size, + }; + + // Runs as a PCS + let security_level = args.security_level; + let pow_bits = args.pow_bits.unwrap(); + let num_variables = args.num_variables; + let starting_rate = args.rate; + let reps = args.verifier_repetitions; + let first_round_folding_factor = args.first_round_folding_factor; + let folding_factor = args.folding_factor; + let fold_optimisation = args.fold_optimisation; + let soundness_type = args.soundness_type; + let num_evaluations = args.num_evaluations; + + if num_evaluations == 0 { + println!("Warning: running as PCS but no evaluations specified."); + } + + let num_coeffs = 1 << num_variables; + + let mv_params = MultivariateParameters::::new(num_variables); + + let whir_params = WhirParameters:: { + initial_statement: true, + security_level, + pow_bits, + folding_factor: FoldingFactor::ConstantFromSecondRound( + first_round_folding_factor, + folding_factor, + ), + leaf_hash_params, + two_to_one_params, + soundness_type, + fold_optimisation, + _pow_parameters: Default::default(), + starting_log_inv_rate: starting_rate, + }; + + let params = WhirConfig::::new(mv_params, 
whir_params); + + let io = IOPattern::::new("🌪️") + .commit_statement(¶ms) + .add_whir_proof(¶ms) + .clone(); + + let mut merlin = io.to_merlin(); + + println!("========================================="); + println!("Whir (PCS) 🌪️"); + println!("Field: {:?} and MT: {:?}", args.field, args.merkle_tree); + println!("{}", params); + if !params.check_pow_bits() { + println!("WARN: more PoW bits required than what specified."); + } + + use ark_ff::Field; + let polynomial = CoefficientList::new( + (0..num_coeffs) + .map(::BasePrimeField::from) + .collect(), + ); + let points: Vec<_> = (0..num_evaluations) + .map(|i| MultilinearPoint(vec![F::from(i as u64); num_variables])) + .collect(); + let evaluations = points + .iter() + .map(|point| polynomial.evaluate_at_extension(point)) + .collect(); + + let statement = Statement { + points, + evaluations, + }; + + let whir_prover_time = Instant::now(); + + let committer = Committer::new(params.clone()); + let witness = committer.commit(&mut merlin, polynomial).unwrap(); + + let prover = Prover(params.clone()); + + let proof = prover + .prove(&mut merlin, statement.clone(), witness) + .unwrap(); + + println!("Prover time: {:.1?}", whir_prover_time.elapsed()); + println!( + "Proof size: {:.1} KiB", + whir_proof_size(merlin.transcript(), &proof) as f64 / 1024.0 + ); + + // Just not to count that initial inversion (which could be precomputed) + let verifier = Verifier::new(params); + + HashCounter::reset(); + let whir_verifier_time = Instant::now(); + for _ in 0..reps { + let mut arthur = io.to_arthur(merlin.transcript()); + verifier.verify(&mut arthur, &statement, &proof).unwrap(); + } + println!( + "Verifier time: {:.1?}", + whir_verifier_time.elapsed() / reps as u32 + ); + println!( + "Average hashes: {:.1}k", + (HashCounter::get() as f64 / reps as f64) / 1000.0 + ); +} diff --git a/whir/src/ceno_binding/merkle_config.rs b/whir/src/ceno_binding/merkle_config.rs new file mode 100644 index 000000000..84eb804da --- /dev/null +++ b/whir/src/ceno_binding/merkle_config.rs @@ -0,0 +1,292 @@ +use ark_crypto_primitives::merkle_tree::Config; +use ark_ff::FftField; +use nimue::{Arthur, DefaultHash, IOPattern, Merlin, ProofResult}; +use nimue_pow::PowStrategy; + +use crate::{ + crypto::merkle_tree::{ + blake3::MerkleTreeParams as Blake3Params, keccak::MerkleTreeParams as KeccakParams, + }, + poly_utils::{MultilinearPoint, coeffs::CoefficientList}, + whir::{ + Statement, WhirProof, + batch::{WhirBatchIOPattern, Witnesses}, + committer::{Committer, Witness}, + fs_utils::DigestWriter, + iopattern::WhirIOPattern, + parameters::WhirConfig, + prover::Prover, + verifier::Verifier, + }, +}; + +pub trait WhirMerkleConfigWrapper { + type MerkleConfig: Config + Clone; + type PowStrategy: PowStrategy; + + fn commit_to_merlin( + committer: &Committer, + merlin: &mut Merlin, + poly: CoefficientList, + ) -> ProofResult>; + + fn commit_to_merlin_batch( + committer: &Committer, + merlin: &mut Merlin, + polys: &[CoefficientList], + ) -> ProofResult>; + + fn prove_with_merlin( + prover: &Prover, + merlin: &mut Merlin, + statement: Statement, + witness: Witness, + ) -> ProofResult>; + + fn prove_with_merlin_simple_batch( + prover: &Prover, + merlin: &mut Merlin, + point: &[F], + evals: &[F], + witness: &Witnesses, + ) -> ProofResult>; + + fn verify_with_arthur( + verifier: &Verifier, + arthur: &mut Arthur, + statement: &Statement, + whir_proof: &WhirProof, + ) -> ProofResult<::InnerDigest>; + + fn verify_with_arthur_simple_batch( + verifier: &Verifier, + arthur: &mut Arthur, + point: 
&[F], + evals: &[F], + whir_proof: &WhirProof, + ) -> ProofResult<::InnerDigest>; + + fn commit_statement_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + ) -> IOPattern; + + fn add_whir_proof_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + ) -> IOPattern; + + fn commit_batch_statement_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + batch_size: usize, + ) -> IOPattern; + + fn add_whir_batch_proof_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + batch_size: usize, + ) -> IOPattern; + + fn add_digest_to_merlin( + merlin: &mut Merlin, + digest: ::InnerDigest, + ) -> ProofResult<()>; +} + +pub struct Blake3ConfigWrapper(Blake3Params); +pub struct KeccakConfigWrapper(KeccakParams); + +impl WhirMerkleConfigWrapper for Blake3ConfigWrapper { + type MerkleConfig = Blake3Params; + type PowStrategy = nimue_pow::blake3::Blake3PoW; + fn commit_to_merlin( + committer: &Committer, + merlin: &mut Merlin, + poly: CoefficientList, + ) -> ProofResult> { + committer.commit(merlin, poly) + } + + fn commit_to_merlin_batch( + committer: &Committer, + merlin: &mut Merlin, + polys: &[CoefficientList], + ) -> ProofResult> { + committer.batch_commit(merlin, polys) + } + + fn prove_with_merlin( + prover: &Prover, + merlin: &mut Merlin, + statement: Statement, + witness: Witness, + ) -> ProofResult> { + prover.prove(merlin, statement, witness) + } + + fn prove_with_merlin_simple_batch( + prover: &Prover, + merlin: &mut Merlin, + point: &[F], + evals: &[F], + witness: &Witnesses, + ) -> ProofResult> { + let points = [MultilinearPoint(point.to_vec())]; + prover.simple_batch_prove(merlin, &points, &[evals.to_vec()], witness) + } + + fn verify_with_arthur( + verifier: &Verifier, + arthur: &mut Arthur, + statement: &Statement, + whir_proof: &WhirProof, + ) -> ProofResult<::InnerDigest> { + verifier.verify(arthur, statement, whir_proof) + } + + fn verify_with_arthur_simple_batch( + verifier: &Verifier, + arthur: &mut Arthur, + point: &[F], + evals: &[F], + whir_proof: &WhirProof, + ) -> ProofResult<::InnerDigest> { + let points = [MultilinearPoint(point.to_vec())]; + verifier.simple_batch_verify(arthur, evals.len(), &points, &[evals.to_vec()], whir_proof) + } + + fn commit_statement_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + ) -> IOPattern { + iopattern.commit_statement(params) + } + + fn add_whir_proof_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + ) -> IOPattern { + iopattern.add_whir_proof(params) + } + + fn commit_batch_statement_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + batch_size: usize, + ) -> IOPattern { + iopattern.commit_batch_statement(params, batch_size) + } + + fn add_whir_batch_proof_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + batch_size: usize, + ) -> IOPattern { + iopattern.add_whir_batch_proof(params, batch_size) + } + + fn add_digest_to_merlin( + merlin: &mut Merlin, + digest: ::InnerDigest, + ) -> ProofResult<()> { + >::add_digest(merlin, digest) + } +} + +impl WhirMerkleConfigWrapper for KeccakConfigWrapper { + type MerkleConfig = KeccakParams; + type PowStrategy = nimue_pow::keccak::KeccakPoW; + fn commit_to_merlin( + committer: &Committer, + merlin: &mut Merlin, + poly: CoefficientList, + ) -> ProofResult> { + committer.commit(merlin, poly) + } + + fn commit_to_merlin_batch( + committer: &Committer, + merlin: &mut Merlin, + polys: &[CoefficientList], + ) -> ProofResult> { + committer.batch_commit(merlin, polys) + } + + fn prove_with_merlin( + prover: &Prover, + 
merlin: &mut Merlin, + statement: Statement, + witness: Witness, + ) -> ProofResult> { + prover.prove(merlin, statement, witness) + } + + fn prove_with_merlin_simple_batch( + prover: &Prover, + merlin: &mut Merlin, + point: &[F], + evals: &[F], + witness: &Witnesses, + ) -> ProofResult> { + let points = [MultilinearPoint(point.to_vec())]; + prover.simple_batch_prove(merlin, &points, &[evals.to_vec()], witness) + } + + fn verify_with_arthur( + verifier: &Verifier, + arthur: &mut Arthur, + statement: &Statement, + whir_proof: &WhirProof, + ) -> ProofResult<::InnerDigest> { + verifier.verify(arthur, statement, whir_proof) + } + + fn verify_with_arthur_simple_batch( + verifier: &Verifier, + arthur: &mut Arthur, + point: &[F], + evals: &[F], + whir_proof: &WhirProof, + ) -> ProofResult<::InnerDigest> { + let points = [MultilinearPoint(point.to_vec())]; + verifier.simple_batch_verify(arthur, evals.len(), &points, &[evals.to_vec()], whir_proof) + } + + fn commit_statement_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + ) -> IOPattern { + iopattern.commit_statement(params) + } + + fn add_whir_proof_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + ) -> IOPattern { + iopattern.add_whir_proof(params) + } + + fn commit_batch_statement_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + batch_size: usize, + ) -> IOPattern { + iopattern.commit_batch_statement(params, batch_size) + } + + fn add_whir_batch_proof_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + batch_size: usize, + ) -> IOPattern { + iopattern.add_whir_batch_proof(params, batch_size) + } + + fn add_digest_to_merlin( + merlin: &mut Merlin, + digest: ::InnerDigest, + ) -> ProofResult<()> { + >::add_digest(merlin, digest) + } +} diff --git a/whir/src/ceno_binding/mod.rs b/whir/src/ceno_binding/mod.rs new file mode 100644 index 000000000..7299cf99a --- /dev/null +++ b/whir/src/ceno_binding/mod.rs @@ -0,0 +1,80 @@ +mod merkle_config; +mod pcs; +pub use ark_crypto_primitives::merkle_tree::Config; +pub use pcs::{DefaultHash, InnerDigestOf, Whir, WhirDefaultSpec, WhirSpec}; + +use ark_ff::FftField; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use serde::{Serialize, de::DeserializeOwned}; +use std::fmt::Debug; + +pub use nimue::{ + ProofResult, + plugins::ark::{FieldChallenges, FieldWriter}, +}; + +#[derive(Debug, Clone, thiserror::Error)] +pub enum Error { + #[error(transparent)] + ProofError(#[from] nimue::ProofError), + #[error("CommitmentMismatchFromDigest")] + CommitmentMismatchFromDigest, + #[error("InvalidPcsParams")] + InvalidPcsParam, +} + +/// The trait for a non-interactive polynomial commitment scheme. +/// This trait serves as the intermediate step between WHIR and the +/// trait required in Ceno mpcs. Because Ceno and the WHIR implementation +/// in this crate assume different types of transcripts, to connect +/// them we can provide a non-interactive interface from WHIR. 
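+///
+/// A minimal sketch of the intended round trip (illustrative only, not
+/// upstream code; `E` stands for any field satisfying this trait's bounds,
+/// and `WhirDefaultSpec` with a constant toy polynomial is an assumption):
+///
+/// ```ignore
+/// use whir::ceno_binding::{PolynomialCommitmentScheme, Whir, WhirDefaultSpec};
+/// use whir::poly_utils::{MultilinearPoint, coeffs::CoefficientList};
+///
+/// type Pcs = Whir<E, WhirDefaultSpec>;
+///
+/// // 1 << 10 coefficients encode a 10-variable multilinear polynomial.
+/// let poly = CoefficientList::new(vec![E::BasePrimeField::from(1u64); 1 << 10]);
+/// let pp = Pcs::setup(1 << 10);
+/// let comm = Pcs::commit(&pp, &poly).unwrap();
+/// let point = vec![E::from(42u64); 10];
+/// let eval = poly.evaluate_at_extension(&MultilinearPoint(point.clone()));
+/// let proof = Pcs::open(&pp, &comm, &point, &eval).unwrap();
+/// // The verifier only needs the Merkle root, the point, the claimed
+/// // evaluation, and the proof (which carries the transcript).
+/// Pcs::verify(&pp, &comm.commitment, &point, &eval, &proof).unwrap();
+/// ```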
+pub trait PolynomialCommitmentScheme: Clone { + type Param: Clone + Debug + Serialize + DeserializeOwned; + type Commitment: Clone + Debug; + type CommitmentWithWitness: Clone + Debug; + type Proof: Clone + CanonicalSerialize + CanonicalDeserialize + Serialize + DeserializeOwned; + type Poly: Clone + Debug + Serialize + DeserializeOwned; + + fn setup(poly_size: usize) -> Self::Param; + + fn commit(pp: &Self::Param, poly: &Self::Poly) -> Result; + + fn batch_commit( + pp: &Self::Param, + polys: &[Self::Poly], + ) -> Result; + + fn open( + pp: &Self::Param, + comm: &Self::CommitmentWithWitness, + point: &[E], + eval: &E, + ) -> Result; + + /// This is a simple version of batch open: + /// 1. Open at one point + /// 2. All the polynomials share the same commitment. + /// 3. The point is already a random point generated by a sum-check. + fn simple_batch_open( + pp: &Self::Param, + comm: &Self::CommitmentWithWitness, + point: &[E], + evals: &[E], + ) -> Result; + + fn verify( + vp: &Self::Param, + comm: &Self::Commitment, + point: &[E], + eval: &E, + proof: &Self::Proof, + ) -> Result<(), Error>; + + fn simple_batch_verify( + vp: &Self::Param, + comm: &Self::Commitment, + point: &[E], + evals: &[E], + proof: &Self::Proof, + ) -> Result<(), Error>; +} diff --git a/whir/src/ceno_binding/pcs.rs b/whir/src/ceno_binding/pcs.rs new file mode 100644 index 000000000..2c8f8ed80 --- /dev/null +++ b/whir/src/ceno_binding/pcs.rs @@ -0,0 +1,442 @@ +use super::{ + Error, PolynomialCommitmentScheme, + merkle_config::{Blake3ConfigWrapper, WhirMerkleConfigWrapper}, +}; +use crate::{ + crypto::merkle_tree::blake3::{self as mt}, + parameters::{ + FoldType, FoldingFactor, MultivariateParameters, SoundnessType, WhirParameters, + default_max_pow, + }, + poly_utils::{MultilinearPoint, coeffs::CoefficientList}, + whir::{ + Statement, WhirProof, batch::Witnesses, committer::Committer, parameters::WhirConfig, + prover::Prover, verifier::Verifier, + }, +}; + +use ark_crypto_primitives::merkle_tree::Config; +use ark_ff::FftField; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +pub use nimue::DefaultHash; +use nimue::{ + IOPattern, + plugins::ark::{FieldChallenges, FieldWriter}, +}; +use nimue_pow::blake3::Blake3PoW; +use rand::prelude::*; +use rand_chacha::ChaCha8Rng; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use std::{ + fmt::{self, Debug, Formatter}, + marker::PhantomData, +}; + +pub trait WhirSpec: Default + std::fmt::Debug + Clone { + type MerkleConfigWrapper: WhirMerkleConfigWrapper; + fn get_parameters( + num_variables: usize, + for_batch: bool, + ) -> WhirParameters, PowOf>; + + fn prepare_whir_config( + num_variables: usize, + for_batch: bool, + ) -> WhirConfig, PowOf> { + let whir_params = Self::get_parameters(num_variables, for_batch); + let mv_params = MultivariateParameters::new(num_variables); + ConfigOf::::new(mv_params, whir_params) + } + + fn prepare_io_pattern(num_variables: usize) -> IOPattern { + let params = Self::prepare_whir_config(num_variables, false); + + let io = IOPattern::::new("🌪️"); + let io = >::commit_statement_to_io_pattern( + io, ¶ms, + ); + >::add_whir_proof_to_io_pattern( + io, ¶ms, + ) + } + + fn prepare_batch_io_pattern(num_variables: usize, batch_size: usize) -> IOPattern { + let params = Self::prepare_whir_config(num_variables, true); + + let io = IOPattern::::new("🌪️"); + let io = >::commit_batch_statement_to_io_pattern( + io, ¶ms, batch_size + ); + + >::add_whir_batch_proof_to_io_pattern( + io, ¶ms, batch_size + ) + } +} + +type MerkleConfigOf 
= + <>::MerkleConfigWrapper as WhirMerkleConfigWrapper>::MerkleConfig; +type ConfigOf = WhirConfig, PowOf>; + +pub type InnerDigestOf = as Config>::InnerDigest; + +type PowOf = + <>::MerkleConfigWrapper as WhirMerkleConfigWrapper>::PowStrategy; + +#[derive(Debug, Clone, Default)] +pub struct WhirDefaultSpec; + +impl WhirSpec for WhirDefaultSpec { + type MerkleConfigWrapper = Blake3ConfigWrapper; + fn get_parameters( + num_variables: usize, + for_batch: bool, + ) -> WhirParameters, Blake3PoW> { + let mut rng = ChaCha8Rng::from_seed([0u8; 32]); + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + WhirParameters::, Blake3PoW> { + initial_statement: true, + security_level: 100, + pow_bits: default_max_pow(num_variables, 1), + // For batching, the first round folding factor should be set small + // to avoid large leaf nodes in proof + folding_factor: if for_batch { + FoldingFactor::ConstantFromSecondRound(1, 4) + } else { + FoldingFactor::Constant(4) + }, + leaf_hash_params, + two_to_one_params, + soundness_type: SoundnessType::ConjectureList, + fold_optimisation: FoldType::ProverHelps, + _pow_parameters: Default::default(), + starting_log_inv_rate: 1, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct WhirSetupParams { + pub num_variables: usize, + _phantom: PhantomData, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct Whir>(PhantomData<(E, Spec)>); + +// Wrapper for WhirProof +#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] +pub struct WhirProofWrapper +where + MerkleConfig: Config, + F: Sized + Clone + CanonicalSerialize + CanonicalDeserialize, +{ + pub proof: WhirProof, + pub transcript: Vec, +} + +impl Serialize for WhirProofWrapper +where + MerkleConfig: Config, + F: Sized + Clone + CanonicalSerialize + CanonicalDeserialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let proof = &self.proof.0; + // Create a buffer that implements the `Write` trait + let mut buffer = Vec::new(); + proof.serialize_compressed(&mut buffer).unwrap(); + let proof_size = buffer.len(); + let proof_size_bytes = proof_size.to_le_bytes(); + let mut data = proof_size_bytes.to_vec(); + data.extend_from_slice(&buffer); + data.extend_from_slice(&self.transcript); + serializer.serialize_bytes(&data) + } +} + +impl<'de, MerkleConfig, F> Deserialize<'de> for WhirProofWrapper +where + MerkleConfig: Config, + F: Sized + Clone + CanonicalSerialize + CanonicalDeserialize, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let data: Vec = Deserialize::deserialize(deserializer)?; + let proof_size_bytes = &data[0..8]; + let proof_size = u64::from_le_bytes(proof_size_bytes.try_into().unwrap()); + let proof_bytes = &data[8..8 + proof_size as usize]; + let proof = WhirProof::deserialize_compressed(proof_bytes).unwrap(); + let transcript = data[8 + proof_size as usize..].to_vec(); + Ok(WhirProofWrapper { proof, transcript }) + } +} + +impl Debug for WhirProofWrapper +where + MerkleConfig: Config, + F: Sized + Clone + CanonicalSerialize + CanonicalDeserialize, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str("WhirProofWrapper") + } +} + +#[derive(Clone)] +pub struct CommitmentWithWitness +where + MerkleConfig: Config, +{ + pub commitment: MerkleConfig::InnerDigest, + pub witness: Witnesses, +} + +impl CommitmentWithWitness +where + MerkleConfig: Config, +{ + pub fn ood_answers(&self) -> Vec { + self.witness.ood_answers.clone() + } 
+} + +impl Debug for CommitmentWithWitness +where + MerkleConfig: Config, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str("CommitmentWithWitness") + } +} + +impl> PolynomialCommitmentScheme for Whir +where + E: FftField + Serialize + DeserializeOwned + Debug, + E::BasePrimeField: Serialize + DeserializeOwned + Debug, +{ + type Param = (); + type Commitment = as Config>::InnerDigest; + type CommitmentWithWitness = CommitmentWithWitness>; + type Proof = WhirProofWrapper, E>; + type Poly = CoefficientList; + + fn setup(_poly_size: usize) -> Self::Param {} + + fn commit(_pp: &Self::Param, poly: &Self::Poly) -> Result { + let params = Spec::prepare_whir_config(poly.num_variables(), false); + + // The merlin here is just for satisfying the interface of + // WHIR, which only provides a commit_and_write function. + // It will be abandoned once this function finishes. + let io = Spec::prepare_io_pattern(poly.num_variables()); + let mut merlin = io.to_merlin(); + + let committer = Committer::new(params); + let witness = Witnesses::from(Spec::MerkleConfigWrapper::commit_to_merlin( + &committer, + &mut merlin, + poly.clone(), + )?); + + Ok(CommitmentWithWitness { + commitment: witness.merkle_tree.root(), + witness, + }) + } + + fn batch_commit( + _pp: &Self::Param, + polys: &[Self::Poly], + ) -> Result { + if polys.is_empty() { + return Err(Error::InvalidPcsParam); + } + + for i in 1..polys.len() { + if polys[i].num_variables() != polys[0].num_variables() { + return Err(Error::InvalidPcsParam); + } + } + + let params = Spec::prepare_whir_config(polys[0].num_variables(), true); + + // The merlin here is just for satisfying the interface of + // WHIR, which only provides a commit_and_write function. + // It will be abandoned once this function finishes. + let io = Spec::prepare_batch_io_pattern(polys[0].num_variables(), polys.len()); + let mut merlin = io.to_merlin(); + + let committer = Committer::new(params); + let witness = + Spec::MerkleConfigWrapper::commit_to_merlin_batch(&committer, &mut merlin, polys)?; + Ok(CommitmentWithWitness { + commitment: witness.merkle_tree.root(), + witness, + }) + } + + fn open( + _pp: &Self::Param, + witness: &Self::CommitmentWithWitness, + point: &[E], + eval: &E, + ) -> Result { + let params = Spec::prepare_whir_config(witness.witness.polys[0].num_variables(), false); + let io = Spec::prepare_io_pattern(witness.witness.polys[0].num_variables()); + let mut merlin = io.to_merlin(); + // In WHIR, the prover writes the commitment to the transcript, then + // the commitment is read from the transcript by the verifier, after + // the transcript is transformed into a arthur transcript. + // Here we repeat whatever the prover does. + // TODO: This is a hack. There should be a better design that does not + // require non-black-box knowledge of the inner working of WHIR. + + >::add_digest_to_merlin( + &mut merlin, + witness.commitment.clone(), + ) + .map_err(Error::ProofError)?; + let ood_answers = witness.ood_answers(); + if !ood_answers.is_empty() { + let mut ood_points = vec![::ZERO; ood_answers.len()]; + merlin + .fill_challenge_scalars(&mut ood_points) + .map_err(Error::ProofError)?; + merlin + .add_scalars(&ood_answers) + .map_err(Error::ProofError)?; + } + // Now the Merlin transcript is ready to pass to the verifier. 
+ + let prover = Prover(params); + let statement = Statement { + points: vec![MultilinearPoint(point.to_vec())], + evaluations: vec![*eval], + }; + + let proof = Spec::MerkleConfigWrapper::prove_with_merlin( + &prover, + &mut merlin, + statement, + witness.witness.clone().into(), + )?; + + Ok(WhirProofWrapper { + proof, + transcript: merlin.transcript().to_vec(), + }) + } + + fn simple_batch_open( + _pp: &Self::Param, + witness: &Self::CommitmentWithWitness, + point: &[E], + evals: &[E], + ) -> Result { + let params = Spec::prepare_whir_config(witness.witness.polys[0].num_variables(), true); + let io = + Spec::prepare_batch_io_pattern(witness.witness.polys[0].num_variables(), evals.len()); + let mut merlin = io.to_merlin(); + // In WHIR, the prover writes the commitment to the transcript, then + // the commitment is read from the transcript by the verifier, after + // the transcript is transformed into a arthur transcript. + // Here we repeat whatever the prover does. + // TODO: This is a hack. There should be a better design that does not + // require non-black-box knowledge of the inner working of WHIR. + + >::add_digest_to_merlin( + &mut merlin, + witness.commitment.clone(), + ) + .map_err(Error::ProofError)?; + let ood_answers = witness.ood_answers(); + if !ood_answers.is_empty() { + let mut ood_points = + vec![::ZERO; ood_answers.len() / evals.len()]; + merlin + .fill_challenge_scalars(&mut ood_points) + .map_err(Error::ProofError)?; + merlin + .add_scalars(&ood_answers) + .map_err(Error::ProofError)?; + } + // Now the Merlin transcript is ready to pass to the verifier. + + let prover = Prover(params); + + let proof = Spec::MerkleConfigWrapper::prove_with_merlin_simple_batch( + &prover, + &mut merlin, + point, + evals, + &witness.witness, + )?; + + Ok(WhirProofWrapper { + proof, + transcript: merlin.transcript().to_vec(), + }) + } + + fn verify( + _vp: &Self::Param, + comm: &Self::Commitment, + point: &[E], + eval: &E, + proof: &Self::Proof, + ) -> Result<(), Error> { + let params = Spec::prepare_whir_config(point.len(), false); + let verifier = Verifier::new(params); + let io = Spec::prepare_io_pattern(point.len()); + let mut arthur = io.to_arthur(&proof.transcript); + + let statement = Statement { + points: vec![MultilinearPoint(point.to_vec())], + evaluations: vec![*eval], + }; + + let digest = Spec::MerkleConfigWrapper::verify_with_arthur( + &verifier, + &mut arthur, + &statement, + &proof.proof, + )?; + + if &digest != comm { + return Err(Error::CommitmentMismatchFromDigest); + } + + Ok(()) + } + + fn simple_batch_verify( + _vp: &Self::Param, + comm: &Self::Commitment, + point: &[E], + evals: &[E], + proof: &Self::Proof, + ) -> Result<(), Error> { + let params = Spec::prepare_whir_config(point.len(), true); + let verifier = Verifier::new(params); + let io = Spec::prepare_batch_io_pattern(point.len(), evals.len()); + let mut arthur = io.to_arthur(&proof.transcript); + + let digest = Spec::MerkleConfigWrapper::verify_with_arthur_simple_batch( + &verifier, + &mut arthur, + point, + evals, + &proof.proof, + )?; + + if &digest != comm { + return Err(Error::CommitmentMismatchFromDigest); + } + + Ok(()) + } +} diff --git a/whir/src/cmdline_utils.rs b/whir/src/cmdline_utils.rs new file mode 100644 index 000000000..9572fa328 --- /dev/null +++ b/whir/src/cmdline_utils.rs @@ -0,0 +1,75 @@ +use std::str::FromStr; + +use serde::Serialize; + +#[derive(Debug, Clone, Copy, Serialize)] +pub enum WhirType { + LDT, + PCS, +} + +impl FromStr for WhirType { + type Err = String; + + fn from_str(s: 
&str) -> Result { + if s == "LDT" { + Ok(Self::LDT) + } else if s == "PCS" { + Ok(Self::PCS) + } else { + Err(format!("Invalid field: {}", s)) + } + } +} + +#[derive(Debug, Clone, Copy, Serialize)] +pub enum AvailableFields { + Goldilocks1, // Just Goldilocks + Goldilocks2, // Quadratic extension of Goldilocks + Goldilocks3, // Cubic extension of Goldilocks + Field128, // 128-bit prime field + Field192, // 192-bit prime field + Field256, // 256-bit prime field +} + +impl FromStr for AvailableFields { + type Err = String; + + fn from_str(s: &str) -> Result { + if s == "Field128" { + Ok(Self::Field128) + } else if s == "Field192" { + Ok(Self::Field192) + } else if s == "Field256" { + Ok(Self::Field256) + } else if s == "Goldilocks1" { + Ok(Self::Goldilocks1) + } else if s == "Goldilocks2" { + Ok(Self::Goldilocks2) + } else if s == "Goldilocks3" { + Ok(Self::Goldilocks3) + } else { + Err(format!("Invalid field: {}", s)) + } + } +} + +#[derive(Debug, Clone, Copy, Serialize)] +pub enum AvailableMerkle { + Keccak256, + Blake3, +} + +impl FromStr for AvailableMerkle { + type Err = String; + + fn from_str(s: &str) -> Result { + if s == "Keccak" { + Ok(Self::Keccak256) + } else if s == "Blake3" { + Ok(Self::Blake3) + } else { + Err(format!("Invalid hash: {}", s)) + } + } +} diff --git a/whir/src/crypto/fields.rs b/whir/src/crypto/fields.rs new file mode 100644 index 000000000..6fcb2f909 --- /dev/null +++ b/whir/src/crypto/fields.rs @@ -0,0 +1,92 @@ +use ark_ff::{ + Field, Fp2, Fp2Config, Fp3, Fp3Config, Fp64, Fp128, Fp192, Fp256, MontBackend, MontConfig, + MontFp, PrimeField, +}; + +pub trait FieldWithSize { + fn field_size_in_bits() -> usize; +} + +impl FieldWithSize for F +where + F: Field, +{ + fn field_size_in_bits() -> usize { + F::BasePrimeField::MODULUS_BIT_SIZE as usize * F::extension_degree() as usize + } +} + +#[derive(MontConfig)] +#[modulus = "21888242871839275222246405745257275088548364400416034343698204186575808495617"] +#[generator = "5"] +pub struct BN254Config; +pub type Field256 = Fp256>; + +#[derive(MontConfig)] +#[modulus = "3801539170989320091464968600173246866371124347557388484609"] +#[generator = "3"] +pub struct FConfig192; +pub type Field192 = Fp192>; + +#[derive(MontConfig)] +#[modulus = "340282366920938463463374557953744961537"] +#[generator = "3"] +pub struct FrConfig128; +pub type Field128 = Fp128>; + +#[derive(MontConfig)] +#[modulus = "18446744069414584321"] +#[generator = "7"] +pub struct FConfig64; +pub type Field64 = Fp64>; + +pub type Field64_2 = Fp2; +pub struct F2Config64; +impl Fp2Config for F2Config64 { + type Fp = Field64; + + const NONRESIDUE: Self::Fp = MontFp!("7"); + + const FROBENIUS_COEFF_FP2_C1: &'static [Self::Fp] = &[ + // Fq(7)**(((q^0) - 1) / 2) + MontFp!("1"), + // Fq(7)**(((q^1) - 1) / 2) + MontFp!("18446744069414584320"), + ]; +} + +pub type Field64_3 = Fp3; +pub struct F3Config64; + +impl Fp3Config for F3Config64 { + type Fp = Field64; + + const NONRESIDUE: Self::Fp = MontFp!("2"); + + const FROBENIUS_COEFF_FP3_C1: &'static [Self::Fp] = &[ + MontFp!("1"), + // Fq(2)^(((q^1) - 1) / 3) + MontFp!("4294967295"), + // Fq(2)^(((q^2) - 1) / 3) + MontFp!("18446744065119617025"), + ]; + + const FROBENIUS_COEFF_FP3_C2: &'static [Self::Fp] = &[ + MontFp!("1"), + // Fq(2)^(((2q^1) - 2) / 3) + MontFp!("18446744065119617025"), + // Fq(2)^(((2q^2) - 2) / 3) + MontFp!("4294967295"), + ]; + + // (q^3 - 1) = 2^32 * T where T = 1461501636310055817916238417282618014431694553085 + const TWO_ADICITY: u32 = 32; + + // 11^T + const QUADRATIC_NONRESIDUE_TO_T: Fp3 = + 
Fp3::new(MontFp!("5944137876247729999"), MontFp!("0"), MontFp!("0")); + + // T - 1 / 2 + const TRACE_MINUS_ONE_DIV_TWO: &'static [u64] = + &[0x80000002fffffffe, 0x80000002fffffffc, 0x7ffffffe]; +} diff --git a/whir/src/crypto/merkle_tree/blake3.rs b/whir/src/crypto/merkle_tree/blake3.rs new file mode 100644 index 000000000..2260f3984 --- /dev/null +++ b/whir/src/crypto/merkle_tree/blake3.rs @@ -0,0 +1,158 @@ +use std::{borrow::Borrow, marker::PhantomData}; + +use super::{HashCounter, IdentityDigestConverter}; +use crate::whir::{ + fs_utils::{DigestReader, DigestWriter}, + iopattern::DigestIOPattern, +}; +use ark_crypto_primitives::{ + crh::{CRHScheme, TwoToOneCRHScheme}, + merkle_tree::Config, + sponge::Absorb, +}; +use ark_ff::Field; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use nimue::{ + Arthur, ByteIOPattern, ByteReader, ByteWriter, IOPattern, Merlin, ProofError, ProofResult, +}; +use rand::RngCore; +use serde::{Deserialize, Serialize}; + +#[derive( + Debug, Default, Clone, Copy, Eq, PartialEq, Hash, CanonicalSerialize, CanonicalDeserialize, +)] +pub struct Blake3Digest([u8; 32]); + +impl AsRef<[u8]> for Blake3Digest { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl From<[u8; 32]> for Blake3Digest { + fn from(value: [u8; 32]) -> Self { + Self(value) + } +} + +impl Absorb for Blake3Digest { + fn to_sponge_bytes(&self, dest: &mut Vec) { + dest.extend_from_slice(&self.0); + } + + fn to_sponge_field_elements(&self, dest: &mut Vec) { + let mut buf = [0; 32]; + buf.copy_from_slice(&self.0); + dest.push(F::from_be_bytes_mod_order(&buf)); + } +} + +pub struct Blake3LeafHash(PhantomData); +pub struct Blake3TwoToOneCRHScheme; + +impl CRHScheme for Blake3LeafHash { + type Input = [F]; + type Output = Blake3Digest; + type Parameters = (); + + fn setup(_: &mut R) -> Result { + Ok(()) + } + + fn evaluate>( + _: &Self::Parameters, + input: T, + ) -> Result { + let mut buf = vec![]; + CanonicalSerialize::serialize_compressed(input.borrow(), &mut buf)?; + + let mut h = blake3::Hasher::new(); + h.update(&buf); + + let mut output = [0; 32]; + output.copy_from_slice(h.finalize().as_bytes()); + HashCounter::add(); + Ok(Blake3Digest(output)) + } +} + +impl TwoToOneCRHScheme for Blake3TwoToOneCRHScheme { + type Input = Blake3Digest; + type Output = Blake3Digest; + type Parameters = (); + + fn setup(_: &mut R) -> Result { + Ok(()) + } + + fn evaluate>( + _: &Self::Parameters, + left_input: T, + right_input: T, + ) -> Result { + let mut h = blake3::Hasher::new(); + h.update(&left_input.borrow().0); + h.update(&right_input.borrow().0); + let mut output = [0; 32]; + output.copy_from_slice(h.finalize().as_bytes()); + HashCounter::add(); + Ok(Blake3Digest(output)) + } + + fn compress>( + parameters: &Self::Parameters, + left_input: T, + right_input: T, + ) -> Result { + ::evaluate(parameters, left_input, right_input) + } +} + +pub type LeafH = Blake3LeafHash; +pub type CompressH = Blake3TwoToOneCRHScheme; + +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +pub struct MerkleTreeParams(PhantomData); + +impl Config for MerkleTreeParams { + type Leaf = [F]; + + type LeafDigest = as CRHScheme>::Output; + type LeafInnerDigestConverter = IdentityDigestConverter; + type InnerDigest = ::Output; + + type LeafHash = LeafH; + type TwoToOneHash = CompressH; +} + +pub fn default_config( + rng: &mut impl RngCore, +) -> ( + as CRHScheme>::Parameters, + ::Parameters, +) { + as CRHScheme>::setup(rng).unwrap(); + ::setup(rng).unwrap(); + + ((), ()) +} + +impl DigestIOPattern> for IOPattern { + 
fn add_digest(self, label: &str) -> Self { + self.add_bytes(32, label) + } +} + +impl DigestWriter> for Merlin { + fn add_digest(&mut self, digest: Blake3Digest) -> ProofResult<()> { + self.add_bytes(&digest.0).map_err(ProofError::InvalidIO) + } +} + +impl DigestReader> for Arthur<'_> { + fn read_digest(&mut self) -> ProofResult { + let mut digest = [0; 32]; + self.fill_next_bytes(&mut digest)?; + Ok(Blake3Digest(digest)) + } +} diff --git a/whir/src/crypto/merkle_tree/keccak.rs b/whir/src/crypto/merkle_tree/keccak.rs new file mode 100644 index 000000000..8bdd0af07 --- /dev/null +++ b/whir/src/crypto/merkle_tree/keccak.rs @@ -0,0 +1,158 @@ +use std::{borrow::Borrow, marker::PhantomData}; + +use super::{HashCounter, IdentityDigestConverter}; +use crate::whir::{ + fs_utils::{DigestReader, DigestWriter}, + iopattern::DigestIOPattern, +}; +use ark_crypto_primitives::{ + crh::{CRHScheme, TwoToOneCRHScheme}, + merkle_tree::Config, + sponge::Absorb, +}; +use ark_ff::Field; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use nimue::{ + Arthur, ByteIOPattern, ByteReader, ByteWriter, IOPattern, Merlin, ProofError, ProofResult, +}; +use rand::RngCore; +use sha3::Digest; + +#[derive( + Debug, Default, Clone, Copy, Eq, PartialEq, Hash, CanonicalSerialize, CanonicalDeserialize, +)] +pub struct KeccakDigest([u8; 32]); + +impl Absorb for KeccakDigest { + fn to_sponge_bytes(&self, dest: &mut Vec) { + dest.extend_from_slice(&self.0); + } + + fn to_sponge_field_elements(&self, dest: &mut Vec) { + let mut buf = [0; 32]; + buf.copy_from_slice(&self.0); + dest.push(F::from_be_bytes_mod_order(&buf)); + } +} + +impl From<[u8; 32]> for KeccakDigest { + fn from(value: [u8; 32]) -> Self { + KeccakDigest(value) + } +} + +impl AsRef<[u8]> for KeccakDigest { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +pub struct KeccakLeafHash(PhantomData); +pub struct KeccakTwoToOneCRHScheme; + +impl CRHScheme for KeccakLeafHash { + type Input = [F]; + type Output = KeccakDigest; + type Parameters = (); + + fn setup(_: &mut R) -> Result { + Ok(()) + } + + fn evaluate>( + _: &Self::Parameters, + input: T, + ) -> Result { + let mut buf = vec![]; + CanonicalSerialize::serialize_compressed(input.borrow(), &mut buf)?; + + let mut h = sha3::Keccak256::new(); + h.update(&buf); + + let mut output = [0; 32]; + output.copy_from_slice(&h.finalize()[..]); + HashCounter::add(); + Ok(KeccakDigest(output)) + } +} + +impl TwoToOneCRHScheme for KeccakTwoToOneCRHScheme { + type Input = KeccakDigest; + type Output = KeccakDigest; + type Parameters = (); + + fn setup(_: &mut R) -> Result { + Ok(()) + } + + fn evaluate>( + _: &Self::Parameters, + left_input: T, + right_input: T, + ) -> Result { + let mut h = sha3::Keccak256::new(); + h.update(left_input.borrow().0); + h.update(right_input.borrow().0); + let mut output = [0; 32]; + output.copy_from_slice(&h.finalize()[..]); + HashCounter::add(); + Ok(KeccakDigest(output)) + } + + fn compress>( + parameters: &Self::Parameters, + left_input: T, + right_input: T, + ) -> Result { + ::evaluate(parameters, left_input, right_input) + } +} + +pub type LeafH = KeccakLeafHash; +pub type CompressH = KeccakTwoToOneCRHScheme; + +#[derive(Debug, Default, Clone)] +pub struct MerkleTreeParams(PhantomData); + +impl Config for MerkleTreeParams { + type Leaf = [F]; + + type LeafDigest = as CRHScheme>::Output; + type LeafInnerDigestConverter = IdentityDigestConverter; + type InnerDigest = ::Output; + + type LeafHash = LeafH; + type TwoToOneHash = CompressH; +} + +pub fn default_config( + rng: &mut impl 
RngCore, +) -> ( + as CRHScheme>::Parameters, + ::Parameters, +) { + as CRHScheme>::setup(rng).unwrap(); + ::setup(rng).unwrap(); + + ((), ()) +} + +impl DigestIOPattern> for IOPattern { + fn add_digest(self, label: &str) -> Self { + self.add_bytes(32, label) + } +} + +impl DigestWriter> for Merlin { + fn add_digest(&mut self, digest: KeccakDigest) -> ProofResult<()> { + self.add_bytes(&digest.0).map_err(ProofError::InvalidIO) + } +} + +impl DigestReader> for Arthur<'_> { + fn read_digest(&mut self) -> ProofResult { + let mut digest = [0; 32]; + self.fill_next_bytes(&mut digest)?; + Ok(KeccakDigest(digest)) + } +} diff --git a/whir/src/crypto/merkle_tree/mock.rs b/whir/src/crypto/merkle_tree/mock.rs new file mode 100644 index 000000000..4becc42da --- /dev/null +++ b/whir/src/crypto/merkle_tree/mock.rs @@ -0,0 +1,67 @@ +use std::{borrow::Borrow, marker::PhantomData}; + +use ark_crypto_primitives::{ + crh::{CRHScheme, TwoToOneCRHScheme}, + merkle_tree::{ByteDigestConverter, Config}, +}; +use ark_serialize::CanonicalSerialize; +use rand::RngCore; + +pub struct Mock; + +impl TwoToOneCRHScheme for Mock { + type Input = [u8]; + type Output = Vec; + type Parameters = (); + + fn setup(_: &mut R) -> Result { + Ok(()) + } + + fn evaluate>( + _: &Self::Parameters, + _: T, + _: T, + ) -> Result { + Ok(vec![0u8; 32]) + } + + fn compress>( + _: &Self::Parameters, + _: T, + _: T, + ) -> Result { + Ok(vec![0u8; 32]) + } +} + +pub type LeafH = super::LeafIdentityHasher; +pub type CompressH = Mock; + +#[derive(Debug, Default)] +pub struct MerkleTreeParams(PhantomData); + +impl Config for MerkleTreeParams { + type Leaf = F; + + type LeafDigest = as CRHScheme>::Output; + type LeafInnerDigestConverter = ByteDigestConverter; + type InnerDigest = ::Output; + + type LeafHash = LeafH; + type TwoToOneHash = CompressH; +} + +pub fn default_config( + rng: &mut impl RngCore, +) -> ( + as CRHScheme>::Parameters, + ::Parameters, +) { + as CRHScheme>::setup(rng).unwrap(); + { + ::setup(rng).unwrap(); + }; + + ((), ()) +} diff --git a/whir/src/crypto/merkle_tree/mod.rs b/whir/src/crypto/merkle_tree/mod.rs new file mode 100644 index 000000000..092358ff5 --- /dev/null +++ b/whir/src/crypto/merkle_tree/mod.rs @@ -0,0 +1,73 @@ +pub mod blake3; +pub mod keccak; +pub mod mock; + +use std::{borrow::Borrow, marker::PhantomData, sync::atomic::AtomicUsize}; + +use ark_crypto_primitives::{Error, crh::CRHScheme, merkle_tree::DigestConverter}; +use ark_serialize::CanonicalSerialize; +use lazy_static::lazy_static; +use rand::RngCore; + +#[derive(Debug, Default)] +pub struct HashCounter { + counter: AtomicUsize, +} + +lazy_static! 
{ + static ref HASH_COUNTER: HashCounter = HashCounter::default(); +} + +impl HashCounter { + pub(crate) fn add() -> usize { + HASH_COUNTER + .counter + .fetch_add(1, std::sync::atomic::Ordering::SeqCst) + } + + pub fn reset() { + HASH_COUNTER + .counter + .store(0, std::sync::atomic::Ordering::SeqCst) + } + + pub fn get() -> usize { + HASH_COUNTER + .counter + .load(std::sync::atomic::Ordering::SeqCst) + } +} + +#[derive(Debug, Default)] +pub struct LeafIdentityHasher(PhantomData); + +impl CRHScheme for LeafIdentityHasher { + type Input = F; + type Output = Vec; + type Parameters = (); + + fn setup(_: &mut R) -> Result { + Ok(()) + } + + fn evaluate>( + _: &Self::Parameters, + input: T, + ) -> Result { + let mut buf = vec![]; + CanonicalSerialize::serialize_compressed(input.borrow(), &mut buf)?; + Ok(buf) + } +} + +/// A trivial converter where digest of previous layer's hash is the same as next layer's input. +pub struct IdentityDigestConverter { + _prev_layer_digest: T, +} + +impl DigestConverter for IdentityDigestConverter { + type TargetType = T; + fn convert(item: T) -> Result { + Ok(item) + } +} diff --git a/whir/src/crypto/mod.rs b/whir/src/crypto/mod.rs new file mode 100644 index 000000000..039b38f23 --- /dev/null +++ b/whir/src/crypto/mod.rs @@ -0,0 +1,2 @@ +pub mod fields; +pub mod merkle_tree; diff --git a/whir/src/domain.rs b/whir/src/domain.rs new file mode 100644 index 000000000..3c2bfc9b1 --- /dev/null +++ b/whir/src/domain.rs @@ -0,0 +1,144 @@ +use ark_ff::FftField; +use ark_poly::{ + EvaluationDomain, GeneralEvaluationDomain, MixedRadixEvaluationDomain, Radix2EvaluationDomain, +}; + +#[derive(Debug, Clone)] +pub struct Domain +where + F: FftField, +{ + pub base_domain: Option>, // The domain (in the base + // field) for the initial FFT + pub backing_domain: GeneralEvaluationDomain, +} + +impl Domain +where + F: FftField, +{ + pub fn new(degree: usize, log_rho_inv: usize) -> Option { + let size = degree * (1 << log_rho_inv); + let base_domain = GeneralEvaluationDomain::new(size)?; + let backing_domain = Self::to_extension_domain(&base_domain); + + Some(Self { + backing_domain, + base_domain: Some(base_domain), + }) + } + + // returns the size of the domain after folding folding_factor many times. 
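+    // e.g. with folding_factor = 2, a domain of size 1024 folds to 1024 / 2^2 = 256.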
+    //
+    // This asserts that the domain size is divisible by 1 << folding_factor
+    pub fn folded_size(&self, folding_factor: usize) -> usize {
+        assert!(self.backing_domain.size() % (1 << folding_factor) == 0);
+        self.backing_domain.size() / (1 << folding_factor)
+    }
+
+    pub fn size(&self) -> usize {
+        self.backing_domain.size()
+    }
+
+    pub fn scale(&self, power: usize) -> Self {
+        Self {
+            backing_domain: self.scale_generator_by(power),
+            base_domain: None, // Set to None because we only care about the initial base domain.
+        }
+    }
+
+    fn to_extension_domain(
+        domain: &GeneralEvaluationDomain<F::BasePrimeField>,
+    ) -> GeneralEvaluationDomain<F> {
+        let group_gen = F::from_base_prime_field(domain.group_gen());
+        let group_gen_inv = F::from_base_prime_field(domain.group_gen_inv());
+        let size = domain.size() as u64;
+        let log_size_of_group = domain.log_size_of_group() as u32;
+        let size_as_field_element = F::from_base_prime_field(domain.size_as_field_element());
+        let size_inv = F::from_base_prime_field(domain.size_inv());
+        let offset = F::from_base_prime_field(domain.coset_offset());
+        let offset_inv = F::from_base_prime_field(domain.coset_offset_inv());
+        let offset_pow_size = F::from_base_prime_field(domain.coset_offset_pow_size());
+        match domain {
+            GeneralEvaluationDomain::Radix2(_) => {
+                GeneralEvaluationDomain::Radix2(Radix2EvaluationDomain {
+                    size,
+                    log_size_of_group,
+                    size_as_field_element,
+                    size_inv,
+                    group_gen,
+                    group_gen_inv,
+                    offset,
+                    offset_inv,
+                    offset_pow_size,
+                })
+            }
+            GeneralEvaluationDomain::MixedRadix(_) => {
+                GeneralEvaluationDomain::MixedRadix(MixedRadixEvaluationDomain {
+                    size,
+                    log_size_of_group,
+                    size_as_field_element,
+                    size_inv,
+                    group_gen,
+                    group_gen_inv,
+                    offset,
+                    offset_inv,
+                    offset_pow_size,
+                })
+            }
+        }
+    }
+
+    // Takes the underlying backing_domain = <w>, and computes the new domain
+    // <w^power> (note this will have size |L| / power)
+    fn scale_generator_by(&self, power: usize) -> GeneralEvaluationDomain<F> {
+        let starting_size = self.size();
+        assert_eq!(starting_size % power, 0);
+        let new_size = starting_size / power;
+        let log_size_of_group = new_size.trailing_zeros();
+        let size_as_field_element = F::from(new_size as u64);
+
+        match self.backing_domain {
+            GeneralEvaluationDomain::Radix2(r2) => {
+                let group_gen = r2.group_gen.pow([power as u64]);
+                let group_gen_inv = group_gen.inverse().unwrap();
+
+                let offset = r2.offset.pow([power as u64]);
+                let offset_inv = r2.offset_inv.pow([power as u64]);
+                let offset_pow_size = offset.pow([new_size as u64]);
+
+                GeneralEvaluationDomain::Radix2(Radix2EvaluationDomain {
+                    size: new_size as u64,
+                    log_size_of_group,
+                    size_as_field_element,
+                    size_inv: size_as_field_element.inverse().unwrap(),
+                    group_gen,
+                    group_gen_inv,
+                    offset,
+                    offset_inv,
+                    offset_pow_size,
+                })
+            }
+            GeneralEvaluationDomain::MixedRadix(mr) => {
+                let group_gen = mr.group_gen.pow([power as u64]);
+                let group_gen_inv = mr.group_gen_inv.pow([power as u64]);
+
+                let offset = mr.offset.pow([power as u64]);
+                let offset_inv = mr.offset_inv.pow([power as u64]);
+                let offset_pow_size = offset.pow([new_size as u64]);
+
+                GeneralEvaluationDomain::MixedRadix(MixedRadixEvaluationDomain {
+                    size: new_size as u64,
+                    log_size_of_group,
+                    size_as_field_element,
+                    size_inv: size_as_field_element.inverse().unwrap(),
+                    group_gen,
+                    group_gen_inv,
+                    offset,
+                    offset_inv,
+                    offset_pow_size,
+                })
+            }
+        }
+    }
+}
diff --git a/whir/src/fs_utils.rs b/whir/src/fs_utils.rs
new file mode 100644
index 000000000..9d9180415
--- /dev/null
+++ b/whir/src/fs_utils.rs
@@ -0,0 +1,38 @@
+use ark_ff::Field;
+use nimue::plugins::ark::FieldIOPattern;
+use nimue_pow::PoWIOPattern;
+pub trait OODIOPattern<F: Field> {
+    fn add_ood(self, num_samples: usize) -> Self;
+}
+
+impl<F, IOPattern> OODIOPattern<F> for IOPattern
+where
+    F: Field,
+    IOPattern: FieldIOPattern<F>,
+{
+    fn add_ood(self, num_samples: usize) -> Self {
+        if num_samples > 0 {
+            self.challenge_scalars(num_samples, "ood_query")
+                .add_scalars(num_samples, "ood_ans")
+        } else {
+            self
+        }
+    }
+}
+
+pub trait WhirPoWIOPattern {
+    fn pow(self, bits: f64) -> Self;
+}
+
+impl<IOPattern> WhirPoWIOPattern for IOPattern
+where
+    IOPattern: PoWIOPattern,
+{
+    fn pow(self, bits: f64) -> Self {
+        if bits > 0. {
+            self.challenge_pow("pow_queries")
+        } else {
+            self
+        }
+    }
+}
diff --git a/whir/src/lib.rs b/whir/src/lib.rs
new file mode 100644
index 000000000..fe4ccb09a
--- /dev/null
+++ b/whir/src/lib.rs
@@ -0,0 +1,13 @@
+#![allow(dead_code)]
+#[cfg(feature = "ceno")]
+pub mod ceno_binding; // Connect whir with ceno
+pub mod cmdline_utils;
+pub mod crypto; // Crypto utils
+pub mod domain; // Domain that we are evaluating over
+pub mod fs_utils;
+pub mod ntt;
+pub mod parameters;
+pub mod poly_utils; // Utils for polynomials
+pub mod sumcheck; // Sumcheck specialised
+pub mod utils; // Utils in general
+pub mod whir; // The real prover
diff --git a/whir/src/ntt/matrix.rs b/whir/src/ntt/matrix.rs
new file mode 100644
index 000000000..0f7eb8780
--- /dev/null
+++ b/whir/src/ntt/matrix.rs
@@ -0,0 +1,187 @@
+//! Minimal matrix class that supports strided access.
+//! This abstracts over the unsafe pointer arithmetic required for transpose-like algorithms.
+
+#![allow(unsafe_code)]
+
+use std::{
+    marker::PhantomData,
+    ops::{Index, IndexMut},
+    ptr, slice,
+};
+
+/// Mutable reference to a matrix.
+///
+/// The invariant this data structure maintains is that `data` has lifetime
+/// `'a` and points to a collection of `rows` rows, at intervals `row_stride`,
+/// each of length `cols`.
+pub struct MatrixMut<'a, T> {
+    data: *mut T,
+    rows: usize,
+    cols: usize,
+    row_stride: usize,
+    _lifetime: PhantomData<&'a mut T>,
+}
+
+unsafe impl<T: Send> Send for MatrixMut<'_, T> {}
+
+unsafe impl<T: Sync> Sync for MatrixMut<'_, T> {}
+
+impl<'a, T> MatrixMut<'a, T> {
+    /// creates a MatrixMut from `slice`, where slice is the concatenation of `rows` rows, each consisting of `cols` many entries.
+    pub fn from_mut_slice(slice: &'a mut [T], rows: usize, cols: usize) -> Self {
+        assert_eq!(slice.len(), rows * cols);
+        // Safety: The input slice is valid for the lifetime `'a` and has
+        // `rows` contiguous rows of length `cols`.
+        Self {
+            data: slice.as_mut_ptr(),
+            rows,
+            cols,
+            row_stride: cols,
+            _lifetime: PhantomData,
+        }
+    }
+
+    /// returns the number of rows
+    pub fn rows(&self) -> usize {
+        self.rows
+    }
+
+    /// returns the number of columns
+    pub fn cols(&self) -> usize {
+        self.cols
+    }
+
+    /// checks whether the matrix is a square matrix
+    pub fn is_square(&self) -> bool {
+        self.rows == self.cols
+    }
+
+    /// returns a mutable reference to the `row`'th row of the MatrixMut
+    pub fn row(&mut self, row: usize) -> &mut [T] {
+        assert!(row < self.rows);
+        // Safety: The structure invariant guarantees that at offset `row * self.row_stride`
+        // there is valid data of length `self.cols`.
+        unsafe { slice::from_raw_parts_mut(self.data.add(row * self.row_stride), self.cols) }
+    }
+
+    /// Split the matrix into two vertically at the `row`'th row (meaning that in the returned pair (A,B), the matrix A has `row` rows).
+    ///
+    /// [A]
+    /// [ ] = self
+    /// [B]
+    pub fn split_vertical(self, row: usize) -> (Self, Self) {
+        assert!(row <= self.rows);
+        (
+            Self {
+                data: self.data,
+                rows: row,
+                cols: self.cols,
+                row_stride: self.row_stride,
+                _lifetime: PhantomData,
+            },
+            Self {
+                data: unsafe { self.data.add(row * self.row_stride) },
+                rows: self.rows - row,
+                cols: self.cols,
+                row_stride: self.row_stride,
+                _lifetime: PhantomData,
+            },
+        )
+    }
+
+    /// Split the matrix into two horizontally at the `col`th column (meaning that in the returned pair (A,B), the matrix A has `col` columns).
+    ///
+    /// [A B] = self
+    pub fn split_horizontal(self, col: usize) -> (Self, Self) {
+        assert!(col <= self.cols);
+        (
+            // Safety: This reduces the number of cols, keeping all else the same.
+            Self {
+                data: self.data,
+                rows: self.rows,
+                cols: col,
+                row_stride: self.row_stride,
+                _lifetime: PhantomData,
+            },
+            // Safety: This reduces the number of cols and offsets the data pointer, keeping all else the same.
+            Self {
+                data: unsafe { self.data.add(col) },
+                rows: self.rows,
+                cols: self.cols - col,
+                row_stride: self.row_stride,
+                _lifetime: PhantomData,
+            },
+        )
+    }
+
+    /// Split the matrix into four quadrants at the indicated `row` and `col` (meaning that in the returned 4-tuple (A,B,C,D), the matrix A is a `row`x`col` matrix)
+    ///
+    /// self = [A B]
+    ///        [C D]
+    pub fn split_quadrants(self, row: usize, col: usize) -> (Self, Self, Self, Self) {
+        let (u, l) = self.split_vertical(row); // split into upper and lower parts
+        let (a, b) = u.split_horizontal(col);
+        let (c, d) = l.split_horizontal(col);
+        (a, b, c, d)
+    }
+
+    /// Swap two elements `a` and `b` in the matrix.
+    /// Each of `a`, `b` is given as (row,column)-pair.
+    /// If the given coordinates are out-of-bounds, the behaviour is undefined.
+    pub unsafe fn swap(&mut self, a: (usize, usize), b: (usize, usize)) {
+        if a != b {
+            unsafe {
+                let a = self.ptr_at_mut(a.0, a.1);
+                let b = self.ptr_at_mut(b.0, b.1);
+                ptr::swap_nonoverlapping(a, b, 1)
+            }
+        }
+    }
+
+    /// returns an immutable pointer to the element at (`row`, `col`). This performs no bounds checking and providing indices out-of-bounds is UB.
+    unsafe fn ptr_at(&self, row: usize, col: usize) -> *const T {
+        // Safe to call under the following assertion (checked by caller)
+        // assert!(row < self.rows);
+        // assert!(col < self.cols);
+
+        // Safety: The structure invariant guarantees that at offset `row * self.row_stride + col`
+        // there is valid data.
+        self.data.add(row * self.row_stride + col)
+    }
+
+    /// returns a mutable pointer to the element at (`row`, `col`). This performs no bounds checking and providing indices out-of-bounds is UB.
+    unsafe fn ptr_at_mut(&mut self, row: usize, col: usize) -> *mut T {
+        // Safe to call under the following assertion (checked by caller)
+        //
+        // assert!(row < self.rows);
+        // assert!(col < self.cols);
+
+        // Safety: The structure invariant guarantees that at offset `row * self.row_stride + col`
+        // there is valid data.
+        self.data.add(row * self.row_stride + col)
+    }
+}
+
+// Use MatrixMut::ptr_at and MatrixMut::ptr_at_mut to implement Index and IndexMut. The latter are not unsafe, since they contain bounds-checks.
+
+impl<T> Index<(usize, usize)> for MatrixMut<'_, T> {
+    type Output = T;
+
+    fn index(&self, (row, col): (usize, usize)) -> &T {
+        assert!(row < self.rows);
+        assert!(col < self.cols);
+        // Safety: The structure invariant guarantees that at offset `row * self.row_stride + col`
+        // there is valid data.
+        unsafe { &*self.ptr_at(row, col) }
+    }
+}
+
+impl<T> IndexMut<(usize, usize)> for MatrixMut<'_, T> {
+    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut T {
+        assert!(row < self.rows);
+        assert!(col < self.cols);
+        // Safety: The structure invariant guarantees that at offset `row * self.row_stride + col`
+        // there is valid data.
+        unsafe { &mut *self.ptr_at_mut(row, col) }
+    }
+}
diff --git a/whir/src/ntt/mod.rs b/whir/src/ntt/mod.rs
new file mode 100644
index 000000000..f2386c38a
--- /dev/null
+++ b/whir/src/ntt/mod.rs
@@ -0,0 +1,61 @@
+//! NTT and related algorithms.
+
+mod matrix;
+mod ntt_impl;
+mod transpose;
+mod utils;
+mod wavelet;
+
+use self::matrix::MatrixMut;
+use ark_ff::FftField;
+
+#[cfg(feature = "parallel")]
+use rayon::prelude::*;
+
+pub use self::{
+    ntt_impl::{intt, intt_batch, ntt, ntt_batch},
+    transpose::{transpose, transpose_bench_allocate, transpose_test},
+    wavelet::wavelet_transform,
+};
+
+/// RS encode at a rate 1/`expansion`.
+pub fn expand_from_coeff<F: FftField>(coeffs: &[F], expansion: usize) -> Vec<F> {
+    let engine = ntt_impl::NttEngine::<F>::new_from_cache();
+    let expanded_size = coeffs.len() * expansion;
+    let mut result = Vec::with_capacity(expanded_size);
+    // Note: We can also zero-extend the coefficients and do a larger NTT.
+    // But this is more efficient.
+
+    // Do coset NTT.
+    let root = engine.root(expanded_size);
+    result.extend_from_slice(coeffs);
+    #[cfg(not(feature = "parallel"))]
+    for i in 1..expansion {
+        let root = root.pow([i as u64]);
+        let mut offset = F::ONE;
+        result.extend(coeffs.iter().map(|x| {
+            let val = *x * offset;
+            offset *= root;
+            val
+        }));
+    }
+    #[cfg(feature = "parallel")]
+    result.par_extend((1..expansion).into_par_iter().flat_map(|i| {
+        let root_i = root.pow([i as u64]);
+        coeffs
+            .par_iter()
+            .enumerate()
+            .map_with(F::ZERO, move |root_j, (j, coeff)| {
+                if root_j.is_zero() {
+                    *root_j = root_i.pow([j as u64]);
+                } else {
+                    *root_j *= root_i;
+                }
                *coeff * *root_j
+            })
+    }));
+
+    ntt_batch(&mut result, coeffs.len());
+    transpose(&mut result, expansion, coeffs.len());
+    result
+}
diff --git a/whir/src/ntt/ntt_impl.rs b/whir/src/ntt/ntt_impl.rs
new file mode 100644
index 000000000..ceb482f81
--- /dev/null
+++ b/whir/src/ntt/ntt_impl.rs
@@ -0,0 +1,408 @@
+//! Number-theoretic transforms (NTTs) over fields with high two-adicity.
+//!
+//! Implements the √N Cooley-Tukey six-step algorithm to achieve parallelism with good locality.
+//! A global cache is used for twiddle factors.
+
+use super::{
+    transpose,
+    utils::{lcm, sqrt_factor, workload_size},
+};
+use ark_ff::{FftField, Field};
+use std::{
+    any::{Any, TypeId},
+    collections::HashMap,
+    sync::{Arc, LazyLock, Mutex, RwLock, RwLockReadGuard},
+};
+
+#[cfg(feature = "parallel")]
+use {rayon::prelude::*, std::cmp::max};
+
+/// Global cache for NTT engines, indexed by field.
+// TODO: Skip `LazyLock` when `HashMap::with_hasher` becomes const.
+// see https://github.com/rust-lang/rust/issues/102575
+static ENGINE_CACHE: LazyLock<Mutex<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>> =
+    LazyLock::new(|| Mutex::new(HashMap::new()));
+
+/// Engine for computing NTTs over arbitrary fields.
+/// Assumes the field has large two-adicity.
+pub struct NttEngine<F: Field> {
+    order: usize,   // order of omega_order
+    omega_order: F, // primitive `order`'th root of unity.
+
+    // Roots of small order (zero if unavailable). The naming convention is that omega_foo has order foo.
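+    // e.g. omega_4_1 is a primitive 4th root of unity (a square root of -1 in F),
+    // and the half_omega_3_* constants feed the 3-point Rader kernel in ntt_dispatch.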
+    half_omega_3_1_plus_2: F, // ½(ω₃ + ω₃²)
+    half_omega_3_1_min_2: F,  // ½(ω₃ - ω₃²)
+    omega_4_1: F,
+    omega_8_1: F,
+    omega_8_3: F,
+    omega_16_1: F,
+    omega_16_3: F,
+    omega_16_9: F,
+
+    // Root lookup table (extended on demand)
+    roots: RwLock<Vec<F>>,
+}
+
+/// Compute the NTT of a slice of field elements using a cached engine.
+pub fn ntt<F: FftField>(values: &mut [F]) {
+    NttEngine::<F>::new_from_cache().ntt(values);
+}
+
+/// Compute many NTTs of size `size` using a cached engine.
+pub fn ntt_batch<F: FftField>(values: &mut [F], size: usize) {
+    NttEngine::<F>::new_from_cache().ntt_batch(values, size);
+}
+
+/// Compute the inverse NTT of a slice of field elements, without the 1/n scaling factor, using a cached engine.
+pub fn intt<F: FftField>(values: &mut [F]) {
+    NttEngine::<F>::new_from_cache().intt(values);
+}
+
+/// Compute the inverse NTT of multiple slices of field elements, each of size `size`, without the 1/n scaling factor and using a cached engine.
+pub fn intt_batch<F: FftField>(values: &mut [F], size: usize) {
+    NttEngine::<F>::new_from_cache().intt_batch(values, size);
+}
+
+impl<F: FftField> NttEngine<F> {
+    /// Get or create a cached engine for the field `F`.
+    pub fn new_from_cache() -> Arc<Self> {
+        let mut cache = ENGINE_CACHE.lock().unwrap();
+        let type_id = TypeId::of::<F>();
+        if let Some(engine) = cache.get(&type_id) {
+            engine.clone().downcast::<NttEngine<F>>().unwrap()
+        } else {
+            let engine = Arc::new(NttEngine::new_from_fftfield());
+            cache.insert(type_id, engine.clone());
+            engine
+        }
+    }
+
+    /// Construct a new engine from the field's `FftField` trait.
+    fn new_from_fftfield() -> Self {
+        // TODO: Support SMALL_SUBGROUP
+        if F::TWO_ADICITY <= 63 {
+            Self::new(1 << F::TWO_ADICITY, F::TWO_ADIC_ROOT_OF_UNITY)
+        } else {
+            let mut generator = F::TWO_ADIC_ROOT_OF_UNITY;
+            for _ in 0..(F::TWO_ADICITY - 63) {
+                generator = generator.square();
+            }
+            Self::new(1 << 63, generator)
+        }
+    }
+}
+
+/// Creates a new NttEngine. `omega_order` must be a primitive root of unity of even order `order`.
+impl<F: Field> NttEngine<F> {
+    pub fn new(order: usize, omega_order: F) -> Self {
+        assert!(order.trailing_zeros() > 0, "Order must be a multiple of 2.");
+        // TODO: Assert that omega factors into 2s and 3s.
+        assert_eq!(omega_order.pow([order as u64]), F::ONE);
+        assert_ne!(omega_order.pow([order as u64 / 2]), F::ONE);
+        let mut res = NttEngine {
+            order,
+            omega_order,
+            half_omega_3_1_plus_2: F::ZERO,
+            half_omega_3_1_min_2: F::ZERO,
+            omega_4_1: F::ZERO,
+            omega_8_1: F::ZERO,
+            omega_8_3: F::ZERO,
+            omega_16_1: F::ZERO,
+            omega_16_3: F::ZERO,
+            omega_16_9: F::ZERO,
+            roots: RwLock::new(Vec::new()),
+        };
+        if order % 3 == 0 {
+            let omega_3_1 = res.root(3);
+            let omega_3_2 = omega_3_1 * omega_3_1;
+            // Note: char F cannot be 2 and so division by 2 works, because primitive roots of unity with even order exist.
+            res.half_omega_3_1_min_2 = (omega_3_1 - omega_3_2) / F::from(2u64);
+            res.half_omega_3_1_plus_2 = (omega_3_1 + omega_3_2) / F::from(2u64);
+        }
+        if order % 4 == 0 {
+            res.omega_4_1 = res.root(4);
+        }
+        if order % 8 == 0 {
+            res.omega_8_1 = res.root(8);
+            res.omega_8_3 = res.omega_8_1.pow([3]);
+        }
+        if order % 16 == 0 {
+            res.omega_16_1 = res.root(16);
+            res.omega_16_3 = res.omega_16_1.pow([3]);
+            res.omega_16_9 = res.omega_16_1.pow([9]);
+        }
+        res
+    }
+
+    pub fn ntt(&self, values: &mut [F]) {
+        self.ntt_batch(values, values.len())
+    }
+
+    pub fn ntt_batch(&self, values: &mut [F], size: usize) {
+        assert!(values.len() % size == 0);
+        let roots = self.roots_table(size);
+        self.ntt_dispatch(values, &roots, size);
+    }
+
+    /// Inverse NTT. Does not apply the 1/n scaling factor.
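+    /// (Reversing all entries but the first maps each power w^k to w^(n-k) = w^(-k),
+    /// so one forward NTT on the reversed data computes n times the inverse transform;
+    /// callers that need the exact inverse divide by n themselves.)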
+    pub fn intt(&self, values: &mut [F]) {
+        values[1..].reverse();
+        self.ntt(values);
+    }
+
+    /// Inverse batch NTT. Does not apply the 1/n scaling factor.
+    pub fn intt_batch(&self, values: &mut [F], size: usize) {
+        assert!(values.len() % size == 0);
+
+        #[cfg(not(feature = "parallel"))]
+        values.chunks_exact_mut(size).for_each(|values| {
+            values[1..].reverse();
+        });
+
+        #[cfg(feature = "parallel")]
+        values.par_chunks_exact_mut(size).for_each(|values| {
+            values[1..].reverse();
+        });
+
+        self.ntt_batch(values, size);
+    }
+
+    pub fn root(&self, order: usize) -> F {
+        assert!(
+            self.order % order == 0,
+            "Subgroup of requested order does not exist."
+        );
+        self.omega_order.pow([(self.order / order) as u64])
+    }
+
+    /// Returns a cached table of roots of unity of the given order.
+    fn roots_table(&self, order: usize) -> RwLockReadGuard<Vec<F>> {
+        // Precompute more roots of unity if requested.
+        let roots = self.roots.read().unwrap();
+        if roots.is_empty() || roots.len() % order != 0 {
+            // Obtain write lock to update the cache.
+            drop(roots);
+            let mut roots = self.roots.write().unwrap();
+            // Race condition: check if another thread updated the cache.
+            if roots.is_empty() || roots.len() % order != 0 {
+                // Compute minimal size to support all sizes seen so far.
+                // TODO: Do we really need all of these? Can we leverage omega_2 = -1?
+                let size = if roots.is_empty() {
+                    order
+                } else {
+                    lcm(roots.len(), order)
+                };
+                roots.clear();
+                roots.reserve_exact(size);
+
+                // Compute powers of roots of unity.
+                let root = self.root(size);
+                #[cfg(not(feature = "parallel"))]
+                {
+                    let mut root_i = F::ONE;
+                    for _ in 0..size {
+                        roots.push(root_i);
+                        root_i *= root;
+                    }
+                }
+                #[cfg(feature = "parallel")]
+                roots.par_extend((0..size).into_par_iter().map_with(F::ZERO, |root_i, i| {
+                    if root_i.is_zero() {
+                        *root_i = root.pow([i as u64]);
+                    } else {
+                        *root_i *= root;
+                    }
+                    *root_i
+                }));
+            }
+            // Back to read lock.
+            drop(roots);
+            self.roots.read().unwrap()
+        } else {
+            roots
+        }
+    }
+
+    /// Compute NTTs in place by splitting into two factors.
+    /// Recurses using the sqrt(N) Cooley-Tukey six-step NTT algorithm.
+    fn ntt_recurse(&self, values: &mut [F], roots: &[F], size: usize) {
+        debug_assert_eq!(values.len() % size, 0);
+        let n1 = sqrt_factor(size);
+        let n2 = size / n1;
+
+        transpose(values, n1, n2);
+        self.ntt_dispatch(values, roots, n1);
+        transpose(values, n2, n1);
+        // TODO: When (n1, n2) are coprime we can use the
+        // Good-Thomas NTT algorithm and avoid the twiddle loop.
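+        // Between the two passes of smaller NTTs, the entry at (row i, col j) of
+        // the n1 x n2 grid picks up the twiddle factor w^(i*j), with w a size'th
+        // root of unity; apply_twiddles multiplies those factors in below.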
+        Self::apply_twiddles(values, roots, n1, n2);
+        self.ntt_dispatch(values, roots, n2);
+        transpose(values, n1, n2);
+    }
+
+    // Note: an associated function (no &self) in both cfg variants, matching the
+    // `Self::apply_twiddles` call sites above.
+    #[cfg(not(feature = "parallel"))]
+    fn apply_twiddles(values: &mut [F], roots: &[F], rows: usize, cols: usize) {
+        debug_assert_eq!(values.len() % (rows * cols), 0);
+        let step = roots.len() / (rows * cols);
+        for values in values.chunks_exact_mut(rows * cols) {
+            for (i, row) in values.chunks_exact_mut(cols).enumerate().skip(1) {
+                let step = (i * step) % roots.len();
+                let mut index = step;
+                for value in row.iter_mut().skip(1) {
+                    index %= roots.len();
+                    *value *= roots[index];
+                    index += step;
+                }
+            }
+        }
+    }
+
+    #[cfg(feature = "parallel")]
+    fn apply_twiddles(values: &mut [F], roots: &[F], rows: usize, cols: usize) {
+        debug_assert_eq!(values.len() % (rows * cols), 0);
+        if values.len() > workload_size::<F>() {
+            let size = rows * cols;
+            if values.len() != size {
+                let workload_size = size * max(1, workload_size::<F>() / size);
+                values.par_chunks_mut(workload_size).for_each(|values| {
+                    Self::apply_twiddles(values, roots, rows, cols);
+                });
+            } else {
+                let step = roots.len() / (rows * cols);
+                values
+                    .par_chunks_exact_mut(cols)
+                    .enumerate()
+                    .skip(1)
+                    .for_each(|(i, row)| {
+                        let step = (i * step) % roots.len();
+                        let mut index = step;
+                        for value in row.iter_mut().skip(1) {
+                            index %= roots.len();
+                            *value *= roots[index];
+                            index += step;
+                        }
+                    });
+            }
+        } else {
+            let step = roots.len() / (rows * cols);
+            for values in values.chunks_exact_mut(rows * cols) {
+                for (i, row) in values.chunks_exact_mut(cols).enumerate().skip(1) {
+                    let step = (i * step) % roots.len();
+                    let mut index = step;
+                    for value in row.iter_mut().skip(1) {
+                        index %= roots.len();
+                        *value *= roots[index];
+                        index += step;
+                    }
+                }
+            }
+        }
+    }
+
+    fn ntt_dispatch(&self, values: &mut [F], roots: &[F], size: usize) {
+        debug_assert_eq!(values.len() % size, 0);
+        debug_assert_eq!(roots.len() % size, 0);
+        #[cfg(feature = "parallel")]
+        if values.len() > workload_size::<F>() && values.len() != size {
+            // Multiple NTTs, compute in parallel.
+            // Work size is largest multiple of `size` smaller than `WORKLOAD_SIZE`.
+            let workload_size = size * max(1, workload_size::<F>() / size);
+            return values.par_chunks_mut(workload_size).for_each(|values| {
+                self.ntt_dispatch(values, roots, size);
+            });
+        }
+        match size {
+            0 | 1 => {}
+            2 => {
+                for v in values.chunks_exact_mut(2) {
+                    (v[0], v[1]) = (v[0] + v[1], v[0] - v[1]);
+                }
+            }
+            3 => {
+                for v in values.chunks_exact_mut(3) {
+                    // Rader NTT to reduce 3 to 2.
+                    let v0 = v[0];
+                    (v[1], v[2]) = (v[1] + v[2], v[1] - v[2]);
+                    v[0] += v[1];
+                    v[1] *= self.half_omega_3_1_plus_2; // ½(ω₃ + ω₃²)
+                    v[2] *= self.half_omega_3_1_min_2; // ½(ω₃ - ω₃²)
+                    v[1] += v0;
+                    (v[1], v[2]) = (v[1] + v[2], v[1] - v[2]);
+                }
+            }
+            4 => {
+                for v in values.chunks_exact_mut(4) {
+                    (v[0], v[2]) = (v[0] + v[2], v[0] - v[2]);
+                    (v[1], v[3]) = (v[1] + v[3], v[1] - v[3]);
+                    v[3] *= self.omega_4_1;
+                    (v[0], v[1]) = (v[0] + v[1], v[0] - v[1]);
+                    (v[2], v[3]) = (v[2] + v[3], v[2] - v[3]);
+                    (v[1], v[2]) = (v[2], v[1]);
+                }
+            }
+            8 => {
+                for v in values.chunks_exact_mut(8) {
+                    // Cooley-Tukey with v as 2x4 matrix.
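+                    // Column butterflies pair v[k] with v[k + 4], twiddles w8^k scale
+                    // the second row, two radix-4 passes handle the rows, and the
+                    // final swaps emit the transposed output order.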
+ (v[0], v[4]) = (v[0] + v[4], v[0] - v[4]); + (v[1], v[5]) = (v[1] + v[5], v[1] - v[5]); + (v[2], v[6]) = (v[2] + v[6], v[2] - v[6]); + (v[3], v[7]) = (v[3] + v[7], v[3] - v[7]); + v[5] *= self.omega_8_1; + v[6] *= self.omega_4_1; // == omega_8_2 + v[7] *= self.omega_8_3; + (v[0], v[2]) = (v[0] + v[2], v[0] - v[2]); + (v[1], v[3]) = (v[1] + v[3], v[1] - v[3]); + v[3] *= self.omega_4_1; + (v[0], v[1]) = (v[0] + v[1], v[0] - v[1]); + (v[2], v[3]) = (v[2] + v[3], v[2] - v[3]); + (v[4], v[6]) = (v[4] + v[6], v[4] - v[6]); + (v[5], v[7]) = (v[5] + v[7], v[5] - v[7]); + v[7] *= self.omega_4_1; + (v[4], v[5]) = (v[4] + v[5], v[4] - v[5]); + (v[6], v[7]) = (v[6] + v[7], v[6] - v[7]); + (v[1], v[4]) = (v[4], v[1]); + (v[3], v[6]) = (v[6], v[3]); + } + } + 16 => { + for v in values.chunks_exact_mut(16) { + // Cooley-Tukey with v as 4x4 matrix. + for i in 0..4 { + let v = &mut v[i..]; + (v[0], v[8]) = (v[0] + v[8], v[0] - v[8]); + (v[4], v[12]) = (v[4] + v[12], v[4] - v[12]); + v[12] *= self.omega_4_1; + (v[0], v[4]) = (v[0] + v[4], v[0] - v[4]); + (v[8], v[12]) = (v[8] + v[12], v[8] - v[12]); + (v[4], v[8]) = (v[8], v[4]); + } + v[5] *= self.omega_16_1; + v[6] *= self.omega_8_1; + v[7] *= self.omega_16_3; + v[9] *= self.omega_8_1; + v[10] *= self.omega_4_1; + v[11] *= self.omega_8_3; + v[13] *= self.omega_16_3; + v[14] *= self.omega_8_3; + v[15] *= self.omega_16_9; + for i in 0..4 { + let v = &mut v[i * 4..]; + (v[0], v[2]) = (v[0] + v[2], v[0] - v[2]); + (v[1], v[3]) = (v[1] + v[3], v[1] - v[3]); + v[3] *= self.omega_4_1; + (v[0], v[1]) = (v[0] + v[1], v[0] - v[1]); + (v[2], v[3]) = (v[2] + v[3], v[2] - v[3]); + (v[1], v[2]) = (v[2], v[1]); + } + (v[1], v[4]) = (v[4], v[1]); + (v[2], v[8]) = (v[8], v[2]); + (v[3], v[12]) = (v[12], v[3]); + (v[6], v[9]) = (v[9], v[6]); + (v[7], v[13]) = (v[13], v[7]); + (v[11], v[14]) = (v[14], v[11]); + } + } + size => self.ntt_recurse(values, roots, size), + } + } +} diff --git a/whir/src/ntt/transpose.rs b/whir/src/ntt/transpose.rs new file mode 100644 index 000000000..b07ad7497 --- /dev/null +++ b/whir/src/ntt/transpose.rs @@ -0,0 +1,550 @@ +use super::{super::utils::is_power_of_two, MatrixMut, utils::workload_size}; +use std::mem::swap; + +use ark_std::{end_timer, start_timer}; +use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; +#[cfg(feature = "parallel")] +use rayon::join; + +// NOTE: The assumption that rows and cols are a power of two are actually only relevant for the square matrix case. +// (This is because the algorithm recurses into 4 sub-matrices of half dimension; we assume those to be square matrices as well, which only works for powers of two). + +/// Transpose a matrix in-place. +/// Will batch transpose multiple matrices if the length of the slice is a multiple of rows * cols. +/// This algorithm assumes that both rows and cols are powers of two. +pub fn transpose(matrix: &mut [F], rows: usize, cols: usize) { + debug_assert_eq!(matrix.len() % (rows * cols), 0); + // eprintln!( + // "Transpose {} x {rows} x {cols} matrix.", + // matrix.len() / (rows * cols) + // ); + if rows == cols { + debug_assert!(is_power_of_two(rows)); + debug_assert!(is_power_of_two(cols)); + for matrix in matrix.chunks_exact_mut(rows * cols) { + let matrix = MatrixMut::from_mut_slice(matrix, rows, cols); + transpose_square(matrix); + } + } else { + // TODO: Special case for rows = 2 * cols and cols = 2 * rows. + // TODO: Special case for very wide matrices (e.g. n x 16). 
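+        // The rectangular case goes out-of-place through a scratch buffer; the
+        // uninitialized scratch is fully overwritten by copy_from_slice below
+        // before any element of it is read.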
+ let mut scratch = Vec::with_capacity(rows * cols); + #[allow(clippy::uninit_vec)] + unsafe { + scratch.set_len(rows * cols); + } + for matrix in matrix.chunks_exact_mut(rows * cols) { + scratch.copy_from_slice(matrix); + let src = MatrixMut::from_mut_slice(scratch.as_mut_slice(), rows, cols); + let dst = MatrixMut::from_mut_slice(matrix, cols, rows); + transpose_copy(src, dst); + } + } +} + +pub fn transpose_bench_allocate( + matrix: &mut [F], + rows: usize, + cols: usize, +) { + debug_assert_eq!(matrix.len() % (rows * cols), 0); + // eprintln!( + // "Transpose {} x {rows} x {cols} matrix.", + // matrix.len() / (rows * cols) + // ); + if rows == cols { + debug_assert!(is_power_of_two(rows)); + debug_assert!(is_power_of_two(cols)); + for matrix in matrix.chunks_exact_mut(rows * cols) { + let matrix = MatrixMut::from_mut_slice(matrix, rows, cols); + transpose_square(matrix); + } + } else { + // TODO: Special case for rows = 2 * cols and cols = 2 * rows. + // TODO: Special case for very wide matrices (e.g. n x 16). + let allocate_timer = start_timer!(|| "Allocate scratch."); + let mut scratch = Vec::with_capacity(rows * cols); + #[allow(clippy::uninit_vec)] + unsafe { + scratch.set_len(rows * cols); + } + end_timer!(allocate_timer); + for matrix in matrix.chunks_exact_mut(rows * cols) { + let copy_timer = start_timer!(|| "Copy from slice."); + scratch.copy_from_slice(matrix); + end_timer!(copy_timer); + let src = MatrixMut::from_mut_slice(scratch.as_mut_slice(), rows, cols); + let dst = MatrixMut::from_mut_slice(matrix, cols, rows); + let transpose_copy_timer = start_timer!(|| "Transpose Copy."); + transpose_copy(src, dst); + end_timer!(transpose_copy_timer); + } + } +} + +pub fn transpose_test( + matrix: &mut [F], + rows: usize, + cols: usize, + buffer: &mut [F], +) { + debug_assert_eq!(matrix.len() % (rows * cols), 0); + // eprintln!( + // "Transpose {} x {rows} x {cols} matrix.", + // matrix.len() / (rows * cols) + // ); + if rows == cols { + debug_assert!(is_power_of_two(rows)); + debug_assert!(is_power_of_two(cols)); + for matrix in matrix.chunks_exact_mut(rows * cols) { + let matrix = MatrixMut::from_mut_slice(matrix, rows, cols); + transpose_square(matrix); + } + } else { + let buffer = &mut buffer[0..rows * cols]; + // TODO: Special case for rows = 2 * cols and cols = 2 * rows. + // TODO: Special case for very wide matrices (e.g. n x 16). + let transpose_timer = start_timer!(|| "Transpose."); + for matrix in matrix.chunks_exact_mut(rows * cols) { + let copy_timer = start_timer!(|| "Copy from slice."); + // buffer.copy_from_slice(matrix); + buffer + .par_iter_mut() + .zip(matrix.par_iter_mut()) + .for_each(|(dst, src)| { + *dst = *src; + }); + end_timer!(copy_timer); + let transform_timer = start_timer!(|| "From mut slice."); + let src = MatrixMut::from_mut_slice(buffer, rows, cols); + let dst = MatrixMut::from_mut_slice(matrix, cols, rows); + end_timer!(transform_timer); + let transpose_copy_timer = start_timer!(|| "Transpose copy."); + transpose_copy(src, dst); + end_timer!(transpose_copy_timer); + } + end_timer!(transpose_timer); + } +} + +// The following function have both a parallel and a non-parallel implementation. +// We fuly split those in a parallel and a non-parallel functions (rather than using #[cfg] within a single function) +// and have main entry point fun that just calls the appropriate version (either fun_parallel or fun_not_parallel). +// The sole reason is that this simplifies unit tests: We otherwise would need to build twice to cover both cases. 
+// For effiency, we assume the compiler inlines away the extra "indirection" that we add to the entry point function. + +// NOTE: We could lift the Send constraints on non-parallel build. + +fn transpose_copy(src: MatrixMut, dst: MatrixMut) { + #[cfg(not(feature = "parallel"))] + transpose_copy_not_parallel(src, dst); + #[cfg(feature = "parallel")] + transpose_copy_parallel(src, dst); +} + +/// Sets `dst` to the transpose of `src`. This will panic if the sizes of `src` and `dst` are not compatible. +#[cfg(feature = "parallel")] +fn transpose_copy_parallel( + src: MatrixMut<'_, F>, + mut dst: MatrixMut<'_, F>, +) { + assert_eq!(src.rows(), dst.cols()); + assert_eq!(src.cols(), dst.rows()); + if src.rows() * src.cols() > workload_size::() { + // Split along longest axis and recurse. + // This results in a cache-oblivious algorithm. + let ((a, b), (x, y)) = if src.rows() > src.cols() { + let n = src.rows() / 2; + (src.split_vertical(n), dst.split_horizontal(n)) + } else { + let n = src.cols() / 2; + (src.split_horizontal(n), dst.split_vertical(n)) + }; + join( + || transpose_copy_parallel(a, x), + || transpose_copy_parallel(b, y), + ); + } else { + for i in 0..src.rows() { + for j in 0..src.cols() { + dst[(j, i)] = src[(i, j)]; + } + } + } +} + +/// Sets `dst` to the transpose of `src`. This will panic if the sizes of `src` and `dst` are not compatible. +/// This is the non-parallel version +fn transpose_copy_not_parallel(src: MatrixMut<'_, F>, mut dst: MatrixMut<'_, F>) { + assert_eq!(src.rows(), dst.cols()); + assert_eq!(src.cols(), dst.rows()); + if src.rows() * src.cols() > workload_size::() { + // Split along longest axis and recurse. + // This results in a cache-oblivious algorithm. + let ((a, b), (x, y)) = if src.rows() > src.cols() { + let n = src.rows() / 2; + (src.split_vertical(n), dst.split_horizontal(n)) + } else { + let n = src.cols() / 2; + (src.split_horizontal(n), dst.split_vertical(n)) + }; + transpose_copy_not_parallel(a, x); + transpose_copy_not_parallel(b, y); + } else { + for i in 0..src.rows() { + for j in 0..src.cols() { + dst[(j, i)] = src[(i, j)]; + } + } + } +} + +/// Transpose a square matrix in-place. Asserts that the size of the matrix is a power of two. +fn transpose_square(m: MatrixMut) { + #[cfg(feature = "parallel")] + transpose_square_parallel(m); + #[cfg(not(feature = "parallel"))] + transpose_square_non_parallel(m); +} + +/// Transpose a square matrix in-place. Asserts that the size of the matrix is a power of two. +/// This is the parallel version. +#[cfg(feature = "parallel")] +fn transpose_square_parallel(mut m: MatrixMut) { + debug_assert!(m.is_square()); + debug_assert!(m.rows().is_power_of_two()); + let size = m.rows(); + if size * size > workload_size::() { + // Recurse into quadrants. + // This results in a cache-oblivious algorithm. + let n = size / 2; + let (a, b, c, d) = m.split_quadrants(n, n); + + join( + || transpose_square_swap_parallel(b, c), + || { + join( + || transpose_square_parallel(a), + || transpose_square_parallel(d), + ) + }, + ); + } else { + for i in 0..size { + for j in (i + 1)..size { + // unsafe needed due to lack of bounds-check by swap. We are guaranteed that (i,j) and (j,i) are within the bounds. + unsafe { + m.swap((i, j), (j, i)); + } + } + } + } +} + +/// Transpose a square matrix in-place. Asserts that the size of the matrix is a power of two. +/// This is the non-parallel version. 
+fn transpose_square_non_parallel(mut m: MatrixMut) { + debug_assert!(m.is_square()); + debug_assert!(m.rows().is_power_of_two()); + let size = m.rows(); + if size * size > workload_size::() { + // Recurse into quadrants. + // This results in a cache-oblivious algorithm. + let n = size / 2; + let (a, b, c, d) = m.split_quadrants(n, n); + transpose_square_non_parallel(a); + transpose_square_non_parallel(d); + transpose_square_swap_non_parallel(b, c); + } else { + for i in 0..size { + for j in (i + 1)..size { + // unsafe needed due to lack of bounds-check by swap. We are guaranteed that (i,j) and (j,i) are within the bounds. + unsafe { + m.swap((i, j), (j, i)); + } + } + } + } +} + +/// Transpose and swap two square size matrices. Sizes must be equal and a power of two. +fn transpose_square_swap(a: MatrixMut, b: MatrixMut) { + #[cfg(feature = "parallel")] + transpose_square_swap_parallel(a, b); + #[cfg(not(feature = "parallel"))] + transpose_square_swap_non_parallel(a, b); +} + +/// Transpose and swap two square size matrices (parallel version). The size must be a power of two. +#[cfg(feature = "parallel")] +fn transpose_square_swap_parallel(mut a: MatrixMut, mut b: MatrixMut) { + debug_assert!(a.is_square()); + debug_assert_eq!(a.rows(), b.cols()); + debug_assert_eq!(a.cols(), b.rows()); + debug_assert!(is_power_of_two(a.rows())); + debug_assert!(workload_size::() >= 2); // otherwise, we would recurse even if size == 1. + let size = a.rows(); + if 2 * size * size > workload_size::() { + // Recurse into quadrants. + // This results in a cache-oblivious algorithm. + let n = size / 2; + let (aa, ab, ac, ad) = a.split_quadrants(n, n); + let (ba, bb, bc, bd) = b.split_quadrants(n, n); + + join( + || { + join( + || transpose_square_swap_parallel(aa, ba), + || transpose_square_swap_parallel(ab, bc), + ) + }, + || { + join( + || transpose_square_swap_parallel(ac, bb), + || transpose_square_swap_parallel(ad, bd), + ) + }, + ); + } else { + for i in 0..size { + for j in 0..size { + swap(&mut a[(i, j)], &mut b[(j, i)]) + } + } + } +} + +/// Transpose and swap two square size matrices, whose sizes are a power of two (non-parallel version) +fn transpose_square_swap_non_parallel(mut a: MatrixMut, mut b: MatrixMut) { + debug_assert!(a.is_square()); + debug_assert_eq!(a.rows(), b.cols()); + debug_assert_eq!(a.cols(), b.rows()); + debug_assert!(is_power_of_two(a.rows())); + debug_assert!(workload_size::() >= 2); // otherwise, we would recurse even if size == 1. + + let size = a.rows(); + if 2 * size * size > workload_size::() { + // Recurse into quadrants. + // This results in a cache-oblivious algorithm. + let n = size / 2; + let (aa, ab, ac, ad) = a.split_quadrants(n, n); + let (ba, bb, bc, bd) = b.split_quadrants(n, n); + transpose_square_swap_non_parallel(aa, ba); + transpose_square_swap_non_parallel(ab, bc); + transpose_square_swap_non_parallel(ac, bb); + transpose_square_swap_non_parallel(ad, bd); + } else { + for i in 0..size { + for j in 0..size { + swap(&mut a[(i, j)], &mut b[(j, i)]) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::{super::utils::workload_size, *}; + + type Pair = (usize, usize); + type Triple = (usize, usize, usize); + + // create a vector (intended to be viewed as a matrix) whose (i,j)'th entry is the pair (i,j) itself. + // This is useful to debug transposition algorithms. 
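+    // e.g. make_example_matrix(2, 3) yields [(0,0), (0,1), (0,2), (1,0), (1,1), (1,2)]:
+    // row-major storage where every entry names its own (row, column) coordinates.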
+ fn make_example_matrix(rows: usize, columns: usize) -> Vec { + let mut v: Vec = vec![(0, 0); rows * columns]; + let mut view = MatrixMut::from_mut_slice(&mut v, rows, columns); + for i in 0..rows { + for j in 0..columns { + view[(i, j)] = (i, j); + } + } + v + } + + // create a vector (intended to be viewed as a sequence of `instances` of matrices) where (i,j)'th entry of the `index`th matrix + // is the triple (index, i,j). + fn make_example_matrices(rows: usize, columns: usize, instances: usize) -> Vec { + let mut v: Vec = vec![(0, 0, 0); rows * columns * instances]; + for index in 0..instances { + let mut view = MatrixMut::from_mut_slice( + &mut v[rows * columns * index..rows * columns * (index + 1)], + rows, + columns, + ); + for i in 0..rows { + for j in 0..columns { + view[(i, j)] = (index, i, j); + } + } + } + v + } + + #[test] + fn test_transpose_copy() { + // iterate over both parallel and non-parallel implementation. + // Needs HRTB, otherwise it won't work. + #[allow(clippy::type_complexity)] + let mut funs: Vec<&dyn for<'a, 'b> Fn(MatrixMut<'a, Pair>, MatrixMut<'b, Pair>)> = vec![ + &transpose_copy_not_parallel::, + &transpose_copy::, + ]; + #[cfg(feature = "parallel")] + funs.push(&transpose_copy_parallel::); + + for f in funs { + let rows: usize = workload_size::() + 1; // intentionally not a power of two: The function is not described as only working for powers of two. + let columns: usize = 4; + let mut srcarray = make_example_matrix(rows, columns); + let mut dstarray: Vec<(usize, usize)> = vec![(0, 0); rows * columns]; + + let src1 = MatrixMut::::from_mut_slice(&mut srcarray[..], rows, columns); + let dst1 = MatrixMut::::from_mut_slice(&mut dstarray[..], columns, rows); + + f(src1, dst1); + let dst1 = MatrixMut::::from_mut_slice(&mut dstarray[..], columns, rows); + + for i in 0..rows { + for j in 0..columns { + assert_eq!(dst1[(j, i)], (i, j)); + } + } + } + } + + #[test] + fn test_transpose_square_swap() { + // iterate over parallel and non-parallel variants: + #[allow(clippy::type_complexity)] + let mut funs: Vec<&dyn for<'a> Fn(MatrixMut<'a, Triple>, MatrixMut<'a, Triple>)> = vec![ + &transpose_square_swap::, + &transpose_square_swap_non_parallel::, + ]; + #[cfg(feature = "parallel")] + funs.push(&transpose_square_swap_parallel::); + + for f in funs { + // Set rows manually. We want to be sure to trigger the actual recursion. + // (Computing this from workload_size was too much hassle.) + let rows = 1024; // workload_size::(); + assert!(rows * rows > 2 * workload_size::()); + + let examples: Vec = make_example_matrices(rows, rows, 2); + // Make copies for simplicity, because we borrow different parts. 
+ let mut examples1 = Vec::from(&examples[0..rows * rows]); + let mut examples2 = Vec::from(&examples[rows * rows..2 * rows * rows]); + + let view1 = MatrixMut::from_mut_slice(&mut examples1, rows, rows); + let view2 = MatrixMut::from_mut_slice(&mut examples2, rows, rows); + for i in 0..rows { + for j in 0..rows { + assert_eq!(view1[(i, j)], (0, i, j)); + assert_eq!(view2[(i, j)], (1, i, j)); + } + } + f(view1, view2); + let view1 = MatrixMut::from_mut_slice(&mut examples1, rows, rows); + let view2 = MatrixMut::from_mut_slice(&mut examples2, rows, rows); + for i in 0..rows { + for j in 0..rows { + assert_eq!(view1[(i, j)], (1, j, i)); + assert_eq!(view2[(i, j)], (0, j, i)); + } + } + } + } + + #[test] + fn test_transpose_square() { + let mut funs: Vec<&dyn for<'a> Fn(MatrixMut<'a, _>)> = vec![ + &transpose_square::, + &transpose_square_parallel::, + ]; + #[cfg(feature = "parallel")] + funs.push(&transpose_square::); + for f in funs { + // Set rows manually. We want to be sure to trigger the actual recursion. + // (Computing this from workload_size was too much hassle.) + let size = 1024; + assert!(size * size > 2 * workload_size::()); + + let mut example = make_example_matrix(size, size); + let view = MatrixMut::from_mut_slice(&mut example, size, size); + f(view); + let view = MatrixMut::from_mut_slice(&mut example, size, size); + for i in 0..size { + for j in 0..size { + assert_eq!(view[(i, j)], (j, i)); + } + } + } + } + + #[test] + fn test_transpose() { + let size = 1024; + + // rectangular matrix: + let rows = size; + let cols = 16; + let mut example = make_example_matrix(rows, cols); + transpose(&mut example, rows, cols); + let view = MatrixMut::from_mut_slice(&mut example, cols, rows); + for i in 0..cols { + for j in 0..rows { + assert_eq!(view[(i, j)], (j, i)); + } + } + + // square matrix: + let rows = size; + let cols = size; + let mut example = make_example_matrix(rows, cols); + transpose(&mut example, rows, cols); + let view = MatrixMut::from_mut_slice(&mut example, cols, rows); + for i in 0..cols { + for j in 0..rows { + assert_eq!(view[(i, j)], (j, i)); + } + } + + // 20 rectangular matrices: + let number_of_matrices = 20; + let rows = size; + let cols = 16; + let mut example = make_example_matrices(rows, cols, number_of_matrices); + transpose(&mut example, rows, cols); + for index in 0..number_of_matrices { + let view = MatrixMut::from_mut_slice( + &mut example[index * rows * cols..(index + 1) * rows * cols], + cols, + rows, + ); + for i in 0..cols { + for j in 0..rows { + assert_eq!(view[(i, j)], (index, j, i)); + } + } + } + + // 20 square matrices: + let number_of_matrices = 20; + let rows = size; + let cols = size; + let mut example = make_example_matrices(rows, cols, number_of_matrices); + transpose(&mut example, rows, cols); + for index in 0..number_of_matrices { + let view = MatrixMut::from_mut_slice( + &mut example[index * rows * cols..(index + 1) * rows * cols], + cols, + rows, + ); + for i in 0..cols { + for j in 0..rows { + assert_eq!(view[(i, j)], (index, j, i)); + } + } + } + } +} diff --git a/whir/src/ntt/utils.rs b/whir/src/ntt/utils.rs new file mode 100644 index 000000000..c2d7f18f5 --- /dev/null +++ b/whir/src/ntt/utils.rs @@ -0,0 +1,143 @@ +/// Target single-thread workload size for `T`. +/// Should ideally be a multiple of a cache line (64 bytes) +/// and close to the L1 cache size (32 KB). +pub const fn workload_size() -> usize { + const CACHE_SIZE: usize = 1 << 15; + CACHE_SIZE / size_of::() +} + +/// Cast a slice into chunks of size N. 
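+/// e.g. a slice of length 6 with N = 2 becomes three [T; 2] chunks; the function
+/// panics if N == 0 or if the length is not a multiple of N.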
+/// +/// TODO: Replace with `slice::as_chunks` when stable. +pub fn as_chunks_exact_mut(slice: &mut [T]) -> &mut [[T; N]] { + assert!(N != 0, "chunk size must be non-zero"); + assert_eq!( + slice.len() % N, + 0, + "slice length must be a multiple of chunk size" + ); + // SAFETY: Caller must guarantee that `N` is nonzero and exactly divides the slice length + let new_len = slice.len() / N; + // SAFETY: We cast a slice of `new_len * N` elements into + // a slice of `new_len` many `N` elements chunks. + unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), new_len) } +} + +/// Compute the largest factor of n that is <= sqrt(n). +/// Assumes n is of the form 2^k * {1,3,9}. +pub fn sqrt_factor(n: usize) -> usize { + let twos = n.trailing_zeros(); + match n >> twos { + 1 => 1 << (twos / 2), + 3 | 9 => 3 << (twos / 2), + _ => panic!(), + } +} + +/// Least common multiple. +/// +/// Note that lcm(0,0) will panic (rather than give the correct answer 0). +pub fn lcm(a: usize, b: usize) -> usize { + a * (b / gcd(a, b)) +} + +/// Greatest common divisor. +pub fn gcd(mut a: usize, mut b: usize) -> usize { + while b != 0 { + (a, b) = (b, a % b); + } + a +} + +#[cfg(test)] +mod tests { + use super::{as_chunks_exact_mut, gcd, lcm, sqrt_factor}; + + #[test] + fn test_gcd() { + assert_eq!(gcd(4, 6), 2); + assert_eq!(gcd(0, 4), 4); + assert_eq!(gcd(4, 0), 4); + assert_eq!(gcd(1, 1), 1); + assert_eq!(gcd(64, 16), 16); + assert_eq!(gcd(81, 9), 9); + assert_eq!(gcd(0, 0), 0); + } + + #[test] + fn test_lcm() { + assert_eq!(lcm(5, 6), 30); + assert_eq!(lcm(3, 7), 21); + assert_eq!(lcm(0, 10), 0); + } + #[test] + fn test_sqrt_factor() { + // naive brute-force search for largest divisor up to sqrt n. + // This is not supposed to be efficient, but optimized for "ease of convincing yourself it's correct (provided none of the asserts trigger)". + fn get_largest_divisor_up_to_sqrt(x: usize) -> usize { + if x == 0 { + return 0; + } + let mut result = 1; + let isqrt_of_x: usize = { + // use x.isqrt() once this is stabilized. That would be MUCH simpler. + + assert!(x < (1 << f64::MANTISSA_DIGITS)); // guarantees that each of {x, floor(sqrt(x)), ceil(sqrt(x))} can be represented exactly by f64. + let x_as_float = x as f64; + // sqrt is guaranteed to be the exact result, then rounded. Due to the above assert, the rounded value is between floor(sqrt(x)) and ceil(sqrt(x)). + let sqrt_x = x_as_float.sqrt(); + // We return sqrt_x, rounded to 0; for correctness, we need to rule out that we rounded from a non-integer up to the integer ceil(sqrt(x)). 
+ if sqrt_x.fract() == 0.0 { + assert!(sqrt_x * sqrt_x == x_as_float); + } + unsafe { sqrt_x.to_int_unchecked() } + }; + for i in 1..=isqrt_of_x { + if x % i == 0 { + result = i; + } + } + result + } + + for i in 0..10 { + assert_eq!(sqrt_factor(1 << i), get_largest_divisor_up_to_sqrt(1 << i)); + } + for i in 0..10 { + assert_eq!(sqrt_factor(1 << i), get_largest_divisor_up_to_sqrt(1 << i)); + } + + for i in 0..10 { + assert_eq!(sqrt_factor(1 << i), get_largest_divisor_up_to_sqrt(1 << i)); + } + } + + #[test] + fn test_as_chunks_exact_mut() { + let v = &mut [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + assert_eq!(as_chunks_exact_mut::<_, 12>(v), &[[ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 + ]]); + assert_eq!(as_chunks_exact_mut::<_, 6>(v), &[[1, 2, 3, 4, 5, 6], [ + 7, 8, 9, 10, 11, 12 + ]]); + assert_eq!(as_chunks_exact_mut::<_, 1>(v), &[ + [1], + [2], + [3], + [4], + [5], + [6], + [7], + [8], + [9], + [10], + [11], + [12] + ]); + let should_not_work = std::panic::catch_unwind(|| { + as_chunks_exact_mut::<_, 2>(&mut [1, 2, 3]); + }); + assert!(should_not_work.is_err()) + } +} diff --git a/whir/src/ntt/wavelet.rs b/whir/src/ntt/wavelet.rs new file mode 100644 index 000000000..88ed2aea6 --- /dev/null +++ b/whir/src/ntt/wavelet.rs @@ -0,0 +1,90 @@ +use super::{transpose, utils::workload_size}; +use ark_ff::Field; +use std::cmp::max; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +/// Fast Wavelet Transform. +/// +/// The input slice must have a length that is a power of two. +/// Recursively applies the kernel +/// [1 0] +/// [1 1] +pub fn wavelet_transform(values: &mut [F]) { + debug_assert!(values.len().is_power_of_two()); + wavelet_transform_batch(values, values.len()) +} + +pub fn wavelet_transform_batch(values: &mut [F], size: usize) { + debug_assert_eq!(values.len() % size, 0); + debug_assert!(size.is_power_of_two()); + #[cfg(feature = "parallel")] + if values.len() > workload_size::() && values.len() != size { + // Multiple wavelet transforms, compute in parallel. + // Work size is largest multiple of `size` smaller than `WORKLOAD_SIZE`. 
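+        // (Applied across every bit position, the kernel computes subset sums of
+        // the coefficients, i.e. the multilinear coefficient-to-evaluation map
+        // over the binary hypercube that poly_utils builds on.)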
+ let workload_size = size * max(1, workload_size::() / size); + return values.par_chunks_mut(workload_size).for_each(|values| { + wavelet_transform_batch(values, size); + }); + } + match size { + 0 | 1 => {} + 2 => { + for v in values.chunks_exact_mut(2) { + v[1] += v[0] + } + } + 4 => { + for v in values.chunks_exact_mut(4) { + v[1] += v[0]; + v[3] += v[2]; + v[2] += v[0]; + v[3] += v[1]; + } + } + 8 => { + for v in values.chunks_exact_mut(8) { + v[1] += v[0]; + v[3] += v[2]; + v[2] += v[0]; + v[3] += v[1]; + v[5] += v[4]; + v[7] += v[6]; + v[6] += v[4]; + v[7] += v[5]; + v[4] += v[0]; + v[5] += v[1]; + v[6] += v[2]; + v[7] += v[3]; + } + } + 16 => { + for v in values.chunks_exact_mut(16) { + for v in v.chunks_exact_mut(4) { + v[1] += v[0]; + v[3] += v[2]; + v[2] += v[0]; + v[3] += v[1]; + } + let (a, v) = v.split_at_mut(4); + let (b, v) = v.split_at_mut(4); + let (c, d) = v.split_at_mut(4); + for i in 0..4 { + b[i] += a[i]; + d[i] += c[i]; + c[i] += a[i]; + d[i] += b[i]; + } + } + } + n => { + let n1 = 1 << (n.trailing_zeros() / 2); + let n2 = n / n1; + wavelet_transform_batch(values, n1); + transpose(values, n2, n1); + wavelet_transform_batch(values, n2); + transpose(values, n1, n2); + } + } +} diff --git a/whir/src/parameters.rs b/whir/src/parameters.rs new file mode 100644 index 000000000..5f806efa2 --- /dev/null +++ b/whir/src/parameters.rs @@ -0,0 +1,213 @@ +use std::{fmt::Display, marker::PhantomData, str::FromStr}; + +use ark_crypto_primitives::merkle_tree::{Config, LeafParam, TwoToOneParam}; +use serde::{Deserialize, Serialize}; + +pub fn default_max_pow(num_variables: usize, log_inv_rate: usize) -> usize { + num_variables + log_inv_rate - 3 +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub enum SoundnessType { + UniqueDecoding, + ProvableList, + ConjectureList, +} + +impl Display for SoundnessType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", match &self { + SoundnessType::ProvableList => "ProvableList", + SoundnessType::ConjectureList => "ConjectureList", + SoundnessType::UniqueDecoding => "UniqueDecoding", + }) + } +} + +impl FromStr for SoundnessType { + type Err = String; + fn from_str(s: &str) -> Result { + if s == "ProvableList" { + Ok(SoundnessType::ProvableList) + } else if s == "ConjectureList" { + Ok(SoundnessType::ConjectureList) + } else if s == "UniqueDecoding" { + Ok(SoundnessType::UniqueDecoding) + } else { + Err(format!("Invalid soundness specification: {}", s)) + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct MultivariateParameters { + pub(crate) num_variables: usize, + _field: PhantomData, +} + +impl MultivariateParameters { + pub fn new(num_variables: usize) -> Self { + Self { + num_variables, + _field: PhantomData, + } + } +} + +impl Display for MultivariateParameters { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Number of variables: {}", self.num_variables) + } +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub enum FoldType { + Naive, + ProverHelps, +} + +impl FromStr for FoldType { + type Err = String; + fn from_str(s: &str) -> Result { + if s == "Naive" { + Ok(FoldType::Naive) + } else if s == "ProverHelps" { + Ok(FoldType::ProverHelps) + } else { + Err(format!("Invalid fold type specification: {}", s)) + } + } +} + +impl Display for FoldType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", match self { + FoldType::Naive => "Naive", + FoldType::ProverHelps => "ProverHelps", + }) + } +} + 
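+// Not part of the original patch: a minimal round-trip sketch of how the
+// Display/FromStr pair above is expected to behave (the command-line binaries
+// parse these names via FromStr). The test module name is hypothetical.
+#[cfg(test)]
+mod fold_type_roundtrip_tests {
+    use super::FoldType;
+    use std::str::FromStr;
+
+    #[test]
+    fn fold_type_roundtrip() {
+        // Every variant should survive Display -> FromStr unchanged.
+        for s in ["Naive", "ProverHelps"] {
+            assert_eq!(FoldType::from_str(s).unwrap().to_string(), s);
+        }
+        // Unknown names are rejected with a descriptive error.
+        assert!(FoldType::from_str("Unknown").is_err());
+    }
+}
+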
+#[derive(Debug, Clone, Copy)] +pub enum FoldingFactor { + Constant(usize), // Use the same folding factor for all rounds + ConstantFromSecondRound(usize, usize), /* Use the same folding factor for all rounds, but the first round uses a different folding factor */ +} + +impl FoldingFactor { + pub fn at_round(&self, round: usize) -> usize { + match self { + FoldingFactor::Constant(factor) => *factor, + FoldingFactor::ConstantFromSecondRound(first_round_factor, factor) => { + if round == 0 { + *first_round_factor + } else { + *factor + } + } + } + } + + pub fn check_validity(&self, _num_variables: usize) -> Result<(), String> { + match self { + FoldingFactor::Constant(factor) => { + if *factor == 0 { + // We should at least fold some time + Err("Folding factor shouldn't be zero.".to_string()) + } else { + Ok(()) + } + } + FoldingFactor::ConstantFromSecondRound(first_round_factor, factor) => { + if *factor == 0 || *first_round_factor == 0 { + // We should at least fold some time + Err("Folding factor shouldn't be zero.".to_string()) + } else { + Ok(()) + } + } + } + } + + /// Compute the number of WHIR rounds and the number of rounds in the final + /// sumcheck. + pub fn compute_number_of_rounds(&self, num_variables: usize) -> (usize, usize) { + match self { + FoldingFactor::Constant(factor) => { + // It's checked that factor > 0 and factor <= num_variables + let final_sumcheck_rounds = num_variables % factor; + ( + (num_variables - final_sumcheck_rounds) / factor - 1, + final_sumcheck_rounds, + ) + } + FoldingFactor::ConstantFromSecondRound(first_round_factor, factor) => { + let nv_except_first_round = num_variables - *first_round_factor; + if nv_except_first_round < *factor { + // This case is equivalent to Constant(first_round_factor) + return (0, nv_except_first_round); + } + let final_sumcheck_rounds = nv_except_first_round % *factor; + ( + // No need to minus 1 because the initial round is already + // excepted out + (nv_except_first_round - final_sumcheck_rounds) / factor, + final_sumcheck_rounds, + ) + } + } + } + + /// Compute folding_factor(0) + ... 
+ folding_factor(n_rounds) + pub fn total_number(&self, n_rounds: usize) -> usize { + match self { + FoldingFactor::Constant(factor) => { + // It's checked that factor > 0 and factor <= num_variables + factor * (n_rounds + 1) + } + FoldingFactor::ConstantFromSecondRound(first_round_factor, factor) => { + first_round_factor + factor * n_rounds + } + } + } +} + +#[derive(Clone)] +pub struct WhirParameters +where + MerkleConfig: Config, +{ + pub initial_statement: bool, + pub starting_log_inv_rate: usize, + pub folding_factor: FoldingFactor, + pub soundness_type: SoundnessType, + pub security_level: usize, + pub pow_bits: usize, + + pub fold_optimisation: FoldType, + + // PoW parameters + pub _pow_parameters: PhantomData, + + // Merkle tree parameters + pub leaf_hash_params: LeafParam, + pub two_to_one_params: TwoToOneParam, +} + +impl Display for WhirParameters +where + MerkleConfig: Config, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!( + f, + "Targeting {}-bits of security with {}-bits of PoW - soundness: {:?}", + self.security_level, self.pow_bits, self.soundness_type + )?; + writeln!( + f, + "Starting rate: 2^-{}, folding_factor: {:?}, fold_opt_type: {}", + self.starting_log_inv_rate, self.folding_factor, self.fold_optimisation, + ) + } +} diff --git a/whir/src/poly_utils/coeffs.rs b/whir/src/poly_utils/coeffs.rs new file mode 100644 index 000000000..649977743 --- /dev/null +++ b/whir/src/poly_utils/coeffs.rs @@ -0,0 +1,467 @@ +use super::{MultilinearPoint, evals::EvaluationsList, hypercube::BinaryHypercubePoint}; +use crate::ntt::wavelet_transform; +use ark_ff::Field; +use ark_poly::{DenseUVPolynomial, Polynomial, univariate::DensePolynomial}; +use serde::{Deserialize, Serialize}; +#[cfg(feature = "parallel")] +use { + rayon::{join, prelude::*}, + std::mem::size_of, +}; + +/// A CoefficientList models a (multilinear) polynomial in `num_variable` variables in coefficient form. +/// +/// The order of coefficients follows the following convention: coeffs[j] corresponds to the monomial +/// determined by the binary decomposition of j with an X_i-variable present if the +/// i-th highest-significant bit among the `num_variables` least significant bits is set. +/// +/// e.g. is `num_variables` is 3 with variables X_0, X_1, X_2, then +/// - coeffs[0] is the coefficient of 1 +/// - coeffs[1] is the coefficient of X_2 +/// - coeffs[2] is the coefficient of X_1 +/// - coeffs[4] is the coefficient of X_0 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CoefficientList { + coeffs: Vec, /* list of coefficients. For multilinear polynomials, we have coeffs.len() == 1 << num_variables. 
+
+impl<F> CoefficientList<F>
+where
+    F: Field,
+{
+    fn coeff_at(&self, index: usize) -> F {
+        self.coeffs[index]
+    }
+
+    pub fn pad_to_num_vars(&mut self, num_vars: usize) {
+        if self.num_variables < num_vars {
+            let pad = (1usize << num_vars) - (1usize << self.num_variables);
+            self.coeffs.extend(vec![F::ZERO; pad]);
+            self.num_variables = num_vars;
+        }
+    }
+
+    /// Linearly combine the given polynomials using the given coefficients
+    pub fn combine(polys: &[Self], coeffs: &[F]) -> Self {
+        Self::new(
+            (0..polys[0].coeffs.len())
+                .into_par_iter()
+                .map(|i| {
+                    let mut acc = F::ZERO;
+                    for (poly, coeff) in polys.iter().zip(coeffs) {
+                        acc += poly.coeff_at(i) * coeff;
+                    }
+                    acc
+                })
+                .collect(),
+        )
+    }
+
+    /// Evaluate the given polynomial at `point` from {0,1}^n
+    pub fn evaluate_hypercube(&self, point: BinaryHypercubePoint) -> F {
+        assert_eq!(self.coeffs.len(), 1 << self.num_variables);
+        assert!(point.0 < (1 << self.num_variables));
+        // TODO: Optimized implementation
+        self.evaluate(&MultilinearPoint::from_binary_hypercube_point(
+            point,
+            self.num_variables,
+        ))
+    }
+
+    /// Evaluate the given polynomial at `point` from F^n.
+    pub fn evaluate(&self, point: &MultilinearPoint<F>) -> F {
+        assert_eq!(self.num_variables, point.n_variables());
+        eval_multivariate(&self.coeffs, &point.0)
+    }
+
+    #[inline]
+    fn eval_extension<E: Field<BasePrimeField = F>>(coeff: &[F], eval: &[E], scalar: E) -> E {
+        // explicit "return" just to simplify static code-analyzers' tasks (that can't figure out the cfg's are disjoint)
+        #[cfg(not(feature = "parallel"))]
+        return Self::eval_extension_nonparallel(coeff, eval, scalar);
+        #[cfg(feature = "parallel")]
+        return Self::eval_extension_parallel(coeff, eval, scalar);
+    }
+
+    // NOTE (Gotti): This algorithm uses 2^{n+1}-1 multiplications for a polynomial in n variables.
+    // You could do with 2^{n}-1 by just doing a + x * b (and not forwarding scalar through the recursion at all).
+    // The difference comes from multiplications by E::ONE at the leaves of the recursion tree.
+
+    // recursive helper function for polynomial evaluation:
+    // Note that eval(coeffs, [X_0, X1,...]) = eval(coeffs_left, [X_1,...]) + X_0 * eval(coeffs_right, [X_1,...])
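+
+    // Illustrative base case (added for exposition, not from the upstream code):
+    // for one variable,
+    //   eval_extension([c0, c1], [x], s) == s * c0 + (s * x) * c1,
+    // i.e. the left half is weighted by the incoming scalar and the right half
+    // by scalar * X_0, matching the identity in the note above.
+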
+    /// Recursively compute scalar * poly_eval(coeffs;eval) where poly_eval interprets coeffs as a polynomial and eval are the evaluation points.
+    fn eval_extension_nonparallel<E: Field<BasePrimeField = F>>(
+        coeff: &[F],
+        eval: &[E],
+        scalar: E,
+    ) -> E {
+        debug_assert_eq!(coeff.len(), 1 << eval.len());
+        if let Some((&x, tail)) = eval.split_first() {
+            let (low, high) = coeff.split_at(coeff.len() / 2);
+            let a = Self::eval_extension_nonparallel(low, tail, scalar);
+            let b = Self::eval_extension_nonparallel(high, tail, scalar * x);
+            a + b
+        } else {
+            scalar.mul_by_base_prime_field(&coeff[0])
+        }
+    }
+
+    #[cfg(feature = "parallel")]
+    fn eval_extension_parallel<E: Field<BasePrimeField = F>>(
+        coeff: &[F],
+        eval: &[E],
+        scalar: E,
+    ) -> E {
+        const PARALLEL_THRESHOLD: usize = 10;
+        debug_assert_eq!(coeff.len(), 1 << eval.len());
+        if let Some((&x, tail)) = eval.split_first() {
+            let (low, high) = coeff.split_at(coeff.len() / 2);
+            if tail.len() > PARALLEL_THRESHOLD {
+                let (a, b) = rayon::join(
+                    || Self::eval_extension_parallel(low, tail, scalar),
+                    || Self::eval_extension_parallel(high, tail, scalar * x),
+                );
+                a + b
+            } else {
+                Self::eval_extension_nonparallel(low, tail, scalar)
+                    + Self::eval_extension_nonparallel(high, tail, scalar * x)
+            }
+        } else {
+            scalar.mul_by_base_prime_field(&coeff[0])
+        }
+    }
+
+    /// Evaluate self at `point`, where `point` is from a field extension extending the field over which the polynomial `self` is defined.
+    ///
+    /// Note that we only support the case where F is a prime field.
+    pub fn evaluate_at_extension<E: Field<BasePrimeField = F>>(
+        &self,
+        point: &MultilinearPoint<E>,
+    ) -> E {
+        assert_eq!(self.num_variables, point.n_variables());
+        Self::eval_extension(&self.coeffs, &point.0, E::ONE)
+    }
+
+    /// Interprets self as a univariate polynomial (with coefficients of X^i in order of ascending i) and evaluates it at each point in `points`.
+    /// We return the vector of evaluations.
+    ///
+    /// NOTE: For the `usual` mapping between univariate and multilinear polynomials, the coefficient ordering is such that
+    /// for a single point x, we have (extending notation to a single point)
+    /// self.evaluate_at_univariate(x) == self.evaluate([x^(2^n), x^(2^{n-1}), ..., x^2, x])
+    pub fn evaluate_at_univariate(&self, points: &[F]) -> Vec<F> {
+        // DensePolynomial::from_coefficients_slice converts to a dense univariate polynomial.
+        // The coefficient order is "coefficient of 1 first".
+        let univariate = DensePolynomial::from_coefficients_slice(&self.coeffs);
+        points
+            .iter()
+            .map(|point| univariate.evaluate(point))
+            .collect()
+    }
+}
+
+impl<F> CoefficientList<F> {
+    pub fn new(coeffs: Vec<F>) -> Self {
+        let len = coeffs.len();
+        assert!(len.is_power_of_two());
+        let num_variables = len.ilog2();
+
+        CoefficientList {
+            coeffs,
+            num_variables: num_variables as usize,
+        }
+    }
+
+    pub fn coeffs(&self) -> &[F] {
+        &self.coeffs
+    }
+
+    pub fn num_variables(&self) -> usize {
+        self.num_variables
+    }
+
+    pub fn num_coeffs(&self) -> usize {
+        self.coeffs.len()
+    }
+
+    /// Map the polynomial `self` from F[X_1,...,X_n] to E[X_1,...,X_n], where E is a field extension of F.
+    ///
+    /// Note that this is currently restricted to the case where F is a prime field.
+    pub fn to_extension<E: Field<BasePrimeField = F>>(self) -> CoefficientList<E> {
+        CoefficientList::new(
+            self.coeffs
+                .into_iter()
+                .map(E::from_base_prime_field)
+                .collect(),
+        )
+    }
+}
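+
+// Illustrative correspondence (added for exposition, not from the upstream code):
+// with 2 variables and coefficients [c0, c1, c2, c3], the univariate view is
+//   g(y) = c0 + c1 * y + c2 * y^2 + c3 * y^3,
+// and g(y) == f(y^2, y) for the multilinear f with c1 on X_1, c2 on X_0 and
+// c3 on X_0 * X_1, matching evaluate_at_univariate's NOTE above.
+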
+/// Multivariate evaluation in coefficient form.
+fn eval_multivariate<F: Field>(coeffs: &[F], point: &[F]) -> F {
+    debug_assert_eq!(coeffs.len(), 1 << point.len());
+    match point {
+        [] => coeffs[0],
+        [x] => coeffs[0] + coeffs[1] * x,
+        [x0, x1] => {
+            let b0 = coeffs[0] + coeffs[1] * x1;
+            let b1 = coeffs[2] + coeffs[3] * x1;
+            b0 + b1 * x0
+        }
+        [x0, x1, x2] => {
+            let b00 = coeffs[0] + coeffs[1] * x2;
+            let b01 = coeffs[2] + coeffs[3] * x2;
+            let b10 = coeffs[4] + coeffs[5] * x2;
+            let b11 = coeffs[6] + coeffs[7] * x2;
+            let b0 = b00 + b01 * x1;
+            let b1 = b10 + b11 * x1;
+            b0 + b1 * x0
+        }
+        [x0, x1, x2, x3] => {
+            let b000 = coeffs[0] + coeffs[1] * x3;
+            let b001 = coeffs[2] + coeffs[3] * x3;
+            let b010 = coeffs[4] + coeffs[5] * x3;
+            let b011 = coeffs[6] + coeffs[7] * x3;
+            let b100 = coeffs[8] + coeffs[9] * x3;
+            let b101 = coeffs[10] + coeffs[11] * x3;
+            let b110 = coeffs[12] + coeffs[13] * x3;
+            let b111 = coeffs[14] + coeffs[15] * x3;
+            let b00 = b000 + b001 * x2;
+            let b01 = b010 + b011 * x2;
+            let b10 = b100 + b101 * x2;
+            let b11 = b110 + b111 * x2;
+            let b0 = b00 + b01 * x1;
+            let b1 = b10 + b11 * x1;
+            b0 + b1 * x0
+        }
+        [x, tail @ ..] => {
+            let (b0t, b1t) = coeffs.split_at(coeffs.len() / 2);
+            #[cfg(not(feature = "parallel"))]
+            let (b0t, b1t) = (eval_multivariate(b0t, tail), eval_multivariate(b1t, tail));
+            #[cfg(feature = "parallel")]
+            let (b0t, b1t) = {
+                let work_size: usize = (1 << 15) / size_of::<F>();
+                if coeffs.len() > work_size {
+                    join(
+                        || eval_multivariate(b0t, tail),
+                        || eval_multivariate(b1t, tail),
+                    )
+                } else {
+                    (eval_multivariate(b0t, tail), eval_multivariate(b1t, tail))
+                }
+            };
+            b0t + b1t * x
+        }
+    }
+}
+
+impl<F> CoefficientList<F>
+where
+    F: Field,
+{
+    /// fold folds the polynomial at the provided folding_randomness.
+    ///
+    /// Namely, when self is interpreted as a multi-linear polynomial f in X_0, ..., X_{n-1},
+    /// it partially evaluates f at the provided `folding_randomness`.
+    /// Our ordering convention is to evaluate at the higher indices, i.e. we return f(X_0,X_1,..., folding_randomness[0], folding_randomness[1],...)
+    pub fn fold(&self, folding_randomness: &MultilinearPoint<F>) -> Self {
+        let folding_factor = folding_randomness.n_variables();
+        #[cfg(not(feature = "parallel"))]
+        let coeffs = self
+            .coeffs
+            .chunks_exact(1 << folding_factor)
+            .map(|coeffs| eval_multivariate(coeffs, &folding_randomness.0))
+            .collect();
+        #[cfg(feature = "parallel")]
+        let coeffs = self
+            .coeffs
+            .par_chunks_exact(1 << folding_factor)
+            .map(|coeffs| eval_multivariate(coeffs, &folding_randomness.0))
+            .collect();
+
+        CoefficientList {
+            coeffs,
+            num_variables: self.num_variables() - folding_factor,
+        }
+    }
+}
+
+impl<F> From<CoefficientList<F>> for DensePolynomial<F>
+where
+    F: Field,
+{
+    fn from(value: CoefficientList<F>) -> Self {
+        DensePolynomial::from_coefficients_vec(value.coeffs)
+    }
+}
+
+impl<F> From<DensePolynomial<F>> for CoefficientList<F>
+where
+    F: Field,
+{
+    fn from(value: DensePolynomial<F>) -> Self {
+        CoefficientList::new(value.coeffs)
+    }
+}
+
+impl<F> From<CoefficientList<F>> for EvaluationsList<F>
+where
+    F: Field,
+{
+    fn from(value: CoefficientList<F>) -> Self {
+        let mut evals = value.coeffs;
+        wavelet_transform(&mut evals);
+        EvaluationsList::new(evals)
+    }
+}
+
+// Previous recursive version
+// impl<F> From<CoefficientList<F>> for EvaluationsList<F>
+// where
+//     F: Field,
+// {
+//     fn from(value: CoefficientList<F>) -> Self {
+//         let num_coeffs = value.num_coeffs();
+//         // Base case
+//         if num_coeffs == 1 {
+//             return EvaluationsList::new(value.coeffs);
+//         }
+//
+//         let half_coeffs = num_coeffs / 2;
+//
+//         // Left is polynomial with last variable set to 0
+//         let mut left = Vec::with_capacity(half_coeffs);
+//
+//         // Right is polynomial with last variable set to 1
+//         let mut right = Vec::with_capacity(half_coeffs);
+//
+//         for i in 0..half_coeffs {
+//             left.push(value.coeffs[2 * i]);
+//             right.push(value.coeffs[2 * i] + value.coeffs[2 * i + 1]);
+//         }
+//
+//         let left_poly = CoefficientList {
+//             coeffs: left,
+//             num_variables: value.num_variables - 1,
+//         };
+//         let right_poly = CoefficientList {
+//             coeffs: right,
+//             num_variables: value.num_variables - 1,
+//         };
+//
+//         // Compute evaluation of right and left
+//         let left_eval = EvaluationsList::from(left_poly);
+//         let right_eval = EvaluationsList::from(right_poly);
+//
+//         // Combine
+//         let mut evaluation_list = Vec::with_capacity(num_coeffs);
+//         for i in 0..half_coeffs {
+//             evaluation_list.push(left_eval[i]);
+//             evaluation_list.push(right_eval[i]);
+//         }
+//
+//         EvaluationsList::new(evaluation_list)
+//     }
+// }
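+
+// Illustrative example (added for exposition, not from the upstream code): the
+// wavelet transform maps coefficients [c0, c1, c2, c3] to evaluations
+//   [c0, c0 + c1, c0 + c2, c0 + c1 + c2 + c3]
+// over {0,1}^2 in lexicographic order, as checked by test_evaluation_conversion below.
+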
+#[cfg(test)]
+mod tests {
+    use ark_poly::{Polynomial, univariate::DensePolynomial};
+
+    use crate::{
+        crypto::fields::Field64,
+        poly_utils::{MultilinearPoint, coeffs::CoefficientList, evals::EvaluationsList},
+    };
+
+    type F = Field64;
+
+    #[test]
+    fn test_evaluation_conversion() {
+        let coeffs = vec![F::from(22), F::from(5), F::from(10), F::from(97)];
+        let coeffs_list = CoefficientList::new(coeffs.clone());
+
+        let evaluations = EvaluationsList::from(coeffs_list);
+
+        assert_eq!(evaluations[0], coeffs[0]);
+        assert_eq!(evaluations[1], coeffs[0] + coeffs[1]);
+        assert_eq!(evaluations[2], coeffs[0] + coeffs[2]);
+        assert_eq!(
+            evaluations[3],
+            coeffs[0] + coeffs[1] + coeffs[2] + coeffs[3]
+        );
+    }
+
+    #[test]
+    fn test_folding() {
+        let coeffs = vec![F::from(22), F::from(5), F::from(00), F::from(00)];
+        let coeffs_list = CoefficientList::new(coeffs);
+
+        let alpha = F::from(100);
+        let beta = F::from(32);
+
+        let folded = coeffs_list.fold(&MultilinearPoint(vec![beta]));
+
+        assert_eq!(
+            coeffs_list.evaluate(&MultilinearPoint(vec![alpha, beta])),
+            folded.evaluate(&MultilinearPoint(vec![alpha]))
+        )
+    }
+
+    #[test]
+    fn test_folding_and_evaluation() {
+        let num_variables = 10;
+        let coeffs = (0..(1 << num_variables)).map(F::from).collect();
+        let coeffs_list = CoefficientList::new(coeffs);
+
+        let randomness: Vec<_> = (0..num_variables).map(|i| F::from(35 * i as u64)).collect();
+        for k in 0..num_variables {
+            let fold_part = randomness[0..k].to_vec();
+            let eval_part = randomness[k..randomness.len()].to_vec();
+
+            let fold_random = MultilinearPoint(fold_part.clone());
+            let eval_point = MultilinearPoint([eval_part.clone(), fold_part].concat());
+
+            let folded = coeffs_list.fold(&fold_random);
+            assert_eq!(
+                folded.evaluate(&MultilinearPoint(eval_part)),
+                coeffs_list.evaluate(&eval_point)
+            );
+        }
+    }
+
+    #[test]
+    fn test_evaluation_mv() {
+        let polynomial = vec![
+            F::from(0),
+            F::from(1),
+            F::from(2),
+            F::from(3),
+            F::from(4),
+            F::from(5),
+            F::from(6),
+            F::from(7),
+            F::from(8),
+            F::from(9),
+            F::from(10),
+            F::from(11),
+            F::from(12),
+            F::from(13),
+            F::from(14),
+            F::from(15),
+        ];
+
+        let mv_poly = CoefficientList::new(polynomial);
+        let uv_poly: DensePolynomial<_> = mv_poly.clone().into();
+
+        let eval_point = F::from(4999);
+        assert_eq!(
+            uv_poly.evaluate(&F::from(1)),
+            F::from((0..=15).sum::<u64>())
+        );
+        assert_eq!(
+            uv_poly.evaluate(&eval_point),
+            mv_poly.evaluate(&MultilinearPoint::expand_from_univariate(eval_point, 4))
+        )
+    }
+}
diff --git a/whir/src/poly_utils/evals.rs b/whir/src/poly_utils/evals.rs
new file mode 100644
index 000000000..d65f92954
--- /dev/null
+++ b/whir/src/poly_utils/evals.rs
@@ -0,0 +1,95 @@
+use std::ops::Index;
+
+use ark_ff::Field;
+
+use super::{MultilinearPoint, sequential_lag_poly::LagrangePolynomialIterator};
+
+/// An EvaluationsList models a multi-linear polynomial f in `num_variables`
+/// unknowns, stored via their evaluations at {0,1}^{num_variables}
+///
+/// `evals` stores the evaluation in lexicographic order.
+#[derive(Debug, Clone)]
+pub struct EvaluationsList<F> {
+    evals: Vec<F>,
+    num_variables: usize,
+}
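+
+// Illustrative ordering (added for exposition, not from the upstream code):
+// for two variables,
+//   evals[0] = f(0, 0), evals[1] = f(0, 1), evals[2] = f(1, 0), evals[3] = f(1, 1),
+// i.e. an index, read as big-endian bits, is the evaluation point.
+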
+impl<F> EvaluationsList<F>
+where
+    F: Field,
+{
+    /// Constructs an EvaluationsList from the given vector `evals` of evaluations.
+    ///
+    /// The provided `evals` is supposed to be the list of evaluations, where the ordering of evaluation points in {0,1}^n
+    /// is lexicographic.
+    pub fn new(evals: Vec<F>) -> Self {
+        let len = evals.len();
+        assert!(len.is_power_of_two());
+        let num_variables = len.ilog2();
+
+        EvaluationsList {
+            evals,
+            num_variables: num_variables as usize,
+        }
+    }
+
+    /// evaluate the polynomial at `point`
+    pub fn evaluate(&self, point: &MultilinearPoint<F>) -> F {
+        if let Some(point) = point.to_hypercube() {
+            return self.evals[point.0];
+        }
+
+        let mut sum = F::ZERO;
+        for (b, lag) in LagrangePolynomialIterator::new(point) {
+            sum += lag * self.evals[b.0]
+        }
+
+        sum
+    }
+
+    pub fn evals(&self) -> &[F] {
+        &self.evals
+    }
+
+    pub fn evals_mut(&mut self) -> &mut [F] {
+        &mut self.evals
+    }
+
+    pub fn num_evals(&self) -> usize {
+        self.evals.len()
+    }
+
+    pub fn num_variables(&self) -> usize {
+        self.num_variables
+    }
+}
+
+impl<F> Index<usize> for EvaluationsList<F> {
+    type Output = F;
+    fn index(&self, index: usize) -> &Self::Output {
+        &self.evals[index]
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::poly_utils::hypercube::BinaryHypercube;
+
+    use super::*;
+    use ark_ff::*;
+
+    type F = crate::crypto::fields::Field64;
+
+    #[test]
+    fn test_evaluation() {
+        let evaluations_vec = vec![F::ZERO, F::ONE, F::ZERO, F::ONE];
+        let evals = EvaluationsList::new(evaluations_vec.clone());
+
+        for i in BinaryHypercube::new(2) {
+            assert_eq!(
+                evaluations_vec[i.0],
+                evals.evaluate(&MultilinearPoint::from_binary_hypercube_point(i, 2))
+            );
+        }
+    }
+}
diff --git a/whir/src/poly_utils/fold.rs b/whir/src/poly_utils/fold.rs
new file mode 100644
index 000000000..5d10d036f
--- /dev/null
+++ b/whir/src/poly_utils/fold.rs
@@ -0,0 +1,223 @@
+use crate::{ntt::intt_batch, parameters::FoldType};
+use ark_ff::{FftField, Field};
+
+#[cfg(feature = "parallel")]
+use rayon::prelude::*;
+
+/// Given the evaluation of f on the coset specified by coset_offset * <coset_gen>,
+/// compute the fold on that point.
+pub fn compute_fold<F: Field>(
+    answers: &[F],
+    folding_randomness: &[F],
+    mut coset_offset_inv: F,
+    mut coset_gen_inv: F,
+    two_inv: F,
+    folding_factor: usize,
+) -> F {
+    let mut answers = answers.to_vec();
+
+    // We recursively compute the fold, rec is where it is
+    for rec in 0..folding_factor {
+        let offset = answers.len() / 2;
+        let mut new_answers = vec![F::ZERO; offset];
+        let mut coset_index_inv = F::ONE;
+        for i in 0..offset {
+            let f_value_0 = answers[i];
+            let f_value_1 = answers[i + offset];
+            let point_inv = coset_offset_inv * coset_index_inv;
+
+            let left = f_value_0 + f_value_1;
+            let right = point_inv * (f_value_0 - f_value_1);
+
+            new_answers[i] =
+                two_inv * (left + folding_randomness[folding_randomness.len() - 1 - rec] * right);
+            coset_index_inv *= coset_gen_inv;
+        }
+        answers = new_answers;
+
+        // Update for next one
+        coset_offset_inv = coset_offset_inv * coset_offset_inv;
+        coset_gen_inv = coset_gen_inv * coset_gen_inv;
+    }
+
+    answers[0]
+}
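+
+// Illustrative example (added for exposition, not from the upstream code): with
+// folding_factor = 1 the coset is {x, -x} and, writing f(y) = g0(y^2) + y * g1(y^2),
+//   compute_fold(&[f(x), f(-x)], &[r], 1/x, -1, 1/2, 1)
+//     == (f(x) + f(-x)) / 2 + r * (f(x) - f(-x)) / (2x)
+//     == g0(x^2) + r * g1(x^2),
+// i.e. one folding step evaluates the even/odd decomposition of f at r.
+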
+pub fn restructure_evaluations<F: FftField>(
+    mut stacked_evaluations: Vec<F>,
+    fold_type: FoldType,
+    _domain_gen: F,
+    domain_gen_inv: F,
+    folding_factor: usize,
+) -> Vec<F> {
+    let folding_size = 1_u64 << folding_factor;
+    assert_eq!(stacked_evaluations.len() % (folding_size as usize), 0);
+    match fold_type {
+        FoldType::Naive => stacked_evaluations,
+        FoldType::ProverHelps => {
+            // TODO: This partially undoes the NTT transform from the encoding.
+            // Maybe there is a way to not do the full transform in the first place.
+
+            // Batch inverse NTTs
+            intt_batch(&mut stacked_evaluations, folding_size as usize);
+
+            // Apply coset and size correction.
+            // Stacked evaluation at i is f(B_l) where B_l = w^i * <w^(N/k)>, the coset
+            // generated by w^(N/k) (N = domain size, k = folding size).
+            let size_inv = F::from(folding_size).inverse().unwrap();
+            #[cfg(not(feature = "parallel"))]
+            {
+                let mut coset_offset_inv = F::ONE;
+                for answers in stacked_evaluations.chunks_exact_mut(folding_size as usize) {
+                    let mut scale = size_inv;
+                    for v in answers.iter_mut() {
+                        *v *= scale;
+                        scale *= coset_offset_inv;
+                    }
+                    coset_offset_inv *= domain_gen_inv;
+                }
+            }
+            #[cfg(feature = "parallel")]
+            stacked_evaluations
+                .par_chunks_exact_mut(folding_size as usize)
+                .enumerate()
+                .for_each_with(F::ZERO, |offset, (i, answers)| {
+                    if *offset == F::ZERO {
+                        *offset = domain_gen_inv.pow([i as u64]);
+                    } else {
+                        *offset *= domain_gen_inv;
+                    }
+                    let mut scale = size_inv;
+                    for v in answers.iter_mut() {
+                        *v *= scale;
+                        scale *= &*offset;
+                    }
+                });
+
+            stacked_evaluations
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use ark_ff::{FftField, Field};
+
+    use crate::{
+        crypto::fields::Field64,
+        poly_utils::{MultilinearPoint, coeffs::CoefficientList},
+        utils::stack_evaluations,
+    };
+
+    use super::{compute_fold, restructure_evaluations};
+
+    type F = Field64;
+
+    #[test]
+    fn test_folding() {
+        let num_variables = 5;
+        let num_coeffs = 1 << num_variables;
+
+        let domain_size = 256;
+        let folding_factor = 3; // We fold in 8
+        let folding_factor_exp = 1 << folding_factor;
+
+        let poly = CoefficientList::new((0..num_coeffs).map(F::from).collect());
+
+        let root_of_unity = F::get_root_of_unity(domain_size).unwrap();
+
+        let index = 15;
+        let folding_randomness: Vec<_> = (0..folding_factor).map(|i| F::from(i as u64)).collect();
+
+        let coset_offset = root_of_unity.pow([index]);
+        let coset_gen = root_of_unity.pow([domain_size / folding_factor_exp]);
+
+        // Evaluate the polynomial on the coset
+        let poly_eval: Vec<_> = (0..folding_factor_exp)
+            .map(|i| {
+                poly.evaluate(&MultilinearPoint::expand_from_univariate(
+                    coset_offset * coset_gen.pow([i]),
+                    num_variables,
+                ))
+            })
+            .collect();
+
+        let fold_value = compute_fold(
+            &poly_eval,
+            &folding_randomness,
+            coset_offset.inverse().unwrap(),
+            coset_gen.inverse().unwrap(),
+            F::from(2).inverse().unwrap(),
+            folding_factor,
+        );
+
+        let truth_value = poly.fold(&MultilinearPoint(folding_randomness)).evaluate(
+            &MultilinearPoint::expand_from_univariate(
+                root_of_unity.pow([folding_factor_exp * index]),
+                2,
+            ),
+        );
+
+        assert_eq!(fold_value, truth_value);
+    }
+
+    #[test]
+    fn test_folding_optimised() {
+        let num_variables = 5;
+        let num_coeffs = 1 << num_variables;
+
+        let domain_size = 256;
+        let folding_factor = 3; // We fold in 8
+        let folding_factor_exp = 1 << folding_factor;
+
+        let poly = CoefficientList::new((0..num_coeffs).map(F::from).collect());
+
+        let root_of_unity = F::get_root_of_unity(domain_size).unwrap();
+        let root_of_unity_inv = root_of_unity.inverse().unwrap();
+
+        let folding_randomness: Vec<_> = (0..folding_factor).map(|i| F::from(i as u64)).collect();
+
+        // Evaluate the polynomial on the domain
+        let domain_evaluations: Vec<_> = (0..domain_size)
+            .map(|w| root_of_unity.pow([w]))
+            .map(|point| {
+                poly.evaluate(&MultilinearPoint::expand_from_univariate(
+                    point,
+                    num_variables,
+                ))
+            })
+            .collect();
+
+        let unprocessed = stack_evaluations(domain_evaluations, folding_factor);
+
+        let processed = restructure_evaluations(
+            unprocessed.clone(),
+            crate::parameters::FoldType::ProverHelps,
+            root_of_unity,
+            root_of_unity_inv,
+            folding_factor,
+        );
+
+        let num = domain_size / folding_factor_exp;
+        let coset_gen_inv = root_of_unity_inv.pow([num]);
+
+        for index in 0..num {
+            let offset_inv = root_of_unity_inv.pow([index]);
+            let span =
+                (index * folding_factor_exp) as usize..((index + 1) * folding_factor_exp) as usize;
+
+            let answer_unprocessed = compute_fold(
+                &unprocessed[span.clone()],
+                &folding_randomness,
+                offset_inv,
+                coset_gen_inv,
+                F::from(2).inverse().unwrap(),
+                folding_factor,
+            );
+
+            let answer_processed = CoefficientList::new(processed[span].to_vec())
+                .evaluate(&MultilinearPoint(folding_randomness.clone()));
+
+            assert_eq!(answer_processed, answer_unprocessed);
+        }
+    }
+}
diff --git a/whir/src/poly_utils/gray_lag_poly.rs b/whir/src/poly_utils/gray_lag_poly.rs
new file mode 100644
index 000000000..e2749ffe8
--- /dev/null
+++ b/whir/src/poly_utils/gray_lag_poly.rs
@@ -0,0 +1,151 @@
+// NOTE: This is the one from Ron's
+
+use ark_ff::{Field, batch_inversion};
+
+use super::{MultilinearPoint, hypercube::BinaryHypercubePoint};
+
+pub struct LagrangePolynomialGray<F: Field> {
+    position_bin: usize,
+    position_gray: usize,
+    value: F,
+    precomputed: Vec<F>,
+    num_variables: usize,
+}
+
+impl<F: Field> LagrangePolynomialGray<F> {
+    pub fn new(point: &MultilinearPoint<F>) -> Self {
+        let num_variables = point.n_variables();
+        // Limitation for bin hypercube
+        assert!(point.0.iter().all(|&p| p != F::ZERO && p != F::ONE));
+
+        // This is negated[i] = eq_poly(z_i, 0) = 1 - z_i
+        let negated_points: Vec<_> = point.0.iter().map(|z| F::ONE - z).collect();
+        // This is points[i] = eq_poly(z_i, 1) = z_i
+        let points = point.0.to_vec();
+
+        let mut to_invert = [negated_points.clone(), points.clone()].concat();
+        batch_inversion(&mut to_invert);
+
+        let (denom_0, denom_1) = to_invert.split_at(num_variables);
+
+        let mut precomputed = vec![F::ZERO; 2 * num_variables];
+        for n in 0..num_variables {
+            precomputed[2 * n] = points[n] * denom_0[n];
+            precomputed[2 * n + 1] = negated_points[n] * denom_1[n];
+        }
+
+        LagrangePolynomialGray {
+            position_gray: gray_encode(0),
+            position_bin: 0,
+            value: negated_points.into_iter().product(),
+            num_variables,
+            precomputed,
+        }
+    }
+}
+
+pub fn gray_encode(integer: usize) -> usize {
+    (integer >> 1) ^ integer
+}
+
+pub fn gray_decode(integer: usize) -> usize {
+    match integer {
+        0 => 0,
+        _ => integer ^ gray_decode(integer >> 1),
+    }
+}
+
+impl<F: Field> Iterator for LagrangePolynomialGray<F> {
+    type Item = (BinaryHypercubePoint, F);
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.position_bin >= (1 << self.num_variables) {
+            return None;
+        }
+
+        let result = (BinaryHypercubePoint(self.position_gray), self.value);
+
+        let prev = self.position_gray;
+
+        self.position_bin += 1;
+        self.position_gray = gray_encode(self.position_bin);
+
+        if self.position_bin < (1 << self.num_variables) {
+            let diff = prev ^ self.position_gray;
+            let i = (self.num_variables - 1) - diff.trailing_zeros() as usize;
+            let flip = (diff & self.position_gray == 0) as usize;
+
+            self.value *= self.precomputed[2 * i + flip];
+        }
+
+        Some(result)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::collections::BTreeSet;
+
+    use crate::{
+        crypto::fields::Field64,
+        poly_utils::{
+            MultilinearPoint, eq_poly,
+            gray_lag_poly::{LagrangePolynomialGray, gray_decode},
+            hypercube::BinaryHypercubePoint,
+        },
+    };
+
+    use super::gray_encode;
+
+    type F = Field64;
+
+    #[test]
+    fn test_gray_ordering() {
+        let values = [
+            (0b0000, 0b0000),
+            (0b0001, 0b0001),
+            (0b0010, 0b0011),
+            (0b0011, 0b0010),
+            (0b0100, 0b0110),
+            (0b0101, 0b0111),
+            (0b0110, 0b0101),
+            (0b0111, 0b0100),
+            (0b1000, 0b1100),
+            (0b1001, 0b1101),
+            (0b1010, 0b1111),
+            (0b1011, 0b1110),
+            (0b1100, 0b1010),
+            (0b1101, 0b1011),
+            (0b1110, 0b1001),
+            (0b1111, 0b1000),
+        ];
+
+        for (bin, gray) in values {
+            assert_eq!(gray_encode(bin), gray);
+            assert_eq!(gray_decode(gray), bin);
+        }
+    }
+
+    #[test]
+    fn test_gray_ordering_iterator() {
+        let point = MultilinearPoint(vec![F::from(2), F::from(3), F::from(4)]);
+
+        for (i, (b, _)) in LagrangePolynomialGray::new(&point).enumerate() {
+            assert_eq!(b.0, gray_encode(i));
+        }
+    }
+
+    #[test]
+    fn test_gray() {
+        let point = MultilinearPoint(vec![F::from(2), F::from(3), F::from(4)]);
+
+        let eq_poly_res: BTreeSet<_> = (0..(1 << 3))
+            .map(BinaryHypercubePoint)
+            .map(|b| (b, eq_poly(&point, b)))
+            .collect();
+
+        let gray_res: BTreeSet<_> = LagrangePolynomialGray::new(&point).collect();
+
+        assert_eq!(eq_poly_res, gray_res);
+    }
+}
diff --git a/whir/src/poly_utils/hypercube.rs b/whir/src/poly_utils/hypercube.rs
new file mode 100644
index 000000000..a5cd53d71
--- /dev/null
+++ b/whir/src/poly_utils/hypercube.rs
@@ -0,0 +1,40 @@
+// TODO (Gotti): Should pos rather be a u64? usize is platform-dependent, giving a platform-dependent limit on the number of variables.
+// num_variables may be smaller as well.
+
+// NOTE: Conversion BinaryHypercube <-> MultilinearPoint is Big Endian, using only the num_variables least significant bits of the number stored inside BinaryHypercube.
+
+/// point on the binary hypercube {0,1}^n for some n.
+///
+/// The point is encoded via the n least significant bits of a usize in big endian order and we do not store n.
+#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
+pub struct BinaryHypercubePoint(pub usize);
+
+/// BinaryHypercube is an Iterator that is used to range over the points of the hypercube {0,1}^n, where n == `num_variables`
+pub struct BinaryHypercube {
+    pos: usize,           // current position, encoded via the bits of pos
+    num_variables: usize, // dimension of the hypercube
+}
+
+impl BinaryHypercube {
+    pub fn new(num_variables: usize) -> Self {
+        debug_assert!(num_variables < usize::BITS as usize); // Note that we need strictly smaller, since some code would overflow otherwise.
+        BinaryHypercube {
+            pos: 0,
+            num_variables,
+        }
+    }
+}
+
+impl Iterator for BinaryHypercube {
+    type Item = BinaryHypercubePoint;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let curr = self.pos;
+        if curr < (1 << self.num_variables) {
+            self.pos += 1;
+            Some(BinaryHypercubePoint(curr))
+        } else {
+            None
+        }
+    }
+}
diff --git a/whir/src/poly_utils/mod.rs b/whir/src/poly_utils/mod.rs
new file mode 100644
index 000000000..107574263
--- /dev/null
+++ b/whir/src/poly_utils/mod.rs
@@ -0,0 +1,325 @@
+use ark_ff::Field;
+use rand::{
+    Rng, RngCore,
+    distributions::{Distribution, Standard},
+};
+
+use crate::utils::to_binary;
+
+use self::hypercube::BinaryHypercubePoint;
+
+pub mod coeffs;
+pub mod evals;
+pub mod fold;
+pub mod gray_lag_poly;
+pub mod hypercube;
+pub mod sequential_lag_poly;
+pub mod streaming_evaluation_helper;
+
+/// Point (x_1,..., x_n) in F^n for some n. Often, the x_i are binary.
+/// For the latter case, we also have BinaryHypercubePoint.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct MultilinearPoint<F>(pub Vec<F>);
+
+impl<F> MultilinearPoint<F>
+where
+    F: Field,
+{
+    /// returns the number of variables.
+    pub fn n_variables(&self) -> usize {
+        self.0.len()
+    }
+
+    // NOTE: Conversion BinaryHypercube <-> MultilinearPoint converts a
+    // multilinear point (x1,x2,...,x_n) into the number with bit-pattern 0...0 x_1 x_2 ... x_n, provided all x_i are in {0,1}.
+    // That means we pad zero bits in BinaryHypercube from the msb end and use big-endian for the actual conversion.
+
+    /// Creates a MultilinearPoint from a BinaryHypercubePoint; the latter models the same thing, but is restricted to binary entries.
+    pub fn from_binary_hypercube_point(point: BinaryHypercubePoint, num_variables: usize) -> Self {
+        Self(
+            to_binary(point.0, num_variables)
+                .into_iter()
+                .map(|x| if x { F::ONE } else { F::ZERO })
+                .collect(),
+        )
+    }
+
+    /// Converts to a BinaryHypercubePoint, provided the MultilinearPoint is actually in {0,1}^n.
+    pub fn to_hypercube(&self) -> Option<BinaryHypercubePoint> {
+        let mut counter = 0;
+        for &coord in &self.0 {
+            if coord == F::ZERO {
+                counter <<= 1;
+            } else if coord == F::ONE {
+                counter = (counter << 1) + 1;
+            } else {
+                return None;
+            }
+        }
+
+        Some(BinaryHypercubePoint(counter))
+    }
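+
+    // Illustrative example (added for exposition, not from the upstream code):
+    // with num_variables = 3, BinaryHypercubePoint(0b110) maps to (1, 1, 0) and
+    // to_hypercube inverts this, so the most significant of the three bits
+    // becomes the first coordinate.
+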
+    /// converts a univariate evaluation point into a multilinear one.
+    ///
+    /// Notably, consider the usual bijection
+    /// {multilinear polys in n variables} <-> {univariate polys of deg < 2^n}
+    /// f(x_1,...x_n) <-> g(y) := f(y^(2^(n-1)), ..., y^4, y^2, y).
+    /// x_1^i_1 * ... *x_n^i_n <-> y^i, where (i_1,...,i_n) is the (big-endian) binary decomposition of i.
+    ///
+    /// expand_from_univariate maps the evaluation points to the multivariate domain, i.e.
+    /// f(expand_from_univariate(y)) == g(y),
+    /// in a way that is compatible with our endianness choices.
+    pub fn expand_from_univariate(point: F, num_variables: usize) -> Self {
+        let mut res = Vec::with_capacity(num_variables);
+        let mut cur = point;
+        for _ in 0..num_variables {
+            res.push(cur);
+            cur = cur * cur;
+        }
+
+        // Reverse so higher power is first
+        res.reverse();
+
+        MultilinearPoint(res)
+    }
+}
+
+/// creates a random MultilinearPoint of length `num_variables` using the RNG `rng`.
+impl<F> MultilinearPoint<F>
+where
+    Standard: Distribution<F>,
+{
+    pub fn rand(rng: &mut impl RngCore, num_variables: usize) -> Self {
+        MultilinearPoint((0..num_variables).map(|_| rng.gen()).collect())
+    }
+}
+
+/// creates a MultilinearPoint of length 1 from a single field element
+impl<F> From<F> for MultilinearPoint<F> {
+    fn from(value: F) -> Self {
+        MultilinearPoint(vec![value])
+    }
+}
+
+/// Compute eq(coords,point), where eq is the equality polynomial, where point is binary.
+///
+/// Recall that the equality polynomial eq(c, p) is defined as eq(c,p) == \prod_i c_i * p_i + (1-c_i)*(1-p_i).
+/// Note that for fixed p, viewed as a polynomial in c, it is the interpolation polynomial associated to the evaluation point p in the evaluation set {0,1}^n.
+pub fn eq_poly<F>(coords: &MultilinearPoint<F>, point: BinaryHypercubePoint) -> F
+where
+    F: Field,
+{
+    let mut point = point.0;
+    let n_variables = coords.n_variables();
+    assert!(point < (1 << n_variables)); // check that the lengths of coords and point match.
+
+    let mut acc = F::ONE;
+
+    for val in coords.0.iter().rev() {
+        let b = point % 2;
+        acc *= if b == 1 { *val } else { F::ONE - *val };
+        point >>= 1;
+    }
+
+    acc
+}
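+
+// Worked example (added for exposition, not from the upstream code): for
+// coords = (2, 3) and point = 0b10, eq_poly evaluates 2 * (1 - 3): the set high
+// bit contributes coords[0] and the clear low bit contributes 1 - coords[1].
+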
+/// Compute eq(coords,point), where eq is the equality polynomial and where point is not necessarily binary.
+///
+/// Recall that the equality polynomial eq(c, p) is defined as eq(c,p) == \prod_i c_i * p_i + (1-c_i)*(1-p_i).
+/// Note that for fixed p, viewed as a polynomial in c, it is the interpolation polynomial associated to the evaluation point p in the evaluation set {0,1}^n.
+pub fn eq_poly_outside<F>(coords: &MultilinearPoint<F>, point: &MultilinearPoint<F>) -> F
+where
+    F: Field,
+{
+    assert_eq!(coords.n_variables(), point.n_variables());
+
+    let mut acc = F::ONE;
+
+    for (&l, &r) in coords.0.iter().zip(&point.0) {
+        acc *= l * r + (F::ONE - l) * (F::ONE - r);
+    }
+
+    acc
+}
+
+// TODO: Precompute two_inv?
+// Alternatively, compute it directly without the general (and slow) .inverse() map.
+
+/// Compute eq3(coords,point), where eq3 is the equality polynomial for {0,1,2}^n and point is interpreted as an element from {0,1,2}^n via (big Endian) ternary decomposition.
+///
+/// eq3(coords, point) is the unique polynomial of degree <=2 in each variable, s.t.
+/// for coords, point in {0,1,2}^n, we have:
+/// eq3(coords,point) = 1 if coords == point and 0 otherwise.
+pub fn eq_poly3<F>(coords: &MultilinearPoint<F>, mut point: usize) -> F
+where
+    F: Field,
+{
+    let two = F::ONE + F::ONE;
+    let two_inv = two.inverse().unwrap();
+
+    let n_variables = coords.n_variables();
+    assert!(point < 3usize.pow(n_variables as u32));
+
+    let mut acc = F::ONE;
+
+    // Note: This iterates over the ternary decomposition least-significant trit(?) first.
+    // Since our convention is big endian, we reverse the order of coords to account for this.
+    for &val in coords.0.iter().rev() {
+        let b = point % 3;
+        acc *= match b {
+            0 => (val - F::ONE) * (val - two) * two_inv,
+            1 => val * (val - two) * (-F::ONE),
+            2 => val * (val - F::ONE) * two_inv,
+            _ => unreachable!(),
+        };
+        point /= 3;
+    }
+
+    acc
+}
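+
+// Illustrative example (added for exposition, not from the upstream code): for
+// n = 2, index 5 has big-endian ternary decomposition (1, 2), so eq_poly3(coords, 5)
+// equals 1 exactly when coords = (1, 2) and vanishes on the rest of {0,1,2}^2.
+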
+#[cfg(test)]
+mod tests {
+    use crate::{
+        crypto::fields::Field64,
+        poly_utils::{eq_poly, eq_poly3, hypercube::BinaryHypercube},
+    };
+
+    use super::{BinaryHypercubePoint, MultilinearPoint, coeffs::CoefficientList};
+
+    type F = Field64;
+
+    #[test]
+    fn test_equality() {
+        let point = MultilinearPoint(vec![F::from(0), F::from(0)]);
+        assert_eq!(eq_poly(&point, BinaryHypercubePoint(0b00)), F::from(1));
+        assert_eq!(eq_poly(&point, BinaryHypercubePoint(0b01)), F::from(0));
+        assert_eq!(eq_poly(&point, BinaryHypercubePoint(0b10)), F::from(0));
+        assert_eq!(eq_poly(&point, BinaryHypercubePoint(0b11)), F::from(0));
+
+        let point = MultilinearPoint(vec![F::from(1), F::from(0)]);
+        assert_eq!(eq_poly(&point, BinaryHypercubePoint(0b00)), F::from(0));
+        assert_eq!(eq_poly(&point, BinaryHypercubePoint(0b01)), F::from(0));
+        assert_eq!(eq_poly(&point, BinaryHypercubePoint(0b10)), F::from(1));
+        assert_eq!(eq_poly(&point, BinaryHypercubePoint(0b11)), F::from(0));
+    }
+
+    #[test]
+    fn test_equality_again() {
+        let poly = CoefficientList::new(vec![F::from(35), F::from(97), F::from(10), F::from(32)]);
+        let point = MultilinearPoint(vec![F::from(42), F::from(36)]);
+        let eval = poly.evaluate(&point);
+
+        assert_eq!(
+            eval,
+            BinaryHypercube::new(2)
+                .map(
+                    |i| poly.evaluate(&MultilinearPoint::from_binary_hypercube_point(i, 2))
+                        * eq_poly(&point, i)
+                )
+                .sum()
+        );
+    }
+
+    #[test]
+    fn test_equality3() {
+        let point = MultilinearPoint(vec![F::from(0), F::from(0)]);
+
+        assert_eq!(eq_poly3(&point, 0), F::from(1));
+        assert_eq!(eq_poly3(&point, 1), F::from(0));
+        assert_eq!(eq_poly3(&point, 2), F::from(0));
+        assert_eq!(eq_poly3(&point, 3), F::from(0));
+        assert_eq!(eq_poly3(&point, 4), F::from(0));
+        assert_eq!(eq_poly3(&point, 5), F::from(0));
+        assert_eq!(eq_poly3(&point, 6), F::from(0));
+        assert_eq!(eq_poly3(&point, 7), F::from(0));
+        assert_eq!(eq_poly3(&point, 8), F::from(0));
+
+        let point = MultilinearPoint(vec![F::from(1), F::from(0)]);
+
+        assert_eq!(eq_poly3(&point, 0), F::from(0));
+        assert_eq!(eq_poly3(&point, 1), F::from(0));
+        assert_eq!(eq_poly3(&point, 2), F::from(0));
+        assert_eq!(eq_poly3(&point, 3), F::from(1)); // 3 corresponds to ternary (1,0)
+        assert_eq!(eq_poly3(&point, 4), F::from(0));
+        assert_eq!(eq_poly3(&point, 5), F::from(0));
+        assert_eq!(eq_poly3(&point, 6), F::from(0));
+        assert_eq!(eq_poly3(&point, 7), F::from(0));
+        assert_eq!(eq_poly3(&point, 8), F::from(0));
+
+        let point = MultilinearPoint(vec![F::from(0), F::from(2)]);
+
+        assert_eq!(eq_poly3(&point, 0), F::from(0));
+        assert_eq!(eq_poly3(&point, 1), F::from(0));
+        assert_eq!(eq_poly3(&point, 2), F::from(1)); // 2 corresponds to ternary (0,2)
+        assert_eq!(eq_poly3(&point, 3), F::from(0));
+        assert_eq!(eq_poly3(&point, 4), F::from(0));
+        assert_eq!(eq_poly3(&point, 5), F::from(0));
+        assert_eq!(eq_poly3(&point, 6), F::from(0));
+        assert_eq!(eq_poly3(&point, 7), F::from(0));
+        assert_eq!(eq_poly3(&point, 8), F::from(0));
+
+        let point = MultilinearPoint(vec![F::from(2), F::from(2)]);
+
+        assert_eq!(eq_poly3(&point, 0), F::from(0));
+        assert_eq!(eq_poly3(&point, 1), F::from(0));
+        assert_eq!(eq_poly3(&point, 2), F::from(0));
+        assert_eq!(eq_poly3(&point, 3), F::from(0));
+        assert_eq!(eq_poly3(&point, 4), F::from(0));
+        assert_eq!(eq_poly3(&point, 5), F::from(0));
+        assert_eq!(eq_poly3(&point, 6), F::from(0));
+        assert_eq!(eq_poly3(&point, 7), F::from(0));
+        assert_eq!(eq_poly3(&point, 8), F::from(1)); // 8 corresponds to ternary (2,2)
+    }
+
+    #[test]
+    #[should_panic]
+    fn test_equality_2() {
+        let coords = MultilinearPoint(vec![F::from(0), F::from(0)]);
+
+        // implicit length of BinaryHypercubePoint is (at least) 3, exceeding length of coords
+        let _x = eq_poly(&coords, BinaryHypercubePoint(0b100));
+    }
+
+    #[test]
+    fn expand_from_univariate() {
+        let num_variables = 4;
+
+        let point0 = MultilinearPoint::expand_from_univariate(F::from(0), num_variables);
+        let point1 = MultilinearPoint::expand_from_univariate(F::from(1), num_variables);
+        let point2 = MultilinearPoint::expand_from_univariate(F::from(2), num_variables);
+
+        assert_eq!(point0.n_variables(), num_variables);
+        assert_eq!(point1.n_variables(), num_variables);
+        assert_eq!(point2.n_variables(), num_variables);
+
+        assert_eq!(
+            MultilinearPoint::from_binary_hypercube_point(BinaryHypercubePoint(0), num_variables),
+            point0
+        );
+
+        assert_eq!(
+            MultilinearPoint::from_binary_hypercube_point(
+                BinaryHypercubePoint((1 << num_variables) - 1),
+                num_variables
+            ),
+            point1
+        );
+
+        assert_eq!(
+            MultilinearPoint(vec![F::from(256), F::from(16), F::from(4), F::from(2)]),
+            point2
+        );
+    }
+
+    #[test]
+    fn from_hypercube_and_back() {
+        let hypercube_point = BinaryHypercubePoint(24);
+        assert_eq!(
+            Some(hypercube_point),
+            MultilinearPoint::<F>::from_binary_hypercube_point(hypercube_point, 5).to_hypercube()
+        );
+    }
+}
diff --git a/whir/src/poly_utils/sequential_lag_poly.rs b/whir/src/poly_utils/sequential_lag_poly.rs
new file mode 100644
index 000000000..56284dedf
--- /dev/null
+++ b/whir/src/poly_utils/sequential_lag_poly.rs
@@ -0,0 +1,174 @@
+// NOTE: This is the one from Blendy
+
+use ark_ff::Field;
+
+use super::{MultilinearPoint, hypercube::BinaryHypercubePoint};
+
+/// There is an alternative (possibly more efficient) implementation that iterates over the x in Gray code ordering.
+/// LagrangePolynomialIterator for a given multilinear n-dimensional `point` iterates over pairs (x, y)
+/// where x ranges over all possible {0,1}^n
+/// and y equals the product y_1 * ... * y_n where
+///
+/// y_i = point[i] if x_i == 1
+/// y_i = 1-point[i] if x_i == 0
+///
+/// This means that y == eq_poly(point, x)
+pub struct LagrangePolynomialIterator<F: Field> {
+    last_position: Option<usize>, /* the previously output BinaryHypercubePoint (encoded as usize). None before the first output. */
+    point: Vec<F>, /* stores a copy of the `point` given when creating the iterator. For easier(?) bit-fiddling, we store it in reverse order. */
+    point_negated: Vec<F>, /* stores the precomputed values 1-point[i] in the same ordering as point. */
+    /// stack Stores the n+1 values (in order) 1, y_1, y_1*y_2, y_1*y_2*y_3, ..., y_1*...*y_n for the previously output y.
+    /// Before the first iteration (if last_position == None), it stores the values for the next (i.e. first) output instead.
+    stack: Vec<F>,
+    num_variables: usize, // dimension
+}
+
+impl<F: Field> LagrangePolynomialIterator<F> {
+    pub fn new(point: &MultilinearPoint<F>) -> Self {
+        let num_variables = point.0.len();
+
+        // Initialize a stack with capacity for messages/ message_hats and the identity element
+        let mut stack: Vec<F> = Vec::with_capacity(point.0.len() + 1);
+        stack.push(F::ONE);
+
+        let mut point = point.0.clone();
+        let mut point_negated: Vec<_> = point.iter().map(|x| F::ONE - *x).collect();
+        // Iterate over the message_hats, update the running product, and push it onto the stack
+        let mut running_product: F = F::ONE;
+        for point_neg in &point_negated {
+            running_product *= point_neg;
+            stack.push(running_product);
+        }
+
+        point.reverse();
+        point_negated.reverse();
+
+        // Return
+        Self {
+            num_variables,
+            point,
+            point_negated,
+            stack,
+            last_position: None,
+        }
+    }
+}
+
+impl<F: Field> Iterator for LagrangePolynomialIterator<F> {
+    type Item = (BinaryHypercubePoint, F);
+    // Iterator implementation for the struct
+    fn next(&mut self) -> Option<Self::Item> {
+        // a) Check if this is the first iteration
+        if self.last_position.is_none() {
+            // Initialize last position
+            self.last_position = Some(0);
+            // Return the top of the stack
+            return Some((BinaryHypercubePoint(0), *self.stack.last().unwrap()));
+        }
+
+        // b) Check if in the last iteration we finished iterating
+        if self.last_position.unwrap() + 1 >= 1 << self.num_variables {
+            return None;
+        }
+
+        // c) Everything else, first get bit diff
+        let last_position = self.last_position.unwrap();
+        let next_position = last_position + 1;
+        let bit_diff = last_position ^ next_position;
+
+        // Determine the shared prefix of the most significant bits
+        let low_index_of_prefix = (bit_diff + 1).trailing_zeros() as usize;
+
+        // Discard any stack values outside of this prefix
+        self.stack.truncate(self.stack.len() - low_index_of_prefix);
+
+        // Iterate up to this prefix computing lag poly correctly
+        for bit_index in (0..low_index_of_prefix).rev() {
+            let last_element = self.stack.last().unwrap();
+            let next_bit: bool = (next_position & (1 << bit_index)) != 0;
+            self.stack.push(match next_bit {
+                true => *last_element * self.point[bit_index],
+                false => *last_element * self.point_negated[bit_index],
+            });
+        }
+
+        // Don't forget to update the last position
+        self.last_position = Some(next_position);
+
+        // Return the top of the stack
+        Some((
+            BinaryHypercubePoint(next_position),
+            *self.stack.last().unwrap(),
+        ))
+    }
+}
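+
+// Illustrative output order (added for exposition, not from the upstream code):
+// for point = (a, b) the iterator yields
+//   (0b00, (1-a)(1-b)), (0b01, (1-a)b), (0b10, a(1-b)), (0b11, ab),
+// exactly the values checked in test_blendy below.
+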
+#[cfg(test)]
+mod tests {
+    use crate::{
+        crypto::fields::Field64,
+        poly_utils::{MultilinearPoint, eq_poly, hypercube::BinaryHypercubePoint},
+    };
+
+    use super::LagrangePolynomialIterator;
+
+    type F = Field64;
+
+    #[test]
+    fn test_blendy() {
+        let one = F::from(1);
+        let (a, b) = (F::from(2), F::from(3));
+        let point_1 = MultilinearPoint(vec![a, b]);
+
+        let mut lag_iterator = LagrangePolynomialIterator::new(&point_1);
+
+        assert_eq!(
+            lag_iterator.next().unwrap(),
+            (BinaryHypercubePoint(0), (one - a) * (one - b))
+        );
+        assert_eq!(
+            lag_iterator.next().unwrap(),
+            (BinaryHypercubePoint(1), (one - a) * b)
+        );
+        assert_eq!(
+            lag_iterator.next().unwrap(),
+            (BinaryHypercubePoint(2), a * (one - b))
+        );
+        assert_eq!(
+            lag_iterator.next().unwrap(),
+            (BinaryHypercubePoint(3), a * b)
+        );
+        assert_eq!(lag_iterator.next(), None);
+    }
+
+    #[test]
+    fn test_blendy_2() {
+        let point = MultilinearPoint(vec![F::from(12), F::from(13), F::from(32)]);
+
+        let mut last_b = None;
+        for (b, lag) in LagrangePolynomialIterator::new(&point) {
+            assert_eq!(eq_poly(&point, b), lag);
+            assert!(b.0 < 1 << 3);
+            last_b = Some(b);
+        }
+        assert_eq!(last_b, Some(BinaryHypercubePoint(7)));
+    }
+
+    #[test]
+    fn test_blendy_3() {
+        let point = MultilinearPoint(vec![
+            F::from(414151),
+            F::from(109849018),
+            F::from(33184190),
+            F::from(33184190),
+            F::from(33184190),
+        ]);
+
+        let mut last_b = None;
+        for (b, lag) in LagrangePolynomialIterator::new(&point) {
+            assert_eq!(eq_poly(&point, b), lag);
+            last_b = Some(b);
+        }
+        assert_eq!(last_b, Some(BinaryHypercubePoint(31)));
+    }
+}
diff --git a/whir/src/poly_utils/streaming_evaluation_helper.rs b/whir/src/poly_utils/streaming_evaluation_helper.rs
new file mode 100644
index 000000000..6e0a3a853
--- /dev/null
+++ b/whir/src/poly_utils/streaming_evaluation_helper.rs
@@ -0,0 +1,82 @@
+// NOTE: This is the one from Blendy adapted for streaming evals
+
+use ark_ff::Field;
+
+use super::{MultilinearPoint, hypercube::BinaryHypercubePoint};
+
+pub struct TermPolynomialIterator<F: Field> {
+    last_position: Option<usize>,
+    point: Vec<F>,
+    stack: Vec<F>,
+    num_variables: usize,
+}
+
+impl<F: Field> TermPolynomialIterator<F> {
+    pub fn new(point: &MultilinearPoint<F>) -> Self {
+        let num_variables = point.0.len();
+
+        // Initialize a stack with capacity for messages/ message_hats and the identity element
+        let stack: Vec<F> = vec![F::ONE; point.0.len() + 1];
+
+        let mut point = point.0.clone();
+
+        point.reverse();
+
+        // Return
+        Self {
+            num_variables,
+            point,
+            stack,
+            last_position: None,
+        }
+    }
+}
+
+impl<F: Field> Iterator for TermPolynomialIterator<F> {
+    type Item = (BinaryHypercubePoint, F);
+    // Iterator implementation for the struct
+    fn next(&mut self) -> Option<Self::Item> {
+        // a) Check if this is the first iteration
+        if self.last_position.is_none() {
+            // Initialize last position
+            self.last_position = Some(0);
+            // Return the top of the stack
+            return Some((BinaryHypercubePoint(0), *self.stack.last().unwrap()));
+        }
+
+        // b) Check if in the last iteration we finished iterating
+        if self.last_position.unwrap() + 1 >= 1 << self.num_variables {
+            return None;
+        }
+
+        // c) Everything else, first get bit diff
+        let last_position = self.last_position.unwrap();
+        let next_position = last_position + 1;
+        let bit_diff = last_position ^ next_position;
+
+        // Determine the shared prefix of the most significant bits
+        let low_index_of_prefix = (bit_diff + 1).trailing_zeros() as usize;
+
+        // Discard any stack values outside of this prefix
+        self.stack.truncate(self.stack.len() - low_index_of_prefix);
+
+        // Iterate up to this prefix computing lag poly correctly
+        for bit_index in (0..low_index_of_prefix).rev() {
+            let last_element = self.stack.last().unwrap();
+            let next_bit: bool = (next_position & (1 << bit_index)) != 0;
+            self.stack.push(match next_bit {
+                true => *last_element * self.point[bit_index],
+                false => *last_element,
+            });
+        }
+
+        // Don't forget to update the last position
+        self.last_position = Some(next_position);
+
+        // Return the top of the stack
+        Some((
+            BinaryHypercubePoint(next_position),
+            *self.stack.last().unwrap(),
+        ))
+    }
+}
diff --git a/whir/src/sumcheck/mod.rs b/whir/src/sumcheck/mod.rs
new file mode 100644
index 000000000..943a51d02
--- /dev/null
+++ b/whir/src/sumcheck/mod.rs
@@ -0,0 +1,321 @@
+pub mod proof;
+pub mod prover_batched;
+pub mod prover_core;
+pub mod prover_not_skipping;
+pub mod prover_not_skipping_batched;
+pub mod prover_single;
+
+#[cfg(test)]
+mod tests {
+    use crate::{
+        crypto::fields::Field64,
+        poly_utils::{MultilinearPoint, coeffs::CoefficientList, eq_poly_outside},
+    };
+
+    use super::prover_core::SumcheckCore;
+
+    type F = Field64;
+
+    #[test]
+    fn test_sumcheck_folding_factor_1() {
+        let folding_factor = 1;
+        let eval_point = MultilinearPoint(vec![F::from(10), F::from(11)]);
+        let polynomial =
+            CoefficientList::new(vec![F::from(1), F::from(5), F::from(10), F::from(14)]);
+
+        let claimed_value = polynomial.evaluate(&eval_point);
+
+        let mut prover = SumcheckCore::new(polynomial, &[eval_point], &[F::from(1)]);
+
+        let poly_1 = prover.compute_sumcheck_polynomial(folding_factor);
+
+        // First, check that it sums to the right value over the hypercube
+        assert_eq!(poly_1.sum_over_hypercube(), claimed_value);
+
+        let combination_randomness = F::from(100101);
+        let folding_randomness = MultilinearPoint(vec![F::from(4999)]);
+
+        prover.compress(folding_factor, combination_randomness, &folding_randomness);
+
+        let poly_2 = prover.compute_sumcheck_polynomial(folding_factor);
+
+        assert_eq!(
+            poly_2.sum_over_hypercube(),
+            combination_randomness * poly_1.evaluate_at_point(&folding_randomness)
+        );
+    }
+
+    #[test]
+    fn test_single_folding() {
+        let num_variables = 2;
+        let folding_factor = 2;
+        let polynomial = CoefficientList::new(vec![F::from(1), F::from(2), F::from(3), F::from(4)]);
+
+        let ood_point = MultilinearPoint::expand_from_univariate(F::from(2), num_variables);
+        let statement_point = MultilinearPoint::expand_from_univariate(F::from(3), num_variables);
+
+        let ood_answer = polynomial.evaluate(&ood_point);
+        let statement_answer = polynomial.evaluate(&statement_point);
+
+        let epsilon_1 = F::from(10);
+        let epsilon_2 = F::from(100);
+
+        let prover = SumcheckCore::new(
+            polynomial.clone(),
+            &[ood_point.clone(), statement_point.clone()],
+            &[epsilon_1, epsilon_2],
+        );
+
+        let poly_1 = prover.compute_sumcheck_polynomial(folding_factor);
+
+        assert_eq!(
+            poly_1.sum_over_hypercube(),
+            epsilon_1 * ood_answer + epsilon_2 * statement_answer
+        );
+
+        let folding_randomness = MultilinearPoint(vec![F::from(400000), F::from(800000)]);
+
+        let poly_eval = polynomial.evaluate(&folding_randomness);
+        let v_eval = epsilon_1 * eq_poly_outside(&ood_point, &folding_randomness)
+            + epsilon_2 * eq_poly_outside(&statement_point, &folding_randomness);
+
+        assert_eq!(
+            poly_1.evaluate_at_point(&folding_randomness),
+            poly_eval * v_eval
+        );
+    }
+
+    #[test]
+    fn test_sumcheck_folding_factor_2() {
+        let num_variables = 6;
+        let folding_factor = 2;
+        let eval_point = MultilinearPoint(vec![F::from(97); num_variables]);
+        let polynomial = CoefficientList::new((0..1 << num_variables).map(F::from).collect());
+
+        let claimed_value = polynomial.evaluate(&eval_point);
+
+        let mut prover = SumcheckCore::new(polynomial.clone(), &[eval_point], &[F::from(1)]);
+
+        let poly_1 = prover.compute_sumcheck_polynomial(folding_factor);
+
+        // First, check that it sums to the 
right value over the hypercube + assert_eq!(poly_1.sum_over_hypercube(), claimed_value); + + let combination_randomness = [F::from(293), F::from(42)]; + let folding_randomness = MultilinearPoint(vec![F::from(335), F::from(222)]); + + let new_eval_point = MultilinearPoint(vec![F::from(32); num_variables - folding_factor]); + let folded_polynomial = polynomial.fold(&folding_randomness); + let new_fold_eval = folded_polynomial.evaluate(&new_eval_point); + + prover.compress( + folding_factor, + combination_randomness[0], + &folding_randomness, + ); + prover.add_new_equality(&[new_eval_point], &combination_randomness[1..]); + + let poly_2 = prover.compute_sumcheck_polynomial(folding_factor); + + assert_eq!( + poly_2.sum_over_hypercube(), + combination_randomness[0] * poly_1.evaluate_at_point(&folding_randomness) + + combination_randomness[1] * new_fold_eval + ); + + let combination_randomness = F::from(23212); + prover.compress(folding_factor, combination_randomness, &folding_randomness); + + let poly_3 = prover.compute_sumcheck_polynomial(folding_factor); + + assert_eq!( + poly_3.sum_over_hypercube(), + combination_randomness * poly_2.evaluate_at_point(&folding_randomness) + ) + } + + #[test] + fn test_e2e() { + let num_variables = 4; + let folding_factor = 2; + let polynomial = CoefficientList::new((0..1 << num_variables).map(F::from).collect()); + + // Initial stuff + let ood_point = MultilinearPoint::expand_from_univariate(F::from(42), num_variables); + let statement_point = MultilinearPoint::expand_from_univariate(F::from(97), num_variables); + + // All the randomness + let [epsilon_1, epsilon_2] = [F::from(15), F::from(32)]; + let folding_randomness_1 = MultilinearPoint(vec![F::from(11), F::from(31)]); + let fold_point = MultilinearPoint(vec![F::from(31), F::from(15)]); + let combination_randomness = [F::from(31), F::from(4999)]; + let folding_randomness_2 = MultilinearPoint(vec![F::from(97), F::from(36)]); + + let mut prover = SumcheckCore::new( + polynomial.clone(), + &[ood_point.clone(), statement_point.clone()], + &[epsilon_1, epsilon_2], + ); + + let sumcheck_poly_1 = prover.compute_sumcheck_polynomial(folding_factor); + + let folded_poly_1 = polynomial.fold(&folding_randomness_1.clone()); + prover.compress( + folding_factor, + combination_randomness[0], + &folding_randomness_1, + ); + prover.add_new_equality(&[fold_point.clone()], &combination_randomness[1..]); + + let sumcheck_poly_2 = prover.compute_sumcheck_polynomial(folding_factor); + + let ood_answer = polynomial.evaluate(&ood_point); + let statement_answer = polynomial.evaluate(&statement_point); + + assert_eq!( + sumcheck_poly_1.sum_over_hypercube(), + epsilon_1 * ood_answer + epsilon_2 * statement_answer + ); + + let fold_answer = folded_poly_1.evaluate(&fold_point); + + assert_eq!( + sumcheck_poly_2.sum_over_hypercube(), + combination_randomness[0] * sumcheck_poly_1.evaluate_at_point(&folding_randomness_1) + + combination_randomness[1] * fold_answer + ); + + let full_folding = + MultilinearPoint([folding_randomness_2.0.clone(), folding_randomness_1.0].concat()); + let eval_coeff = folded_poly_1.fold(&folding_randomness_2).coeffs()[0]; + assert_eq!( + sumcheck_poly_2.evaluate_at_point(&folding_randomness_2), + eval_coeff + * (combination_randomness[0] + * (epsilon_1 * eq_poly_outside(&full_folding, &ood_point) + + epsilon_2 * eq_poly_outside(&full_folding, &statement_point)) + + combination_randomness[1] + * eq_poly_outside(&folding_randomness_2, &fold_point)) + ) + } + + #[test] + fn test_e2e_larger() { + let num_variables 
= 6; + let folding_factor = 2; + let polynomial = CoefficientList::new((0..1 << num_variables).map(F::from).collect()); + + // Initial stuff + let ood_point = MultilinearPoint::expand_from_univariate(F::from(42), num_variables); + let statement_point = MultilinearPoint::expand_from_univariate(F::from(97), num_variables); + + // All the randomness + let [epsilon_1, epsilon_2] = [F::from(15), F::from(32)]; + let folding_randomness_1 = MultilinearPoint(vec![F::from(11), F::from(31)]); + let folding_randomness_2 = MultilinearPoint(vec![F::from(97), F::from(36)]); + let folding_randomness_3 = MultilinearPoint(vec![F::from(11297), F::from(42136)]); + let fold_point_11 = + MultilinearPoint(vec![F::from(31), F::from(15), F::from(31), F::from(15)]); + let fold_point_12 = + MultilinearPoint(vec![F::from(1231), F::from(15), F::from(4231), F::from(15)]); + let fold_point_2 = MultilinearPoint(vec![F::from(311), F::from(115)]); + let combination_randomness_1 = [F::from(1289), F::from(3281), F::from(10921)]; + let combination_randomness_2 = [F::from(3281), F::from(3232)]; + + let mut prover = SumcheckCore::new( + polynomial.clone(), + &[ood_point.clone(), statement_point.clone()], + &[epsilon_1, epsilon_2], + ); + + let sumcheck_poly_1 = prover.compute_sumcheck_polynomial(folding_factor); + + let folded_poly_1 = polynomial.fold(&folding_randomness_1.clone()); + prover.compress( + folding_factor, + combination_randomness_1[0], + &folding_randomness_1, + ); + prover.add_new_equality( + &[fold_point_11.clone(), fold_point_12.clone()], + &combination_randomness_1[1..], + ); + + let sumcheck_poly_2 = prover.compute_sumcheck_polynomial(folding_factor); + + let folded_poly_2 = folded_poly_1.fold(&folding_randomness_2.clone()); + prover.compress( + folding_factor, + combination_randomness_2[0], + &folding_randomness_2, + ); + prover.add_new_equality(&[fold_point_2.clone()], &combination_randomness_2[1..]); + + let sumcheck_poly_3 = prover.compute_sumcheck_polynomial(folding_factor); + let final_coeff = folded_poly_2.fold(&folding_randomness_3.clone()).coeffs()[0]; + + // Compute all evaluations + let ood_answer = polynomial.evaluate(&ood_point); + let statement_answer = polynomial.evaluate(&statement_point); + let fold_answer_11 = folded_poly_1.evaluate(&fold_point_11); + let fold_answer_12 = folded_poly_1.evaluate(&fold_point_12); + let fold_answer_2 = folded_poly_2.evaluate(&fold_point_2); + + assert_eq!( + sumcheck_poly_1.sum_over_hypercube(), + epsilon_1 * ood_answer + epsilon_2 * statement_answer + ); + + assert_eq!( + sumcheck_poly_2.sum_over_hypercube(), + combination_randomness_1[0] * sumcheck_poly_1.evaluate_at_point(&folding_randomness_1) + + combination_randomness_1[1] * fold_answer_11 + + combination_randomness_1[2] * fold_answer_12 + ); + + assert_eq!( + sumcheck_poly_3.sum_over_hypercube(), + combination_randomness_2[0] * sumcheck_poly_2.evaluate_at_point(&folding_randomness_2) + + combination_randomness_2[1] * fold_answer_2 + ); + + let full_folding = MultilinearPoint( + [ + folding_randomness_3.0.clone(), + folding_randomness_2.0.clone(), + folding_randomness_1.0, + ] + .concat(), + ); + + assert_eq!( + sumcheck_poly_3.evaluate_at_point(&folding_randomness_3), + final_coeff + * (combination_randomness_2[0] + * (combination_randomness_1[0] + * (epsilon_1 * eq_poly_outside(&full_folding, &ood_point) + + epsilon_2 * eq_poly_outside(&full_folding, &statement_point)) + + combination_randomness_1[1] + * eq_poly_outside( + &fold_point_11, + &MultilinearPoint( + [ + folding_randomness_3.0.clone(), + 
folding_randomness_2.0.clone()
+                                    ]
+                                    .concat()
+                                )
+                            )
+                        + combination_randomness_1[2]
+                            * eq_poly_outside(
+                                &fold_point_12,
+                                &MultilinearPoint(
+                                    [folding_randomness_3.0.clone(), folding_randomness_2.0]
+                                        .concat()
+                                )
+                            ))
+                    + combination_randomness_2[1]
+                        * eq_poly_outside(&folding_randomness_3, &fold_point_2))
+        )
+    }
+}
diff --git a/whir/src/sumcheck/proof.rs b/whir/src/sumcheck/proof.rs
new file mode 100644
index 000000000..b3e486403
--- /dev/null
+++ b/whir/src/sumcheck/proof.rs
@@ -0,0 +1,101 @@
+use ark_ff::Field;
+
+use crate::{
+    poly_utils::{MultilinearPoint, eq_poly3},
+    utils::base_decomposition,
+};
+
+// Stored in evaluation form
+#[derive(Debug, Clone)]
+pub struct SumcheckPolynomial<F> {
+    n_variables: usize, // number of variables
+    // evaluations has length 3^{n_variables}
+    // The order in which it is stored is such that evaluations[i]
+    // corresponds to the evaluation at utils::base_decomposition(i, 3, n_variables),
+    // which performs (big-endian) ternary decomposition.
+    // (in other words, the ordering is lexicographic wrt the evaluation point)
+    evaluations: Vec<F>, /* Each of our polynomials will be in F^{<3}[X_1, \dots, X_k],
+                          * so it is uniquely determined by its evaluations over {0, 1, 2}^k */
+}
+
+impl<F> SumcheckPolynomial<F>
+where
+    F: Field,
+{
+    pub fn new(evaluations: Vec<F>, n_variables: usize) -> Self {
+        SumcheckPolynomial {
+            evaluations,
+            n_variables,
+        }
+    }
+
+    /// Returns the vector of evaluations at {0,1,2}^n_variables of the polynomial f
+    /// in the following order: [f(0,0,..,0), f(0,0,..,1), f(0,0,...,2), f(0,0,...,1,0), ...]
+    /// (i.e. lexicographic wrt. the evaluation points.)
+    pub fn evaluations(&self) -> &[F] {
+        &self.evaluations
+    }
+
+    // TODO(Gotti): Rename to sum_over_binary_hypercube for clarity?
+    // TODO(Gotti): Make more efficient; the base_decomposition and filtering is unnecessary.
+
+    /// Returns the sum of evaluations of f, when summed only over {0,1}^n_variables
+    ///
+    /// (and not over {0,1,2}^n_variables)
+    pub fn sum_over_hypercube(&self) -> F {
+        let num_evaluation_points = 3_usize.pow(self.n_variables as u32);
+
+        let mut sum = F::ZERO;
+        for point in 0..num_evaluation_points {
+            if base_decomposition(point, 3, self.n_variables)
+                .into_iter()
+                .all(|v| matches!(v, 0 | 1))
+            {
+                sum += self.evaluations[point];
+            }
+        }
+
+        sum
+    }
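+
+    // Illustrative example (added for exposition, not from the upstream code):
+    // for n_variables = 1 the evaluations are [f(0), f(1), f(2)] and
+    // sum_over_hypercube returns f(0) + f(1), skipping the point 2 which lies
+    // outside {0,1}.
+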
+    ///
+    /// We assert that point.n_variables() == self.n_variables
+    pub fn evaluate_at_point(&self, point: &MultilinearPoint<F>) -> F {
+        assert!(point.n_variables() == self.n_variables);
+        let num_evaluation_points = 3_usize.pow(self.n_variables as u32);
+
+        let mut evaluation = F::ZERO;
+
+        for index in 0..num_evaluation_points {
+            evaluation += self.evaluations[index] * eq_poly3(point, index);
+        }
+
+        evaluation
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{crypto::fields::Field64, poly_utils::MultilinearPoint, utils::base_decomposition};
+
+    use super::SumcheckPolynomial;
+
+    type F = Field64;
+
+    #[test]
+    fn test_evaluation() {
+        let num_variables = 2;
+
+        let num_evaluation_points = 3_usize.pow(num_variables as u32);
+        let evaluations = (0..num_evaluation_points as u64).map(F::from).collect();
+
+        let poly = SumcheckPolynomial::new(evaluations, num_variables);
+
+        for i in 0..num_evaluation_points {
+            let decomp = base_decomposition(i, 3, num_variables);
+            let point = MultilinearPoint(decomp.into_iter().map(F::from).collect());
+            assert_eq!(poly.evaluate_at_point(&point), poly.evaluations()[i]);
+        }
+    }
+}
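A concrete view of the ordering `proof.rs` documents above: `evaluations[i]` holds the value at the point whose coordinates are the big-endian base-3 digits of `i`, and `sum_over_hypercube` keeps exactly the indices whose digits are all 0 or 1. A self-contained sketch of that indexing (plain integer arithmetic only; `ternary` restates `utils::base_decomposition` with base 3 so the snippet runs on its own):

```rust
// Big-endian base-3 digits of `value`, padded to `n` digits
// (mirrors utils::base_decomposition with base = 3).
fn ternary(value: usize, n: usize) -> Vec<u8> {
    let mut digits = vec![0u8; n];
    let mut v = value;
    for i in (0..n).rev() {
        digits[i] = (v % 3) as u8;
        v /= 3;
    }
    digits
}

fn main() {
    // With n_variables = 2 the grid is visited lexicographically:
    // index 0 -> (0,0), 1 -> (0,1), 2 -> (0,2), 3 -> (1,0), ...
    assert_eq!(ternary(0, 2), vec![0, 0]);
    assert_eq!(ternary(3, 2), vec![1, 0]);
    assert_eq!(ternary(5, 2), vec![1, 2]);

    // sum_over_hypercube keeps exactly the indices whose digits are all 0/1,
    // i.e. the 2^n corners of the binary hypercube inside the 3^n grid.
    let binary_corners: Vec<usize> = (0..9)
        .filter(|&i| ternary(i, 2).iter().all(|&d| d <= 1))
        .collect();
    assert_eq!(binary_corners, vec![0, 1, 3, 4]);
}
```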
diff --git a/whir/src/sumcheck/prover_batched.rs b/whir/src/sumcheck/prover_batched.rs
new file mode 100644
index 000000000..49ee27bf3
--- /dev/null
+++ b/whir/src/sumcheck/prover_batched.rs
@@ -0,0 +1,307 @@
+use super::proof::SumcheckPolynomial;
+use crate::{
+    poly_utils::{MultilinearPoint, coeffs::CoefficientList, evals::EvaluationsList},
+    sumcheck::prover_single::SumcheckSingle,
+};
+use ark_ff::Field;
+#[cfg(feature = "parallel")]
+use rayon::{join, prelude::*};
+
+pub struct SumcheckBatched<F> {
+    // The evaluation on each p and eq
+    evaluations_of_p: Vec<EvaluationsList<F>>,
+    evaluations_of_equality: Vec<EvaluationsList<F>>,
+    comb_coeff: Vec<F>,
+    num_polys: usize,
+    num_variables: usize,
+    sum: F,
+}
+
+impl<F> SumcheckBatched<F>
+where
+    F: Field,
+{
+    // Input includes the following:
+    // coeffs: coefficient of a list of polynomials p
+    // points: one point per poly
+    // and initialises the table of the initial polynomial
+    // v(X_1, ..., X_n) = p0(..) * eq0(..) + p1(..) * eq1(..) + ...
+    pub fn new(
+        coeffs: Vec<CoefficientList<F>>,
+        points: &[MultilinearPoint<F>],
+        poly_comb_coeff: &[F], // random coefficients for combining each poly
+        evals: &[F],
+    ) -> Self {
+        let num_polys = coeffs.len();
+        assert_eq!(poly_comb_coeff.len(), num_polys);
+        assert_eq!(points.len(), num_polys);
+        assert_eq!(evals.len(), num_polys);
+        let num_variables = coeffs[0].num_variables();
+
+        let mut prover = SumcheckBatched {
+            evaluations_of_p: coeffs.into_iter().map(|c| c.into()).collect(),
+            evaluations_of_equality: vec![
+                EvaluationsList::new(vec![F::ZERO; 1 << num_variables]);
+                num_polys
+            ],
+            comb_coeff: poly_comb_coeff.to_vec(),
+            num_polys,
+            num_variables,
+            sum: F::ZERO,
+        };
+
+        // Eval points
+        for (i, point) in points.iter().enumerate() {
+            SumcheckSingle::eval_eq(
+                &point.0,
+                prover.evaluations_of_equality[i].evals_mut(),
+                F::from(1),
+            );
+            prover.sum += poly_comb_coeff[i] * evals[i];
+        }
+        prover
+    }
+
+    pub fn get_folded_polys(&self) -> Vec<F> {
+        self.evaluations_of_p
+            .iter()
+            .map(|e| {
+                assert_eq!(e.num_variables(), 0);
+                e.evals()[0]
+            })
+            .collect()
+    }
+
+    pub fn get_folded_eqs(&self) -> Vec<F> {
+        self.evaluations_of_equality
+            .iter()
+            .map(|e| {
+                assert_eq!(e.num_variables(), 0);
+                e.evals()[0]
+            })
+            .collect()
+    }
+
+    #[cfg(not(feature = "parallel"))]
+    pub fn compute_sumcheck_polynomial(&self) -> SumcheckPolynomial<F> {
+        // No sequential fallback is implemented for the batched prover.
+        panic!("Non-parallel version not supported!");
+    }
+
+    #[cfg(feature = "parallel")]
+    pub fn compute_sumcheck_polynomial(&self) -> SumcheckPolynomial<F> {
+        assert!(self.num_variables >= 1);
+
+        // Compute coefficients of the quadratic result polynomial
+        let (_, c0, c2) = self
+            .comb_coeff
+            .par_iter()
+            .zip(&self.evaluations_of_p)
+            .zip(&self.evaluations_of_equality)
+            .map(|((rand, eval_p), eval_eq)| {
+                let eval_p_iter = eval_p.evals().par_chunks_exact(2);
+                let eval_eq_iter = eval_eq.evals().par_chunks_exact(2);
+                let (c0, c2) = eval_p_iter
+                    .zip(eval_eq_iter)
+                    .map(|(p_at, eq_at)| {
+                        // Convert evaluations to coefficients for the linear fns p and eq.
+                        let (p_0, p_1) = (p_at[0], p_at[1] - p_at[0]);
+                        let (eq_0, eq_1) = (eq_at[0], eq_at[1] - eq_at[0]);
+
+                        // Now we need to add the contribution of p(x) * eq(x)
+                        (p_0 * eq_0, p_1 * eq_1)
+                    })
+                    .reduce(
+                        || (F::ZERO, F::ZERO),
+                        |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2),
+                    );
+                (*rand, c0, c2)
+            })
+            .reduce(
+                || (F::ONE, F::ZERO, F::ZERO),
+                |(r0, a0, a2), (r1, b0, b2)| (F::ONE, r0 * a0 + r1 * b0, r0 * a2 + r1 * b2),
+            );
+
+        // Use the fact that self.sum = p(0) + p(1) = 2 * coeff_0 + coeff_1 + coeff_2
+        let c1 = self.sum - c0.double() - c2;
+
+        // Evaluate the quadratic polynomial at 0, 1, 2
+        let eval_0 = c0;
+        let eval_1 = c0 + c1 + c2;
+        let eval_2 = eval_1 + c1 + c2 + c2.double();
+
+        SumcheckPolynomial::new(vec![eval_0, eval_1, eval_2], 1)
+    }
+
+    // When the folding randomness arrives, compress the table accordingly (adding the new points)
+    #[cfg(not(feature = "parallel"))]
+    pub fn compress(
+        &mut self,
+        _combination_randomness: F, // Scale the initial point
+        _folding_randomness: &MultilinearPoint<F>,
+        _sumcheck_poly: &SumcheckPolynomial<F>,
+    ) {
+        // No sequential fallback is implemented for the batched prover.
+        panic!("Non-parallel version not supported!");
+    }
+
+    #[cfg(feature = "parallel")]
+    pub fn compress(
+        &mut self,
+        combination_randomness: F, // Scale the initial point
+        folding_randomness: &MultilinearPoint<F>,
+        sumcheck_poly: &SumcheckPolynomial<F>,
+    ) {
+        assert_eq!(folding_randomness.n_variables(), 1);
+        assert!(self.num_variables >= 1);
+
+        let randomness = folding_randomness.0[0];
+        let evaluations: Vec<_> = self
+            .evaluations_of_p
+            .par_iter()
+            .zip(&self.evaluations_of_equality)
+            .map(|(eval_p, eval_eq)| {
+                let (evaluation_of_p, evaluation_of_eq) = join(
+                    || {
+                        eval_p
+                            .evals()
+                            .par_chunks_exact(2)
+                            .map(|at| (at[1] - at[0]) * randomness + at[0])
+                            .collect()
+                    },
+                    || {
+                        eval_eq
+                            .evals()
+                            .par_chunks_exact(2)
+                            .map(|at| (at[1] - at[0]) * randomness + at[0])
+                            .collect()
+                    },
+                );
+                (
+                    EvaluationsList::new(evaluation_of_p),
+                    EvaluationsList::new(evaluation_of_eq),
+                )
+            })
+            .collect();
+        let (evaluations_of_p, evaluations_of_eq) = evaluations.into_iter().unzip();
+
+        // Update
+        self.num_variables -= 1;
+        self.evaluations_of_p = evaluations_of_p;
+        self.evaluations_of_equality = evaluations_of_eq;
+        self.sum = combination_randomness * sumcheck_poly.evaluate_at_point(folding_randomness);
+    }
+}
+
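The `c0`/`c2` accumulation above deserves one spelled-out step: over the last free variable, `p` and `eq` are both linear, so each chunk pair contributes a quadratic `h(X)`, and `c1` never has to be computed directly because the running claim pins it down via `h(0) + h(1) = 2*c0 + c1 + c2`. A minimal sketch with plain integers standing in for field elements (no modular reduction; all values illustrative):

```rust
// h(X) = sum over chunk pairs of (p0 + p1*X) * (e0 + e1*X); the sumcheck
// message is (h(0), h(1), h(2)). Integers stand in for field elements.
fn sumcheck_message(p: &[i64], eq: &[i64], claimed_sum: i64) -> (i64, i64, i64) {
    let (mut c0, mut c2) = (0, 0);
    for (p_at, eq_at) in p.chunks_exact(2).zip(eq.chunks_exact(2)) {
        let (p0, p1) = (p_at[0], p_at[1] - p_at[0]); // intercept/slope of p
        let (e0, e1) = (eq_at[0], eq_at[1] - eq_at[0]); // intercept/slope of eq
        c0 += p0 * e0; // coefficient of X^0
        c2 += p1 * e1; // coefficient of X^2
    }
    // claimed_sum = h(0) + h(1) = 2*c0 + c1 + c2, so c1 comes for free:
    let c1 = claimed_sum - 2 * c0 - c2;
    (c0, c0 + c1 + c2, c0 + 2 * c1 + 4 * c2) // h(0), h(1), h(2)
}

fn main() {
    // One variable left: p = [3, 5] and eq = [2, 7] as evaluations at X = 0, 1.
    let (p, eq) = ([3i64, 5], [2i64, 7]);
    let claimed: i64 = p.iter().zip(eq).map(|(a, b)| a * b).sum(); // 3*2 + 5*7 = 41
    let (h0, h1, h2) = sumcheck_message(&p, &eq, claimed);
    assert_eq!((h0, h1), (6, 35)); // h(0) = p(0)*eq(0), h(1) = p(1)*eq(1)
    assert_eq!(h0 + h1, claimed);
    assert_eq!(h2, 84); // p(2) = 7, eq(2) = 12, so h(2) = 84
}
```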
+#[cfg(test)]
+mod tests {
+    use crate::{
+        crypto::fields::Field64,
+        poly_utils::{MultilinearPoint, coeffs::CoefficientList},
+    };
+
+    use super::SumcheckBatched;
+
+    type F = Field64;
+
+    #[test]
+    fn test_sumcheck_folding_factor_1() {
+        let num_rounds = 2;
+        let eval_points = vec![
+            MultilinearPoint(vec![F::from(10), F::from(11)]),
+            MultilinearPoint(vec![F::from(7), F::from(8)]),
+        ];
+        let polynomials = vec![
+            CoefficientList::new(vec![F::from(1), F::from(5), F::from(10), F::from(14)]),
+            CoefficientList::new(vec![F::from(2), F::from(6), F::from(11), F::from(13)]),
+        ];
+        let poly_comb_coeffs = vec![F::from(2), F::from(3)];
+
+        let evals: Vec<F> = polynomials
+            .iter()
+            .zip(&eval_points)
+            .map(|(poly, point)| poly.evaluate(point))
+            .collect();
+        let mut claimed_value: F = evals
+            .iter()
+            .zip(&poly_comb_coeffs)
+            .fold(F::from(0), |sum, (eval, poly_rand)| eval * poly_rand + sum);
+
+        let mut prover =
+            SumcheckBatched::new(polynomials.clone(), &eval_points, &poly_comb_coeffs, &evals);
+        let mut comb_randomness_list = Vec::new();
+        let mut fold_randomness_list = Vec::new();
+
+        for _ in 0..num_rounds {
+            let poly = prover.compute_sumcheck_polynomial();
+
+            // First, check that it sums to the right value over the hypercube
+            assert_eq!(poly.sum_over_hypercube(), claimed_value);
+
+            let next_comb_randomness = F::from(100101);
+            let next_fold_randomness = MultilinearPoint(vec![F::from(4999)]);
+
+            prover.compress(next_comb_randomness, &next_fold_randomness, &poly);
+            claimed_value = next_comb_randomness * poly.evaluate_at_point(&next_fold_randomness);
+
+            comb_randomness_list.push(next_comb_randomness);
+            fold_randomness_list.extend(next_fold_randomness.0);
+        }
+        println!("CLAIM:");
+        for poly in prover.evaluations_of_p {
+            println!("POLY: {:?}", poly);
+        }
+        println!("EXPECTED:");
+        let fold_randomness_list = MultilinearPoint(fold_randomness_list);
+        for poly in polynomials {
+            println!("EVAL: {:?}", poly.evaluate(&fold_randomness_list));
+        }
+    }
+}
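The test above exercises the verifier-side recurrence that all the provers in this module share: check `h(0) + h(1)` against the running claim, then fold the claim to `comb_randomness * h(alpha)` for the next round. A sketch of just that bookkeeping, again with integers in place of field elements (the particular `h`, randomness, and challenge values are illustrative):

```rust
// One verifier round: check the claimed sum, then fold the claim with the
// combination randomness, exactly as `claimed_value` is updated in the test.
fn round_check(h: impl Fn(i64) -> i64, claim: i64, comb_randomness: i64, alpha: i64) -> i64 {
    assert_eq!(h(0) + h(1), claim, "sumcheck round rejected");
    comb_randomness * h(alpha)
}

fn main() {
    let h = |x: i64| 6 + 19 * x + 10 * x * x; // a round polynomial with h(0) + h(1) = 41
    let next_claim = round_check(h, 41, 2, 5); // comb = 2, alpha = 5
    assert_eq!(next_claim, 2 * h(5));
}
```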
diff --git a/whir/src/sumcheck/prover_core.rs b/whir/src/sumcheck/prover_core.rs
new file mode 100644
index 000000000..9ecdfeeea
--- /dev/null
+++ b/whir/src/sumcheck/prover_core.rs
@@ -0,0 +1,151 @@
+use ark_ff::Field;
+
+use crate::{
+    poly_utils::{
+        MultilinearPoint, coeffs::CoefficientList, evals::EvaluationsList,
+        sequential_lag_poly::LagrangePolynomialIterator,
+    },
+    utils::base_decomposition,
+};
+
+use super::proof::SumcheckPolynomial;
+
+pub struct SumcheckCore<F> {
+    // The evaluation of p
+    evaluation_of_p: EvaluationsList<F>,
+    evaluation_of_equality: EvaluationsList<F>,
+    num_variables: usize,
+}
+
+impl<F> SumcheckCore<F>
+where
+    F: Field,
+{
+    // Get the coefficient of polynomial p and a list of points
+    // and initialises the table of the initial polynomial
+    // v(X_1, ..., X_n) = p(X_1, ... X_n) * (epsilon_1 eq_z_1(X) + epsilon_2 eq_z_2(X) ...)
+    pub fn new(
+        coeffs: CoefficientList<F>,      // multilinear polynomial in n variables
+        points: &[MultilinearPoint<F>],  // list of points, each of length n
+        combination_randomness: &[F],
+    ) -> Self {
+        assert_eq!(points.len(), combination_randomness.len());
+        let num_variables = coeffs.num_variables();
+
+        let mut prover = SumcheckCore {
+            evaluation_of_p: coeffs.into(), // transform coefficient form -> evaluation form
+            evaluation_of_equality: EvaluationsList::new(vec![F::ZERO; 1 << num_variables]),
+            num_variables,
+        };
+
+        prover.add_new_equality(points, combination_randomness);
+        prover
+    }
+
+    pub fn compute_sumcheck_polynomial(&self, folding_factor: usize) -> SumcheckPolynomial<F> {
+        let two = F::ONE + F::ONE; // the field element 2, used for the {0,1,2} grid
+
+        assert!(self.num_variables >= folding_factor);
+
+        let num_evaluation_points = 3_usize.pow(folding_factor as u32);
+        let suffix_len = 1 << folding_factor;
+        let prefix_len = (1 << self.num_variables) / suffix_len;
+
+        // sets evaluation_points to the set of all {0,1,2}^folding_factor
+        let evaluation_points: Vec<_> = (0..num_evaluation_points)
+            .map(|point| {
+                MultilinearPoint(
+                    base_decomposition(point, 3, folding_factor)
+                        .into_iter()
+                        .map(|v| match v {
+                            0 => F::ZERO,
+                            1 => F::ONE,
+                            2 => two,
+                            _ => unreachable!(),
+                        })
+                        .collect(),
+                )
+            })
+            .collect();
+        let mut evaluations = vec![F::ZERO; num_evaluation_points];
+
+        // NOTE: This can probably be optimised a fair bit: there are a bunch of Lagrange
+        // evaluations that can be computed at the same time, some allocations to save, etc.
+        for beta_prefix in 0..prefix_len {
+            // Gather the evaluations that we are concerned about
+            let indexes: Vec<_> = (0..suffix_len)
+                .map(|beta_suffix| suffix_len * beta_prefix + beta_suffix)
+                .collect();
+            let left_poly =
+                EvaluationsList::new(indexes.iter().map(|&i| self.evaluation_of_p[i]).collect());
+            let right_poly = EvaluationsList::new(
+                indexes
+                    .iter()
+                    .map(|&i| self.evaluation_of_equality[i])
+                    .collect(),
+            );
+
+            // For each evaluation point, update with the right added
+            for point in 0..num_evaluation_points {
+                evaluations[point] += left_poly.evaluate(&evaluation_points[point])
+                    * right_poly.evaluate(&evaluation_points[point]);
+            }
+        }
+
+        SumcheckPolynomial::new(evaluations, folding_factor)
+    }
+
+    pub fn add_new_equality(
+        &mut self,
+        points: &[MultilinearPoint<F>],
+        combination_randomness: &[F],
+    ) {
+        assert_eq!(combination_randomness.len(), points.len());
+        for (point, rand) in points.iter().zip(combination_randomness) {
+            for (prefix, lag) in LagrangePolynomialIterator::new(point) {
+                self.evaluation_of_equality.evals_mut()[prefix.0] += *rand * lag;
+            }
+        }
+    }
+
+    // When the folding randomness arrives, compress the table accordingly (adding the new points)
+    pub fn compress(
+        &mut self,
+        folding_factor: usize,
+        combination_randomness: F, // Scale the initial point
+        folding_randomness: &MultilinearPoint<F>,
+    ) {
+        assert_eq!(folding_randomness.n_variables(), folding_factor);
+        assert!(self.num_variables >= folding_factor);
+
+        let suffix_len = 1 << folding_factor;
+        let prefix_len = (1 << self.num_variables) / suffix_len;
+        let mut evaluations_of_p = Vec::with_capacity(prefix_len);
+        let mut evaluations_of_eq = Vec::with_capacity(prefix_len);
+
+        // Compress the table
+        for beta_prefix in 0..prefix_len {
+            let indexes: Vec<_> = (0..suffix_len)
+                .map(|beta_suffix| suffix_len * beta_prefix + beta_suffix)
+                .collect();
+
+            let left_poly =
+                EvaluationsList::new(indexes.iter().map(|&i| self.evaluation_of_p[i]).collect());
+            let right_poly = EvaluationsList::new(
+                indexes
+                    .iter()
+                    .map(|&i| self.evaluation_of_equality[i])
+                    .collect(),
+            );
+
evaluations_of_p.push(left_poly.evaluate(folding_randomness)); + evaluations_of_eq + .push(combination_randomness * right_poly.evaluate(folding_randomness)); + } + + // Update + self.num_variables -= folding_factor; + self.evaluation_of_p = EvaluationsList::new(evaluations_of_p); + self.evaluation_of_equality = EvaluationsList::new(evaluations_of_eq); + } +} diff --git a/whir/src/sumcheck/prover_not_skipping.rs b/whir/src/sumcheck/prover_not_skipping.rs new file mode 100644 index 000000000..602e9ab1c --- /dev/null +++ b/whir/src/sumcheck/prover_not_skipping.rs @@ -0,0 +1,329 @@ +use ark_ff::Field; +use nimue::{ + ProofResult, + plugins::ark::{FieldChallenges, FieldIOPattern, FieldWriter}, +}; +use nimue_pow::{PoWChallenge, PowStrategy}; + +use crate::{ + fs_utils::WhirPoWIOPattern, + poly_utils::{MultilinearPoint, coeffs::CoefficientList}, +}; + +use super::prover_single::SumcheckSingle; + +pub trait SumcheckNotSkippingIOPattern { + fn add_sumcheck(self, folding_factor: usize, pow_bits: f64) -> Self; +} + +impl SumcheckNotSkippingIOPattern for IOPattern +where + F: Field, + IOPattern: FieldIOPattern + WhirPoWIOPattern, +{ + fn add_sumcheck(mut self, folding_factor: usize, pow_bits: f64) -> Self { + for _ in 0..folding_factor { + self = self + .add_scalars(3, "sumcheck_poly") + .challenge_scalars(1, "folding_randomness") + .pow(pow_bits); + } + self + } +} + +pub struct SumcheckProverNotSkipping { + sumcheck_prover: SumcheckSingle, +} + +impl SumcheckProverNotSkipping +where + F: Field, +{ + // Get the coefficient of polynomial p and a list of points + // and initialises the table of the initial polynomial + // v(X_1, ..., X_n) = p(X_1, ... X_n) * (epsilon_1 eq_z_1(X) + epsilon_2 eq_z_2(X) ...) + pub fn new( + coeffs: CoefficientList, + points: &[MultilinearPoint], + combination_randomness: &[F], + evaluations: &[F], + ) -> Self { + Self { + sumcheck_prover: SumcheckSingle::new( + coeffs, + points, + combination_randomness, + evaluations, + ), + } + } + + pub fn compute_sumcheck_polynomials( + &mut self, + merlin: &mut Merlin, + folding_factor: usize, + pow_bits: f64, + ) -> ProofResult> + where + S: PowStrategy, + Merlin: FieldChallenges + FieldWriter + PoWChallenge, + { + let mut res = Vec::with_capacity(folding_factor); + + for _ in 0..folding_factor { + let sumcheck_poly = self.sumcheck_prover.compute_sumcheck_polynomial(); + merlin.add_scalars(sumcheck_poly.evaluations())?; + let [folding_randomness]: [F; 1] = merlin.challenge_scalars()?; + res.push(folding_randomness); + + // Do PoW if needed + if pow_bits > 0. 
{ + merlin.challenge_pow::(pow_bits)?; + } + + self.sumcheck_prover + .compress(F::ONE, &folding_randomness.into(), &sumcheck_poly); + } + + res.reverse(); + Ok(MultilinearPoint(res)) + } + + pub fn add_new_equality( + &mut self, + points: &[MultilinearPoint], + combination_randomness: &[F], + evaluations: &[F], + ) { + self.sumcheck_prover + .add_new_equality(points, combination_randomness, evaluations) + } +} + +#[cfg(test)] +mod tests { + use ark_ff::Field; + use nimue::{ + IOPattern, Merlin, ProofResult, + plugins::ark::{FieldChallenges, FieldIOPattern, FieldReader}, + }; + use nimue_pow::blake3::Blake3PoW; + + use crate::{ + crypto::fields::Field64, + poly_utils::{MultilinearPoint, coeffs::CoefficientList, eq_poly_outside}, + sumcheck::{proof::SumcheckPolynomial, prover_not_skipping::SumcheckProverNotSkipping}, + }; + + type F = Field64; + + #[test] + fn test_e2e_short() -> ProofResult<()> { + let num_variables = 2; + let folding_factor = 2; + let polynomial = CoefficientList::new((0..1 << num_variables).map(F::from).collect()); + + // Initial stuff + let ood_point = MultilinearPoint::expand_from_univariate(F::from(42), num_variables); + let statement_point = MultilinearPoint::expand_from_univariate(F::from(97), num_variables); + + // All the randomness + let [epsilon_1, epsilon_2] = [F::from(15), F::from(32)]; + + fn add_sumcheck_io_pattern() -> IOPattern + where + F: Field, + IOPattern: FieldIOPattern, + { + IOPattern::new("test") + .add_scalars(3, "sumcheck_poly") + .challenge_scalars(1, "folding_randomness") + .add_scalars(3, "sumcheck_poly") + .challenge_scalars(1, "folding_randomness") + } + + let iopattern = add_sumcheck_io_pattern::(); + + // Prover part + let mut merlin = iopattern.to_merlin(); + let mut prover = SumcheckProverNotSkipping::new( + polynomial.clone(), + &[ood_point.clone(), statement_point.clone()], + &[epsilon_1, epsilon_2], + &[ + polynomial.evaluate_at_extension(&ood_point), + polynomial.evaluate_at_extension(&statement_point), + ], + ); + + let folding_randomness_1 = prover.compute_sumcheck_polynomials::( + &mut merlin, + folding_factor, + 0., + )?; + + // Compute the answers + let folded_poly_1 = polynomial.fold(&folding_randomness_1); + + let ood_answer = polynomial.evaluate(&ood_point); + let statement_answer = polynomial.evaluate(&statement_point); + + // Verifier part + let mut arthur = iopattern.to_arthur(merlin.transcript()); + let sumcheck_poly_11: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly_11 = SumcheckPolynomial::new(sumcheck_poly_11.to_vec(), 1); + let [folding_randomness_11]: [F; 1] = arthur.challenge_scalars()?; + let sumcheck_poly_12: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly_12 = SumcheckPolynomial::new(sumcheck_poly_12.to_vec(), 1); + let [folding_randomness_12]: [F; 1] = arthur.challenge_scalars()?; + + assert_eq!( + sumcheck_poly_11.sum_over_hypercube(), + epsilon_1 * ood_answer + epsilon_2 * statement_answer + ); + + assert_eq!( + sumcheck_poly_12.sum_over_hypercube(), + sumcheck_poly_11.evaluate_at_point(&folding_randomness_11.into()) + ); + + let full_folding = MultilinearPoint(vec![folding_randomness_12, folding_randomness_11]); + + let eval_coeff = folded_poly_1.coeffs()[0]; + assert_eq!( + sumcheck_poly_12.evaluate_at_point(&folding_randomness_12.into()), + eval_coeff + * (epsilon_1 * eq_poly_outside(&full_folding, &ood_point) + + epsilon_2 * eq_poly_outside(&full_folding, &statement_point)) + ); + + Ok(()) + } + + #[test] + fn test_e2e() -> ProofResult<()> { + let num_variables = 4; + let folding_factor = 2; + 
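One ordering subtlety in `compute_sumcheck_polynomials` above: challenges are pushed in round order and the vector is reversed before being wrapped, so the returned `MultilinearPoint` lists the last round's challenge first. That is why the test builds `full_folding` as `vec![folding_randomness_12, folding_randomness_11]`. A tiny sketch of the convention (hypothetical challenge values):

```rust
// Challenges arrive as [r_round1, r_round2, ...]; the prover returns them
// reversed, so index 0 of the folding point is the most recent challenge.
fn folding_point(challenges_in_round_order: Vec<u64>) -> Vec<u64> {
    let mut res = challenges_in_round_order;
    res.reverse();
    res
}

fn main() {
    assert_eq!(folding_point(vec![11, 12]), vec![12, 11]); // round 1 ends up last
}
```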
let polynomial = CoefficientList::new((0..1 << num_variables).map(F::from).collect()); + + // Initial stuff + let ood_point = MultilinearPoint::expand_from_univariate(F::from(42), num_variables); + let statement_point = MultilinearPoint::expand_from_univariate(F::from(97), num_variables); + + // All the randomness + let [epsilon_1, epsilon_2] = [F::from(15), F::from(32)]; + let fold_point = MultilinearPoint(vec![F::from(31), F::from(15)]); + let combination_randomness = vec![F::from(1000)]; + + fn add_sumcheck_io_pattern() -> IOPattern + where + F: Field, + IOPattern: FieldIOPattern, + { + IOPattern::new("test") + .add_scalars(3, "sumcheck_poly") + .challenge_scalars(1, "folding_randomness") + .add_scalars(3, "sumcheck_poly") + .challenge_scalars(1, "folding_randomness") + .add_scalars(3, "sumcheck_poly") + .challenge_scalars(1, "folding_randomness") + .add_scalars(3, "sumcheck_poly") + .challenge_scalars(1, "folding_randomness") + } + + let iopattern = add_sumcheck_io_pattern::(); + + // Prover part + let mut merlin = iopattern.to_merlin(); + let mut prover = SumcheckProverNotSkipping::new( + polynomial.clone(), + &[ood_point.clone(), statement_point.clone()], + &[epsilon_1, epsilon_2], + &[ + polynomial.evaluate_at_extension(&ood_point), + polynomial.evaluate_at_extension(&statement_point), + ], + ); + + let folding_randomness_1 = prover.compute_sumcheck_polynomials::( + &mut merlin, + folding_factor, + 0., + )?; + + let folded_poly_1 = polynomial.fold(&folding_randomness_1); + let fold_eval = folded_poly_1.evaluate_at_extension(&fold_point); + prover.add_new_equality(&[fold_point.clone()], &combination_randomness, &[fold_eval]); + + let folding_randomness_2 = prover.compute_sumcheck_polynomials::( + &mut merlin, + folding_factor, + 0., + )?; + + // Compute the answers + let folded_poly_1 = polynomial.fold(&folding_randomness_1); + let folded_poly_2 = folded_poly_1.fold(&folding_randomness_2); + + let ood_answer = polynomial.evaluate(&ood_point); + let statement_answer = polynomial.evaluate(&statement_point); + let fold_answer = folded_poly_1.evaluate(&fold_point); + + // Verifier part + let mut arthur = iopattern.to_arthur(merlin.transcript()); + let sumcheck_poly_11: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly_11 = SumcheckPolynomial::new(sumcheck_poly_11.to_vec(), 1); + let [folding_randomness_11]: [F; 1] = arthur.challenge_scalars()?; + let sumcheck_poly_12: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly_12 = SumcheckPolynomial::new(sumcheck_poly_12.to_vec(), 1); + let [folding_randomness_12]: [F; 1] = arthur.challenge_scalars()?; + let sumcheck_poly_21: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly_21 = SumcheckPolynomial::new(sumcheck_poly_21.to_vec(), 1); + let [folding_randomness_21]: [F; 1] = arthur.challenge_scalars()?; + let sumcheck_poly_22: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly_22 = SumcheckPolynomial::new(sumcheck_poly_22.to_vec(), 1); + let [folding_randomness_22]: [F; 1] = arthur.challenge_scalars()?; + + assert_eq!( + sumcheck_poly_11.sum_over_hypercube(), + epsilon_1 * ood_answer + epsilon_2 * statement_answer + ); + + assert_eq!( + sumcheck_poly_12.sum_over_hypercube(), + sumcheck_poly_11.evaluate_at_point(&folding_randomness_11.into()) + ); + + assert_eq!( + sumcheck_poly_21.sum_over_hypercube(), + sumcheck_poly_12.evaluate_at_point(&folding_randomness_12.into()) + + combination_randomness[0] * fold_answer + ); + + assert_eq!( + sumcheck_poly_22.sum_over_hypercube(), + 
sumcheck_poly_21.evaluate_at_point(&folding_randomness_21.into()) + ); + + let full_folding = MultilinearPoint(vec![ + folding_randomness_22, + folding_randomness_21, + folding_randomness_12, + folding_randomness_11, + ]); + + let partial_folding = MultilinearPoint(vec![folding_randomness_22, folding_randomness_21]); + + let eval_coeff = folded_poly_2.coeffs()[0]; + assert_eq!( + sumcheck_poly_22.evaluate_at_point(&folding_randomness_22.into()), + eval_coeff + * ((epsilon_1 * eq_poly_outside(&full_folding, &ood_point) + + epsilon_2 * eq_poly_outside(&full_folding, &statement_point)) + + combination_randomness[0] * eq_poly_outside(&partial_folding, &fold_point)) + ); + + Ok(()) + } +} diff --git a/whir/src/sumcheck/prover_not_skipping_batched.rs b/whir/src/sumcheck/prover_not_skipping_batched.rs new file mode 100644 index 000000000..5c5e85770 --- /dev/null +++ b/whir/src/sumcheck/prover_not_skipping_batched.rs @@ -0,0 +1,186 @@ +use ark_ff::Field; +use nimue::{ + ProofResult, + plugins::ark::{FieldChallenges, FieldWriter}, +}; +use nimue_pow::{PoWChallenge, PowStrategy}; + +use crate::poly_utils::{MultilinearPoint, coeffs::CoefficientList}; + +use super::prover_batched::SumcheckBatched; + +pub struct SumcheckProverNotSkippingBatched { + sumcheck_prover: SumcheckBatched, +} + +impl SumcheckProverNotSkippingBatched +where + F: Field, +{ + // Get the coefficient of polynomial p and a list of points + // and initialises the table of the initial polynomial + // v(X_1, ..., X_n) = p(X_1, ... X_n) * (epsilon_1 eq_z_1(X) + epsilon_2 eq_z_2(X) ...) + pub fn new( + coeffs: Vec>, + points: &[MultilinearPoint], + poly_comb_coeff: &[F], // random coefficients for combining each poly + evals: &[F], + ) -> Self { + Self { + sumcheck_prover: SumcheckBatched::new(coeffs, points, poly_comb_coeff, evals), + } + } + + pub fn get_folded_polys(&self) -> Vec { + self.sumcheck_prover.get_folded_polys() + } + + pub fn _get_folded_eqs(&self) -> Vec { + self.sumcheck_prover.get_folded_eqs() + } + + pub fn compute_sumcheck_polynomials( + &mut self, + merlin: &mut Merlin, + folding_factor: usize, + pow_bits: f64, + ) -> ProofResult> + where + S: PowStrategy, + Merlin: FieldChallenges + FieldWriter + PoWChallenge, + { + let mut res = Vec::with_capacity(folding_factor); + + for _ in 0..folding_factor { + let sumcheck_poly = self.sumcheck_prover.compute_sumcheck_polynomial(); + merlin.add_scalars(sumcheck_poly.evaluations())?; + let [folding_randomness]: [F; 1] = merlin.challenge_scalars()?; + res.push(folding_randomness); + + // Do PoW if needed + if pow_bits > 0. 
{ + merlin.challenge_pow::(pow_bits)?; + } + + self.sumcheck_prover + .compress(F::ONE, &folding_randomness.into(), &sumcheck_poly); + } + + res.reverse(); + Ok(MultilinearPoint(res)) + } +} + +#[cfg(test)] +mod tests { + use ark_ff::Field; + use nimue::{ + IOPattern, Merlin, ProofResult, + plugins::ark::{FieldChallenges, FieldIOPattern, FieldReader}, + }; + use nimue_pow::blake3::Blake3PoW; + + use crate::{ + crypto::fields::Field64, + poly_utils::{MultilinearPoint, coeffs::CoefficientList, eq_poly_outside}, + sumcheck::{ + proof::SumcheckPolynomial, + prover_not_skipping_batched::SumcheckProverNotSkippingBatched, + }, + }; + + type F = Field64; + + #[test] + fn test_e2e_short() -> ProofResult<()> { + let num_variables = 2; + let folding_factor = 2; + let polynomials = vec![ + CoefficientList::new((0..1 << num_variables).map(F::from).collect()), + CoefficientList::new((1..(1 << num_variables) + 1).map(F::from).collect()), + ]; + + // Initial stuff + let statement_points = vec![ + MultilinearPoint::expand_from_univariate(F::from(97), num_variables), + MultilinearPoint::expand_from_univariate(F::from(75), num_variables), + ]; + + // Poly randomness + let [alpha_1, alpha_2] = [F::from(15), F::from(32)]; + + fn add_sumcheck_io_pattern() -> IOPattern + where + F: Field, + IOPattern: FieldIOPattern, + { + IOPattern::new("test") + .add_scalars(3, "sumcheck_poly") + .challenge_scalars(1, "folding_randomness") + .add_scalars(3, "sumcheck_poly") + .challenge_scalars(1, "folding_randomness") + } + + let iopattern = add_sumcheck_io_pattern::(); + + // Prover part + let mut merlin = iopattern.to_merlin(); + let mut prover = SumcheckProverNotSkippingBatched::new( + polynomials.clone(), + &statement_points, + &[alpha_1, alpha_2], + &[ + polynomials[0].evaluate_at_extension(&statement_points[0]), + polynomials[1].evaluate_at_extension(&statement_points[1]), + ], + ); + + let folding_randomness_1 = prover.compute_sumcheck_polynomials::( + &mut merlin, + folding_factor, + 0., + )?; + + // Compute the answers + let folded_polys_1: Vec<_> = polynomials + .iter() + .map(|poly| poly.fold(&folding_randomness_1)) + .collect(); + + let statement_answers: Vec = polynomials + .iter() + .zip(&statement_points) + .map(|(poly, point)| poly.evaluate(point)) + .collect(); + + // Verifier part + let mut arthur = iopattern.to_arthur(merlin.transcript()); + let sumcheck_poly_11: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly_11 = SumcheckPolynomial::new(sumcheck_poly_11.to_vec(), 1); + let [folding_randomness_11]: [F; 1] = arthur.challenge_scalars()?; + let sumcheck_poly_12: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly_12 = SumcheckPolynomial::new(sumcheck_poly_12.to_vec(), 1); + let [folding_randomness_12]: [F; 1] = arthur.challenge_scalars()?; + + assert_eq!( + sumcheck_poly_11.sum_over_hypercube(), + alpha_1 * statement_answers[0] + alpha_2 * statement_answers[1] + ); + + assert_eq!( + sumcheck_poly_12.sum_over_hypercube(), + sumcheck_poly_11.evaluate_at_point(&folding_randomness_11.into()) + ); + + let full_folding = MultilinearPoint(vec![folding_randomness_12, folding_randomness_11]); + + let eval_coeff = [folded_polys_1[0].coeffs()[0], folded_polys_1[1].coeffs()[0]]; + assert_eq!( + sumcheck_poly_12.evaluate_at_point(&folding_randomness_12.into()), + eval_coeff[0] * alpha_1 * eq_poly_outside(&full_folding, &statement_points[0]) + + eval_coeff[1] * alpha_2 * eq_poly_outside(&full_folding, &statement_points[1]) + ); + + Ok(()) + } +} diff --git a/whir/src/sumcheck/prover_single.rs 
b/whir/src/sumcheck/prover_single.rs new file mode 100644 index 000000000..5e04ec5ca --- /dev/null +++ b/whir/src/sumcheck/prover_single.rs @@ -0,0 +1,299 @@ +use super::proof::SumcheckPolynomial; +use crate::poly_utils::{MultilinearPoint, coeffs::CoefficientList, evals::EvaluationsList}; +use ark_ff::Field; +#[cfg(feature = "parallel")] +use rayon::{join, prelude::*}; + +pub struct SumcheckSingle { + // The evaluation of p + evaluation_of_p: EvaluationsList, + evaluation_of_equality: EvaluationsList, + num_variables: usize, + sum: F, +} + +impl SumcheckSingle +where + F: Field, +{ + // Get the coefficient of polynomial p and a list of points + // and initialises the table of the initial polynomial + // v(X_1, ..., X_n) = p(X_1, ... X_n) * (epsilon_1 eq_z_1(X) + epsilon_2 eq_z_2(X) ...) + pub fn new( + coeffs: CoefficientList, + points: &[MultilinearPoint], + combination_randomness: &[F], + evaluations: &[F], + ) -> Self { + assert_eq!(points.len(), combination_randomness.len()); + assert_eq!(points.len(), evaluations.len()); + let num_variables = coeffs.num_variables(); + + let mut prover = SumcheckSingle { + evaluation_of_p: coeffs.into(), + evaluation_of_equality: EvaluationsList::new(vec![F::ZERO; 1 << num_variables]), + num_variables, + sum: F::ZERO, + }; + + prover.add_new_equality(points, combination_randomness, evaluations); + prover + } + + #[cfg(not(feature = "parallel"))] + pub(crate) fn compute_sumcheck_polynomial(&self) -> SumcheckPolynomial { + assert!(self.num_variables >= 1); + + // Compute coefficients of the quadratic result polynomial + let eval_p_iter = self.evaluation_of_p.evals().chunks_exact(2); + let eval_eq_iter = self.evaluation_of_equality.evals().chunks_exact(2); + let (c0, c2) = eval_p_iter + .zip(eval_eq_iter) + .map(|(p_at, eq_at)| { + // Convert evaluations to coefficients for the linear fns p and eq. + let (p_0, p_1) = (p_at[0], p_at[1] - p_at[0]); + let (eq_0, eq_1) = (eq_at[0], eq_at[1] - eq_at[0]); + + // Now we need to add the contribution of p(x) * eq(x) + (p_0 * eq_0, p_1 * eq_1) + }) + .reduce(|(a0, a2), (b0, b2)| (a0 + b0, a2 + b2)) + .unwrap_or((F::ZERO, F::ZERO)); + + // Use the fact that self.sum = p(0) + p(1) = 2 * c0 + c1 + c2 + let c1 = self.sum - c0.double() - c2; + + // Evaluate the quadratic polynomial at 0, 1, 2 + let eval_0 = c0; + let eval_1 = c0 + c1 + c2; + let eval_2 = eval_1 + c1 + c2 + c2.double(); + + SumcheckPolynomial::new(vec![eval_0, eval_1, eval_2], 1) + } + + #[cfg(feature = "parallel")] + pub(crate) fn compute_sumcheck_polynomial(&self) -> SumcheckPolynomial { + assert!(self.num_variables >= 1); + + // Compute coefficients of the quadratic result polynomial + let eval_p_iter = self.evaluation_of_p.evals().par_chunks_exact(2); + let eval_eq_iter = self.evaluation_of_equality.evals().par_chunks_exact(2); + let (c0, c2) = eval_p_iter + .zip(eval_eq_iter) + .map(|(p_at, eq_at)| { + // Convert evaluations to coefficients for the linear fns p and eq. 
+ let (p_0, p_1) = (p_at[0], p_at[1] - p_at[0]); + let (eq_0, eq_1) = (eq_at[0], eq_at[1] - eq_at[0]); + + // Now we need to add the contribution of p(x) * eq(x) + (p_0 * eq_0, p_1 * eq_1) + }) + .reduce( + || (F::ZERO, F::ZERO), + |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2), + ); + + // Use the fact that self.sum = p(0) + p(1) = 2 * coeff_0 + coeff_1 + coeff_2 + let c1 = self.sum - c0.double() - c2; + + // Evaluate the quadratic polynomial at 0, 1, 2 + let eval_0 = c0; + let eval_1 = c0 + c1 + c2; + let eval_2 = eval_1 + c1 + c2 + c2.double(); + + SumcheckPolynomial::new(vec![eval_0, eval_1, eval_2], 1) + } + + // Evaluate the eq function on for a given point on the hypercube, and add + // the result multiplied by the scalar to the output. + #[cfg(not(feature = "parallel"))] + pub(crate) fn eval_eq(eval: &[F], out: &mut [F], scalar: F) { + debug_assert_eq!(out.len(), 1 << eval.len()); + if let Some((&x, tail)) = eval.split_first() { + let (low, high) = out.split_at_mut(out.len() / 2); + let s1 = scalar * x; + let s0 = scalar - s1; + Self::eval_eq(tail, low, s0); + Self::eval_eq(tail, high, s1); + } else { + out[0] += scalar; + } + } + + // Evaluate the eq function on a given point on the hypercube, and add + // the result multiplied by the scalar to the output. + #[cfg(feature = "parallel")] + pub(crate) fn eval_eq(eval: &[F], out: &mut [F], scalar: F) { + const PARALLEL_THRESHOLD: usize = 10; + debug_assert_eq!(out.len(), 1 << eval.len()); + if let Some((&x, tail)) = eval.split_first() { + let (low, high) = out.split_at_mut(out.len() / 2); + // Update scalars using a single mul. Note that this causes a data dependency, + // so for small fields it might be better to use two muls. + // This data dependency should go away once we implement parallel point evaluation. + let s1 = scalar * x; + let s0 = scalar - s1; + if tail.len() > PARALLEL_THRESHOLD { + join( + || Self::eval_eq(tail, low, s0), + || Self::eval_eq(tail, high, s1), + ); + } else { + Self::eval_eq(tail, low, s0); + Self::eval_eq(tail, high, s1); + } + } else { + out[0] += scalar; + } + } + + pub fn add_new_equality( + &mut self, + points: &[MultilinearPoint], + combination_randomness: &[F], + evaluations: &[F], + ) { + assert_eq!(combination_randomness.len(), points.len()); + assert_eq!(combination_randomness.len(), evaluations.len()); + for (point, rand) in points.iter().zip(combination_randomness) { + // TODO: We might want to do all points simultaneously so we + // do only a single pass over the data. 
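Worth spelling out what the `eval_eq` recursion above computes before it is invoked just below: it fills the table `eq_z(b) = prod_i (z_i*b_i + (1 - z_i)*(1 - b_i))` over all `b` in `{0,1}^n`, scaled by `scalar`, spending one multiplication per node by deriving the `b_i = 0` weight as `scalar - s1`. A self-contained check of that product formula, with `f64` standing in for field elements:

```rust
// Recursive equality-table construction, mirroring SumcheckSingle::eval_eq:
// the first variable splits the output into a low half (b_0 = 0) and a high
// half (b_0 = 1), so the table index is the big-endian bit string of b.
fn eval_eq(z: &[f64], out: &mut [f64], scalar: f64) {
    debug_assert_eq!(out.len(), 1 << z.len());
    if let Some((&x, tail)) = z.split_first() {
        let (low, high) = out.split_at_mut(out.len() / 2);
        let s1 = scalar * x; // weight of the b_i = 1 half
        eval_eq(tail, low, scalar - s1); // weight of the b_i = 0 half, one mul saved
        eval_eq(tail, high, s1);
    } else {
        out[0] += scalar;
    }
}

fn main() {
    let z = [0.25, 0.5];
    let mut out = [0.0; 4];
    eval_eq(&z, &mut out, 1.0);
    // Compare against the direct product formula, big-endian bit order.
    for (i, &v) in out.iter().enumerate() {
        let (b0, b1) = ((i >> 1) as f64, (i & 1) as f64);
        let expected =
            (z[0] * b0 + (1.0 - z[0]) * (1.0 - b0)) * (z[1] * b1 + (1.0 - z[1]) * (1.0 - b1));
        assert!((v - expected).abs() < 1e-12);
    }
}
```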
+            Self::eval_eq(&point.0, self.evaluation_of_equality.evals_mut(), *rand);
+        }
+
+        // Update the sum
+        for (rand, eval) in combination_randomness.iter().zip(evaluations.iter()) {
+            self.sum += *rand * eval;
+        }
+    }
+
+    // When the folding randomness arrives, compress the table accordingly (adding the new points)
+    #[cfg(not(feature = "parallel"))]
+    pub fn compress(
+        &mut self,
+        combination_randomness: F, // Scale the initial point
+        folding_randomness: &MultilinearPoint<F>,
+        sumcheck_poly: &SumcheckPolynomial<F>,
+    ) {
+        assert_eq!(folding_randomness.n_variables(), 1);
+        assert!(self.num_variables >= 1);
+
+        let randomness = folding_randomness.0[0];
+        let evaluations_of_p = self
+            .evaluation_of_p
+            .evals()
+            .chunks_exact(2)
+            .map(|at| (at[1] - at[0]) * randomness + at[0])
+            .collect();
+        let evaluations_of_eq = self
+            .evaluation_of_equality
+            .evals()
+            .chunks_exact(2)
+            .map(|at| (at[1] - at[0]) * randomness + at[0])
+            .collect();
+
+        // Update
+        self.num_variables -= 1;
+        self.evaluation_of_p = EvaluationsList::new(evaluations_of_p);
+        self.evaluation_of_equality = EvaluationsList::new(evaluations_of_eq);
+        self.sum = combination_randomness * sumcheck_poly.evaluate_at_point(folding_randomness);
+    }
+
+    #[cfg(feature = "parallel")]
+    pub fn compress(
+        &mut self,
+        combination_randomness: F, // Scale the initial point
+        folding_randomness: &MultilinearPoint<F>,
+        sumcheck_poly: &SumcheckPolynomial<F>,
+    ) {
+        assert_eq!(folding_randomness.n_variables(), 1);
+        assert!(self.num_variables >= 1);
+
+        let randomness = folding_randomness.0[0];
+        let (evaluations_of_p, evaluations_of_eq) = join(
+            || {
+                self.evaluation_of_p
+                    .evals()
+                    .par_chunks_exact(2)
+                    .map(|at| (at[1] - at[0]) * randomness + at[0])
+                    .collect()
+            },
+            || {
+                self.evaluation_of_equality
+                    .evals()
+                    .par_chunks_exact(2)
+                    .map(|at| (at[1] - at[0]) * randomness + at[0])
+                    .collect()
+            },
+        );
+
+        // Update
+        self.num_variables -= 1;
+        self.evaluation_of_p = EvaluationsList::new(evaluations_of_p);
+        self.evaluation_of_equality = EvaluationsList::new(evaluations_of_eq);
+        self.sum = combination_randomness * sumcheck_poly.evaluate_at_point(folding_randomness);
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{
+        crypto::fields::Field64,
+        poly_utils::{MultilinearPoint, coeffs::CoefficientList},
+    };
+
+    use super::SumcheckSingle;
+
+    type F = Field64;
+
+    #[test]
+    fn test_sumcheck_folding_factor_1() {
+        let eval_point = MultilinearPoint(vec![F::from(10), F::from(11)]);
+        let polynomial =
+            CoefficientList::new(vec![F::from(1), F::from(5), F::from(10), F::from(14)]);
+
+        let claimed_value = polynomial.evaluate(&eval_point);
+
+        let eval = polynomial.evaluate(&eval_point);
+        let mut prover = SumcheckSingle::new(polynomial, &[eval_point], &[F::from(1)], &[eval]);
+
+        let poly_1 = prover.compute_sumcheck_polynomial();
+
+        // First, check that it sums to the right value over the hypercube
+        assert_eq!(poly_1.sum_over_hypercube(), claimed_value);
+
+        let combination_randomness = F::from(100101);
+        let folding_randomness = MultilinearPoint(vec![F::from(4999)]);
+
+        prover.compress(combination_randomness, &folding_randomness, &poly_1);
+
+        let poly_2 = prover.compute_sumcheck_polynomial();
+
+        assert_eq!(
+            poly_2.sum_over_hypercube(),
+            combination_randomness * poly_1.evaluate_at_point(&folding_randomness)
+        );
+    }
+}
+
+#[test]
+fn test_eval_eq() {
+    use crate::{
+        crypto::fields::Field64 as F, poly_utils::sequential_lag_poly::LagrangePolynomialIterator,
+    };
+    use ark_ff::AdditiveGroup;
+
+    let eval = vec![F::from(3), F::from(5)];
+    let mut out = vec![F::ZERO; 4];
+    SumcheckSingle::eval_eq(&eval, &mut out, F::ONE);
+
+    let point = MultilinearPoint(eval.clone());
+    let mut expected = vec![F::ZERO; 4];
+    for (prefix, lag) in LagrangePolynomialIterator::new(&point) {
+        expected[prefix.0] = lag;
+    }
+
+    assert_eq!(&out, &expected);
+}
diff --git a/whir/src/utils.rs b/whir/src/utils.rs
new file mode 100644
index 000000000..328553196
--- /dev/null
+++ b/whir/src/utils.rs
@@ -0,0 +1,152 @@
+use crate::ntt::{transpose, transpose_bench_allocate};
+use ark_ff::Field;
+use std::collections::BTreeSet;
+
+// checks whether the given number n is a power of two.
+pub fn is_power_of_two(n: usize) -> bool {
+    n != 0 && (n & (n - 1) == 0)
+}
+
+/// performs big-endian binary decomposition of `value` and returns the result.
+///
+/// `n_bits` must be at most `usize::BITS`. If it is strictly smaller, the most significant bits of `value` are ignored.
+/// The returned vector v ends with the least significant bit of `value` and always has exactly `n_bits` many elements.
+pub fn to_binary(value: usize, n_bits: usize) -> Vec<bool> {
+    // Ensure that n is within the bounds of the input integer type
+    assert!(n_bits <= usize::BITS as usize);
+    let mut result = vec![false; n_bits];
+    for i in 0..n_bits {
+        result[n_bits - 1 - i] = (value & (1 << i)) != 0;
+    }
+    result
+}
+
+// TODO(Gotti): n_bits is a misnomer if base > 2. Should be n_limbs or similar.
+// Also, should the behaviour for value >= base^n_bits be specified as part of the API or asserted not to happen?
+// Currently, we compute the decomposition of value % (base^n_bits).
+
+/// decomposes value into its big-endian base-ary decomposition, meaning we return a vector v, s.t.
+///
+/// value = v[0]*base^(n_bits-1) + v[1] * base^(n_bits-2) + ... + v[n_bits-1] * 1,
+/// where each v[i] is in 0..base.
+/// The returned vector always has length exactly n_bits (we pad with leading zeros);
+pub fn base_decomposition(value: usize, base: u8, n_bits: usize) -> Vec<u8> {
+    // Initialize the result vector with zeros of the specified length
+    let mut result = vec![0u8; n_bits];
+
+    // Create a mutable copy of the value for computation
+    // Note: We could just make the local passed-by-value argument `value` mutable, but this is clearer.
+    let mut value = value;
+
+    // Compute the base decomposition
+    for i in 0..n_bits {
+        result[n_bits - 1 - i] = (value % (base as usize)) as u8;
+        value /= base as usize;
+    }
+    // TODO: Should we assert!(value == 0) here to check that the originally passed `value` is < base^n_bits ?
+
+    result
+}
+
+// Gotti: Consider renaming this function. The name sounds like it's a PRG.
+// TODO (Gotti): Check that ordering is actually correct at point of use (everything else is big-endian).
+
+/// expand_randomness outputs the vector [1, base, base^2, base^3, ...] of length len.
+pub fn expand_randomness<F: Field>(base: F, len: usize) -> Vec<F> {
+    let mut res = Vec::with_capacity(len);
+    let mut acc = F::ONE;
+    for _ in 0..len {
+        res.push(acc);
+        acc *= base;
+    }
+
+    res
+}
+
+/// Deduplicates AND orders a vector
+pub fn dedup<T: Ord>(v: impl IntoIterator<Item = T>) -> Vec<T> {
+    Vec::from_iter(BTreeSet::from_iter(v))
+}
+
+// FIXME(Gotti): comment does not match what function does (due to mismatch between folding_factor and folding_factor_exp)
+// Also, k should be defined: k = evals.len() / 2^{folding_factor}, I guess.
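To pin down the semantics the FIXME above asks about, for `stack_evaluations` defined next: with `fold_size = 2^folding_factor` and `k = evals.len() / fold_size`, output chunk `i` collects the evaluations of `f` on the coset `{omega^i, omega^(i+k), omega^(i+2k), ...}`. A naive index-arithmetic version the transpose is equivalent to (a sketch with `u64` placeholders for field elements):

```rust
// Naive form of the stacking: chunk i of the output holds
// [evals[i], evals[i + k], evals[i + 2k], ...] for k = evals.len() / fold_size.
fn stack_naive(evals: &[u64], folding_factor: usize) -> Vec<u64> {
    let fold_size = 1 << folding_factor;
    assert_eq!(evals.len() % fold_size, 0);
    let k = evals.len() / fold_size;
    let mut out = Vec::with_capacity(evals.len());
    for i in 0..k {
        for j in 0..fold_size {
            out.push(evals[i + k * j]); // the coset {i, i+k, i+2k, ...}
        }
    }
    out
}

fn main() {
    // 8 evaluations, folding factor 1 -> pairs (f(w^i), f(w^(i+4))).
    let evals: Vec<u64> = (0..8).collect();
    assert_eq!(stack_naive(&evals, 1), vec![0, 4, 1, 5, 2, 6, 3, 7]);
}
```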
+ +/// Takes the vector of evaluations (assume that evals[i] = f(omega^i)) +/// and folds them into a vector of such that folded_evals[i] = [f(omega^(i + k * j)) for j in 0..folding_factor] +pub fn stack_evaluations(mut evals: Vec, folding_factor: usize) -> Vec { + let folding_factor_exp = 1 << folding_factor; + assert!(evals.len() % folding_factor_exp == 0); + let size_of_new_domain = evals.len() / folding_factor_exp; + + // interpret evals as (folding_factor_exp x size_of_new_domain)-matrix and transpose in-place + transpose(&mut evals, folding_factor_exp, size_of_new_domain); + evals +} + +pub fn stack_evaluations_bench_allocate( + mut evals: Vec, + folding_factor: usize, +) -> Vec { + let folding_factor_exp = 1 << folding_factor; + assert!(evals.len() % folding_factor_exp == 0); + let size_of_new_domain = evals.len() / folding_factor_exp; + + // interpret evals as (folding_factor_exp x size_of_new_domain)-matrix and transpose in-place + transpose_bench_allocate(&mut evals, folding_factor_exp, size_of_new_domain); + evals +} + +#[cfg(test)] +mod tests { + use crate::utils::base_decomposition; + + use super::{is_power_of_two, stack_evaluations, to_binary}; + + #[test] + fn test_evaluations_stack() { + use crate::crypto::fields::Field64 as F; + + let num = 256; + let folding_factor = 3; + let fold_size = 1 << folding_factor; + assert_eq!(num % fold_size, 0); + let evals: Vec<_> = (0..num as u64).map(F::from).collect(); + + let stacked = stack_evaluations(evals, folding_factor); + assert_eq!(stacked.len(), num); + + for (i, fold) in stacked.chunks_exact(fold_size).enumerate() { + assert_eq!(fold.len(), fold_size); + for (j, item) in fold.iter().copied().enumerate().take(fold_size) { + assert_eq!(item, F::from((i + j * num / fold_size) as u64)); + } + } + } + + #[test] + fn test_to_binary() { + assert_eq!(to_binary(0b10111, 5), vec![true, false, true, true, true]); + assert_eq!(to_binary(0b11001, 2), vec![false, true]); // truncate + let empty_vec: Vec = vec![]; // just for the explicit bool type. + assert_eq!(to_binary(1, 0), empty_vec); + assert_eq!(to_binary(0, 0), empty_vec); + } + + #[test] + fn test_is_power_of_two() { + assert!(!is_power_of_two(0)); + assert!(is_power_of_two(1)); + assert!(is_power_of_two(2)); + assert!(!is_power_of_two(3)); + assert!(!is_power_of_two(usize::MAX)); + } + + #[test] + fn test_base_decomposition() { + assert_eq!(base_decomposition(0b1011, 2, 6), vec![0, 0, 1, 0, 1, 1]); + assert_eq!(base_decomposition(15, 3, 3), vec![1, 2, 0]); + // check truncation: This checks the current (undocumented) behaviour (compute modulo base^number_of_limbs) works as believed. + // If we actually specify the API to have a different behaviour, this test should change. 
+ assert_eq!(base_decomposition(15 + 81, 3, 3), vec![1, 2, 0]); + } +} diff --git a/whir/src/whir/batch/committer.rs b/whir/src/whir/batch/committer.rs new file mode 100644 index 000000000..11144ed68 --- /dev/null +++ b/whir/src/whir/batch/committer.rs @@ -0,0 +1,184 @@ +use crate::{ + ntt::expand_from_coeff, + poly_utils::{MultilinearPoint, coeffs::CoefficientList, fold::restructure_evaluations}, + utils, + whir::committer::{Committer, Witness}, +}; +use ark_crypto_primitives::merkle_tree::{Config, MerkleTree}; +use ark_ff::FftField; +use ark_poly::EvaluationDomain; +use ark_std::{end_timer, start_timer}; +use derive_more::Debug; +use nimue::{ + ByteWriter, ProofResult, + plugins::ark::{FieldChallenges, FieldWriter}, +}; + +use crate::whir::fs_utils::DigestWriter; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +#[derive(Debug, Clone)] +pub struct Witnesses +where + MerkleConfig: Config, +{ + pub(crate) polys: Vec>, + #[debug(skip)] + pub(crate) merkle_tree: MerkleTree, + pub(crate) merkle_leaves: Vec, + pub(crate) ood_points: Vec, + pub(crate) ood_answers: Vec, +} + +impl From> for Witnesses { + fn from(witness: Witness) -> Self { + Self { + polys: vec![witness.polynomial], + merkle_tree: witness.merkle_tree, + merkle_leaves: witness.merkle_leaves, + ood_points: witness.ood_points, + ood_answers: witness.ood_answers, + } + } +} + +impl From> for Witness { + fn from(witness: Witnesses) -> Self { + Self { + polynomial: witness.polys[0].clone(), + merkle_tree: witness.merkle_tree, + merkle_leaves: witness.merkle_leaves, + ood_points: witness.ood_points, + ood_answers: witness.ood_answers, + } + } +} + +impl Committer +where + F: FftField, + MerkleConfig: Config, + PowStrategy: Sync, +{ + pub fn batch_commit( + &self, + merlin: &mut Merlin, + polys: &[CoefficientList], + ) -> ProofResult> + where + Merlin: FieldWriter + FieldChallenges + ByteWriter + DigestWriter, + { + let timer = start_timer!(|| "Batch Commit"); + let base_domain = self.0.starting_domain.base_domain.unwrap(); + let expansion = base_domain.size() / polys[0].num_coeffs(); + let expand_timer = start_timer!(|| "Batch Expand"); + let evals = polys + .par_iter() + .map(|poly| expand_from_coeff(poly.coeffs(), expansion)) + .collect::>>(); + end_timer!(expand_timer); + + assert_eq!(base_domain.size(), evals[0].len()); + + // These stacking operations are bottleneck of the commitment process. + // Try to finish the tasks with as few allocations as possible. 
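The stacking below feeds a Merkle tree whose leaves each pack `fold_size * polys.len()` values: the fold-size coset of every polynomial at the same query index sits in one leaf, so a single authentication path answers a STIR query for the whole batch. An index sketch of that layout (hypothetical values, `u64` in place of field elements):

```rust
// Leaf i of the batched tree = [poly_0 coset i | poly_1 coset i | ...],
// flattened, so one Merkle path opens all polynomials at query i.
fn leaf(stacked: &[u64], fold_size: usize, num_polys: usize, i: usize) -> &[u64] {
    let width = fold_size * num_polys;
    &stacked[i * width..(i + 1) * width]
}

fn main() {
    // 2 polys, fold_size 2, 2 leaves: [p0c0 | p1c0 | p0c1 | p1c1] flattened.
    let stacked: [u64; 8] = [10, 11, 20, 21, 12, 13, 22, 23];
    assert_eq!(leaf(&stacked, 2, 2, 0).to_vec(), vec![10, 11, 20, 21]);
    assert_eq!(leaf(&stacked, 2, 2, 1).to_vec(), vec![12, 13, 22, 23]);
}
```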
+ let stack_evaluations_timer = start_timer!(|| "Stack Evaluations"); + let folded_evals = evals + .into_par_iter() + .map(|evals| { + let sub_stack_evaluations_timer = start_timer!(|| "Sub Stack Evaluations"); + let ret = utils::stack_evaluations(evals, self.0.folding_factor.at_round(0)); + end_timer!(sub_stack_evaluations_timer); + ret + }) + .map(|evals| { + let restructure_evaluations_timer = start_timer!(|| "Restructure Evaluations"); + let ret = restructure_evaluations( + evals, + self.0.fold_optimisation, + base_domain.group_gen(), + base_domain.group_gen_inv(), + self.0.folding_factor.at_round(0), + ); + end_timer!(restructure_evaluations_timer); + ret + }) + .flat_map(|evals| evals.into_par_iter().map(F::from_base_prime_field)) + .collect::>(); + end_timer!(stack_evaluations_timer); + + let allocate_timer = start_timer!(|| "Allocate buffer."); + + let mut buffer = Vec::with_capacity(folded_evals.len()); + #[allow(clippy::uninit_vec)] + unsafe { + buffer.set_len(folded_evals.len()); + } + end_timer!(allocate_timer); + let horizontal_stacking_timer = start_timer!(|| "Horizontal Stacking"); + let folded_evals = super::utils::horizontal_stacking( + folded_evals, + base_domain.size(), + self.0.folding_factor.at_round(0), + buffer.as_mut_slice(), + ); + end_timer!(horizontal_stacking_timer); + + // Group folds together as a leaf. + let fold_size = 1 << self.0.folding_factor.at_round(0); + #[cfg(not(feature = "parallel"))] + let leafs_iter = folded_evals.chunks_exact(fold_size * polys.len()); + #[cfg(feature = "parallel")] + let leafs_iter = folded_evals.par_chunks_exact(fold_size * polys.len()); + + let merkle_build_timer = start_timer!(|| "Build Merkle Tree"); + let merkle_tree = MerkleTree::::new( + &self.0.leaf_hash_params, + &self.0.two_to_one_params, + leafs_iter, + ) + .unwrap(); + end_timer!(merkle_build_timer); + + let root = merkle_tree.root(); + + merlin.add_digest(root)?; + + let mut ood_points = vec![F::ZERO; self.0.committment_ood_samples]; + let mut ood_answers = vec![F::ZERO; polys.len() * self.0.committment_ood_samples]; + if self.0.committment_ood_samples > 0 { + merlin.fill_challenge_scalars(&mut ood_points)?; + ood_points + .par_iter() + .zip(ood_answers.par_chunks_mut(polys.len())) + .for_each(|(ood_point, ood_answers)| { + for j in 0..polys.len() { + let eval = polys[j].evaluate_at_extension( + &MultilinearPoint::expand_from_univariate( + *ood_point, + self.0.mv_parameters.num_variables, + ), + ); + ood_answers[j] = eval; + } + }); + merlin.add_scalars(&ood_answers)?; + } + + let polys = polys + .into_par_iter() + .map(|poly| poly.clone().to_extension()) + .collect::>(); + + end_timer!(timer); + + Ok(Witnesses { + polys, + merkle_tree, + merkle_leaves: folded_evals, + ood_points, + ood_answers, + }) + } +} diff --git a/whir/src/whir/batch/iopattern.rs b/whir/src/whir/batch/iopattern.rs new file mode 100644 index 000000000..13ec986ec --- /dev/null +++ b/whir/src/whir/batch/iopattern.rs @@ -0,0 +1,124 @@ +use ark_crypto_primitives::merkle_tree::Config; +use ark_ff::FftField; +use nimue::plugins::ark::*; + +use crate::{ + fs_utils::{OODIOPattern, WhirPoWIOPattern}, + sumcheck::prover_not_skipping::SumcheckNotSkippingIOPattern, + whir::iopattern::DigestIOPattern, +}; + +use crate::whir::parameters::WhirConfig; + +pub trait WhirBatchIOPattern { + fn commit_batch_statement( + self, + params: &WhirConfig, + batch_size: usize, + ) -> Self; + fn add_whir_unify_proof( + self, + params: &WhirConfig, + batch_size: usize, + ) -> Self; + fn add_whir_batch_proof( + self, + params: 
&WhirConfig, + batch_size: usize, + ) -> Self; +} + +impl WhirBatchIOPattern for IOPattern +where + F: FftField, + MerkleConfig: Config, + IOPattern: ByteIOPattern + + FieldIOPattern + + SumcheckNotSkippingIOPattern + + WhirPoWIOPattern + + OODIOPattern + + DigestIOPattern, +{ + fn commit_batch_statement( + self, + params: &WhirConfig, + batch_size: usize, + ) -> Self { + // TODO: Add params + let mut this = self.add_digest("merkle_digest"); + if params.committment_ood_samples > 0 { + assert!(params.initial_statement); + this = this + .challenge_scalars(params.committment_ood_samples, "ood_query") + .add_scalars(params.committment_ood_samples * batch_size, "ood_ans"); + } + this + } + + fn add_whir_unify_proof( + mut self, + params: &WhirConfig, + batch_size: usize, + ) -> Self { + if batch_size > 1 { + self = self.challenge_scalars(1, "batch_poly_combination_randomness"); + } + self = self + // .challenge_scalars(1, "initial_combination_randomness") + .add_sumcheck(params.mv_parameters.num_variables, 0.); + self.add_scalars(batch_size, "unified_folded_evals") + } + + fn add_whir_batch_proof( + mut self, + params: &WhirConfig, + batch_size: usize, + ) -> Self { + if batch_size > 1 { + self = self.challenge_scalars(1, "batch_poly_combination_randomness"); + } + + // TODO: Add statement + if params.initial_statement { + self = self + .challenge_scalars(1, "initial_combination_randomness") + .add_sumcheck( + params.folding_factor.at_round(0), + params.starting_folding_pow_bits, + ); + } else { + self = self + .challenge_scalars(params.folding_factor.at_round(0), "folding_randomness") + .pow(params.starting_folding_pow_bits); + } + + let mut domain_size = params.starting_domain.size(); + + for (round, r) in params.round_parameters.iter().enumerate() { + let folded_domain_size = domain_size >> params.folding_factor.at_round(round); + let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8; + self = self + .add_digest("merkle_digest") + .add_ood(r.ood_samples) + .challenge_bytes(r.num_queries * domain_size_bytes, "stir_queries") + .pow(r.pow_bits) + .challenge_scalars(1, "combination_randomness") + .add_sumcheck( + params.folding_factor.at_round(round + 1), + r.folding_pow_bits, + ); + domain_size >>= 1; + } + + let folded_domain_size = domain_size + >> params + .folding_factor + .at_round(params.round_parameters.len()); + let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8; + + self.add_scalars(1 << params.final_sumcheck_rounds, "final_coeffs") + .challenge_bytes(domain_size_bytes * params.final_queries, "final_queries") + .pow(params.final_pow_bits) + .add_sumcheck(params.final_sumcheck_rounds, params.final_folding_pow_bits) + } +} diff --git a/whir/src/whir/batch/mod.rs b/whir/src/whir/batch/mod.rs new file mode 100644 index 000000000..315e9ad92 --- /dev/null +++ b/whir/src/whir/batch/mod.rs @@ -0,0 +1,8 @@ +mod committer; +mod iopattern; +mod prover; +mod utils; +mod verifier; + +pub use committer::Witnesses; +pub use iopattern::WhirBatchIOPattern; diff --git a/whir/src/whir/batch/prover.rs b/whir/src/whir/batch/prover.rs new file mode 100644 index 000000000..04bfa2e46 --- /dev/null +++ b/whir/src/whir/batch/prover.rs @@ -0,0 +1,540 @@ +use super::committer::Witnesses; +use crate::{ + ntt::expand_from_coeff, + parameters::FoldType, + poly_utils::{ + MultilinearPoint, + coeffs::CoefficientList, + fold::{compute_fold, restructure_evaluations}, + }, + sumcheck::{ + prover_not_skipping::SumcheckProverNotSkipping, + 
prover_not_skipping_batched::SumcheckProverNotSkippingBatched, + }, + utils::{self, expand_randomness}, + whir::{ + WhirProof, + prover::{Prover, RoundState}, + }, +}; +use ark_crypto_primitives::merkle_tree::{Config, MerkleTree}; +use ark_ff::FftField; +use ark_poly::EvaluationDomain; +use ark_std::{end_timer, start_timer}; +use itertools::zip_eq; +use nimue::{ + ByteChallenges, ByteWriter, ProofResult, + plugins::ark::{FieldChallenges, FieldWriter}, +}; +use nimue_pow::{self, PoWChallenge}; + +use crate::whir::fs_utils::{DigestWriter, get_challenge_stir_queries}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +struct RoundStateBatch<'a, F, MerkleConfig> +where + F: FftField, + MerkleConfig: Config, +{ + round_state: RoundState, + batching_randomness: Vec, + prev_merkle: &'a MerkleTree, + prev_merkle_answers: &'a Vec, +} + +impl Prover +where + F: FftField, + MerkleConfig: Config, + PowStrategy: nimue_pow::PowStrategy, +{ + fn validate_witnesses(&self, witness: &Witnesses) -> bool { + assert_eq!( + witness.ood_points.len() * witness.polys.len(), + witness.ood_answers.len() + ); + if !self.0.initial_statement { + assert!(witness.ood_points.is_empty()); + } + assert!(!witness.polys.is_empty(), "Input polys cannot be empty"); + witness.polys.iter().skip(1).for_each(|poly| { + assert_eq!( + poly.num_variables(), + witness.polys[0].num_variables(), + "All polys must have the same number of variables" + ); + }); + witness.polys[0].num_variables() == self.0.mv_parameters.num_variables + } + + /// batch open the same points for multiple polys + pub fn simple_batch_prove( + &self, + merlin: &mut Merlin, + points: &[MultilinearPoint], + evals_per_point: &[Vec], // outer loop on each point, inner loop on each poly + witness: &Witnesses, + ) -> ProofResult> + where + Merlin: FieldChallenges + + FieldWriter + + ByteChallenges + + ByteWriter + + PoWChallenge + + DigestWriter, + { + let prove_timer = start_timer!(|| "prove"); + let initial_timer = start_timer!(|| "init"); + assert!(self.0.initial_statement, "must be true for pcs"); + assert!(self.validate_parameters()); + assert!(self.validate_witnesses(witness)); + for point in points { + assert_eq!( + point.0.len(), + self.0.mv_parameters.num_variables, + "number of variables mismatch" + ); + } + let num_polys = witness.polys.len(); + for evals in evals_per_point { + assert_eq!( + evals.len(), + num_polys, + "number of polynomials not equal number of evaluations" + ); + } + + let compute_dot_product = + |evals: &[F], coeff: &[F]| -> F { zip_eq(evals, coeff).map(|(a, b)| *a * *b).sum() }; + end_timer!(initial_timer); + + let random_coeff_timer = start_timer!(|| "random coeff"); + let random_coeff = + super::utils::generate_random_vector_batch_open(merlin, witness.polys.len())?; + end_timer!(random_coeff_timer); + + let initial_claims_timer = start_timer!(|| "initial claims"); + let initial_claims: Vec<_> = witness + .ood_points + .par_iter() + .map(|ood_point| { + MultilinearPoint::expand_from_univariate( + *ood_point, + self.0.mv_parameters.num_variables, + ) + }) + .chain(points.to_vec()) + .collect(); + end_timer!(initial_claims_timer); + + let ood_answers_timer = start_timer!(|| "ood answers"); + let ood_answers = witness + .ood_answers + .par_chunks_exact(witness.polys.len()) + .map(|answer| compute_dot_product(answer, &random_coeff)) + .collect::>(); + end_timer!(ood_answers_timer); + + let eval_timer = start_timer!(|| "eval"); + let eval_per_point: Vec = evals_per_point + .par_iter() + .map(|evals| compute_dot_product(evals, 
&random_coeff)) + .collect(); + end_timer!(eval_timer); + + let combine_timer = start_timer!(|| "Combine polynomial"); + let initial_answers: Vec<_> = ood_answers.into_iter().chain(eval_per_point).collect(); + + let polynomial = CoefficientList::combine(&witness.polys, &random_coeff); + end_timer!(combine_timer); + + let comb_timer = start_timer!(|| "combination randomness"); + let [combination_randomness_gen] = merlin.challenge_scalars()?; + let combination_randomness = + expand_randomness(combination_randomness_gen, initial_claims.len()); + end_timer!(comb_timer); + + let sumcheck_timer = start_timer!(|| "sumcheck"); + let mut sumcheck_prover = Some(SumcheckProverNotSkipping::new( + polynomial.clone(), + &initial_claims, + &combination_randomness, + &initial_answers, + )); + end_timer!(sumcheck_timer); + + let sumcheck_prover_timer = start_timer!(|| "sumcheck_prover"); + let folding_randomness = sumcheck_prover + .as_mut() + .unwrap() + .compute_sumcheck_polynomials::( + merlin, + self.0.folding_factor.at_round(0), + self.0.starting_folding_pow_bits, + )?; + end_timer!(sumcheck_prover_timer); + + let timer = start_timer!(|| "round_batch"); + let round_state = RoundStateBatch { + round_state: RoundState { + domain: self.0.starting_domain.clone(), + round: 0, + sumcheck_prover, + folding_randomness, + coefficients: polynomial, + prev_merkle: MerkleTree::blank( + &self.0.leaf_hash_params, + &self.0.two_to_one_params, + 2, + ) + .unwrap(), + prev_merkle_answers: Vec::new(), + merkle_proofs: vec![], + }, + prev_merkle: &witness.merkle_tree, + prev_merkle_answers: &witness.merkle_leaves, + batching_randomness: random_coeff, + }; + + let result = self.simple_round_batch(merlin, round_state, num_polys); + end_timer!(timer); + end_timer!(prove_timer); + + result + } + + fn simple_round_batch( + &self, + merlin: &mut Merlin, + round_state: RoundStateBatch, + num_polys: usize, + ) -> ProofResult> + where + Merlin: FieldChallenges + + ByteChallenges + + FieldWriter + + ByteWriter + + PoWChallenge + + DigestWriter, + { + let batching_randomness = round_state.batching_randomness; + let prev_merkle = round_state.prev_merkle; + let prev_merkle_answers = round_state.prev_merkle_answers; + let mut round_state = round_state.round_state; + // Fold the coefficients + let folded_coefficients = round_state + .coefficients + .fold(&round_state.folding_randomness); + + let num_variables = self.0.mv_parameters.num_variables + - self.0.folding_factor.total_number(round_state.round); + + // Base case + if round_state.round == self.0.n_rounds() { + // Coefficients of the polynomial + merlin.add_scalars(folded_coefficients.coeffs())?; + + // Final verifier queries and answers + let final_challenge_indexes = get_challenge_stir_queries( + round_state.domain.size(), + self.0.folding_factor.at_round(round_state.round), + self.0.final_queries, + merlin, + )?; + + let merkle_proof = prev_merkle + .generate_multi_proof(final_challenge_indexes.clone()) + .unwrap(); + let fold_size = 1 << self.0.folding_factor.at_round(round_state.round); + let answers = final_challenge_indexes + .into_par_iter() + .map(|i| { + prev_merkle_answers + [i * (fold_size * num_polys)..(i + 1) * (fold_size * num_polys)] + .to_vec() + }) + .collect(); + + round_state.merkle_proofs.push((merkle_proof, answers)); + + // PoW + if self.0.final_pow_bits > 0. 
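+            // Grinding: the prover searches a nonce carrying
+            // `final_pow_bits` bits of work and appends it to the
+            // transcript; the verifier re-checks it with the matching
+            // `challenge_pow` call.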
{ + merlin.challenge_pow::(self.0.final_pow_bits)?; + } + + // Final sumcheck + if self.0.final_sumcheck_rounds > 0 { + round_state + .sumcheck_prover + .unwrap_or_else(|| { + SumcheckProverNotSkipping::new(folded_coefficients.clone(), &[], &[], &[]) + }) + .compute_sumcheck_polynomials::( + merlin, + self.0.final_sumcheck_rounds, + self.0.final_folding_pow_bits, + )?; + } + + return Ok(WhirProof(round_state.merkle_proofs)); + } + + let round_params = &self.0.round_parameters[round_state.round]; + + // Fold the coefficients, and compute fft of polynomial (and commit) + let new_domain = round_state.domain.scale(2); + let expansion = new_domain.size() / folded_coefficients.num_coeffs(); + let evals = expand_from_coeff(folded_coefficients.coeffs(), expansion); + // TODO: `stack_evaluations` and `restructure_evaluations` are really in-place algorithms. + // They also partially overlap and undo one another. We should merge them. + let folded_evals = + utils::stack_evaluations(evals, self.0.folding_factor.at_round(round_state.round + 1)); + let folded_evals = restructure_evaluations( + folded_evals, + self.0.fold_optimisation, + new_domain.backing_domain.group_gen(), + new_domain.backing_domain.group_gen_inv(), + self.0.folding_factor.at_round(round_state.round + 1), + ); + + #[cfg(not(feature = "parallel"))] + let leafs_iter = + folded_evals.chunks_exact(1 << self.0.folding_factor.at_round(round_state.round + 1)); + #[cfg(feature = "parallel")] + let leafs_iter = folded_evals + .par_chunks_exact(1 << self.0.folding_factor.at_round(round_state.round + 1)); + let merkle_tree = MerkleTree::::new( + &self.0.leaf_hash_params, + &self.0.two_to_one_params, + leafs_iter, + ) + .unwrap(); + + let root = merkle_tree.root(); + merlin.add_digest(root)?; + + // OOD Samples + let mut ood_points = vec![F::ZERO; round_params.ood_samples]; + let mut ood_answers = Vec::with_capacity(round_params.ood_samples); + if round_params.ood_samples > 0 { + merlin.fill_challenge_scalars(&mut ood_points)?; + ood_answers.extend(ood_points.iter().map(|ood_point| { + folded_coefficients.evaluate(&MultilinearPoint::expand_from_univariate( + *ood_point, + num_variables, + )) + })); + merlin.add_scalars(&ood_answers)?; + } + + // STIR queries + let stir_challenges_indexes = get_challenge_stir_queries( + round_state.domain.size(), + self.0.folding_factor.at_round(round_state.round), + round_params.num_queries, + merlin, + )?; + let domain_scaled_gen = round_state + .domain + .backing_domain + .element(1 << self.0.folding_factor.at_round(round_state.round)); + let stir_challenges: Vec<_> = ood_points + .into_par_iter() + .chain( + stir_challenges_indexes + .par_iter() + .map(|i| domain_scaled_gen.pow([*i as u64])), + ) + .map(|univariate| MultilinearPoint::expand_from_univariate(univariate, num_variables)) + .collect(); + + let merkle_proof = prev_merkle + .generate_multi_proof(stir_challenges_indexes.clone()) + .unwrap(); + let fold_size = (1 << self.0.folding_factor.at_round(round_state.round)) * num_polys; + let answers = stir_challenges_indexes + .par_iter() + .map(|i| prev_merkle_answers[i * fold_size..(i + 1) * fold_size].to_vec()) + .collect::>(); + let batched_answers = answers + .par_iter() + .map(|answer| { + let chunk_size = 1 << self.0.folding_factor.at_round(round_state.round); + let mut res = vec![F::ZERO; chunk_size]; + for i in 0..chunk_size { + for j in 0..num_polys { + res[i] += answer[i + j * chunk_size] * batching_randomness[j]; + } + } + res + }) + .collect::>(); + // Evaluate answers in the folding randomness. 
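+        // Two strategies, selected by `fold_optimisation`:
+        // - `Naive`: each batched answer is the batched polynomial's
+        //   evaluations on a coset of size 2^folding_factor, so
+        //   `compute_fold` folds them pairwise using the coset offsets and
+        //   generator inverses computed below.
+        // - `ProverHelps`: `restructure_evaluations` already put each leaf
+        //   into coefficient form, so folding reduces to evaluating that
+        //   `CoefficientList` at the folding randomness.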
+ let mut stir_evaluations = ood_answers.clone(); + match self.0.fold_optimisation { + FoldType::Naive => { + // See `Verifier::compute_folds_full` + let domain_size = round_state.domain.backing_domain.size(); + let domain_gen = round_state.domain.backing_domain.element(1); + let domain_gen_inv = domain_gen.inverse().unwrap(); + let coset_domain_size = 1 << self.0.folding_factor.at_round(round_state.round); + let coset_generator_inv = + domain_gen_inv.pow([(domain_size / coset_domain_size) as u64]); + stir_evaluations.extend(stir_challenges_indexes.iter().zip(&batched_answers).map( + |(index, batched_answers)| { + // The coset is w^index * + // let _coset_offset = domain_gen.pow(&[*index as u64]); + let coset_offset_inv = domain_gen_inv.pow([*index as u64]); + + compute_fold( + batched_answers, + &round_state.folding_randomness.0, + coset_offset_inv, + coset_generator_inv, + F::from(2).inverse().unwrap(), + self.0.folding_factor.at_round(round_state.round), + ) + }, + )) + } + FoldType::ProverHelps => { + stir_evaluations.extend(batched_answers.iter().map(|batched_answers| { + CoefficientList::new(batched_answers.to_vec()) + .evaluate(&round_state.folding_randomness) + })) + } + } + round_state.merkle_proofs.push((merkle_proof, answers)); + + // PoW + if round_params.pow_bits > 0. { + merlin.challenge_pow::(round_params.pow_bits)?; + } + + // Randomness for combination + let [combination_randomness_gen] = merlin.challenge_scalars()?; + let combination_randomness = + expand_randomness(combination_randomness_gen, stir_challenges.len()); + + let mut sumcheck_prover = round_state + .sumcheck_prover + .take() + .map(|mut sumcheck_prover| { + sumcheck_prover.add_new_equality( + &stir_challenges, + &combination_randomness, + &stir_evaluations, + ); + sumcheck_prover + }) + .unwrap_or_else(|| { + SumcheckProverNotSkipping::new( + folded_coefficients.clone(), + &stir_challenges, + &combination_randomness, + &stir_evaluations, + ) + }); + + let folding_randomness = sumcheck_prover + .compute_sumcheck_polynomials::( + merlin, + self.0.folding_factor.at_round(round_state.round + 1), + round_params.folding_pow_bits, + )?; + + let round_state = RoundState { + round: round_state.round + 1, + domain: new_domain, + sumcheck_prover: Some(sumcheck_prover), + folding_randomness, + coefficients: folded_coefficients, /* TODO: Is this redundant with `sumcheck_prover.coeff` ? 
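+        (Each round entry re-folds `coefficients` before anything else, so
+        the field is still read even while the sumcheck prover keeps its
+        own copy.)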
*/ + prev_merkle: merkle_tree, + prev_merkle_answers: folded_evals, + merkle_proofs: round_state.merkle_proofs, + }; + + self.round(merlin, round_state) + } +} + +impl Prover +where + F: FftField, + MerkleConfig: Config, + PowStrategy: nimue_pow::PowStrategy, +{ + /// each poly on a different point, same size + pub fn same_size_batch_prove( + &self, + merlin: &mut Merlin, + point_per_poly: &[MultilinearPoint], + eval_per_poly: &[F], + witness: &Witnesses, + ) -> ProofResult> + where + Merlin: FieldChallenges + + FieldWriter + + ByteChallenges + + ByteWriter + + PoWChallenge + + DigestWriter, + { + let prove_timer = start_timer!(|| "prove"); + let initial_timer = start_timer!(|| "init"); + assert!(self.0.initial_statement, "must be true for pcs"); + assert!(self.validate_parameters()); + assert!(self.validate_witnesses(witness)); + for point in point_per_poly { + assert_eq!( + point.0.len(), + self.0.mv_parameters.num_variables, + "number of variables mismatch" + ); + } + let num_polys = witness.polys.len(); + assert_eq!( + eval_per_poly.len(), + num_polys, + "number of polynomials not equal number of evaluations" + ); + end_timer!(initial_timer); + + let poly_comb_randomness_timer = start_timer!(|| "poly comb randomness"); + let poly_comb_randomness = + super::utils::generate_random_vector_batch_open(merlin, witness.polys.len())?; + end_timer!(poly_comb_randomness_timer); + + let initial_claims_timer = start_timer!(|| "initial claims"); + let initial_eval_claims = point_per_poly; + end_timer!(initial_claims_timer); + + let sumcheck_timer = start_timer!(|| "unifying sumcheck"); + let mut sumcheck_prover = SumcheckProverNotSkippingBatched::new( + witness.polys.clone(), + initial_eval_claims, + &poly_comb_randomness, + eval_per_poly, + ); + + // Perform the entire sumcheck + let folded_point = sumcheck_prover.compute_sumcheck_polynomials::( + merlin, + self.0.mv_parameters.num_variables, + 0., + )?; + let folded_evals = sumcheck_prover.get_folded_polys(); + merlin.add_scalars(&folded_evals)?; + end_timer!(sumcheck_timer); + // Problem now reduced to the polys(folded_point) =?= folded_evals + + let timer = start_timer!(|| "simple_batch"); + // perform simple_batch on folded_point and folded_evals + let result = self.simple_batch_prove(merlin, &[folded_point], &[folded_evals], witness)?; + end_timer!(timer); + end_timer!(prove_timer); + + Ok(result) + } +} diff --git a/whir/src/whir/batch/utils.rs b/whir/src/whir/batch/utils.rs new file mode 100644 index 000000000..8b1fae172 --- /dev/null +++ b/whir/src/whir/batch/utils.rs @@ -0,0 +1,101 @@ +use crate::{ + ntt::{transpose, transpose_test}, + utils::expand_randomness, + whir::fs_utils::{DigestReader, DigestWriter}, +}; +use ark_crypto_primitives::merkle_tree::Config; +use ark_ff::Field; +use ark_std::{end_timer, start_timer}; +use nimue::{ + ByteReader, ByteWriter, ProofResult, + plugins::ark::{FieldChallenges, FieldReader, FieldWriter}, +}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +pub fn stack_evaluations( + mut evals: Vec, + folding_factor: usize, + buffer: &mut [F], +) -> Vec { + assert!(evals.len() % folding_factor == 0); + let size_of_new_domain = evals.len() / folding_factor; + + // interpret evals as (folding_factor_exp x size_of_new_domain)-matrix and transpose in-place + transpose_test(&mut evals, folding_factor, size_of_new_domain, buffer); + evals +} + +/// Takes the vector of evaluations (assume that evals[i] = f(omega^i)) +/// and folds them into a vector of such that folded_evals[i] = [f(omega^(i + k * j)) for j in 
0..folding_factor] +/// This function will mutate the function without return +pub fn stack_evaluations_mut(evals: &mut [F], folding_factor: usize) { + let folding_factor_exp = 1 << folding_factor; + assert!(evals.len() % folding_factor_exp == 0); + let size_of_new_domain = evals.len() / folding_factor_exp; + + // interpret evals as (folding_factor_exp x size_of_new_domain)-matrix and transpose in-place + transpose(evals, folding_factor_exp, size_of_new_domain); +} + +/// Takes a vector of matrix and stacking them horizontally +/// Use in-place matrix transposes to avoid data copy +/// each matrix has domain_size elements +/// each matrix has shape (*, 1<( + evals: Vec, + domain_size: usize, + folding_factor: usize, + buffer: &mut [F], +) -> Vec { + let fold_size = 1 << folding_factor; + let num_polys: usize = evals.len() / domain_size; + + let stack_evaluation_timer = start_timer!(|| "Stack Evaluation"); + let mut evals = stack_evaluations(evals, num_polys, buffer); + end_timer!(stack_evaluation_timer); + #[cfg(not(feature = "parallel"))] + let stacked_evals = evals.chunks_exact_mut(fold_size * num_polys); + #[cfg(feature = "parallel")] + let stacked_evals = evals.par_chunks_exact_mut(fold_size * num_polys); + let stack_evaluation_mut_timer = start_timer!(|| "Stack Evaluation Mut"); + stacked_evals.for_each(|eval| stack_evaluations_mut(eval, folding_factor)); + end_timer!(stack_evaluation_mut_timer); + evals +} + +// generate a random vector for batching open +pub fn generate_random_vector_batch_open( + merlin: &mut Merlin, + size: usize, +) -> ProofResult> +where + F: Field, + MerkleConfig: Config, + Merlin: FieldChallenges + FieldWriter + ByteWriter + DigestWriter, +{ + if size == 1 { + return Ok(vec![F::one()]); + } + let [gamma] = merlin.challenge_scalars()?; + let res = expand_randomness(gamma, size); + Ok(res) +} + +// generate a random vector for batching verify +pub fn generate_random_vector_batch_verify( + arthur: &mut Arthur, + size: usize, +) -> ProofResult> +where + F: Field, + MerkleConfig: Config, + Arthur: FieldChallenges + FieldReader + ByteReader + DigestReader, +{ + if size == 1 { + return Ok(vec![F::one()]); + } + let [gamma] = arthur.challenge_scalars()?; + let res = expand_randomness(gamma, size); + Ok(res) +} diff --git a/whir/src/whir/batch/verifier.rs b/whir/src/whir/batch/verifier.rs new file mode 100644 index 000000000..98bd2ca0e --- /dev/null +++ b/whir/src/whir/batch/verifier.rs @@ -0,0 +1,726 @@ +use std::iter; + +use ark_crypto_primitives::merkle_tree::Config; +use ark_ff::FftField; +use ark_poly::EvaluationDomain; +use ark_std::{iterable::Iterable, log2}; +use itertools::zip_eq; +use nimue::{ + ByteChallenges, ByteReader, ProofError, ProofResult, + plugins::ark::{FieldChallenges, FieldReader}, +}; +use nimue_pow::{self, PoWChallenge}; + +use crate::{ + poly_utils::{MultilinearPoint, coeffs::CoefficientList, eq_poly_outside}, + sumcheck::proof::SumcheckPolynomial, + utils::expand_randomness, + whir::{ + Statement, WhirProof, + fs_utils::{DigestReader, get_challenge_stir_queries}, + verifier::{ParsedCommitment, ParsedProof, ParsedRound, Verifier}, + }, +}; + +impl Verifier +where + F: FftField, + MerkleConfig: Config, + PowStrategy: nimue_pow::PowStrategy, +{ + // Same multiple points on each polynomial + pub fn simple_batch_verify( + &self, + arthur: &mut Arthur, + num_polys: usize, + points: &[MultilinearPoint], + evals_per_point: &[Vec], + whir_proof: &WhirProof, + ) -> ProofResult + where + Arthur: FieldChallenges + + FieldReader + + ByteChallenges + + 
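+            // (`Arthur` is the verifier side of the nimue transcript: it
+            // replays the prover's messages from the proof bytes and
+            // re-derives the same Fiat-Shamir challenges that the prover's
+            // `Merlin` produced.)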
ByteReader + + PoWChallenge + + DigestReader, + { + for evals in evals_per_point { + assert_eq!(num_polys, evals.len()); + } + + // We first do a pass in which we rederive all the FS challenges + // Then we will check the algebraic part (so to optimise inversions) + let parsed_commitment = self.parse_commitment_batch(arthur, num_polys)?; + self.batch_verify_internal( + arthur, + num_polys, + points, + evals_per_point, + parsed_commitment, + whir_proof, + ) + } + + // Different points on each polynomial + pub fn same_size_batch_verify( + &self, + arthur: &mut Arthur, + num_polys: usize, + point_per_poly: &[MultilinearPoint], + eval_per_poly: &[F], // evaluations of the polys on individual points + whir_proof: &WhirProof, + ) -> ProofResult + where + Arthur: FieldChallenges + + FieldReader + + ByteChallenges + + ByteReader + + PoWChallenge + + DigestReader, + { + assert_eq!(num_polys, point_per_poly.len()); + assert_eq!(num_polys, eval_per_poly.len()); + + // We first do a pass in which we rederive all the FS challenges + // Then we will check the algebraic part (so to optimise inversions) + let parsed_commitment = self.parse_commitment_batch(arthur, num_polys)?; + + // parse proof + let poly_comb_randomness = + super::utils::generate_random_vector_batch_verify(arthur, num_polys)?; + let (folded_points, folded_evals) = + self.parse_unify_sumcheck(arthur, point_per_poly, poly_comb_randomness)?; + + self.batch_verify_internal( + arthur, + num_polys, + &[folded_points], + &[folded_evals.clone()], + parsed_commitment, + whir_proof, + ) + } + + fn batch_verify_internal( + &self, + arthur: &mut Arthur, + num_polys: usize, + points: &[MultilinearPoint], + evals_per_point: &[Vec], + parsed_commitment: ParsedCommitment, + whir_proof: &WhirProof, + ) -> ProofResult + where + Arthur: FieldChallenges + + FieldReader + + ByteChallenges + + ByteReader + + PoWChallenge + + DigestReader, + { + // parse proof + let compute_dot_product = + |evals: &[F], coeff: &[F]| -> F { zip_eq(evals, coeff).map(|(a, b)| *a * *b).sum() }; + + let random_coeff = super::utils::generate_random_vector_batch_verify(arthur, num_polys)?; + let initial_claims: Vec<_> = parsed_commitment + .ood_points + .clone() + .into_iter() + .map(|ood_point| { + MultilinearPoint::expand_from_univariate( + ood_point, + self.params.mv_parameters.num_variables, + ) + }) + .chain(points.to_vec()) + .collect(); + + let ood_answers = parsed_commitment + .ood_answers + .clone() + .chunks_exact(num_polys) + .map(|answer| compute_dot_product(answer, &random_coeff)) + .collect::>(); + let eval_per_point = evals_per_point + .iter() + .map(|evals| compute_dot_product(evals, &random_coeff)); + + let initial_answers: Vec<_> = ood_answers.into_iter().chain(eval_per_point).collect(); + + let statement = Statement { + points: initial_claims, + evaluations: initial_answers, + }; + let parsed = self.parse_proof_batch( + arthur, + &parsed_commitment, + &statement, + whir_proof, + random_coeff.clone(), + num_polys, + )?; + + let computed_folds = self.compute_folds(&parsed); + + let mut prev: Option<(SumcheckPolynomial, F)> = None; + if let Some(round) = parsed.initial_sumcheck_rounds.first() { + // Check the first polynomial + let (mut prev_poly, mut randomness) = round.clone(); + if prev_poly.sum_over_hypercube() + != statement + .evaluations + .clone() + .into_iter() + .zip(&parsed.initial_combination_randomness) + .map(|(ans, rand)| ans * rand) + .sum() + { + return Err(ProofError::InvalidProof); + } + + // Check the rest of the rounds + for (sumcheck_poly, 
new_randomness) in &parsed.initial_sumcheck_rounds[1..] { + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev_poly = sumcheck_poly.clone(); + randomness = *new_randomness; + } + + prev = Some((prev_poly, randomness)); + } + + for (round, folds) in parsed.rounds.iter().zip(&computed_folds) { + let (sumcheck_poly, new_randomness) = &round.sumcheck_rounds[0].clone(); + + let values = round.ood_answers.iter().copied().chain(folds.clone()); + + let prev_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + let claimed_sum = prev_eval + + values + .zip(&round.combination_randomness) + .map(|(val, rand)| val * rand) + .sum::(); + + if sumcheck_poly.sum_over_hypercube() != claimed_sum { + return Err(ProofError::InvalidProof); + } + + prev = Some((sumcheck_poly.clone(), *new_randomness)); + + // Check the rest of the round + for (sumcheck_poly, new_randomness) in &round.sumcheck_rounds[1..] { + let (prev_poly, randomness) = prev.unwrap(); + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev = Some((sumcheck_poly.clone(), *new_randomness)); + } + } + + // Check the foldings computed from the proof match the evaluations of the polynomial + let final_folds = &computed_folds[computed_folds.len() - 1]; + let final_evaluations = parsed + .final_coefficients + .evaluate_at_univariate(&parsed.final_randomness_points); + if !final_folds + .iter() + .zip(final_evaluations) + .all(|(&fold, eval)| fold == eval) + { + return Err(ProofError::InvalidProof); + } + + // Check the final sumchecks + if self.params.final_sumcheck_rounds > 0 { + let prev_sumcheck_poly_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + let (sumcheck_poly, new_randomness) = &parsed.final_sumcheck_rounds[0].clone(); + let claimed_sum = prev_sumcheck_poly_eval; + + if sumcheck_poly.sum_over_hypercube() != claimed_sum { + return Err(ProofError::InvalidProof); + } + + prev = Some((sumcheck_poly.clone(), *new_randomness)); + + // Check the rest of the round + for (sumcheck_poly, new_randomness) in &parsed.final_sumcheck_rounds[1..] 
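+            // Same chaining invariant as above: each round's polynomial
+            // must sum over {0,1} to the previous polynomial evaluated at
+            // its sampled randomness.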
{ + let (prev_poly, randomness) = prev.unwrap(); + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev = Some((sumcheck_poly.clone(), *new_randomness)); + } + } + + let prev_sumcheck_poly_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + + // Check the final sumcheck evaluation + let evaluation_of_v_poly = self.compute_v_poly_for_batched(&statement, &parsed); + + if prev_sumcheck_poly_eval + != evaluation_of_v_poly + * parsed + .final_coefficients + .evaluate(&parsed.final_sumcheck_randomness) + { + return Err(ProofError::InvalidProof); + } + + Ok(parsed_commitment.root) + } + + fn parse_commitment_batch( + &self, + arthur: &mut Arthur, + num_polys: usize, + ) -> ProofResult> + where + Arthur: ByteReader + FieldReader + FieldChallenges + DigestReader, + { + let root = arthur.read_digest()?; + + let mut ood_points = vec![F::ZERO; self.params.committment_ood_samples]; + let mut ood_answers = vec![F::ZERO; self.params.committment_ood_samples * num_polys]; + if self.params.committment_ood_samples > 0 { + arthur.fill_challenge_scalars(&mut ood_points)?; + arthur.fill_next_scalars(&mut ood_answers)?; + } + + Ok(ParsedCommitment { + root, + ood_points, + ood_answers, + }) + } + + fn parse_unify_sumcheck( + &self, + arthur: &mut Arthur, + point_per_poly: &[MultilinearPoint], + poly_comb_randomness: Vec, + ) -> ProofResult<(MultilinearPoint, Vec)> + where + Arthur: FieldReader + + FieldChallenges + + PoWChallenge + + ByteReader + + ByteChallenges + + DigestReader, + { + let num_variables = self.params.mv_parameters.num_variables; + let mut sumcheck_rounds = Vec::new(); + + // Derive combination randomness and first sumcheck polynomial + // let [point_comb_randomness_gen]: [F; 1] = arthur.challenge_scalars()?; + // let point_comb_randomness = expand_randomness(point_comb_randomness_gen, num_points); + + // Unifying sumcheck + sumcheck_rounds.reserve_exact(num_variables); + for _ in 0..num_variables { + let sumcheck_poly_evals: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly = SumcheckPolynomial::new(sumcheck_poly_evals.to_vec(), 1); + let [folding_randomness_single] = arthur.challenge_scalars()?; + sumcheck_rounds.push((sumcheck_poly, folding_randomness_single)); + + if self.params.starting_folding_pow_bits > 0. 
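+            // Per-round grinding, mirroring the PoW the prover performs
+            // inside `compute_sumcheck_polynomials`.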
{ + arthur.challenge_pow::(self.params.starting_folding_pow_bits)?; + } + } + let folded_point = + MultilinearPoint(sumcheck_rounds.iter().map(|&(_, r)| r).rev().collect()); + let folded_eqs: Vec = point_per_poly + .iter() + .zip(&poly_comb_randomness) + .map(|(point, randomness)| *randomness * eq_poly_outside(point, &folded_point)) + .collect(); + let mut folded_evals = vec![F::ZERO; point_per_poly.len()]; + arthur.fill_next_scalars(&mut folded_evals)?; + let sumcheck_claim = + sumcheck_rounds[num_variables - 1] + .0 + .evaluate_at_point(&MultilinearPoint(vec![ + sumcheck_rounds[num_variables - 1].1, + ])); + let sumcheck_expected: F = folded_evals + .iter() + .zip(&folded_eqs) + .map(|(eval, eq)| *eval * *eq) + .sum(); + if sumcheck_claim != sumcheck_expected { + return Err(ProofError::InvalidProof); + } + + Ok((folded_point, folded_evals)) + } + + fn pow_with_precomputed_squares(squares: &[F], mut index: usize) -> F { + let mut result = F::one(); + let mut i = 0; + while index > 0 { + if index & 1 == 1 { + result *= squares[i]; + } + index >>= 1; + i += 1; + } + result + } + + fn parse_proof_batch( + &self, + arthur: &mut Arthur, + parsed_commitment: &ParsedCommitment, + statement: &Statement, // Will be needed later + whir_proof: &WhirProof, + batched_randomness: Vec, + num_polys: usize, + ) -> ProofResult> + where + Arthur: FieldReader + + FieldChallenges + + PoWChallenge + + ByteReader + + ByteChallenges + + DigestReader, + { + let mut sumcheck_rounds = Vec::new(); + let mut folding_randomness: MultilinearPoint; + let initial_combination_randomness; + + if self.params.initial_statement { + // Derive combination randomness and first sumcheck polynomial + let [combination_randomness_gen]: [F; 1] = arthur.challenge_scalars()?; + initial_combination_randomness = expand_randomness( + combination_randomness_gen, + parsed_commitment.ood_points.len() + statement.points.len(), + ); + + // Initial sumcheck + sumcheck_rounds.reserve_exact(self.params.folding_factor.at_round(0)); + for _ in 0..self.params.folding_factor.at_round(0) { + let sumcheck_poly_evals: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly = SumcheckPolynomial::new(sumcheck_poly_evals.to_vec(), 1); + let [folding_randomness_single] = arthur.challenge_scalars()?; + sumcheck_rounds.push((sumcheck_poly, folding_randomness_single)); + + if self.params.starting_folding_pow_bits > 0. { + arthur.challenge_pow::(self.params.starting_folding_pow_bits)?; + } + } + + folding_randomness = + MultilinearPoint(sumcheck_rounds.iter().map(|&(_, r)| r).rev().collect()); + } else { + assert_eq!(parsed_commitment.ood_points.len(), 0); + assert_eq!(statement.points.len(), 0); + + initial_combination_randomness = vec![F::ONE]; + + let mut folding_randomness_vec = vec![F::ZERO; self.params.folding_factor.at_round(0)]; + arthur.fill_challenge_scalars(&mut folding_randomness_vec)?; + folding_randomness = MultilinearPoint(folding_randomness_vec); + + // PoW + if self.params.starting_folding_pow_bits > 0. 
{ + arthur.challenge_pow::(self.params.starting_folding_pow_bits)?; + } + }; + + let mut prev_root = parsed_commitment.root.clone(); + let domain_gen = self.params.starting_domain.backing_domain.group_gen(); + // Precompute the powers of the domain generator, so that + // we can always compute domain_gen.pow(1 << i) by domain_gen_powers[i] + let domain_gen_powers = std::iter::successors(Some(domain_gen), |&curr| Some(curr * curr)) + .take(log2(self.params.starting_domain.size()) as usize) + .collect::>(); + // Since the generator of the domain will be repeatedly squared in + // the future, keep track of the log of the power (i.e., how many times + // it has been squared from domain_gen). + // In another word, always ensure current domain generator = domain_gen_powers[log_based_on_domain_gen] + let mut log_based_on_domain_gen: usize = 0; + let mut domain_gen_inv = self.params.starting_domain.backing_domain.group_gen_inv(); + let mut domain_size = self.params.starting_domain.size(); + let mut rounds = vec![]; + + for r in 0..self.params.n_rounds() { + let (merkle_proof, answers) = &whir_proof.0[r]; + let round_params = &self.params.round_parameters[r]; + + let new_root = arthur.read_digest()?; + + let mut ood_points = vec![F::ZERO; round_params.ood_samples]; + let mut ood_answers = vec![F::ZERO; round_params.ood_samples]; + if round_params.ood_samples > 0 { + arthur.fill_challenge_scalars(&mut ood_points)?; + arthur.fill_next_scalars(&mut ood_answers)?; + } + + let stir_challenges_indexes = get_challenge_stir_queries( + domain_size, + self.params.folding_factor.at_round(r), + round_params.num_queries, + arthur, + )?; + + let stir_challenges_points = stir_challenges_indexes + .iter() + .map(|index| { + Self::pow_with_precomputed_squares( + &domain_gen_powers.as_slice() + [log_based_on_domain_gen + self.params.folding_factor.at_round(r)..], + *index, + ) + }) + .collect(); + + if !merkle_proof + .verify( + &self.params.leaf_hash_params, + &self.params.two_to_one_params, + &prev_root, + answers.iter().map(|a| a.as_ref()), + ) + .unwrap() + || merkle_proof.leaf_indexes != stir_challenges_indexes + { + return Err(ProofError::InvalidProof); + } + + let answers: Vec<_> = if r == 0 { + answers + .iter() + .map(|raw_answer| { + if !batched_randomness.is_empty() { + let chunk_size = 1 << self.params.folding_factor.at_round(r); + let mut res = vec![F::ZERO; chunk_size]; + for i in 0..chunk_size { + for j in 0..num_polys { + res[i] += + raw_answer[i + j * chunk_size] * batched_randomness[j]; + } + } + res + } else { + raw_answer.clone() + } + }) + .collect() + } else { + answers.to_vec() + }; + + if round_params.pow_bits > 0. { + arthur.challenge_pow::(round_params.pow_bits)?; + } + + let [combination_randomness_gen] = arthur.challenge_scalars()?; + let combination_randomness = expand_randomness( + combination_randomness_gen, + stir_challenges_indexes.len() + round_params.ood_samples, + ); + + let mut sumcheck_rounds = + Vec::with_capacity(self.params.folding_factor.at_round(r + 1)); + for _ in 0..self.params.folding_factor.at_round(r + 1) { + let sumcheck_poly_evals: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly = SumcheckPolynomial::new(sumcheck_poly_evals.to_vec(), 1); + let [folding_randomness_single] = arthur.challenge_scalars()?; + sumcheck_rounds.push((sumcheck_poly, folding_randomness_single)); + + if round_params.folding_pow_bits > 0. 
{ + arthur.challenge_pow::(round_params.folding_pow_bits)?; + } + } + + let new_folding_randomness = + MultilinearPoint(sumcheck_rounds.iter().map(|&(_, r)| r).rev().collect()); + + rounds.push(ParsedRound { + folding_randomness, + ood_points, + ood_answers, + stir_challenges_indexes, + stir_challenges_points, + stir_challenges_answers: answers, + combination_randomness, + sumcheck_rounds, + domain_gen_inv, + }); + + folding_randomness = new_folding_randomness; + + prev_root = new_root.clone(); + log_based_on_domain_gen += 1; + domain_gen_inv = domain_gen_inv * domain_gen_inv; + domain_size >>= 1; + } + + let mut final_coefficients = vec![F::ZERO; 1 << self.params.final_sumcheck_rounds]; + arthur.fill_next_scalars(&mut final_coefficients)?; + let final_coefficients = CoefficientList::new(final_coefficients); + + // Final queries verify + let final_randomness_indexes = get_challenge_stir_queries( + domain_size, + self.params.folding_factor.at_round(self.params.n_rounds()), + self.params.final_queries, + arthur, + )?; + let final_randomness_points = final_randomness_indexes + .iter() + .map(|index| { + Self::pow_with_precomputed_squares( + &domain_gen_powers.as_slice()[log_based_on_domain_gen + + self.params.folding_factor.at_round(self.params.n_rounds())..], + *index, + ) + }) + .collect(); + + let (final_merkle_proof, final_randomness_answers) = &whir_proof.0[whir_proof.0.len() - 1]; + if !final_merkle_proof + .verify( + &self.params.leaf_hash_params, + &self.params.two_to_one_params, + &prev_root, + final_randomness_answers.iter().map(|a| a.as_ref()), + ) + .unwrap() + || final_merkle_proof.leaf_indexes != final_randomness_indexes + { + return Err(ProofError::InvalidProof); + } + + let final_randomness_answers: Vec<_> = if self.params.n_rounds() == 0 { + final_randomness_answers + .iter() + .map(|raw_answer| { + if !batched_randomness.is_empty() { + let chunk_size = + 1 << self.params.folding_factor.at_round(self.params.n_rounds()); + let mut res = vec![F::ZERO; chunk_size]; + for i in 0..chunk_size { + for j in 0..num_polys { + res[i] += raw_answer[i + j * chunk_size] * batched_randomness[j]; + } + } + res + } else { + raw_answer.clone() + } + }) + .collect() + } else { + final_randomness_answers.to_vec() + }; + + if self.params.final_pow_bits > 0. { + arthur.challenge_pow::(self.params.final_pow_bits)?; + } + + let mut final_sumcheck_rounds = Vec::with_capacity(self.params.final_sumcheck_rounds); + for _ in 0..self.params.final_sumcheck_rounds { + let sumcheck_poly_evals: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly = SumcheckPolynomial::new(sumcheck_poly_evals.to_vec(), 1); + let [folding_randomness_single] = arthur.challenge_scalars()?; + final_sumcheck_rounds.push((sumcheck_poly, folding_randomness_single)); + + if self.params.final_folding_pow_bits > 0. 
{ + arthur.challenge_pow::(self.params.final_folding_pow_bits)?; + } + } + let final_sumcheck_randomness = MultilinearPoint( + final_sumcheck_rounds + .iter() + .map(|&(_, r)| r) + .rev() + .collect(), + ); + + Ok(ParsedProof { + initial_combination_randomness, + initial_sumcheck_rounds: sumcheck_rounds, + rounds, + final_domain_gen_inv: domain_gen_inv, + final_folding_randomness: folding_randomness, + final_randomness_indexes, + final_randomness_points, + final_randomness_answers: final_randomness_answers.to_vec(), + final_sumcheck_rounds, + final_sumcheck_randomness, + final_coefficients, + }) + } + + /// this is copied and modified from `fn compute_v_poly` + /// to avoid modify the original function for compatibility + fn compute_v_poly_for_batched(&self, statement: &Statement, proof: &ParsedProof) -> F { + let mut num_variables = self.params.mv_parameters.num_variables; + + let mut folding_randomness = MultilinearPoint( + iter::once(&proof.final_sumcheck_randomness.0) + .chain(iter::once(&proof.final_folding_randomness.0)) + .chain(proof.rounds.iter().rev().map(|r| &r.folding_randomness.0)) + .flatten() + .copied() + .collect(), + ); + + let mut value = statement + .points + .iter() + .zip(&proof.initial_combination_randomness) + .map(|(point, randomness)| *randomness * eq_poly_outside(point, &folding_randomness)) + .sum(); + + for (round, round_proof) in proof.rounds.iter().enumerate() { + num_variables -= self.params.folding_factor.at_round(round); + folding_randomness = MultilinearPoint(folding_randomness.0[..num_variables].to_vec()); + + let ood_points = &round_proof.ood_points; + let stir_challenges_points = &round_proof.stir_challenges_points; + let stir_challenges: Vec<_> = ood_points + .iter() + .chain(stir_challenges_points) + .cloned() + .map(|univariate| { + MultilinearPoint::expand_from_univariate(univariate, num_variables) + // TODO: + // Maybe refactor outside + }) + .collect(); + + let sum_of_claims: F = stir_challenges + .into_iter() + .map(|point| eq_poly_outside(&point, &folding_randomness)) + .zip(&round_proof.combination_randomness) + .map(|(point, rand)| point * rand) + .sum(); + + value += sum_of_claims; + } + + value + } +} diff --git a/whir/src/whir/committer.rs b/whir/src/whir/committer.rs new file mode 100644 index 000000000..3ca14d772 --- /dev/null +++ b/whir/src/whir/committer.rs @@ -0,0 +1,128 @@ +use super::parameters::WhirConfig; +use crate::{ + ntt::expand_from_coeff, + poly_utils::{MultilinearPoint, coeffs::CoefficientList, fold::restructure_evaluations}, + utils, +}; +use ark_crypto_primitives::merkle_tree::{Config, MerkleTree}; +use ark_ff::FftField; +use ark_poly::EvaluationDomain; +use ark_std::{end_timer, start_timer}; +use derive_more::Debug; +use nimue::{ + ByteWriter, ProofResult, + plugins::ark::{FieldChallenges, FieldWriter}, +}; + +use crate::whir::fs_utils::DigestWriter; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +#[derive(Clone, Debug)] +pub struct Witness +where + MerkleConfig: Config, +{ + pub(crate) polynomial: CoefficientList, + #[debug(skip)] + pub(crate) merkle_tree: MerkleTree, + pub(crate) merkle_leaves: Vec, + pub(crate) ood_points: Vec, + pub(crate) ood_answers: Vec, +} + +pub struct Committer( + pub(crate) WhirConfig, +) +where + F: FftField, + MerkleConfig: Config; + +impl Committer +where + F: FftField, + MerkleConfig: Config, +{ + pub fn new(config: WhirConfig) -> Self { + Self(config) + } + + pub fn commit( + &self, + merlin: &mut Merlin, + mut polynomial: CoefficientList, + ) -> ProofResult> + where + Merlin: 
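+        // `Merlin` is the prover side of the nimue transcript: every value
+        // written here is absorbed into the Fiat-Shamir state from which
+        // the OOD challenge points are later squeezed.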
FieldWriter + FieldChallenges + ByteWriter + DigestWriter, + { + let timer = start_timer!(|| "Single Commit"); + // If size of polynomial < folding factor, keep doubling polynomial size by cloning itself + polynomial.pad_to_num_vars(self.0.folding_factor.at_round(0)); + + let base_domain = self.0.starting_domain.base_domain.unwrap(); + let expansion = base_domain.size() / polynomial.num_coeffs(); + let evals = expand_from_coeff(polynomial.coeffs(), expansion); + // TODO: `stack_evaluations` and `restructure_evaluations` are really in-place algorithms. + // They also partially overlap and undo one another. We should merge them. + let folded_evals = utils::stack_evaluations(evals, self.0.folding_factor.at_round(0)); + let folded_evals = restructure_evaluations( + folded_evals, + self.0.fold_optimisation, + base_domain.group_gen(), + base_domain.group_gen_inv(), + self.0.folding_factor.at_round(0), + ); + + // Convert to extension field. + // This is not necessary for the commit, but in further rounds + // we will need the extension field. For symplicity we do it here too. + // TODO: Commit to base field directly. + let folded_evals = folded_evals + .into_iter() + .map(F::from_base_prime_field) + .collect::>(); + + // Group folds together as a leaf. + let fold_size = 1 << self.0.folding_factor.at_round(0); + #[cfg(not(feature = "parallel"))] + let leafs_iter = folded_evals.chunks_exact(fold_size); + #[cfg(feature = "parallel")] + let leafs_iter = folded_evals.par_chunks_exact(fold_size); + + let merkle_build_timer = start_timer!(|| "Single Merkle Tree Build"); + let merkle_tree = MerkleTree::::new( + &self.0.leaf_hash_params, + &self.0.two_to_one_params, + leafs_iter, + ) + .unwrap(); + end_timer!(merkle_build_timer); + + let root = merkle_tree.root(); + + merlin.add_digest(root)?; + + let mut ood_points = vec![F::ZERO; self.0.committment_ood_samples]; + let mut ood_answers = Vec::with_capacity(self.0.committment_ood_samples); + if self.0.committment_ood_samples > 0 { + merlin.fill_challenge_scalars(&mut ood_points)?; + ood_answers.extend(ood_points.iter().map(|ood_point| { + polynomial.evaluate_at_extension(&MultilinearPoint::expand_from_univariate( + *ood_point, + self.0.mv_parameters.num_variables, + )) + })); + merlin.add_scalars(&ood_answers)?; + } + + end_timer!(timer); + + Ok(Witness { + polynomial: polynomial.to_extension(), + merkle_tree, + merkle_leaves: folded_evals, + ood_points, + ood_answers, + }) + } +} diff --git a/whir/src/whir/fs_utils.rs b/whir/src/whir/fs_utils.rs new file mode 100644 index 000000000..e9a2922e8 --- /dev/null +++ b/whir/src/whir/fs_utils.rs @@ -0,0 +1,39 @@ +use crate::utils::dedup; +use ark_crypto_primitives::merkle_tree::Config; +use nimue::{ByteChallenges, ProofResult}; + +pub fn get_challenge_stir_queries( + domain_size: usize, + folding_factor: usize, + num_queries: usize, + transcript: &mut T, +) -> ProofResult> +where + T: ByteChallenges, +{ + let folded_domain_size = domain_size / (1 << folding_factor); + // How many bytes do we need to represent an index in the folded domain? 
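+    // (Worked example: folded_domain_size = 2^16 gives
+    // (2 * 2^16 - 1).ilog2() = 16, hence (16 + 7) / 8 = 2 bytes per query.)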
+ // domain_size_bytes = log2(folded_domain_size) / 8 + // (both operations are rounded up) + let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8; + // We need these many bytes to represent the query indices + let mut queries = vec![0u8; num_queries * domain_size_bytes]; + transcript.fill_challenge_bytes(&mut queries)?; + let indices = queries.chunks_exact(domain_size_bytes).map(|chunk| { + let mut result = 0; + for byte in chunk { + result <<= 8; + result |= *byte as usize; + } + result % folded_domain_size + }); + Ok(dedup(indices)) +} + +pub trait DigestWriter { + fn add_digest(&mut self, digest: MerkleConfig::InnerDigest) -> ProofResult<()>; +} + +pub trait DigestReader { + fn read_digest(&mut self) -> ProofResult; +} diff --git a/whir/src/whir/iopattern.rs b/whir/src/whir/iopattern.rs new file mode 100644 index 000000000..a97d5e3b6 --- /dev/null +++ b/whir/src/whir/iopattern.rs @@ -0,0 +1,95 @@ +use ark_crypto_primitives::merkle_tree::Config; +use ark_ff::FftField; +use nimue::plugins::ark::*; + +use crate::{ + fs_utils::{OODIOPattern, WhirPoWIOPattern}, + sumcheck::prover_not_skipping::SumcheckNotSkippingIOPattern, +}; + +use super::parameters::WhirConfig; + +pub trait DigestIOPattern { + fn add_digest(self, label: &str) -> Self; +} + +pub trait WhirIOPattern { + fn commit_statement( + self, + params: &WhirConfig, + ) -> Self; + fn add_whir_proof(self, params: &WhirConfig) + -> Self; +} + +impl WhirIOPattern for IOPattern +where + F: FftField, + MerkleConfig: Config, + IOPattern: ByteIOPattern + + FieldIOPattern + + SumcheckNotSkippingIOPattern + + WhirPoWIOPattern + + OODIOPattern + + DigestIOPattern, +{ + fn commit_statement( + self, + params: &WhirConfig, + ) -> Self { + // TODO: Add params + let mut this = self.add_digest("merkle_digest"); + if params.committment_ood_samples > 0 { + assert!(params.initial_statement); + this = this.add_ood(params.committment_ood_samples); + } + this + } + + fn add_whir_proof( + mut self, + params: &WhirConfig, + ) -> Self { + // TODO: Add statement + if params.initial_statement { + self = self + .challenge_scalars(1, "initial_combination_randomness") + .add_sumcheck( + params.folding_factor.at_round(0), + params.starting_folding_pow_bits, + ); + } else { + self = self + .challenge_scalars(params.folding_factor.at_round(0), "folding_randomness") + .pow(params.starting_folding_pow_bits); + } + + let mut domain_size = params.starting_domain.size(); + for (round, r) in params.round_parameters.iter().enumerate() { + let folded_domain_size = domain_size >> params.folding_factor.at_round(round); + let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8; + self = self + .add_digest("merkle_digest") + .add_ood(r.ood_samples) + .challenge_bytes(r.num_queries * domain_size_bytes, "stir_queries") + .pow(r.pow_bits) + .challenge_scalars(1, "combination_randomness") + .add_sumcheck( + params.folding_factor.at_round(round + 1), + r.folding_pow_bits, + ); + domain_size >>= 1; + } + + let folded_domain_size = domain_size + >> params + .folding_factor + .at_round(params.round_parameters.len()); + let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8; + + self.add_scalars(1 << params.final_sumcheck_rounds, "final_coeffs") + .challenge_bytes(domain_size_bytes * params.final_queries, "final_queries") + .pow(params.final_pow_bits) + .add_sumcheck(params.final_sumcheck_rounds, params.final_folding_pow_bits) + } +} diff --git a/whir/src/whir/mod.rs b/whir/src/whir/mod.rs new file mode 100644 
index 000000000..816ac3c84 --- /dev/null +++ b/whir/src/whir/mod.rs @@ -0,0 +1,369 @@ +use ark_crypto_primitives::merkle_tree::{Config, MultiPath}; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; + +use crate::poly_utils::MultilinearPoint; + +pub mod batch; +pub mod committer; +pub mod fs_utils; +pub mod iopattern; +pub mod parameters; +pub mod prover; +pub mod verifier; + +#[derive(Debug, Clone, Default)] +pub struct Statement { + pub points: Vec>, + pub evaluations: Vec, +} + +// Only includes the authentication paths +#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] +pub struct WhirProof(pub(crate) Vec<(MultiPath, Vec>)>) +where + MerkleConfig: Config, + F: Sized + Clone + CanonicalSerialize + CanonicalDeserialize; + +pub fn whir_proof_size( + transcript: &[u8], + whir_proof: &WhirProof, +) -> usize +where + MerkleConfig: Config, + F: Sized + Clone + CanonicalSerialize + CanonicalDeserialize, +{ + transcript.len() + whir_proof.serialized_size(ark_serialize::Compress::Yes) +} + +#[cfg(test)] +mod tests { + use nimue::{DefaultHash, IOPattern}; + use nimue_pow::blake3::Blake3PoW; + + use crate::{ + crypto::{fields::Field64, merkle_tree::blake3 as merkle_tree}, + parameters::{ + FoldType, FoldingFactor, MultivariateParameters, SoundnessType, WhirParameters, + }, + poly_utils::{MultilinearPoint, coeffs::CoefficientList}, + whir::{ + Statement, batch::WhirBatchIOPattern, committer::Committer, iopattern::WhirIOPattern, + parameters::WhirConfig, prover::Prover, verifier::Verifier, + }, + }; + + type MerkleConfig = merkle_tree::MerkleTreeParams; + type PowStrategy = Blake3PoW; + type F = Field64; + + fn make_whir_things( + num_variables: usize, + folding_factor: FoldingFactor, + num_points: usize, + soundness_type: SoundnessType, + pow_bits: usize, + fold_type: FoldType, + ) { + let num_coeffs = 1 << num_variables; + + let mut rng = ark_std::test_rng(); + let (leaf_hash_params, two_to_one_params) = merkle_tree::default_config::(&mut rng); + + let mv_params = MultivariateParameters::::new(num_variables); + + let whir_params = WhirParameters:: { + initial_statement: true, + security_level: 32, + pow_bits, + folding_factor, + leaf_hash_params, + two_to_one_params, + soundness_type, + _pow_parameters: Default::default(), + starting_log_inv_rate: 1, + fold_optimisation: fold_type, + }; + + let params = WhirConfig::::new(mv_params, whir_params); + + let polynomial = CoefficientList::new(vec![F::from(1); num_coeffs]); + + let points: Vec<_> = (0..num_points) + .map(|_| MultilinearPoint::rand(&mut rng, num_variables)) + .collect(); + + let statement = Statement { + points: points.clone(), + evaluations: points + .iter() + .map(|point| polynomial.evaluate(point)) + .collect(), + }; + + let io = IOPattern::::new("🌪️") + .commit_statement(¶ms) + .add_whir_proof(¶ms) + .clone(); + + let mut merlin = io.to_merlin(); + + let committer = Committer::new(params.clone()); + let witness = committer.commit(&mut merlin, polynomial).unwrap(); + + let prover = Prover(params.clone()); + + let proof = prover + .prove(&mut merlin, statement.clone(), witness) + .unwrap(); + + let verifier = Verifier::new(params); + let mut arthur = io.to_arthur(merlin.transcript()); + assert!(verifier.verify(&mut arthur, &statement, &proof).is_ok()); + } + + fn make_whir_batch_things_same_point( + num_polynomials: usize, + num_variables: usize, + num_points: usize, + folding_factor: usize, + soundness_type: SoundnessType, + pow_bits: usize, + fold_type: FoldType, + ) { + println!( + "NP = {num_polynomials}, NE = 
{num_points}, NV = {num_variables}, FOLD_TYPE = {:?}", + fold_type + ); + let num_coeffs = 1 << num_variables; + + let mut rng = ark_std::test_rng(); + let (leaf_hash_params, two_to_one_params) = merkle_tree::default_config::(&mut rng); + + let mv_params = MultivariateParameters::::new(num_variables); + + let whir_params = WhirParameters:: { + initial_statement: true, + security_level: 32, + pow_bits, + folding_factor: FoldingFactor::Constant(folding_factor), + leaf_hash_params, + two_to_one_params, + soundness_type, + _pow_parameters: Default::default(), + starting_log_inv_rate: 1, + fold_optimisation: fold_type, + }; + + let params = WhirConfig::::new(mv_params, whir_params); + + let polynomials: Vec> = (0..num_polynomials) + .map(|i| CoefficientList::new(vec![F::from((i + 1) as i32); num_coeffs])) + .collect(); + + let points: Vec> = (0..num_points) + .map(|_| MultilinearPoint::rand(&mut rng, num_variables)) + .collect(); + let evals_per_point: Vec> = points + .iter() + .map(|point| { + polynomials + .iter() + .map(|poly| poly.evaluate(point)) + .collect() + }) + .collect(); + + let io = IOPattern::::new("🌪️") + .commit_batch_statement(¶ms, num_polynomials) + .add_whir_batch_proof(¶ms, num_polynomials) + .clone(); + let mut merlin = io.to_merlin(); + + let committer = Committer::new(params.clone()); + let witnesses = committer.batch_commit(&mut merlin, &polynomials).unwrap(); + + let prover = Prover(params.clone()); + + let proof = prover + .simple_batch_prove(&mut merlin, &points, &evals_per_point, &witnesses) + .unwrap(); + + let verifier = Verifier::new(params); + let mut arthur = io.to_arthur(merlin.transcript()); + assert!( + verifier + .simple_batch_verify( + &mut arthur, + num_polynomials, + &points, + &evals_per_point, + &proof + ) + .is_ok() + ); + println!("PASSED!"); + } + + fn make_whir_batch_things_diff_point( + num_polynomials: usize, + num_variables: usize, + folding_factor: usize, + soundness_type: SoundnessType, + pow_bits: usize, + fold_type: FoldType, + ) { + println!( + "NP = {num_polynomials}, NV = {num_variables}, FOLD_TYPE = {:?}", + fold_type + ); + let num_coeffs = 1 << num_variables; + + let mut rng = ark_std::test_rng(); + let (leaf_hash_params, two_to_one_params) = merkle_tree::default_config::(&mut rng); + + let mv_params = MultivariateParameters::::new(num_variables); + + let whir_params = WhirParameters:: { + initial_statement: true, + security_level: 32, + pow_bits, + folding_factor: FoldingFactor::Constant(folding_factor), + leaf_hash_params, + two_to_one_params, + soundness_type, + _pow_parameters: Default::default(), + starting_log_inv_rate: 1, + fold_optimisation: fold_type, + }; + + let params = WhirConfig::::new(mv_params, whir_params); + + let polynomials: Vec> = (0..num_polynomials) + .map(|i| CoefficientList::new(vec![F::from((i + 1) as i32); num_coeffs])) + .collect(); + + let point_per_poly: Vec> = (0..num_polynomials) + .map(|_| MultilinearPoint::rand(&mut rng, num_variables)) + .collect(); + let eval_per_poly: Vec = polynomials + .iter() + .zip(&point_per_poly) + .map(|(poly, point)| poly.evaluate(point)) + .collect(); + + let io = IOPattern::::new("🌪️") + .commit_batch_statement(¶ms, num_polynomials) + .add_whir_unify_proof(¶ms, num_polynomials) + .add_whir_batch_proof(¶ms, num_polynomials) + .clone(); + let mut merlin = io.to_merlin(); + + let committer = Committer::new(params.clone()); + let witnesses = committer.batch_commit(&mut merlin, &polynomials).unwrap(); + + let prover = Prover(params.clone()); + + let proof = prover + 
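+        // Two-phase batch opening: a unifying sumcheck first reduces the
+        // per-polynomial (point, eval) claims to one shared folded point,
+        // then `simple_batch_prove` opens every polynomial at that point.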
.same_size_batch_prove(&mut merlin, &point_per_poly, &eval_per_poly, &witnesses) + .unwrap(); + + let verifier = Verifier::new(params); + let mut arthur = io.to_arthur(merlin.transcript()); + verifier + .same_size_batch_verify( + &mut arthur, + num_polynomials, + &point_per_poly, + &eval_per_poly, + &proof, + ) + .unwrap(); + // assert!(verifier + // .same_size_batch_verify(&mut arthur, num_polynomials, &point_per_poly, &eval_per_poly, &proof) + // .is_ok()); + println!("PASSED!"); + } + + #[test] + fn test_whir() { + let folding_factors = [2, 3, 4, 5]; + let soundness_type = [ + SoundnessType::ConjectureList, + SoundnessType::ProvableList, + SoundnessType::UniqueDecoding, + ]; + let fold_types = [FoldType::Naive, FoldType::ProverHelps]; + let num_points = [0, 1, 2]; + let num_polys = [1, 2, 3]; + let pow_bits = [0, 5, 10]; + + for folding_factor in folding_factors { + let num_variables = folding_factor - 1..=2 * folding_factor; + for num_variables in num_variables { + for fold_type in fold_types { + for num_points in num_points { + for soundness_type in soundness_type { + for pow_bits in pow_bits { + make_whir_things( + num_variables, + FoldingFactor::Constant(folding_factor), + num_points, + soundness_type, + pow_bits, + fold_type, + ); + } + } + } + } + } + } + + for folding_factor in folding_factors { + let num_variables = folding_factor..=2 * folding_factor; + for num_variables in num_variables { + for fold_type in fold_types { + for num_points in num_points { + for num_polys in num_polys { + for soundness_type in soundness_type { + for pow_bits in pow_bits { + make_whir_batch_things_same_point( + num_polys, + num_variables, + num_points, + folding_factor, + soundness_type, + pow_bits, + fold_type, + ); + } + } + } + } + } + } + } + + for folding_factor in folding_factors { + let num_variables = folding_factor..=2 * folding_factor; + for num_variables in num_variables { + for fold_type in fold_types { + for num_polys in num_polys { + for soundness_type in soundness_type { + for pow_bits in pow_bits { + make_whir_batch_things_diff_point( + num_polys, + num_variables, + folding_factor, + soundness_type, + pow_bits, + fold_type, + ); + } + } + } + } + } + } + } +} diff --git a/whir/src/whir/parameters.rs b/whir/src/whir/parameters.rs new file mode 100644 index 000000000..23a0bcb7b --- /dev/null +++ b/whir/src/whir/parameters.rs @@ -0,0 +1,633 @@ +use core::{fmt, panic}; +use derive_more::Debug; +use std::{f64::consts::LOG2_10, fmt::Display, marker::PhantomData}; + +use ark_crypto_primitives::merkle_tree::{Config, LeafParam, TwoToOneParam}; +use ark_ff::FftField; +use serde::{Deserialize, Serialize}; + +use crate::{ + crypto::fields::FieldWithSize, + domain::Domain, + parameters::{FoldType, FoldingFactor, MultivariateParameters, SoundnessType, WhirParameters}, +}; + +#[derive(Clone, Debug)] +pub struct WhirConfig +where + F: FftField, + MerkleConfig: Config, +{ + pub(crate) mv_parameters: MultivariateParameters, + pub(crate) soundness_type: SoundnessType, + pub(crate) security_level: usize, + pub(crate) max_pow_bits: usize, + + pub(crate) committment_ood_samples: usize, + // The WHIR protocol can prove either: + // 1. The commitment is a valid low degree polynomial. In that case, the + // initial statement is set to false. + // 2. The commitment is a valid folded polynomial, and an additional + // polynomial evaluation statement. In that case, the initial statement + // is set to true. 
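+    // The PCS entry points always use mode 2: `simple_batch_prove` and
+    // `same_size_batch_prove` assert that `initial_statement` is true.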
+ pub(crate) initial_statement: bool, + pub(crate) starting_domain: Domain, + pub(crate) starting_log_inv_rate: usize, + pub(crate) starting_folding_pow_bits: f64, + + pub(crate) folding_factor: FoldingFactor, + pub(crate) round_parameters: Vec, + pub(crate) fold_optimisation: FoldType, + + pub(crate) final_queries: usize, + pub(crate) final_pow_bits: f64, + pub(crate) final_log_inv_rate: usize, + pub(crate) final_sumcheck_rounds: usize, + pub(crate) final_folding_pow_bits: f64, + + // PoW parameters + pub(crate) pow_strategy: PhantomData, + + // Merkle tree parameters + #[debug(skip)] + pub(crate) leaf_hash_params: LeafParam, + #[debug(skip)] + pub(crate) two_to_one_params: TwoToOneParam, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct RoundConfig { + pub(crate) pow_bits: f64, + pub(crate) folding_pow_bits: f64, + pub(crate) num_queries: usize, + pub(crate) ood_samples: usize, + pub(crate) log_inv_rate: usize, +} + +impl WhirConfig +where + F: FftField + FieldWithSize, + MerkleConfig: Config, +{ + pub fn new( + mut mv_parameters: MultivariateParameters, + whir_parameters: WhirParameters, + ) -> Self { + // Pad the number of variables to folding factor + if mv_parameters.num_variables < whir_parameters.folding_factor.at_round(0) { + mv_parameters.num_variables = whir_parameters.folding_factor.at_round(0); + } + whir_parameters + .folding_factor + .check_validity(mv_parameters.num_variables) + .unwrap(); + + let protocol_security_level = + 0.max(whir_parameters.security_level - whir_parameters.pow_bits); + + let starting_domain = Domain::new( + 1 << mv_parameters.num_variables, + whir_parameters.starting_log_inv_rate, + ) + .expect("Should have found an appropriate domain - check Field 2 adicity?"); + + let (num_rounds, final_sumcheck_rounds) = whir_parameters + .folding_factor + .compute_number_of_rounds(mv_parameters.num_variables); + + let field_size_bits = F::field_size_in_bits(); + + let committment_ood_samples = if whir_parameters.initial_statement { + Self::ood_samples( + whir_parameters.security_level, + whir_parameters.soundness_type, + mv_parameters.num_variables, + whir_parameters.starting_log_inv_rate, + Self::log_eta( + whir_parameters.soundness_type, + whir_parameters.starting_log_inv_rate, + ), + field_size_bits, + ) + } else { + 0 + }; + + let starting_folding_pow_bits = if whir_parameters.initial_statement { + Self::folding_pow_bits( + whir_parameters.security_level, + whir_parameters.soundness_type, + field_size_bits, + mv_parameters.num_variables, + whir_parameters.starting_log_inv_rate, + Self::log_eta( + whir_parameters.soundness_type, + whir_parameters.starting_log_inv_rate, + ), + ) + } else { + let prox_gaps_error = Self::rbr_soundness_fold_prox_gaps( + whir_parameters.soundness_type, + field_size_bits, + mv_parameters.num_variables, + whir_parameters.starting_log_inv_rate, + Self::log_eta( + whir_parameters.soundness_type, + whir_parameters.starting_log_inv_rate, + ), + ) + (whir_parameters.folding_factor.at_round(0) as f64).log2(); + 0_f64.max(whir_parameters.security_level as f64 - prox_gaps_error) + }; + + let mut round_parameters = Vec::with_capacity(num_rounds); + let mut num_variables = + mv_parameters.num_variables - whir_parameters.folding_factor.at_round(0); + let mut log_inv_rate = whir_parameters.starting_log_inv_rate; + for round in 0..num_rounds { + // Queries are set w.r.t. 
the old rate, while the rest use the new rate
+            let next_rate = log_inv_rate + (whir_parameters.folding_factor.at_round(round) - 1);
+
+            let log_next_eta = Self::log_eta(whir_parameters.soundness_type, next_rate);
+            let num_queries = Self::queries(
+                whir_parameters.soundness_type,
+                protocol_security_level,
+                log_inv_rate,
+            );
+
+            let ood_samples = Self::ood_samples(
+                whir_parameters.security_level,
+                whir_parameters.soundness_type,
+                num_variables,
+                next_rate,
+                log_next_eta,
+                field_size_bits,
+            );
+
+            let query_error =
+                Self::rbr_queries(whir_parameters.soundness_type, log_inv_rate, num_queries);
+            let combination_error = Self::rbr_soundness_queries_combination(
+                whir_parameters.soundness_type,
+                field_size_bits,
+                num_variables,
+                next_rate,
+                log_next_eta,
+                ood_samples,
+                num_queries,
+            );
+
+            let pow_bits = 0_f64
+                .max(whir_parameters.security_level as f64 - (query_error.min(combination_error)));
+
+            let folding_pow_bits = Self::folding_pow_bits(
+                whir_parameters.security_level,
+                whir_parameters.soundness_type,
+                field_size_bits,
+                num_variables,
+                next_rate,
+                log_next_eta,
+            );
+
+            round_parameters.push(RoundConfig {
+                ood_samples,
+                num_queries,
+                pow_bits,
+                folding_pow_bits,
+                log_inv_rate,
+            });
+
+            num_variables -= whir_parameters.folding_factor.at_round(round + 1);
+            log_inv_rate = next_rate;
+        }
+
+        let final_queries = Self::queries(
+            whir_parameters.soundness_type,
+            protocol_security_level,
+            log_inv_rate,
+        );
+
+        let final_pow_bits = 0_f64.max(
+            whir_parameters.security_level as f64
+                - Self::rbr_queries(whir_parameters.soundness_type, log_inv_rate, final_queries),
+        );
+
+        let final_folding_pow_bits =
+            0_f64.max(whir_parameters.security_level as f64 - (field_size_bits - 1) as f64);
+
+        WhirConfig {
+            security_level: whir_parameters.security_level,
+            max_pow_bits: whir_parameters.pow_bits,
+            initial_statement: whir_parameters.initial_statement,
+            committment_ood_samples,
+            mv_parameters,
+            starting_domain,
+            soundness_type: whir_parameters.soundness_type,
+            starting_log_inv_rate: whir_parameters.starting_log_inv_rate,
+            starting_folding_pow_bits,
+            folding_factor: whir_parameters.folding_factor,
+            round_parameters,
+            final_queries,
+            final_pow_bits,
+            final_sumcheck_rounds,
+            final_folding_pow_bits,
+            pow_strategy: PhantomData,
+            fold_optimisation: whir_parameters.fold_optimisation,
+            final_log_inv_rate: log_inv_rate,
+            leaf_hash_params: whir_parameters.leaf_hash_params,
+            two_to_one_params: whir_parameters.two_to_one_params,
+        }
+    }
+
+    pub fn n_rounds(&self) -> usize {
+        self.round_parameters.len()
+    }
+
+    pub fn check_pow_bits(&self) -> bool {
+        [
+            self.starting_folding_pow_bits,
+            self.final_pow_bits,
+            self.final_folding_pow_bits,
+        ]
+        .into_iter()
+        .all(|x| x <= self.max_pow_bits as f64)
+            && self.round_parameters.iter().all(|r| {
+                r.pow_bits <= self.max_pow_bits as f64
+                    && r.folding_pow_bits <= self.max_pow_bits as f64
+            })
+    }
+
+    pub fn log_eta(soundness_type: SoundnessType, log_inv_rate: usize) -> f64 {
+        // Ask me how I did this? At the time, only God and I knew.
Now only God knows + match soundness_type { + SoundnessType::ProvableList => -(0.5 * log_inv_rate as f64 + LOG2_10 + 1.), + SoundnessType::UniqueDecoding => 0., + SoundnessType::ConjectureList => -(log_inv_rate as f64 + 1.), + } + } + + pub fn list_size_bits( + soundness_type: SoundnessType, + num_variables: usize, + log_inv_rate: usize, + log_eta: f64, + ) -> f64 { + match soundness_type { + SoundnessType::ConjectureList => (num_variables + log_inv_rate) as f64 - log_eta, + SoundnessType::ProvableList => { + let log_inv_sqrt_rate: f64 = log_inv_rate as f64 / 2.; + log_inv_sqrt_rate - (1. + log_eta) + } + SoundnessType::UniqueDecoding => 0.0, + } + } + + pub fn rbr_ood_sample( + soundness_type: SoundnessType, + num_variables: usize, + log_inv_rate: usize, + log_eta: f64, + field_size_bits: usize, + ood_samples: usize, + ) -> f64 { + let list_size_bits = + Self::list_size_bits(soundness_type, num_variables, log_inv_rate, log_eta); + + let error = 2. * list_size_bits + (num_variables * ood_samples) as f64; + (ood_samples * field_size_bits) as f64 + 1. - error + } + + pub fn ood_samples( + security_level: usize, // We don't do PoW for OOD + soundness_type: SoundnessType, + num_variables: usize, + log_inv_rate: usize, + log_eta: f64, + field_size_bits: usize, + ) -> usize { + if matches!(soundness_type, SoundnessType::UniqueDecoding) { + 0 + } else { + for ood_samples in 1..64 { + if Self::rbr_ood_sample( + soundness_type, + num_variables, + log_inv_rate, + log_eta, + field_size_bits, + ood_samples, + ) >= security_level as f64 + { + return ood_samples; + } + } + + panic!("Could not find an appropriate number of OOD samples"); + } + } + + // Compute the proximity gaps term of the fold + pub fn rbr_soundness_fold_prox_gaps( + soundness_type: SoundnessType, + field_size_bits: usize, + num_variables: usize, + log_inv_rate: usize, + log_eta: f64, + ) -> f64 { + // Recall, at each round we are only folding by two at a time + let error = match soundness_type { + SoundnessType::ConjectureList => (num_variables + log_inv_rate) as f64 - log_eta, + SoundnessType::ProvableList => { + LOG2_10 + 3.5 * log_inv_rate as f64 + 2. * num_variables as f64 + } + SoundnessType::UniqueDecoding => (num_variables + log_inv_rate) as f64, + }; + + field_size_bits as f64 - error + } + + pub fn rbr_soundness_fold_sumcheck( + soundness_type: SoundnessType, + field_size_bits: usize, + num_variables: usize, + log_inv_rate: usize, + log_eta: f64, + ) -> f64 { + let list_size = Self::list_size_bits(soundness_type, num_variables, log_inv_rate, log_eta); + + field_size_bits as f64 - (list_size + 1.) + } + + pub fn folding_pow_bits( + security_level: usize, + soundness_type: SoundnessType, + field_size_bits: usize, + num_variables: usize, + log_inv_rate: usize, + log_eta: f64, + ) -> f64 { + let prox_gaps_error = Self::rbr_soundness_fold_prox_gaps( + soundness_type, + field_size_bits, + num_variables, + log_inv_rate, + log_eta, + ); + let sumcheck_error = Self::rbr_soundness_fold_sumcheck( + soundness_type, + field_size_bits, + num_variables, + log_inv_rate, + log_eta, + ); + + let error = prox_gaps_error.min(sumcheck_error); + + 0_f64.max(security_level as f64 - error) + } + + // Used to select the number of queries + pub fn queries( + soundness_type: SoundnessType, + protocol_security_level: usize, + log_inv_rate: usize, + ) -> usize { + let num_queries_f = match soundness_type { + SoundnessType::UniqueDecoding => { + let rate = 1. / ((1 << log_inv_rate) as f64); + let denom = (0.5 * (1. 
+ rate)).log2(); + + -(protocol_security_level as f64) / denom + } + SoundnessType::ProvableList => { + (2 * protocol_security_level) as f64 / log_inv_rate as f64 + } + SoundnessType::ConjectureList => protocol_security_level as f64 / log_inv_rate as f64, + }; + num_queries_f.ceil() as usize + } + + // This is the bits of security of the query step + pub fn rbr_queries( + soundness_type: SoundnessType, + log_inv_rate: usize, + num_queries: usize, + ) -> f64 { + let num_queries = num_queries as f64; + + match soundness_type { + SoundnessType::UniqueDecoding => { + let rate = 1. / ((1 << log_inv_rate) as f64); + let denom = -(0.5 * (1. + rate)).log2(); + + num_queries * denom + } + SoundnessType::ProvableList => num_queries * 0.5 * log_inv_rate as f64, + SoundnessType::ConjectureList => num_queries * log_inv_rate as f64, + } + } + + pub fn rbr_soundness_queries_combination( + soundness_type: SoundnessType, + field_size_bits: usize, + num_variables: usize, + log_inv_rate: usize, + log_eta: f64, + ood_samples: usize, + num_queries: usize, + ) -> f64 { + let list_size = Self::list_size_bits(soundness_type, num_variables, log_inv_rate, log_eta); + + let log_combination = ((ood_samples + num_queries) as f64).log2(); + + field_size_bits as f64 - (log_combination + list_size + 1.) + } +} + +impl Display for WhirConfig +where + F: FftField, + MerkleConfig: Config, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt::Display::fmt(&self.mv_parameters, f)?; + writeln!(f, ", folding factor: {:?}", self.folding_factor)?; + writeln!( + f, + "Security level: {} bits using {} security and {} bits of PoW", + self.security_level, self.soundness_type, self.max_pow_bits + )?; + + writeln!( + f, + "initial_folding_pow_bits: {}", + self.starting_folding_pow_bits + )?; + for r in &self.round_parameters { + fmt::Display::fmt(&r, f)?; + } + + writeln!( + f, + "final_queries: {}, final_rate: 2^-{}, final_pow_bits: {}, final_folding_pow_bits: {}", + self.final_queries, + self.final_log_inv_rate, + self.final_pow_bits, + self.final_folding_pow_bits, + )?; + + writeln!(f, "------------------------------------")?; + writeln!(f, "Round by round soundness analysis:")?; + writeln!(f, "------------------------------------")?; + + let field_size_bits = F::field_size_in_bits(); + let log_eta = Self::log_eta(self.soundness_type, self.starting_log_inv_rate); + let mut num_variables = self.mv_parameters.num_variables; + + if self.committment_ood_samples > 0 { + writeln!( + f, + "{:.1} bits -- OOD commitment", + Self::rbr_ood_sample( + self.soundness_type, + num_variables, + self.starting_log_inv_rate, + log_eta, + field_size_bits, + self.committment_ood_samples + ) + )?; + } + + let prox_gaps_error = Self::rbr_soundness_fold_prox_gaps( + self.soundness_type, + field_size_bits, + num_variables, + self.starting_log_inv_rate, + log_eta, + ); + let sumcheck_error = Self::rbr_soundness_fold_sumcheck( + self.soundness_type, + field_size_bits, + num_variables, + self.starting_log_inv_rate, + log_eta, + ); + writeln!( + f, + "{:.1} bits -- (x{:?}) prox gaps: {:.1}, sumcheck: {:.1}, pow: {:.1}", + prox_gaps_error.min(sumcheck_error) + self.starting_folding_pow_bits, + self.folding_factor, + prox_gaps_error, + sumcheck_error, + self.starting_folding_pow_bits, + )?; + + num_variables -= self.folding_factor.at_round(0); + + for (round, r) in self.round_parameters.iter().enumerate() { + let next_rate = r.log_inv_rate + (self.folding_factor.at_round(round) - 1); + let log_eta = Self::log_eta(self.soundness_type, 
next_rate); + + if r.ood_samples > 0 { + writeln!( + f, + "{:.1} bits -- OOD sample", + Self::rbr_ood_sample( + self.soundness_type, + num_variables, + next_rate, + log_eta, + field_size_bits, + r.ood_samples + ) + )?; + } + + let query_error = Self::rbr_queries(self.soundness_type, r.log_inv_rate, r.num_queries); + let combination_error = Self::rbr_soundness_queries_combination( + self.soundness_type, + field_size_bits, + num_variables, + next_rate, + log_eta, + r.ood_samples, + r.num_queries, + ); + writeln!( + f, + "{:.1} bits -- query error: {:.1}, combination: {:.1}, pow: {:.1}", + query_error.min(combination_error) + r.pow_bits, + query_error, + combination_error, + r.pow_bits, + )?; + + let prox_gaps_error = Self::rbr_soundness_fold_prox_gaps( + self.soundness_type, + field_size_bits, + num_variables, + next_rate, + log_eta, + ); + let sumcheck_error = Self::rbr_soundness_fold_sumcheck( + self.soundness_type, + field_size_bits, + num_variables, + next_rate, + log_eta, + ); + + writeln!( + f, + "{:.1} bits -- (x{:?}) prox gaps: {:.1}, sumcheck: {:.1}, pow: {:.1}", + prox_gaps_error.min(sumcheck_error) + r.folding_pow_bits, + self.folding_factor, + prox_gaps_error, + sumcheck_error, + r.folding_pow_bits, + )?; + + num_variables -= self.folding_factor.at_round(round + 1); + } + + let query_error = Self::rbr_queries( + self.soundness_type, + self.final_log_inv_rate, + self.final_queries, + ); + writeln!( + f, + "{:.1} bits -- query error: {:.1}, pow: {:.1}", + query_error + self.final_pow_bits, + query_error, + self.final_pow_bits, + )?; + + if self.final_sumcheck_rounds > 0 { + let combination_error = field_size_bits as f64 - 1.; + writeln!( + f, + "{:.1} bits -- (x{}) combination: {:.1}, pow: {:.1}", + combination_error + self.final_pow_bits, + self.final_sumcheck_rounds, + combination_error, + self.final_folding_pow_bits, + )?; + } + + Ok(()) + } +} + +impl Display for RoundConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!( + f, + "Num_queries: {}, rate: 2^-{}, pow_bits: {}, ood_samples: {}, folding_pow: {}", + self.num_queries, + self.log_inv_rate, + self.pow_bits, + self.ood_samples, + self.folding_pow_bits, + ) + } +} diff --git a/whir/src/whir/prover.rs b/whir/src/whir/prover.rs new file mode 100644 index 000000000..a1e0bcd8b --- /dev/null +++ b/whir/src/whir/prover.rs @@ -0,0 +1,456 @@ +use super::{Statement, WhirProof, committer::Witness, parameters::WhirConfig}; +use crate::{ + domain::Domain, + ntt::expand_from_coeff, + parameters::FoldType, + poly_utils::{ + MultilinearPoint, + coeffs::CoefficientList, + fold::{compute_fold, restructure_evaluations}, + }, + sumcheck::prover_not_skipping::SumcheckProverNotSkipping, + utils::{self, expand_randomness}, +}; +use ark_crypto_primitives::merkle_tree::{Config, MerkleTree, MultiPath}; +use ark_ff::FftField; +use ark_poly::EvaluationDomain; +use ark_std::{end_timer, start_timer}; +use nimue::{ + ByteChallenges, ByteWriter, ProofResult, + plugins::ark::{FieldChallenges, FieldWriter}, +}; +use nimue_pow::{self, PoWChallenge}; + +use crate::whir::fs_utils::{DigestWriter, get_challenge_stir_queries}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +pub struct Prover(pub WhirConfig) +where + F: FftField, + MerkleConfig: Config; + +impl Prover +where + F: FftField, + MerkleConfig: Config, + PowStrategy: nimue_pow::PowStrategy, +{ + pub(crate) fn validate_parameters(&self) -> bool { + self.0.mv_parameters.num_variables + == self.0.folding_factor.total_number(self.0.n_rounds()) + 
self.0.final_sumcheck_rounds + } + + fn validate_statement(&self, statement: &Statement) -> bool { + if statement.points.len() != statement.evaluations.len() { + return false; + } + if !statement + .points + .iter() + .all(|point| point.0.len() == self.0.mv_parameters.num_variables) + { + return false; + } + if !self.0.initial_statement && !statement.points.is_empty() { + return false; + } + true + } + + fn validate_witness(&self, witness: &Witness) -> bool { + assert_eq!(witness.ood_points.len(), witness.ood_answers.len()); + if !self.0.initial_statement { + assert!(witness.ood_points.is_empty()); + } + witness.polynomial.num_variables() == self.0.mv_parameters.num_variables + } + + pub fn prove( + &self, + merlin: &mut Merlin, + mut statement: Statement, + witness: Witness, + ) -> ProofResult> + where + Merlin: FieldChallenges + + FieldWriter + + ByteChallenges + + ByteWriter + + PoWChallenge + + DigestWriter, + { + // If any evaluation point is shorter than the folding factor, pad with 0 in front + for p in statement.points.iter_mut() { + while p.n_variables() < self.0.folding_factor.at_round(0) { + p.0.insert(0, F::ONE); + } + } + + assert!(self.validate_parameters()); + assert!(self.validate_statement(&statement)); + assert!(self.validate_witness(&witness)); + + let timer = start_timer!(|| "Single Prover"); + let initial_claims: Vec<_> = witness + .ood_points + .into_iter() + .map(|ood_point| { + MultilinearPoint::expand_from_univariate( + ood_point, + self.0.mv_parameters.num_variables, + ) + }) + .chain(statement.points) + .collect(); + let initial_answers: Vec<_> = witness + .ood_answers + .into_iter() + .chain(statement.evaluations) + .collect(); + + if !self.0.initial_statement { + // It is ensured that if there is no initial statement, the + // number of ood samples is also zero. + assert!( + initial_answers.is_empty(), + "Can not have initial answers without initial statement" + ); + } + + let mut sumcheck_prover = None; + let folding_randomness = if self.0.initial_statement { + // If there is initial statement, then we run the sum-check for + // this initial statement. + let [combination_randomness_gen] = merlin.challenge_scalars()?; + let combination_randomness = + expand_randomness(combination_randomness_gen, initial_claims.len()); + + sumcheck_prover = Some(SumcheckProverNotSkipping::new( + witness.polynomial.clone(), + &initial_claims, + &combination_randomness, + &initial_answers, + )); + + sumcheck_prover + .as_mut() + .unwrap() + .compute_sumcheck_polynomials::( + merlin, + self.0.folding_factor.at_round(0), + self.0.starting_folding_pow_bits, + )? + } else { + // If there is no initial statement, there is no need to run the + // initial rounds of the sum-check, and the verifier directly sends + // the initial folding randomnesses. + let mut folding_randomness = vec![F::ZERO; self.0.folding_factor.at_round(0)]; + merlin.fill_challenge_scalars(&mut folding_randomness)?; + + if self.0.starting_folding_pow_bits > 0. 
{ + merlin.challenge_pow::(self.0.starting_folding_pow_bits)?; + } + MultilinearPoint(folding_randomness) + }; + + let round_state = RoundState { + domain: self.0.starting_domain.clone(), + round: 0, + sumcheck_prover, + folding_randomness, + coefficients: witness.polynomial, + prev_merkle: witness.merkle_tree, + prev_merkle_answers: witness.merkle_leaves, + merkle_proofs: vec![], + }; + + let round_timer = start_timer!(|| "Single Round"); + let result = self.round(merlin, round_state); + end_timer!(round_timer); + + end_timer!(timer); + + result + } + + pub(crate) fn round( + &self, + merlin: &mut Merlin, + mut round_state: RoundState, + ) -> ProofResult> + where + Merlin: FieldChallenges + + ByteChallenges + + FieldWriter + + ByteWriter + + PoWChallenge + + DigestWriter, + { + // Fold the coefficients + let folded_coefficients = round_state + .coefficients + .fold(&round_state.folding_randomness); + + let num_variables = self.0.mv_parameters.num_variables + - self.0.folding_factor.total_number(round_state.round); + // num_variables should match the folded_coefficients here. + assert_eq!(num_variables, folded_coefficients.num_variables()); + + // Base case + if round_state.round == self.0.n_rounds() { + // Directly send coefficients of the polynomial to the verifier. + merlin.add_scalars(folded_coefficients.coeffs())?; + + // Final verifier queries and answers. The indices are over the + // *folded* domain. + let final_challenge_indexes = get_challenge_stir_queries( + round_state.domain.size(), // The size of the *original* domain before folding + self.0.folding_factor.at_round(round_state.round), /* The folding factor we used to fold the previous polynomial */ + self.0.final_queries, + merlin, + )?; + + let merkle_proof = round_state + .prev_merkle + .generate_multi_proof(final_challenge_indexes.clone()) + .unwrap(); + // Every query requires opening these many in the previous Merkle tree + let fold_size = 1 << self.0.folding_factor.at_round(round_state.round); + let answers = final_challenge_indexes + .into_iter() + .map(|i| { + round_state.prev_merkle_answers[i * fold_size..(i + 1) * fold_size].to_vec() + }) + .collect(); + round_state.merkle_proofs.push((merkle_proof, answers)); + + // PoW + if self.0.final_pow_bits > 0. { + merlin.challenge_pow::(self.0.final_pow_bits)?; + } + + // Final sumcheck + if self.0.final_sumcheck_rounds > 0 { + round_state + .sumcheck_prover + .unwrap_or_else(|| { + SumcheckProverNotSkipping::new(folded_coefficients.clone(), &[], &[], &[]) + }) + .compute_sumcheck_polynomials::( + merlin, + self.0.final_sumcheck_rounds, + self.0.final_folding_pow_bits, + )?; + } + + return Ok(WhirProof(round_state.merkle_proofs)); + } + + let round_params = &self.0.round_parameters[round_state.round]; + + // Fold the coefficients, and compute fft of polynomial (and commit) + let new_domain = round_state.domain.scale(2); + let expansion = new_domain.size() / folded_coefficients.num_coeffs(); + let evals = expand_from_coeff(folded_coefficients.coeffs(), expansion); + // Group the evaluations into leaves by the *next* round folding factor + // TODO: `stack_evaluations` and `restructure_evaluations` are really in-place algorithms. + // They also partially overlap and undo one another. We should merge them. 
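+        // Rough picture of the stacking, assuming `stack_evaluations` keeps its
+        // current contract: for a domain of size N folded by a factor of 2^k,
+        // each Merkle leaf collects the 2^k evaluations of one coset, so a
+        // single query opens exactly what one fold needs. E.g. for N = 8 and
+        // k = 2 (coset generator w^2, offsets 1 and w):
+        //
+        //     leaf 0: [f(w^0), f(w^2), f(w^4), f(w^6)]
+        //     leaf 1: [f(w^1), f(w^3), f(w^5), f(w^7)]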
+ let folded_evals = utils::stack_evaluations( + evals, + self.0.folding_factor.at_round(round_state.round + 1), // Next round fold factor + ); + let folded_evals = restructure_evaluations( + folded_evals, + self.0.fold_optimisation, + new_domain.backing_domain.group_gen(), + new_domain.backing_domain.group_gen_inv(), + self.0.folding_factor.at_round(round_state.round + 1), + ); + + #[cfg(not(feature = "parallel"))] + let leafs_iter = folded_evals.chunks_exact( + 1 << self + .0 + .folding_factor + .get_folding_factor_of_round(round_state.round + 1), + ); + #[cfg(feature = "parallel")] + let leafs_iter = folded_evals + .par_chunks_exact(1 << self.0.folding_factor.at_round(round_state.round + 1)); + let merkle_tree = MerkleTree::::new( + &self.0.leaf_hash_params, + &self.0.two_to_one_params, + leafs_iter, + ) + .unwrap(); + + let root = merkle_tree.root(); + merlin.add_digest(root)?; + + // OOD Samples + let mut ood_points = vec![F::ZERO; round_params.ood_samples]; + let mut ood_answers = Vec::with_capacity(round_params.ood_samples); + if round_params.ood_samples > 0 { + merlin.fill_challenge_scalars(&mut ood_points)?; + ood_answers.extend(ood_points.iter().map(|ood_point| { + folded_coefficients.evaluate(&MultilinearPoint::expand_from_univariate( + *ood_point, + num_variables, + )) + })); + merlin.add_scalars(&ood_answers)?; + } + + // STIR queries + let stir_challenges_indexes = get_challenge_stir_queries( + round_state.domain.size(), // Current domain size *before* folding + self.0.folding_factor.at_round(round_state.round), // Current fold factor + round_params.num_queries, + merlin, + )?; + // Compute the generator of the folded domain, in the extension field + let domain_scaled_gen = round_state + .domain + .backing_domain + .element(1 << self.0.folding_factor.at_round(round_state.round)); + let stir_challenges: Vec<_> = ood_points + .into_iter() + .chain( + stir_challenges_indexes + .iter() + .map(|i| domain_scaled_gen.pow([*i as u64])), + ) + .map(|univariate| MultilinearPoint::expand_from_univariate(univariate, num_variables)) + .collect(); + + let merkle_proof = round_state + .prev_merkle + .generate_multi_proof(stir_challenges_indexes.clone()) + .unwrap(); + let fold_size = 1 << self.0.folding_factor.at_round(round_state.round); + let answers: Vec<_> = stir_challenges_indexes + .iter() + .map(|i| round_state.prev_merkle_answers[i * fold_size..(i + 1) * fold_size].to_vec()) + .collect(); + // Evaluate answers in the folding randomness. + let mut stir_evaluations = ood_answers.clone(); + match self.0.fold_optimisation { + FoldType::Naive => { + // See `Verifier::compute_folds_full` + let domain_size = round_state.domain.backing_domain.size(); + let domain_gen = round_state.domain.backing_domain.element(1); + let domain_gen_inv = domain_gen.inverse().unwrap(); + let coset_domain_size = 1 << self.0.folding_factor.at_round(round_state.round); + // The domain (before folding) is split into cosets of size + // `coset_domain_size` (which is just `fold_size`). Each coset + // is generated by powers of `coset_generator` (which is just the + // `fold_size`-root of unity) multiplied by a different + // `coset_offset`. + // For example, if `fold_size = 16`, and the domain size is N, then + // the domain is (1, w, w^2, ..., w^(N-1)), the domain generator + // is w, and the coset generator is w^(N/16). + // The first coset is (1, w^(N/16), w^(2N/16), ..., w^(15N/16)) + // which is also a subgroup itself (the coset_offset is 1). 
+ // The second coset would be w * , the third coset would be + // w^2 * , and so on. Until w^(N/16-1) * . + let coset_generator_inv = + domain_gen_inv.pow([(domain_size / coset_domain_size) as u64]); + stir_evaluations.extend(stir_challenges_indexes.iter().zip(&answers).map( + |(index, answers)| { + // The coset is w^index * + // let _coset_offset = domain_gen.pow(&[*index as u64]); + let coset_offset_inv = domain_gen_inv.pow([*index as u64]); + + // In the Naive mode, the oracle consists directly of the + // evaluations of f over the domain. We leverage an + // algorithm to compute the evaluations of the folded f + // at the corresponding point in folded domain (which is + // coset_offset^fold_size). + compute_fold( + answers, + &round_state.folding_randomness.0, + coset_offset_inv, + coset_generator_inv, + F::from(2).inverse().unwrap(), + self.0.folding_factor.at_round(round_state.round), + ) + }, + )) + } + FoldType::ProverHelps => stir_evaluations.extend(answers.iter().map(|answers| { + // In the ProverHelps mode, the oracle values have been linearly + // transformed such that they are exactly the coefficients of the + // multilinear polynomial whose evaluation at the folding randomness + // is just the folding of f evaluated at the folded point. + CoefficientList::new(answers.to_vec()).evaluate(&round_state.folding_randomness) + })), + } + round_state.merkle_proofs.push((merkle_proof, answers)); + + // PoW + if round_params.pow_bits > 0. { + merlin.challenge_pow::(round_params.pow_bits)?; + } + + // Randomness for combination + let [combination_randomness_gen] = merlin.challenge_scalars()?; + let combination_randomness = + expand_randomness(combination_randomness_gen, stir_challenges.len()); + + let mut sumcheck_prover = round_state + .sumcheck_prover + .take() + .map(|mut sumcheck_prover| { + sumcheck_prover.add_new_equality( + &stir_challenges, + &combination_randomness, + &stir_evaluations, + ); + sumcheck_prover + }) + .unwrap_or_else(|| { + SumcheckProverNotSkipping::new( + folded_coefficients.clone(), + &stir_challenges, + &combination_randomness, + &stir_evaluations, + ) + }); + + let folding_randomness = sumcheck_prover + .compute_sumcheck_polynomials::( + merlin, + self.0.folding_factor.at_round(round_state.round + 1), + round_params.folding_pow_bits, + )?; + + let round_state = RoundState { + round: round_state.round + 1, + domain: new_domain, + sumcheck_prover: Some(sumcheck_prover), + folding_randomness, + coefficients: folded_coefficients, /* TODO: Is this redundant with `sumcheck_prover.coeff` ? 
*/ + prev_merkle: merkle_tree, + prev_merkle_answers: folded_evals, + merkle_proofs: round_state.merkle_proofs, + }; + + self.round(merlin, round_state) + } +} + +pub(crate) struct RoundState +where + F: FftField, + MerkleConfig: Config, +{ + pub(crate) round: usize, + pub(crate) domain: Domain, + pub(crate) sumcheck_prover: Option>, + pub(crate) folding_randomness: MultilinearPoint, + pub(crate) coefficients: CoefficientList, + pub(crate) prev_merkle: MerkleTree, + pub(crate) prev_merkle_answers: Vec, + pub(crate) merkle_proofs: Vec<(MultiPath, Vec>)>, +} diff --git a/whir/src/whir/verifier.rs b/whir/src/whir/verifier.rs new file mode 100644 index 000000000..0cbe366e9 --- /dev/null +++ b/whir/src/whir/verifier.rs @@ -0,0 +1,631 @@ +use std::iter; + +use ark_crypto_primitives::merkle_tree::Config; +use ark_ff::FftField; +use ark_poly::EvaluationDomain; +use nimue::{ + ByteChallenges, ByteReader, ProofError, ProofResult, + plugins::ark::{FieldChallenges, FieldReader}, +}; +use nimue_pow::{self, PoWChallenge}; + +use super::{Statement, WhirProof, parameters::WhirConfig}; +use crate::{ + parameters::FoldType, + poly_utils::{MultilinearPoint, coeffs::CoefficientList, eq_poly_outside, fold::compute_fold}, + sumcheck::proof::SumcheckPolynomial, + utils::expand_randomness, + whir::fs_utils::{DigestReader, get_challenge_stir_queries}, +}; + +pub struct Verifier +where + F: FftField, + MerkleConfig: Config, +{ + pub(crate) params: WhirConfig, + pub(crate) two_inv: F, +} + +#[derive(Clone)] +pub(crate) struct ParsedCommitment { + pub(crate) root: D, + pub(crate) ood_points: Vec, + pub(crate) ood_answers: Vec, +} + +#[derive(Clone)] +pub(crate) struct ParsedProof { + pub(crate) initial_combination_randomness: Vec, + pub(crate) initial_sumcheck_rounds: Vec<(SumcheckPolynomial, F)>, + pub(crate) rounds: Vec>, + pub(crate) final_domain_gen_inv: F, + pub(crate) final_randomness_indexes: Vec, + pub(crate) final_randomness_points: Vec, + pub(crate) final_randomness_answers: Vec>, + pub(crate) final_folding_randomness: MultilinearPoint, + pub(crate) final_sumcheck_rounds: Vec<(SumcheckPolynomial, F)>, + pub(crate) final_sumcheck_randomness: MultilinearPoint, + pub(crate) final_coefficients: CoefficientList, +} + +#[derive(Debug, Clone)] +pub(crate) struct ParsedRound { + pub(crate) folding_randomness: MultilinearPoint, + pub(crate) ood_points: Vec, + pub(crate) ood_answers: Vec, + pub(crate) stir_challenges_indexes: Vec, + pub(crate) stir_challenges_points: Vec, + pub(crate) stir_challenges_answers: Vec>, + pub(crate) combination_randomness: Vec, + pub(crate) sumcheck_rounds: Vec<(SumcheckPolynomial, F)>, + pub(crate) domain_gen_inv: F, +} + +impl Verifier +where + F: FftField, + MerkleConfig: Config, + PowStrategy: nimue_pow::PowStrategy, +{ + pub fn new(params: WhirConfig) -> Self { + Verifier { + params, + two_inv: F::from(2).inverse().unwrap(), // The only inverse in the entire code :) + } + } + + fn parse_commitment( + &self, + arthur: &mut Arthur, + ) -> ProofResult> + where + Arthur: ByteReader + FieldReader + FieldChallenges + DigestReader, + { + let root = arthur.read_digest()?; + + let mut ood_points = vec![F::ZERO; self.params.committment_ood_samples]; + let mut ood_answers = vec![F::ZERO; self.params.committment_ood_samples]; + if self.params.committment_ood_samples > 0 { + arthur.fill_challenge_scalars(&mut ood_points)?; + arthur.fill_next_scalars(&mut ood_answers)?; + } + + Ok(ParsedCommitment { + root, + ood_points, + ood_answers, + }) + } + + fn parse_proof( + &self, + arthur: &mut Arthur, + 
parsed_commitment: &ParsedCommitment, + statement: &Statement, // Will be needed later + whir_proof: &WhirProof, + ) -> ProofResult> + where + Arthur: FieldReader + + FieldChallenges + + PoWChallenge + + ByteReader + + ByteChallenges + + DigestReader, + { + let mut sumcheck_rounds = Vec::new(); + let mut folding_randomness: MultilinearPoint; + let initial_combination_randomness; + if self.params.initial_statement { + // Derive combination randomness and first sumcheck polynomial + let [combination_randomness_gen]: [F; 1] = arthur.challenge_scalars()?; + initial_combination_randomness = expand_randomness( + combination_randomness_gen, + parsed_commitment.ood_points.len() + statement.points.len(), + ); + + // Initial sumcheck + sumcheck_rounds.reserve_exact(self.params.folding_factor.at_round(0)); + for _ in 0..self.params.folding_factor.at_round(0) { + let sumcheck_poly_evals: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly = SumcheckPolynomial::new(sumcheck_poly_evals.to_vec(), 1); + let [folding_randomness_single] = arthur.challenge_scalars()?; + sumcheck_rounds.push((sumcheck_poly, folding_randomness_single)); + + if self.params.starting_folding_pow_bits > 0. { + arthur.challenge_pow::(self.params.starting_folding_pow_bits)?; + } + } + + folding_randomness = + MultilinearPoint(sumcheck_rounds.iter().map(|&(_, r)| r).rev().collect()); + } else { + assert_eq!(parsed_commitment.ood_points.len(), 0); + assert_eq!(statement.points.len(), 0); + + initial_combination_randomness = vec![F::ONE]; + + let mut folding_randomness_vec = vec![F::ZERO; self.params.folding_factor.at_round(0)]; + arthur.fill_challenge_scalars(&mut folding_randomness_vec)?; + folding_randomness = MultilinearPoint(folding_randomness_vec); + + // PoW + if self.params.starting_folding_pow_bits > 0. { + arthur.challenge_pow::(self.params.starting_folding_pow_bits)?; + } + }; + + let mut prev_root = parsed_commitment.root.clone(); + let mut domain_gen = self.params.starting_domain.backing_domain.group_gen(); + let mut exp_domain_gen = domain_gen.pow([1 << self.params.folding_factor.at_round(0)]); + let mut domain_gen_inv = self.params.starting_domain.backing_domain.group_gen_inv(); + let mut domain_size = self.params.starting_domain.size(); + let mut rounds = vec![]; + + for r in 0..self.params.n_rounds() { + let (merkle_proof, answers) = &whir_proof.0[r]; + let round_params = &self.params.round_parameters[r]; + + let new_root = arthur.read_digest()?; + + let mut ood_points = vec![F::ZERO; round_params.ood_samples]; + let mut ood_answers = vec![F::ZERO; round_params.ood_samples]; + if round_params.ood_samples > 0 { + arthur.fill_challenge_scalars(&mut ood_points)?; + arthur.fill_next_scalars(&mut ood_answers)?; + } + + let stir_challenges_indexes = get_challenge_stir_queries( + domain_size, + self.params.folding_factor.at_round(r), + round_params.num_queries, + arthur, + )?; + + let stir_challenges_points = stir_challenges_indexes + .iter() + .map(|index| exp_domain_gen.pow([*index as u64])) + .collect(); + + if !merkle_proof + .verify( + &self.params.leaf_hash_params, + &self.params.two_to_one_params, + &prev_root, + answers.iter().map(|a| a.as_ref()), + ) + .unwrap() + || merkle_proof.leaf_indexes != stir_challenges_indexes + { + return Err(ProofError::InvalidProof); + } + + if round_params.pow_bits > 0. 
{ + arthur.challenge_pow::(round_params.pow_bits)?; + } + + let [combination_randomness_gen] = arthur.challenge_scalars()?; + let combination_randomness = expand_randomness( + combination_randomness_gen, + stir_challenges_indexes.len() + round_params.ood_samples, + ); + + let mut sumcheck_rounds = + Vec::with_capacity(self.params.folding_factor.at_round(r + 1)); + for _ in 0..self.params.folding_factor.at_round(r + 1) { + let sumcheck_poly_evals: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly = SumcheckPolynomial::new(sumcheck_poly_evals.to_vec(), 1); + let [folding_randomness_single] = arthur.challenge_scalars()?; + sumcheck_rounds.push((sumcheck_poly, folding_randomness_single)); + + if round_params.folding_pow_bits > 0. { + arthur.challenge_pow::(round_params.folding_pow_bits)?; + } + } + + let new_folding_randomness = + MultilinearPoint(sumcheck_rounds.iter().map(|&(_, r)| r).rev().collect()); + + rounds.push(ParsedRound { + folding_randomness, + ood_points, + ood_answers, + stir_challenges_indexes, + stir_challenges_points, + stir_challenges_answers: answers.to_vec(), + combination_randomness, + sumcheck_rounds, + domain_gen_inv, + }); + + folding_randomness = new_folding_randomness; + + prev_root = new_root.clone(); + domain_gen = domain_gen * domain_gen; + exp_domain_gen = domain_gen.pow([1 << self.params.folding_factor.at_round(r + 1)]); + domain_gen_inv = domain_gen_inv * domain_gen_inv; + domain_size /= 2; + } + + let mut final_coefficients = vec![F::ZERO; 1 << self.params.final_sumcheck_rounds]; + arthur.fill_next_scalars(&mut final_coefficients)?; + let final_coefficients = CoefficientList::new(final_coefficients); + + // Final queries verify + let final_randomness_indexes = get_challenge_stir_queries( + domain_size, + self.params.folding_factor.at_round(self.params.n_rounds()), + self.params.final_queries, + arthur, + )?; + let final_randomness_points = final_randomness_indexes + .iter() + .map(|index| exp_domain_gen.pow([*index as u64])) + .collect(); + + let (final_merkle_proof, final_randomness_answers) = &whir_proof.0[whir_proof.0.len() - 1]; + if !final_merkle_proof + .verify( + &self.params.leaf_hash_params, + &self.params.two_to_one_params, + &prev_root, + final_randomness_answers.iter().map(|a| a.as_ref()), + ) + .unwrap() + || final_merkle_proof.leaf_indexes != final_randomness_indexes + { + return Err(ProofError::InvalidProof); + } + + if self.params.final_pow_bits > 0. { + arthur.challenge_pow::(self.params.final_pow_bits)?; + } + + let mut final_sumcheck_rounds = Vec::with_capacity(self.params.final_sumcheck_rounds); + for _ in 0..self.params.final_sumcheck_rounds { + let sumcheck_poly_evals: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly = SumcheckPolynomial::new(sumcheck_poly_evals.to_vec(), 1); + let [folding_randomness_single] = arthur.challenge_scalars()?; + final_sumcheck_rounds.push((sumcheck_poly, folding_randomness_single)); + + if self.params.final_folding_pow_bits > 0. 
{ + arthur.challenge_pow::(self.params.final_folding_pow_bits)?; + } + } + let final_sumcheck_randomness = MultilinearPoint( + final_sumcheck_rounds + .iter() + .map(|&(_, r)| r) + .rev() + .collect(), + ); + + Ok(ParsedProof { + initial_combination_randomness, + initial_sumcheck_rounds: sumcheck_rounds, + rounds, + final_domain_gen_inv: domain_gen_inv, + final_folding_randomness: folding_randomness, + final_randomness_indexes, + final_randomness_points, + final_randomness_answers: final_randomness_answers.to_vec(), + final_sumcheck_rounds, + final_sumcheck_randomness, + final_coefficients, + }) + } + + fn compute_v_poly( + &self, + parsed_commitment: &ParsedCommitment, + statement: &Statement, + proof: &ParsedProof, + ) -> F { + let mut num_variables = self.params.mv_parameters.num_variables; + + let mut folding_randomness = MultilinearPoint( + iter::once(&proof.final_sumcheck_randomness.0) + .chain(iter::once(&proof.final_folding_randomness.0)) + .chain(proof.rounds.iter().rev().map(|r| &r.folding_randomness.0)) + .flatten() + .copied() + .collect(), + ); + + let statement_points: Vec> = statement + .points + .clone() + .into_iter() + .map(|mut p| { + while p.n_variables() < self.params.folding_factor.at_round(0) { + p.0.insert(0, F::ONE); + } + p + }) + .collect(); + let mut value = parsed_commitment + .ood_points + .iter() + .map(|ood_point| MultilinearPoint::expand_from_univariate(*ood_point, num_variables)) + .chain(statement_points) + .zip(&proof.initial_combination_randomness) + .map(|(point, randomness)| *randomness * eq_poly_outside(&point, &folding_randomness)) + .sum(); + + for (round, round_proof) in proof.rounds.iter().enumerate() { + num_variables -= self.params.folding_factor.at_round(round); + folding_randomness = MultilinearPoint(folding_randomness.0[..num_variables].to_vec()); + + let ood_points = &round_proof.ood_points; + let stir_challenges_points = &round_proof.stir_challenges_points; + let stir_challenges: Vec<_> = ood_points + .iter() + .chain(stir_challenges_points) + .cloned() + .map(|univariate| { + MultilinearPoint::expand_from_univariate(univariate, num_variables) + // TODO: + // Maybe refactor outside + }) + .collect(); + + let sum_of_claims: F = stir_challenges + .into_iter() + .map(|point| eq_poly_outside(&point, &folding_randomness)) + .zip(&round_proof.combination_randomness) + .map(|(point, rand)| point * rand) + .sum(); + + value += sum_of_claims; + } + + value + } + + pub(crate) fn compute_folds(&self, parsed: &ParsedProof) -> Vec> { + match self.params.fold_optimisation { + FoldType::Naive => self.compute_folds_full(parsed), + FoldType::ProverHelps => self.compute_folds_helped(parsed), + } + } + + fn compute_folds_full(&self, parsed: &ParsedProof) -> Vec> { + let mut domain_size = self.params.starting_domain.backing_domain.size(); + + let mut result = Vec::new(); + + for (round_index, round) in parsed.rounds.iter().enumerate() { + let coset_domain_size = 1 << self.params.folding_factor.at_round(round_index); + // This is such that coset_generator^coset_domain_size = F::ONE + // let _coset_generator = domain_gen.pow(&[(domain_size / coset_domain_size) as u64]); + let coset_generator_inv = round + .domain_gen_inv + .pow([(domain_size / coset_domain_size) as u64]); + + let evaluations: Vec<_> = round + .stir_challenges_indexes + .iter() + .zip(&round.stir_challenges_answers) + .map(|(index, answers)| { + // The coset is w^index * + // let _coset_offset = domain_gen.pow(&[*index as u64]); + let coset_offset_inv = round.domain_gen_inv.pow([*index as u64]); 
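+                    // For intuition: `compute_fold` performs the standard
+                    // FRI-style fold, combining f(x) and f(-x) one step at a
+                    // time as
+                    //     g(x^2) = (f(x) + f(-x)) / 2 + r * (f(x) - f(-x)) / (2x),
+                    // repeated `folding_factor` times across the coset; the
+                    // precomputed `two_inv` and inverse offsets avoid any
+                    // field inversion per query.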
+ + compute_fold( + answers, + &round.folding_randomness.0, + coset_offset_inv, + coset_generator_inv, + self.two_inv, + self.params.folding_factor.at_round(round_index), + ) + }) + .collect(); + result.push(evaluations); + domain_size /= 2; + } + + let coset_domain_size = 1 << self.params.folding_factor.at_round(parsed.rounds.len()); + let domain_gen_inv = parsed.final_domain_gen_inv; + + // Final round + let coset_generator_inv = domain_gen_inv.pow([(domain_size / coset_domain_size) as u64]); + let evaluations: Vec<_> = parsed + .final_randomness_indexes + .iter() + .zip(&parsed.final_randomness_answers) + .map(|(index, answers)| { + // The coset is w^index * + // let _coset_offset = domain_gen.pow(&[*index as u64]); + let coset_offset_inv = domain_gen_inv.pow([*index as u64]); + + compute_fold( + answers, + &parsed.final_folding_randomness.0, + coset_offset_inv, + coset_generator_inv, + self.two_inv, + self.params.folding_factor.at_round(parsed.rounds.len()), + ) + }) + .collect(); + result.push(evaluations); + + result + } + + fn compute_folds_helped(&self, parsed: &ParsedProof) -> Vec> { + let mut result = Vec::new(); + + for round in &parsed.rounds { + let evaluations: Vec<_> = round + .stir_challenges_answers + .iter() + .map(|answers| { + CoefficientList::new(answers.to_vec()).evaluate(&round.folding_randomness) + }) + .collect(); + result.push(evaluations); + } + + // Final round + let evaluations: Vec<_> = parsed + .final_randomness_answers + .iter() + .map(|answers| { + CoefficientList::new(answers.to_vec()).evaluate(&parsed.final_folding_randomness) + }) + .collect(); + result.push(evaluations); + + result + } + + pub fn verify( + &self, + arthur: &mut Arthur, + statement: &Statement, + whir_proof: &WhirProof, + ) -> ProofResult + where + Arthur: FieldChallenges + + FieldReader + + ByteChallenges + + ByteReader + + PoWChallenge + + DigestReader, + { + // We first do a pass in which we rederive all the FS challenges + // Then we will check the algebraic part (so to optimise inversions) + let parsed_commitment = self.parse_commitment(arthur)?; + let parsed = self.parse_proof(arthur, &parsed_commitment, statement, whir_proof)?; + + let computed_folds = self.compute_folds(&parsed); + + let mut prev: Option<(SumcheckPolynomial, F)> = None; + if let Some(round) = parsed.initial_sumcheck_rounds.first() { + // Check the first polynomial + let (mut prev_poly, mut randomness) = round.clone(); + if prev_poly.sum_over_hypercube() + != parsed_commitment + .ood_answers + .iter() + .copied() + .chain(statement.evaluations.clone()) + .zip(&parsed.initial_combination_randomness) + .map(|(ans, rand)| ans * rand) + .sum() + { + return Err(ProofError::InvalidProof); + } + + // Check the rest of the rounds + for (sumcheck_poly, new_randomness) in &parsed.initial_sumcheck_rounds[1..] 
{ + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev_poly = sumcheck_poly.clone(); + randomness = *new_randomness; + } + + prev = Some((prev_poly, randomness)); + } + + for (round, folds) in parsed.rounds.iter().zip(&computed_folds) { + let (sumcheck_poly, new_randomness) = &round.sumcheck_rounds[0].clone(); + + let values = round.ood_answers.iter().copied().chain(folds.clone()); + + let prev_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + let claimed_sum = prev_eval + + values + .zip(&round.combination_randomness) + .map(|(val, rand)| val * rand) + .sum::(); + + if sumcheck_poly.sum_over_hypercube() != claimed_sum { + return Err(ProofError::InvalidProof); + } + + prev = Some((sumcheck_poly.clone(), *new_randomness)); + + // Check the rest of the round + for (sumcheck_poly, new_randomness) in &round.sumcheck_rounds[1..] { + let (prev_poly, randomness) = prev.unwrap(); + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev = Some((sumcheck_poly.clone(), *new_randomness)); + } + } + + // Check the foldings computed from the proof match the evaluations of the polynomial + let final_folds = &computed_folds[computed_folds.len() - 1]; + let final_evaluations = parsed + .final_coefficients + .evaluate_at_univariate(&parsed.final_randomness_points); + if !final_folds + .iter() + .zip(final_evaluations) + .all(|(&fold, eval)| fold == eval) + { + return Err(ProofError::InvalidProof); + } + + // Check the final sumchecks + if self.params.final_sumcheck_rounds > 0 { + let prev_sumcheck_poly_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + let (sumcheck_poly, new_randomness) = &parsed.final_sumcheck_rounds[0].clone(); + let claimed_sum = prev_sumcheck_poly_eval; + + if sumcheck_poly.sum_over_hypercube() != claimed_sum { + return Err(ProofError::InvalidProof); + } + + prev = Some((sumcheck_poly.clone(), *new_randomness)); + + // Check the rest of the round + for (sumcheck_poly, new_randomness) in &parsed.final_sumcheck_rounds[1..] { + let (prev_poly, randomness) = prev.unwrap(); + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev = Some((sumcheck_poly.clone(), *new_randomness)); + } + } + + let prev_sumcheck_poly_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + + // Check the final sumcheck evaluation + let evaluation_of_v_poly = self.compute_v_poly(&parsed_commitment, statement, &parsed); + + if prev_sumcheck_poly_eval + != evaluation_of_v_poly + * parsed + .final_coefficients + .evaluate(&parsed.final_sumcheck_randomness) + { + return Err(ProofError::InvalidProof); + } + + Ok(parsed_commitment.root) + } +}
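
Reviewer note: for orientation, a minimal sketch of the single-polynomial round
trip these three files implement, using the names introduced in this patch. The
field, Merkle config, IO pattern and the `Committer` API are elided and assumed
to be wired up as in the `make_whir_things` test helper above; treat this as a
sketch, not as code from the patch.

    let committer = Committer::new(params.clone());
    let witness = committer.commit(&mut merlin, polynomial)?;

    let prover = Prover(params.clone());
    let proof = prover.prove(&mut merlin, statement.clone(), witness)?;

    let verifier = Verifier::new(params);
    let mut arthur = io.to_arthur(merlin.transcript());
    let _root = verifier.verify(&mut arthur, &statement, &proof)?;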