diff --git a/Cargo.lock b/Cargo.lock index 29a5ffb7f..a025f7e29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3153,7 +3153,6 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/whir?branch=feat%2Fceno-binding-batch#cc05cba75d96a5c3caa78e06a6279ca741954c2b" dependencies = [ "ark-crypto-primitives", "ark-ff", diff --git a/Cargo.toml b/Cargo.toml index 462acb8c5..615ed199f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ members = [ "poseidon", "sumcheck", "transcript", + "whir", ] resolver = "2" diff --git a/mpcs/Cargo.toml b/mpcs/Cargo.toml index 318ff6a31..54305ecc7 100644 --- a/mpcs/Cargo.toml +++ b/mpcs/Cargo.toml @@ -35,7 +35,7 @@ rand_chacha.workspace = true rayon = { workspace = true, optional = true } serde.workspace = true transcript = { path = "../transcript" } -whir = { git = "https://github.com/scroll-tech/whir", branch = "feat/ceno-binding-batch", features = ["ceno"] } +whir = { path = "../whir", features = ["ceno"] } zeroize = "1.8" [dev-dependencies] diff --git a/whir/.github/workflows/rust.yml b/whir/.github/workflows/rust.yml new file mode 100644 index 000000000..11b7199e8 --- /dev/null +++ b/whir/.github/workflows/rust.yml @@ -0,0 +1,24 @@ +name: Rust + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Switch toolchain + run: rustup update nightly && rustup default nightly + - name: Build + run: cargo build --release --verbose + - name: Run tests + run: cargo test --release --verbose diff --git a/whir/.gitignore b/whir/.gitignore new file mode 100644 index 000000000..83fff18b6 --- /dev/null +++ b/whir/.gitignore @@ -0,0 +1,12 @@ +/target +scripts/temp/ +bench_utils/target +*_proof +artifacts +outputs/temp/ +*.pdf +scripts/__pycache__/ +.DS_Store +outputs/ +.idea +.vscode \ No newline at end of file diff --git a/whir/Cargo.lock b/whir/Cargo.lock new file mode 100644 index 000000000..c3207f271 --- /dev/null +++ b/whir/Cargo.lock @@ -0,0 +1,1300 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "allocator-api2" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" + +[[package]] +name = "anstream" +version = "0.6.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" + +[[package]] +name = "anstyle-parse" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +dependencies = [ + "anstyle", + "windows-sys", +] + +[[package]] +name = "ark-crypto-primitives" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e0c292754729c8a190e50414fd1a37093c786c709899f29c9f7daccecfa855e" +dependencies = [ + "ahash", + "ark-crypto-primitives-macros", + "ark-ec", + "ark-ff", + "ark-relations", + "ark-serialize", + "ark-snark", + "ark-std", + "blake2", + "derivative", + "digest", + "fnv", + "hashbrown 0.14.5", + "merlin", + "rayon", + "sha2", +] + +[[package]] +name = "ark-crypto-primitives-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7e89fe77d1f0f4fe5b96dfc940923d88d17b6a773808124f21e764dfb063c6a" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "ark-ec" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d68f2d516162846c1238e755a7c4d131b892b70cc70c471a8e3ca3ed818fce" +dependencies = [ + "ahash", + "ark-ff", + "ark-poly", + "ark-serialize", + "ark-std", + "educe", + "fnv", + "hashbrown 0.15.1", + "itertools 0.13.0", + "num-bigint", + "num-integer", + "num-traits", + "rayon", + "zeroize", +] + +[[package]] +name = "ark-ff" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a177aba0ed1e0fbb62aa9f6d0502e9b46dad8c2eab04c14258a1212d2557ea70" +dependencies = [ + "ark-ff-asm", + "ark-ff-macros", + "ark-serialize", + "ark-std", + "arrayvec", + "digest", + "educe", + "itertools 0.13.0", + "num-bigint", + "num-traits", + "paste", + "rayon", + "zeroize", +] + +[[package]] +name = "ark-ff-asm" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62945a2f7e6de02a31fe400aa489f0e0f5b2502e69f95f853adb82a96c7a6b60" +dependencies = [ + "quote", + "syn 2.0.87", +] + +[[package]] +name = "ark-ff-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09be120733ee33f7693ceaa202ca41accd5653b779563608f1234f78ae07c4b3" +dependencies = [ + "num-bigint", + "num-traits", + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "ark-poly" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "579305839da207f02b89cd1679e50e67b4331e2f9294a57693e5051b7703fe27" +dependencies = [ + "ahash", + "ark-ff", + "ark-serialize", + "ark-std", + "educe", + "fnv", + "hashbrown 0.15.1", + "rayon", +] + +[[package]] +name = "ark-relations" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec46ddc93e7af44bcab5230937635b06fb5744464dd6a7e7b083e80ebd274384" +dependencies = [ + "ark-ff", + "ark-std", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "ark-serialize" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f4d068aaf107ebcd7dfb52bc748f8030e0fc930ac8e360146ca54c1203088f7" +dependencies = [ + "ark-serialize-derive", + "ark-std", + "arrayvec", + "digest", + "num-bigint", + "rayon", +] + +[[package]] +name = "ark-serialize-derive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213888f660fddcca0d257e88e54ac05bca01885f258ccdf695bafd77031bb69d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "ark-snark" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d368e2848c2d4c129ce7679a7d0d2d612b6a274d3ea6a13bad4445d61b381b88" +dependencies = [ + "ark-ff", + "ark-relations", + "ark-serialize", + "ark-std", +] + +[[package]] +name = "ark-std" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a" +dependencies = [ + "colored", + "num-traits", + "rand", + "rayon", +] + +[[package]] +name = "ark-test-curves" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cc137bb3271671a597ae79157030f88affe9fa7975207c3b781a01cb9ed0372" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-std", +] + +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + +[[package]] +name = "blake2b_simd" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23285ad32269793932e830392f2fe2f83e26488fd3ec778883a93c8323735780" +dependencies = [ + "arrayref", + "arrayvec", + "constant_time_eq", +] + +[[package]] +name = "blake3" +version = "1.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d82033247fd8e890df8f740e407ad4d038debb9eb1f40533fffb32e7d17dc6f7" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", +] + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "bytemuck" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "cc" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" +dependencies = [ + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "clap" +version = "4.5.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "clap_lex" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" + +[[package]] +name = "colorchoice" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" + +[[package]] +name = "colored" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c" +dependencies = [ + "lazy_static", + "windows-sys", +] + +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + +[[package]] +name = "cpufeatures" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ca741a962e1b0bff6d724a1a0958b686406e853bb14061f218562e1896f95e6" +dependencies = [ + "libc", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive_more" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", + "unicode-xid", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + +[[package]] +name = "educe" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d7bc049e1bd8cdeb31b68bbd586a9464ecf9f3944af3958a7a9d0f8b9799417" +dependencies = [ + "enum-ordinalize", + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + +[[package]] +name = "enum-ordinalize" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea0dcfa4e54eeb516fe454635a95753ddd39acda650ce703031c6973e315dd5" +dependencies = [ + "enum-ordinalize-derive", +] + +[[package]] +name = "enum-ordinalize-derive" +version = "4.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "ff" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ded41244b729663b1e574f1b4fb731469f69f79c17667b5d776b16cda0479449" +dependencies = [ + "bitvec", + "rand_core", + "subtle", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "goldilocks" +version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno-Goldilocks#29a15d186ce4375dab346a3cc9eca6e43540cb8d" +dependencies = [ + "ff", + "halo2curves", + "itertools 0.12.1", + "rand_core", + "serde", + "subtle", +] + +[[package]] +name = "group" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" +dependencies = [ + "ff", + "rand_core", + "subtle", +] + +[[package]] +name = "halo2curves" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6b1142bd1059aacde1b477e0c80c142910f1ceae67fc619311d6a17428007ab" +dependencies = [ + "blake2b_simd", + "ff", + "group", + "lazy_static", + "num-bigint", + "num-traits", + "pasta_curves", + "paste", + "rand", + "rand_core", + "static_assertions", + "subtle", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "allocator-api2", +] + +[[package]] +name = "hashbrown" +version = "0.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" +dependencies = [ + "allocator-api2", +] + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "keccak" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654" +dependencies = [ + "cpufeatures", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +dependencies = [ + "spin", +] + +[[package]] +name = "libc" +version = "0.2.162" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" + +[[package]] +name = "log" +version = "0.4.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" + +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "merlin" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58c38e2799fc0978b65dfff8023ec7843e2330bb462f19198840b34b6582397d" +dependencies = [ + "byteorder", + "keccak", + "rand_core", + "zeroize", +] + +[[package]] +name = "nimue" +version = "0.2.0" +source = "git+https://github.com/arkworks-rs/nimue#b28eb124420eede1a890e3c64b37adfff94938d3" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-serialize", + "digest", + "generic-array", + "hex", + "keccak", + "log", + "rand", + "zeroize", +] + +[[package]] +name = "nimue-pow" +version = "0.1.0" +source = "git+https://github.com/arkworks-rs/nimue#b28eb124420eede1a890e3c64b37adfff94938d3" +dependencies = [ + "blake3", + "bytemuck", + "keccak", + "nimue", + "rayon", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" + +[[package]] +name = "pasta_curves" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3e57598f73cc7e1b2ac63c79c517b31a0877cd7c402cdcaa311b5208de7a095" +dependencies = [ + "blake2b_simd", + "ff", + "group", + "lazy_static", + "rand", + "static_assertions", + "subtle", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "pin-project-lite" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" + +[[package]] +name = "ppv-lite86" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "serde" +version = "1.0.215" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.215" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "serde_json" +version = "1.0.132" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "sha3" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +dependencies = [ + "digest", + "keccak", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "tracing" +version = "0.1.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +dependencies = [ + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-subscriber" +version = "0.2.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e0d2eaa99c3c2e41547cfa109e910a68ea03823cccad4a0525dcbc9b01e8c71" +dependencies = [ + "tracing-core", +] + +[[package]] +name = "transpose" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e" +dependencies = [ + "num-integer", + "strength_reduce", +] + +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + +[[package]] +name = "unicode-ident" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" + +[[package]] +name = "unicode-xid" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "whir" +version = "0.1.0" +dependencies = [ + "ark-crypto-primitives", + "ark-ff", + "ark-poly", + "ark-serialize", + "ark-std", + "ark-test-curves", + "blake2", + "blake3", + "clap", + "derivative", + "derive_more", + "goldilocks", + "itertools 0.14.0", + "lazy_static", + "nimue", + "nimue-pow", + "rand", + "rand_chacha", + "rayon", + "serde", + "serde_json", + "sha3", + "thiserror", + "transpose", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] diff --git a/whir/Cargo.toml b/whir/Cargo.toml new file mode 100644 index 000000000..66124c294 --- /dev/null +++ b/whir/Cargo.toml @@ -0,0 +1,56 @@ +[package] +categories = ["cryptography", "zk", "blockchain", "pcs"] +description = "Multilinear Polynomial Commitment Scheme" +edition = "2021" +keywords = ["cryptography", "zk", "blockchain", "pcs"] +license = "MIT OR Apache-2.0" +name = "whir" +readme = "README.md" +repository = "https://github.com/WizardOfMenlo/whir/" +version = "0.1.0" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +default-run = "main" + +[dependencies] +ark-crypto-primitives = { version = "0.5", features = ["merkle_tree"] } +ark-ff = { version = "0.5", features = ["asm", "std"] } +ark-poly = "0.5" +ark-serialize = "0.5" +ark-std = { version = "0.5", features = ["std"] } +ark-test-curves = { version = "0.5", features = ["bls12_381_curve"] } +blake2 = "0.10" +blake3 = "1.5.0" +clap = { version = "4.4.17", features = ["derive"] } +derivative = { version = "2", features = ["use_core"] } +lazy_static = "1.4" +nimue = { git = "https://github.com/arkworks-rs/nimue", features = ["ark"] } +nimue-pow = { git = "https://github.com/arkworks-rs/nimue" } +rand = "0.8" +rand_chacha = "0.3" +rayon = { version = "1.10.0", optional = true } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +sha3 = "0.10" +transpose = "0.2.3" + +derive_more = { version = "1.0.0", features = ["debug"] } +goldilocks = { git = "https://github.com/scroll-tech/ceno-Goldilocks" } +itertools = "0.14.0" +thiserror = "1" + +[profile.release] +debug = true + +[features] +asm = [] +ceno = [] +default = ["parallel", "ceno"] +parallel = [ + "dep:rayon", + "ark-poly/parallel", + "ark-ff/parallel", + "ark-crypto-primitives/parallel", +] +print-trace = ["ark-std/print-trace"] +rayon = ["dep:rayon"] diff --git a/whir/LICENSE-APACHE b/whir/LICENSE-APACHE new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/whir/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/whir/LICENSE-MIT b/whir/LICENSE-MIT new file mode 100644 index 000000000..c67929c8b --- /dev/null +++ b/whir/LICENSE-MIT @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Gal Arnon, Alessandro Chiesa, Giacomo Fenzi, Eylon Yogev + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/whir/README.md b/whir/README.md new file mode 100644 index 000000000..ef91b05da --- /dev/null +++ b/whir/README.md @@ -0,0 +1,47 @@ +

WHIR 🌪️

+ +This library was developed using the [arkworks](https://arkworks.rs) ecosystem to accompany [WHIR 🌪️](https://eprint.iacr.org/2024/1586). +By [Gal Arnon](https://galarnon42.github.io/) [Alessandro Chiesa](https://ic-people.epfl.ch/~achiesa/), [Giacomo Fenzi](https://gfenzi.io), and [Eylon Yogev](https://www.eylonyogev.com/about). + +**WARNING:** This is an academic prototype and has not received careful code review. This implementation is NOT ready for production use. + +

+ + +

+ +# Usage +``` +cargo run --release -- --help + +Usage: main [OPTIONS] + +Options: + -t, --type [default: PCS] + -l, --security-level [default: 100] + -p, --pow-bits + -d, --num-variables [default: 20] + -e, --evaluations [default: 1] + -r, --rate [default: 1] + --reps [default: 1000] + -k, --fold [default: 4] + --sec [default: ConjectureList] + --fold_type [default: ProverHelps] + -f, --field [default: Goldilocks2] + --hash [default: Blake3] + -h, --help Print help + -V, --version Print version +``` + +Options: +- `-t` can be either `PCS` or `LDT` to run as a (multilinear) PCS or a LDT +- `-l` sets the (overall) security level of the scheme +- `-p` sets the number of PoW bits (used for the query-phase). PoW bits for proximity gaps are set automatically. +- `-d` sets the number of variables of the scheme. +- `-e` sets the number of evaluations to prove. Only meaningful in PCS mode. +- `-r` sets the log_inv of the rate +- `-k` sets the number of variables to fold at each iteration. +- `--sec` sets the settings used to compute security. Available `UniqueDecoding`, `ProvableList`, `ConjectureList` +- `--fold_type` sets the settings used to compute folds. Available `Naive`, `ProverHelps` +- `-f` sets the field used, available are `Goldilocks2, Goldilocks3, Field192, Field256`. +- `--hash` sets the hash used for the Merkle tree, available are `SHA3` and `Blake3` diff --git a/whir/rust-toolchain.toml b/whir/rust-toolchain.toml new file mode 100644 index 000000000..5d1274a3e --- /dev/null +++ b/whir/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "nightly-2024-10-03" diff --git a/whir/src/bin/benchmark.rs b/whir/src/bin/benchmark.rs new file mode 100644 index 000000000..a9c388124 --- /dev/null +++ b/whir/src/bin/benchmark.rs @@ -0,0 +1,446 @@ +use std::{ + fs::OpenOptions, + time::{Duration, Instant}, +}; + +use ark_crypto_primitives::{ + crh::{CRHScheme, TwoToOneCRHScheme}, + merkle_tree::Config, +}; +use ark_ff::{FftField, Field}; +use ark_serialize::CanonicalSerialize; +use nimue::{Arthur, DefaultHash, IOPattern, Merlin}; +use nimue_pow::blake3::Blake3PoW; +use whir::{ + cmdline_utils::{AvailableFields, AvailableMerkle}, + crypto::{ + fields, + merkle_tree::{self, HashCounter}, + }, + parameters::*, + poly_utils::coeffs::CoefficientList, + whir::Statement, +}; + +use serde::Serialize; + +use clap::Parser; +use whir::whir::{ + fs_utils::{DigestReader, DigestWriter}, + iopattern::DigestIOPattern, +}; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + #[arg(short = 'l', long, default_value = "100")] + security_level: usize, + + #[arg(short = 'p', long)] + pow_bits: Option, + + #[arg(short = 'd', long, default_value = "20")] + num_variables: usize, + + #[arg(short = 'e', long = "evaluations", default_value = "1")] + num_evaluations: usize, + + #[arg(short = 'r', long, default_value = "1")] + rate: usize, + + #[arg(long = "reps", default_value = "1000")] + verifier_repetitions: usize, + + #[arg(short = 'i', long = "initfold", default_value = "4")] + first_round_folding_factor: usize, + + #[arg(short = 'k', long = "fold", default_value = "4")] + folding_factor: usize, + + #[arg(long = "sec", default_value = "ConjectureList")] + soundness_type: SoundnessType, + + #[arg(long = "fold_type", default_value = "ProverHelps")] + fold_optimisation: FoldType, + + #[arg(short = 'f', long = "field", default_value = "Goldilocks2")] + field: AvailableFields, + + #[arg(long = "hash", default_value = "Blake3")] + merkle_tree: AvailableMerkle, +} + +#[derive(Debug, Serialize)] +struct BenchmarkOutput { + security_level: usize, + pow_bits: usize, + starting_rate: usize, + num_variables: usize, + repetitions: usize, + folding_factor: usize, + soundness_type: SoundnessType, + field: AvailableFields, + merkle_tree: AvailableMerkle, + + // Whir + whir_evaluations: usize, + whir_argument_size: usize, + whir_prover_time: Duration, + whir_prover_hashes: usize, + whir_verifier_time: Duration, + whir_verifier_hashes: usize, + + // Whir LDT + whir_ldt_argument_size: usize, + whir_ldt_prover_time: Duration, + whir_ldt_prover_hashes: usize, + whir_ldt_verifier_time: Duration, + whir_ldt_verifier_hashes: usize, +} + +type PowStrategy = Blake3PoW; + +fn main() { + let mut args = Args::parse(); + let field = args.field; + let merkle = args.merkle_tree; + + if args.pow_bits.is_none() { + args.pow_bits = Some(default_max_pow(args.num_variables, args.rate)); + } + + let mut rng = ark_std::test_rng(); + + match (field, merkle) { + (AvailableFields::Goldilocks1, AvailableMerkle::Blake3) => { + use fields::Field64 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks1, AvailableMerkle::Keccak256) => { + use fields::Field64 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks2, AvailableMerkle::Blake3) => { + use fields::Field64_2 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks2, AvailableMerkle::Keccak256) => { + use fields::Field64_2 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks3, AvailableMerkle::Blake3) => { + use fields::Field64_3 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks3, AvailableMerkle::Keccak256) => { + use fields::Field64_3 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field128, AvailableMerkle::Blake3) => { + use fields::Field128 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field128, AvailableMerkle::Keccak256) => { + use fields::Field128 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field192, AvailableMerkle::Blake3) => { + use fields::Field192 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field192, AvailableMerkle::Keccak256) => { + use fields::Field192 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field256, AvailableMerkle::Blake3) => { + use fields::Field256 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field256, AvailableMerkle::Keccak256) => { + use fields::Field256 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + } +} + +fn run_whir( + args: Args, + leaf_hash_params: <::LeafHash as CRHScheme>::Parameters, + two_to_one_params: <::TwoToOneHash as TwoToOneCRHScheme>::Parameters, +) where + F: FftField + CanonicalSerialize, + MerkleConfig: Config + Clone, + MerkleConfig::InnerDigest: AsRef<[u8]> + From<[u8; 32]>, + IOPattern: DigestIOPattern, + Merlin: DigestWriter, + for<'a> Arthur<'a>: DigestReader, +{ + let security_level = args.security_level; + let pow_bits = args.pow_bits.unwrap(); + let num_variables = args.num_variables; + let starting_rate = args.rate; + let reps = args.verifier_repetitions; + let folding_factor = args.folding_factor; + let first_round_folding_factor = args.first_round_folding_factor; + let soundness_type = args.soundness_type; + let fold_optimisation = args.fold_optimisation; + + std::fs::create_dir_all("outputs").unwrap(); + + let num_coeffs = 1 << num_variables; + + let mv_params = MultivariateParameters::::new(num_variables); + + let whir_params = WhirParameters:: { + initial_statement: true, + security_level, + pow_bits, + folding_factor: FoldingFactor::ConstantFromSecondRound( + first_round_folding_factor, + folding_factor, + ), + leaf_hash_params, + two_to_one_params, + soundness_type, + fold_optimisation, + _pow_parameters: Default::default(), + starting_log_inv_rate: starting_rate, + }; + + let polynomial = CoefficientList::new( + (0..num_coeffs) + .map(::BasePrimeField::from) + .collect(), + ); + + let ( + whir_ldt_prover_time, + whir_ldt_argument_size, + whir_ldt_prover_hashes, + whir_ldt_verifier_time, + whir_ldt_verifier_hashes, + ) = { + // Run LDT + use whir::whir::{ + committer::Committer, iopattern::WhirIOPattern, parameters::WhirConfig, prover::Prover, + verifier::Verifier, whir_proof_size, + }; + + let whir_params = WhirParameters:: { + initial_statement: false, + ..whir_params.clone() + }; + let params = + WhirConfig::::new(mv_params, whir_params.clone()); + if !params.check_pow_bits() { + println!("WARN: more PoW bits required than what specified."); + } + + let io = IOPattern::::new("🌪️") + .commit_statement(¶ms) + .add_whir_proof(¶ms) + .clone(); + + let mut merlin = io.to_merlin(); + + let whir_ldt_prover_time = Instant::now(); + + HashCounter::reset(); + + let committer = Committer::new(params.clone()); + let witness = committer.commit(&mut merlin, polynomial.clone()).unwrap(); + + let prover = Prover(params.clone()); + + let proof = prover + .prove(&mut merlin, Statement::default(), witness) + .unwrap(); + + let whir_ldt_prover_time = whir_ldt_prover_time.elapsed(); + let whir_ldt_argument_size = whir_proof_size(merlin.transcript(), &proof); + let whir_ldt_prover_hashes = HashCounter::get(); + + // Just not to count that initial inversion (which could be precomputed) + let verifier = Verifier::new(params); + + HashCounter::reset(); + let whir_ldt_verifier_time = Instant::now(); + for _ in 0..reps { + let mut arthur = io.to_arthur(merlin.transcript()); + verifier + .verify(&mut arthur, &Statement::default(), &proof) + .unwrap(); + } + + let whir_ldt_verifier_time = whir_ldt_verifier_time.elapsed(); + let whir_ldt_verifier_hashes = HashCounter::get() / reps; + + ( + whir_ldt_prover_time, + whir_ldt_argument_size, + whir_ldt_prover_hashes, + whir_ldt_verifier_time, + whir_ldt_verifier_hashes, + ) + }; + + let ( + whir_prover_time, + whir_argument_size, + whir_prover_hashes, + whir_verifier_time, + whir_verifier_hashes, + ) = { + // Run PCS + use whir::{ + poly_utils::MultilinearPoint, + whir::{ + committer::Committer, iopattern::WhirIOPattern, parameters::WhirConfig, + prover::Prover, verifier::Verifier, whir_proof_size, + }, + }; + + let params = WhirConfig::::new(mv_params, whir_params); + if !params.check_pow_bits() { + println!("WARN: more PoW bits required than what specified."); + } + + let io = IOPattern::::new("🌪️") + .commit_statement(¶ms) + .add_whir_proof(¶ms) + .clone(); + + let mut merlin = io.to_merlin(); + + let points: Vec<_> = (0..args.num_evaluations) + .map(|i| MultilinearPoint(vec![F::from(i as u64); num_variables])) + .collect(); + let evaluations = points + .iter() + .map(|point| polynomial.evaluate_at_extension(point)) + .collect(); + let statement = Statement { + points, + evaluations, + }; + + HashCounter::reset(); + let whir_prover_time = Instant::now(); + + let committer = Committer::new(params.clone()); + let witness = committer.commit(&mut merlin, polynomial.clone()).unwrap(); + + let prover = Prover(params.clone()); + + let proof = prover + .prove(&mut merlin, statement.clone(), witness) + .unwrap(); + + let whir_prover_time = whir_prover_time.elapsed(); + let whir_argument_size = whir_proof_size(merlin.transcript(), &proof); + let whir_prover_hashes = HashCounter::get(); + + // Just not to count that initial inversion (which could be precomputed) + let verifier = Verifier::new(params); + + HashCounter::reset(); + let whir_verifier_time = Instant::now(); + for _ in 0..reps { + let mut arthur = io.to_arthur(merlin.transcript()); + verifier.verify(&mut arthur, &statement, &proof).unwrap(); + } + + let whir_verifier_time = whir_verifier_time.elapsed(); + let whir_verifier_hashes = HashCounter::get() / reps; + + ( + whir_prover_time, + whir_argument_size, + whir_prover_hashes, + whir_verifier_time, + whir_verifier_hashes, + ) + }; + + let output = BenchmarkOutput { + security_level, + pow_bits, + starting_rate, + num_variables, + repetitions: reps, + folding_factor, + soundness_type, + field: args.field, + merkle_tree: args.merkle_tree, + + // Whir + whir_evaluations: args.num_evaluations, + whir_prover_time, + whir_argument_size, + whir_prover_hashes, + whir_verifier_time, + whir_verifier_hashes, + + // Whir LDT + whir_ldt_prover_time, + whir_ldt_argument_size, + whir_ldt_prover_hashes, + whir_ldt_verifier_time, + whir_ldt_verifier_hashes, + }; + + let mut out_file = OpenOptions::new() + .append(true) + .create(true) + .open("outputs/bench_output.json") + .unwrap(); + use std::io::Write; + writeln!(out_file, "{}", serde_json::to_string(&output).unwrap()).unwrap(); +} diff --git a/whir/src/bin/main.rs b/whir/src/bin/main.rs new file mode 100644 index 000000000..0fc804d3a --- /dev/null +++ b/whir/src/bin/main.rs @@ -0,0 +1,438 @@ +use std::time::Instant; + +use ark_crypto_primitives::{ + crh::{CRHScheme, TwoToOneCRHScheme}, + merkle_tree::Config, +}; +use ark_ff::FftField; +use ark_serialize::CanonicalSerialize; +use nimue::{Arthur, DefaultHash, IOPattern, Merlin}; +use whir::{ + cmdline_utils::{AvailableFields, AvailableMerkle, WhirType}, + crypto::{ + fields, + merkle_tree::{self, HashCounter}, + }, + parameters::*, + poly_utils::{MultilinearPoint, coeffs::CoefficientList}, + whir::Statement, +}; + +use nimue_pow::blake3::Blake3PoW; + +use clap::Parser; +use whir::whir::{ + fs_utils::{DigestReader, DigestWriter}, + iopattern::DigestIOPattern, +}; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + #[arg(short = 't', long = "type", default_value = "PCS")] + protocol_type: WhirType, + + #[arg(short = 'l', long, default_value = "100")] + security_level: usize, + + #[arg(short = 'p', long)] + pow_bits: Option, + + #[arg(short = 'd', long, default_value = "20")] + num_variables: usize, + + #[arg(short = 'e', long = "evaluations", default_value = "1")] + num_evaluations: usize, + + #[arg(short = 'r', long, default_value = "1")] + rate: usize, + + #[arg(long = "reps", default_value = "1000")] + verifier_repetitions: usize, + + #[arg(short = 'i', long = "initfold", default_value = "4")] + first_round_folding_factor: usize, + + #[arg(short = 'k', long = "fold", default_value = "4")] + folding_factor: usize, + + #[arg(long = "sec", default_value = "ConjectureList")] + soundness_type: SoundnessType, + + #[arg(long = "fold_type", default_value = "ProverHelps")] + fold_optimisation: FoldType, + + #[arg(short = 'f', long = "field", default_value = "Goldilocks2")] + field: AvailableFields, + + #[arg(long = "hash", default_value = "Blake3")] + merkle_tree: AvailableMerkle, +} + +type PowStrategy = Blake3PoW; + +fn main() { + let mut args = Args::parse(); + let field = args.field; + let merkle = args.merkle_tree; + + if args.pow_bits.is_none() { + args.pow_bits = Some(default_max_pow(args.num_variables, args.rate)); + } + + let mut rng = ark_std::test_rng(); + + match (field, merkle) { + (AvailableFields::Goldilocks1, AvailableMerkle::Blake3) => { + use fields::Field64 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks1, AvailableMerkle::Keccak256) => { + use fields::Field64 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks2, AvailableMerkle::Blake3) => { + use fields::Field64_2 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks2, AvailableMerkle::Keccak256) => { + use fields::Field64_2 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks3, AvailableMerkle::Blake3) => { + use fields::Field64_3 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Goldilocks3, AvailableMerkle::Keccak256) => { + use fields::Field64_3 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field128, AvailableMerkle::Blake3) => { + use fields::Field128 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field128, AvailableMerkle::Keccak256) => { + use fields::Field128 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field192, AvailableMerkle::Blake3) => { + use fields::Field192 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field192, AvailableMerkle::Keccak256) => { + use fields::Field192 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field256, AvailableMerkle::Blake3) => { + use fields::Field256 as F; + use merkle_tree::blake3 as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + + (AvailableFields::Field256, AvailableMerkle::Keccak256) => { + use fields::Field256 as F; + use merkle_tree::keccak as mt; + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + run_whir::>(args, leaf_hash_params, two_to_one_params); + } + } +} + +fn run_whir( + args: Args, + leaf_hash_params: <::LeafHash as CRHScheme>::Parameters, + two_to_one_params: <::TwoToOneHash as TwoToOneCRHScheme>::Parameters, +) where + F: FftField + CanonicalSerialize, + MerkleConfig: Config + Clone, + MerkleConfig::InnerDigest: AsRef<[u8]> + From<[u8; 32]>, + IOPattern: DigestIOPattern, + Merlin: DigestWriter, + for<'a> Arthur<'a>: DigestReader, +{ + match args.protocol_type { + WhirType::PCS => run_whir_pcs::(args, leaf_hash_params, two_to_one_params), + WhirType::LDT => { + run_whir_as_ldt::(args, leaf_hash_params, two_to_one_params) + } + } +} + +fn run_whir_as_ldt( + args: Args, + leaf_hash_params: <::LeafHash as CRHScheme>::Parameters, + two_to_one_params: <::TwoToOneHash as TwoToOneCRHScheme>::Parameters, +) where + F: FftField + CanonicalSerialize, + MerkleConfig: Config + Clone, + MerkleConfig::InnerDigest: AsRef<[u8]> + From<[u8; 32]>, + IOPattern: DigestIOPattern, + Merlin: DigestWriter, + for<'a> Arthur<'a>: DigestReader, +{ + use whir::whir::{ + committer::Committer, iopattern::WhirIOPattern, parameters::WhirConfig, prover::Prover, + verifier::Verifier, + }; + + // Runs as a LDT + let security_level = args.security_level; + let pow_bits = args.pow_bits.unwrap(); + let num_variables = args.num_variables; + let starting_rate = args.rate; + let reps = args.verifier_repetitions; + let first_round_folding_factor = args.first_round_folding_factor; + let folding_factor = args.folding_factor; + let fold_optimisation = args.fold_optimisation; + let soundness_type = args.soundness_type; + + if args.num_evaluations > 1 { + println!("Warning: running as LDT but a number of evaluations to be proven was specified."); + } + + let num_coeffs = 1 << num_variables; + + let mv_params = MultivariateParameters::::new(num_variables); + + let whir_params = WhirParameters:: { + initial_statement: false, + security_level, + pow_bits, + folding_factor: FoldingFactor::ConstantFromSecondRound( + first_round_folding_factor, + folding_factor, + ), + leaf_hash_params, + two_to_one_params, + soundness_type, + fold_optimisation, + _pow_parameters: Default::default(), + starting_log_inv_rate: starting_rate, + }; + + let params = WhirConfig::::new(mv_params, whir_params.clone()); + + let io = IOPattern::::new("🌪️") + .commit_statement(¶ms) + .add_whir_proof(¶ms); + + let mut merlin = io.to_merlin(); + + println!("========================================="); + println!("Whir (LDT) 🌪️"); + println!("Field: {:?} and MT: {:?}", args.field, args.merkle_tree); + println!("{}", params); + if !params.check_pow_bits() { + println!("WARN: more PoW bits required than what specified."); + } + + use ark_ff::Field; + let polynomial = CoefficientList::new( + (0..num_coeffs) + .map(::BasePrimeField::from) + .collect(), + ); + + let whir_prover_time = Instant::now(); + + let committer = Committer::new(params.clone()); + let witness = committer.commit(&mut merlin, polynomial).unwrap(); + + let prover = Prover(params.clone()); + + let proof = prover + .prove(&mut merlin, Statement::default(), witness) + .unwrap(); + + dbg!(whir_prover_time.elapsed()); + + // Serialize proof + let transcript = merlin.transcript().to_vec(); + let mut proof_bytes = vec![]; + proof.serialize_compressed(&mut proof_bytes).unwrap(); + + let proof_size = transcript.len() + proof_bytes.len(); + dbg!(proof_size); + + // Just not to count that initial inversion (which could be precomputed) + let verifier = Verifier::new(params.clone()); + + HashCounter::reset(); + let whir_verifier_time = Instant::now(); + for _ in 0..reps { + let mut arthur = io.to_arthur(&transcript); + verifier + .verify(&mut arthur, &Statement::default(), &proof) + .unwrap(); + } + dbg!(whir_verifier_time.elapsed() / reps as u32); + dbg!(HashCounter::get() as f64 / reps as f64); +} + +fn run_whir_pcs( + args: Args, + leaf_hash_params: <::LeafHash as CRHScheme>::Parameters, + two_to_one_params: <::TwoToOneHash as TwoToOneCRHScheme>::Parameters, +) where + F: FftField + CanonicalSerialize, + MerkleConfig: Config + Clone, + MerkleConfig::InnerDigest: AsRef<[u8]> + From<[u8; 32]>, + IOPattern: DigestIOPattern, + Merlin: DigestWriter, + for<'a> Arthur<'a>: DigestReader, +{ + use whir::whir::{ + Statement, committer::Committer, iopattern::WhirIOPattern, parameters::WhirConfig, + prover::Prover, verifier::Verifier, whir_proof_size, + }; + + // Runs as a PCS + let security_level = args.security_level; + let pow_bits = args.pow_bits.unwrap(); + let num_variables = args.num_variables; + let starting_rate = args.rate; + let reps = args.verifier_repetitions; + let first_round_folding_factor = args.first_round_folding_factor; + let folding_factor = args.folding_factor; + let fold_optimisation = args.fold_optimisation; + let soundness_type = args.soundness_type; + let num_evaluations = args.num_evaluations; + + if num_evaluations == 0 { + println!("Warning: running as PCS but no evaluations specified."); + } + + let num_coeffs = 1 << num_variables; + + let mv_params = MultivariateParameters::::new(num_variables); + + let whir_params = WhirParameters:: { + initial_statement: true, + security_level, + pow_bits, + folding_factor: FoldingFactor::ConstantFromSecondRound( + first_round_folding_factor, + folding_factor, + ), + leaf_hash_params, + two_to_one_params, + soundness_type, + fold_optimisation, + _pow_parameters: Default::default(), + starting_log_inv_rate: starting_rate, + }; + + let params = WhirConfig::::new(mv_params, whir_params); + + let io = IOPattern::::new("🌪️") + .commit_statement(¶ms) + .add_whir_proof(¶ms) + .clone(); + + let mut merlin = io.to_merlin(); + + println!("========================================="); + println!("Whir (PCS) 🌪️"); + println!("Field: {:?} and MT: {:?}", args.field, args.merkle_tree); + println!("{}", params); + if !params.check_pow_bits() { + println!("WARN: more PoW bits required than what specified."); + } + + use ark_ff::Field; + let polynomial = CoefficientList::new( + (0..num_coeffs) + .map(::BasePrimeField::from) + .collect(), + ); + let points: Vec<_> = (0..num_evaluations) + .map(|i| MultilinearPoint(vec![F::from(i as u64); num_variables])) + .collect(); + let evaluations = points + .iter() + .map(|point| polynomial.evaluate_at_extension(point)) + .collect(); + + let statement = Statement { + points, + evaluations, + }; + + let whir_prover_time = Instant::now(); + + let committer = Committer::new(params.clone()); + let witness = committer.commit(&mut merlin, polynomial).unwrap(); + + let prover = Prover(params.clone()); + + let proof = prover + .prove(&mut merlin, statement.clone(), witness) + .unwrap(); + + println!("Prover time: {:.1?}", whir_prover_time.elapsed()); + println!( + "Proof size: {:.1} KiB", + whir_proof_size(merlin.transcript(), &proof) as f64 / 1024.0 + ); + + // Just not to count that initial inversion (which could be precomputed) + let verifier = Verifier::new(params); + + HashCounter::reset(); + let whir_verifier_time = Instant::now(); + for _ in 0..reps { + let mut arthur = io.to_arthur(merlin.transcript()); + verifier.verify(&mut arthur, &statement, &proof).unwrap(); + } + println!( + "Verifier time: {:.1?}", + whir_verifier_time.elapsed() / reps as u32 + ); + println!( + "Average hashes: {:.1}k", + (HashCounter::get() as f64 / reps as f64) / 1000.0 + ); +} diff --git a/whir/src/ceno_binding/merkle_config.rs b/whir/src/ceno_binding/merkle_config.rs new file mode 100644 index 000000000..84eb804da --- /dev/null +++ b/whir/src/ceno_binding/merkle_config.rs @@ -0,0 +1,292 @@ +use ark_crypto_primitives::merkle_tree::Config; +use ark_ff::FftField; +use nimue::{Arthur, DefaultHash, IOPattern, Merlin, ProofResult}; +use nimue_pow::PowStrategy; + +use crate::{ + crypto::merkle_tree::{ + blake3::MerkleTreeParams as Blake3Params, keccak::MerkleTreeParams as KeccakParams, + }, + poly_utils::{MultilinearPoint, coeffs::CoefficientList}, + whir::{ + Statement, WhirProof, + batch::{WhirBatchIOPattern, Witnesses}, + committer::{Committer, Witness}, + fs_utils::DigestWriter, + iopattern::WhirIOPattern, + parameters::WhirConfig, + prover::Prover, + verifier::Verifier, + }, +}; + +pub trait WhirMerkleConfigWrapper { + type MerkleConfig: Config + Clone; + type PowStrategy: PowStrategy; + + fn commit_to_merlin( + committer: &Committer, + merlin: &mut Merlin, + poly: CoefficientList, + ) -> ProofResult>; + + fn commit_to_merlin_batch( + committer: &Committer, + merlin: &mut Merlin, + polys: &[CoefficientList], + ) -> ProofResult>; + + fn prove_with_merlin( + prover: &Prover, + merlin: &mut Merlin, + statement: Statement, + witness: Witness, + ) -> ProofResult>; + + fn prove_with_merlin_simple_batch( + prover: &Prover, + merlin: &mut Merlin, + point: &[F], + evals: &[F], + witness: &Witnesses, + ) -> ProofResult>; + + fn verify_with_arthur( + verifier: &Verifier, + arthur: &mut Arthur, + statement: &Statement, + whir_proof: &WhirProof, + ) -> ProofResult<::InnerDigest>; + + fn verify_with_arthur_simple_batch( + verifier: &Verifier, + arthur: &mut Arthur, + point: &[F], + evals: &[F], + whir_proof: &WhirProof, + ) -> ProofResult<::InnerDigest>; + + fn commit_statement_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + ) -> IOPattern; + + fn add_whir_proof_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + ) -> IOPattern; + + fn commit_batch_statement_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + batch_size: usize, + ) -> IOPattern; + + fn add_whir_batch_proof_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + batch_size: usize, + ) -> IOPattern; + + fn add_digest_to_merlin( + merlin: &mut Merlin, + digest: ::InnerDigest, + ) -> ProofResult<()>; +} + +pub struct Blake3ConfigWrapper(Blake3Params); +pub struct KeccakConfigWrapper(KeccakParams); + +impl WhirMerkleConfigWrapper for Blake3ConfigWrapper { + type MerkleConfig = Blake3Params; + type PowStrategy = nimue_pow::blake3::Blake3PoW; + fn commit_to_merlin( + committer: &Committer, + merlin: &mut Merlin, + poly: CoefficientList, + ) -> ProofResult> { + committer.commit(merlin, poly) + } + + fn commit_to_merlin_batch( + committer: &Committer, + merlin: &mut Merlin, + polys: &[CoefficientList], + ) -> ProofResult> { + committer.batch_commit(merlin, polys) + } + + fn prove_with_merlin( + prover: &Prover, + merlin: &mut Merlin, + statement: Statement, + witness: Witness, + ) -> ProofResult> { + prover.prove(merlin, statement, witness) + } + + fn prove_with_merlin_simple_batch( + prover: &Prover, + merlin: &mut Merlin, + point: &[F], + evals: &[F], + witness: &Witnesses, + ) -> ProofResult> { + let points = [MultilinearPoint(point.to_vec())]; + prover.simple_batch_prove(merlin, &points, &[evals.to_vec()], witness) + } + + fn verify_with_arthur( + verifier: &Verifier, + arthur: &mut Arthur, + statement: &Statement, + whir_proof: &WhirProof, + ) -> ProofResult<::InnerDigest> { + verifier.verify(arthur, statement, whir_proof) + } + + fn verify_with_arthur_simple_batch( + verifier: &Verifier, + arthur: &mut Arthur, + point: &[F], + evals: &[F], + whir_proof: &WhirProof, + ) -> ProofResult<::InnerDigest> { + let points = [MultilinearPoint(point.to_vec())]; + verifier.simple_batch_verify(arthur, evals.len(), &points, &[evals.to_vec()], whir_proof) + } + + fn commit_statement_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + ) -> IOPattern { + iopattern.commit_statement(params) + } + + fn add_whir_proof_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + ) -> IOPattern { + iopattern.add_whir_proof(params) + } + + fn commit_batch_statement_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + batch_size: usize, + ) -> IOPattern { + iopattern.commit_batch_statement(params, batch_size) + } + + fn add_whir_batch_proof_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + batch_size: usize, + ) -> IOPattern { + iopattern.add_whir_batch_proof(params, batch_size) + } + + fn add_digest_to_merlin( + merlin: &mut Merlin, + digest: ::InnerDigest, + ) -> ProofResult<()> { + >::add_digest(merlin, digest) + } +} + +impl WhirMerkleConfigWrapper for KeccakConfigWrapper { + type MerkleConfig = KeccakParams; + type PowStrategy = nimue_pow::keccak::KeccakPoW; + fn commit_to_merlin( + committer: &Committer, + merlin: &mut Merlin, + poly: CoefficientList, + ) -> ProofResult> { + committer.commit(merlin, poly) + } + + fn commit_to_merlin_batch( + committer: &Committer, + merlin: &mut Merlin, + polys: &[CoefficientList], + ) -> ProofResult> { + committer.batch_commit(merlin, polys) + } + + fn prove_with_merlin( + prover: &Prover, + merlin: &mut Merlin, + statement: Statement, + witness: Witness, + ) -> ProofResult> { + prover.prove(merlin, statement, witness) + } + + fn prove_with_merlin_simple_batch( + prover: &Prover, + merlin: &mut Merlin, + point: &[F], + evals: &[F], + witness: &Witnesses, + ) -> ProofResult> { + let points = [MultilinearPoint(point.to_vec())]; + prover.simple_batch_prove(merlin, &points, &[evals.to_vec()], witness) + } + + fn verify_with_arthur( + verifier: &Verifier, + arthur: &mut Arthur, + statement: &Statement, + whir_proof: &WhirProof, + ) -> ProofResult<::InnerDigest> { + verifier.verify(arthur, statement, whir_proof) + } + + fn verify_with_arthur_simple_batch( + verifier: &Verifier, + arthur: &mut Arthur, + point: &[F], + evals: &[F], + whir_proof: &WhirProof, + ) -> ProofResult<::InnerDigest> { + let points = [MultilinearPoint(point.to_vec())]; + verifier.simple_batch_verify(arthur, evals.len(), &points, &[evals.to_vec()], whir_proof) + } + + fn commit_statement_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + ) -> IOPattern { + iopattern.commit_statement(params) + } + + fn add_whir_proof_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + ) -> IOPattern { + iopattern.add_whir_proof(params) + } + + fn commit_batch_statement_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + batch_size: usize, + ) -> IOPattern { + iopattern.commit_batch_statement(params, batch_size) + } + + fn add_whir_batch_proof_to_io_pattern( + iopattern: IOPattern, + params: &WhirConfig, + batch_size: usize, + ) -> IOPattern { + iopattern.add_whir_batch_proof(params, batch_size) + } + + fn add_digest_to_merlin( + merlin: &mut Merlin, + digest: ::InnerDigest, + ) -> ProofResult<()> { + >::add_digest(merlin, digest) + } +} diff --git a/whir/src/ceno_binding/mod.rs b/whir/src/ceno_binding/mod.rs new file mode 100644 index 000000000..7299cf99a --- /dev/null +++ b/whir/src/ceno_binding/mod.rs @@ -0,0 +1,80 @@ +mod merkle_config; +mod pcs; +pub use ark_crypto_primitives::merkle_tree::Config; +pub use pcs::{DefaultHash, InnerDigestOf, Whir, WhirDefaultSpec, WhirSpec}; + +use ark_ff::FftField; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use serde::{Serialize, de::DeserializeOwned}; +use std::fmt::Debug; + +pub use nimue::{ + ProofResult, + plugins::ark::{FieldChallenges, FieldWriter}, +}; + +#[derive(Debug, Clone, thiserror::Error)] +pub enum Error { + #[error(transparent)] + ProofError(#[from] nimue::ProofError), + #[error("CommitmentMismatchFromDigest")] + CommitmentMismatchFromDigest, + #[error("InvalidPcsParams")] + InvalidPcsParam, +} + +/// The trait for a non-interactive polynomial commitment scheme. +/// This trait serves as the intermediate step between WHIR and the +/// trait required in Ceno mpcs. Because Ceno and the WHIR implementation +/// in this crate assume different types of transcripts, to connect +/// them we can provide a non-interactive interface from WHIR. +pub trait PolynomialCommitmentScheme: Clone { + type Param: Clone + Debug + Serialize + DeserializeOwned; + type Commitment: Clone + Debug; + type CommitmentWithWitness: Clone + Debug; + type Proof: Clone + CanonicalSerialize + CanonicalDeserialize + Serialize + DeserializeOwned; + type Poly: Clone + Debug + Serialize + DeserializeOwned; + + fn setup(poly_size: usize) -> Self::Param; + + fn commit(pp: &Self::Param, poly: &Self::Poly) -> Result; + + fn batch_commit( + pp: &Self::Param, + polys: &[Self::Poly], + ) -> Result; + + fn open( + pp: &Self::Param, + comm: &Self::CommitmentWithWitness, + point: &[E], + eval: &E, + ) -> Result; + + /// This is a simple version of batch open: + /// 1. Open at one point + /// 2. All the polynomials share the same commitment. + /// 3. The point is already a random point generated by a sum-check. + fn simple_batch_open( + pp: &Self::Param, + comm: &Self::CommitmentWithWitness, + point: &[E], + evals: &[E], + ) -> Result; + + fn verify( + vp: &Self::Param, + comm: &Self::Commitment, + point: &[E], + eval: &E, + proof: &Self::Proof, + ) -> Result<(), Error>; + + fn simple_batch_verify( + vp: &Self::Param, + comm: &Self::Commitment, + point: &[E], + evals: &[E], + proof: &Self::Proof, + ) -> Result<(), Error>; +} diff --git a/whir/src/ceno_binding/pcs.rs b/whir/src/ceno_binding/pcs.rs new file mode 100644 index 000000000..2c8f8ed80 --- /dev/null +++ b/whir/src/ceno_binding/pcs.rs @@ -0,0 +1,442 @@ +use super::{ + Error, PolynomialCommitmentScheme, + merkle_config::{Blake3ConfigWrapper, WhirMerkleConfigWrapper}, +}; +use crate::{ + crypto::merkle_tree::blake3::{self as mt}, + parameters::{ + FoldType, FoldingFactor, MultivariateParameters, SoundnessType, WhirParameters, + default_max_pow, + }, + poly_utils::{MultilinearPoint, coeffs::CoefficientList}, + whir::{ + Statement, WhirProof, batch::Witnesses, committer::Committer, parameters::WhirConfig, + prover::Prover, verifier::Verifier, + }, +}; + +use ark_crypto_primitives::merkle_tree::Config; +use ark_ff::FftField; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +pub use nimue::DefaultHash; +use nimue::{ + IOPattern, + plugins::ark::{FieldChallenges, FieldWriter}, +}; +use nimue_pow::blake3::Blake3PoW; +use rand::prelude::*; +use rand_chacha::ChaCha8Rng; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use std::{ + fmt::{self, Debug, Formatter}, + marker::PhantomData, +}; + +pub trait WhirSpec: Default + std::fmt::Debug + Clone { + type MerkleConfigWrapper: WhirMerkleConfigWrapper; + fn get_parameters( + num_variables: usize, + for_batch: bool, + ) -> WhirParameters, PowOf>; + + fn prepare_whir_config( + num_variables: usize, + for_batch: bool, + ) -> WhirConfig, PowOf> { + let whir_params = Self::get_parameters(num_variables, for_batch); + let mv_params = MultivariateParameters::new(num_variables); + ConfigOf::::new(mv_params, whir_params) + } + + fn prepare_io_pattern(num_variables: usize) -> IOPattern { + let params = Self::prepare_whir_config(num_variables, false); + + let io = IOPattern::::new("🌪️"); + let io = >::commit_statement_to_io_pattern( + io, ¶ms, + ); + >::add_whir_proof_to_io_pattern( + io, ¶ms, + ) + } + + fn prepare_batch_io_pattern(num_variables: usize, batch_size: usize) -> IOPattern { + let params = Self::prepare_whir_config(num_variables, true); + + let io = IOPattern::::new("🌪️"); + let io = >::commit_batch_statement_to_io_pattern( + io, ¶ms, batch_size + ); + + >::add_whir_batch_proof_to_io_pattern( + io, ¶ms, batch_size + ) + } +} + +type MerkleConfigOf = + <>::MerkleConfigWrapper as WhirMerkleConfigWrapper>::MerkleConfig; +type ConfigOf = WhirConfig, PowOf>; + +pub type InnerDigestOf = as Config>::InnerDigest; + +type PowOf = + <>::MerkleConfigWrapper as WhirMerkleConfigWrapper>::PowStrategy; + +#[derive(Debug, Clone, Default)] +pub struct WhirDefaultSpec; + +impl WhirSpec for WhirDefaultSpec { + type MerkleConfigWrapper = Blake3ConfigWrapper; + fn get_parameters( + num_variables: usize, + for_batch: bool, + ) -> WhirParameters, Blake3PoW> { + let mut rng = ChaCha8Rng::from_seed([0u8; 32]); + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + WhirParameters::, Blake3PoW> { + initial_statement: true, + security_level: 100, + pow_bits: default_max_pow(num_variables, 1), + // For batching, the first round folding factor should be set small + // to avoid large leaf nodes in proof + folding_factor: if for_batch { + FoldingFactor::ConstantFromSecondRound(1, 4) + } else { + FoldingFactor::Constant(4) + }, + leaf_hash_params, + two_to_one_params, + soundness_type: SoundnessType::ConjectureList, + fold_optimisation: FoldType::ProverHelps, + _pow_parameters: Default::default(), + starting_log_inv_rate: 1, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct WhirSetupParams { + pub num_variables: usize, + _phantom: PhantomData, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct Whir>(PhantomData<(E, Spec)>); + +// Wrapper for WhirProof +#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] +pub struct WhirProofWrapper +where + MerkleConfig: Config, + F: Sized + Clone + CanonicalSerialize + CanonicalDeserialize, +{ + pub proof: WhirProof, + pub transcript: Vec, +} + +impl Serialize for WhirProofWrapper +where + MerkleConfig: Config, + F: Sized + Clone + CanonicalSerialize + CanonicalDeserialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let proof = &self.proof.0; + // Create a buffer that implements the `Write` trait + let mut buffer = Vec::new(); + proof.serialize_compressed(&mut buffer).unwrap(); + let proof_size = buffer.len(); + let proof_size_bytes = proof_size.to_le_bytes(); + let mut data = proof_size_bytes.to_vec(); + data.extend_from_slice(&buffer); + data.extend_from_slice(&self.transcript); + serializer.serialize_bytes(&data) + } +} + +impl<'de, MerkleConfig, F> Deserialize<'de> for WhirProofWrapper +where + MerkleConfig: Config, + F: Sized + Clone + CanonicalSerialize + CanonicalDeserialize, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let data: Vec = Deserialize::deserialize(deserializer)?; + let proof_size_bytes = &data[0..8]; + let proof_size = u64::from_le_bytes(proof_size_bytes.try_into().unwrap()); + let proof_bytes = &data[8..8 + proof_size as usize]; + let proof = WhirProof::deserialize_compressed(proof_bytes).unwrap(); + let transcript = data[8 + proof_size as usize..].to_vec(); + Ok(WhirProofWrapper { proof, transcript }) + } +} + +impl Debug for WhirProofWrapper +where + MerkleConfig: Config, + F: Sized + Clone + CanonicalSerialize + CanonicalDeserialize, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str("WhirProofWrapper") + } +} + +#[derive(Clone)] +pub struct CommitmentWithWitness +where + MerkleConfig: Config, +{ + pub commitment: MerkleConfig::InnerDigest, + pub witness: Witnesses, +} + +impl CommitmentWithWitness +where + MerkleConfig: Config, +{ + pub fn ood_answers(&self) -> Vec { + self.witness.ood_answers.clone() + } +} + +impl Debug for CommitmentWithWitness +where + MerkleConfig: Config, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str("CommitmentWithWitness") + } +} + +impl> PolynomialCommitmentScheme for Whir +where + E: FftField + Serialize + DeserializeOwned + Debug, + E::BasePrimeField: Serialize + DeserializeOwned + Debug, +{ + type Param = (); + type Commitment = as Config>::InnerDigest; + type CommitmentWithWitness = CommitmentWithWitness>; + type Proof = WhirProofWrapper, E>; + type Poly = CoefficientList; + + fn setup(_poly_size: usize) -> Self::Param {} + + fn commit(_pp: &Self::Param, poly: &Self::Poly) -> Result { + let params = Spec::prepare_whir_config(poly.num_variables(), false); + + // The merlin here is just for satisfying the interface of + // WHIR, which only provides a commit_and_write function. + // It will be abandoned once this function finishes. + let io = Spec::prepare_io_pattern(poly.num_variables()); + let mut merlin = io.to_merlin(); + + let committer = Committer::new(params); + let witness = Witnesses::from(Spec::MerkleConfigWrapper::commit_to_merlin( + &committer, + &mut merlin, + poly.clone(), + )?); + + Ok(CommitmentWithWitness { + commitment: witness.merkle_tree.root(), + witness, + }) + } + + fn batch_commit( + _pp: &Self::Param, + polys: &[Self::Poly], + ) -> Result { + if polys.is_empty() { + return Err(Error::InvalidPcsParam); + } + + for i in 1..polys.len() { + if polys[i].num_variables() != polys[0].num_variables() { + return Err(Error::InvalidPcsParam); + } + } + + let params = Spec::prepare_whir_config(polys[0].num_variables(), true); + + // The merlin here is just for satisfying the interface of + // WHIR, which only provides a commit_and_write function. + // It will be abandoned once this function finishes. + let io = Spec::prepare_batch_io_pattern(polys[0].num_variables(), polys.len()); + let mut merlin = io.to_merlin(); + + let committer = Committer::new(params); + let witness = + Spec::MerkleConfigWrapper::commit_to_merlin_batch(&committer, &mut merlin, polys)?; + Ok(CommitmentWithWitness { + commitment: witness.merkle_tree.root(), + witness, + }) + } + + fn open( + _pp: &Self::Param, + witness: &Self::CommitmentWithWitness, + point: &[E], + eval: &E, + ) -> Result { + let params = Spec::prepare_whir_config(witness.witness.polys[0].num_variables(), false); + let io = Spec::prepare_io_pattern(witness.witness.polys[0].num_variables()); + let mut merlin = io.to_merlin(); + // In WHIR, the prover writes the commitment to the transcript, then + // the commitment is read from the transcript by the verifier, after + // the transcript is transformed into a arthur transcript. + // Here we repeat whatever the prover does. + // TODO: This is a hack. There should be a better design that does not + // require non-black-box knowledge of the inner working of WHIR. + + >::add_digest_to_merlin( + &mut merlin, + witness.commitment.clone(), + ) + .map_err(Error::ProofError)?; + let ood_answers = witness.ood_answers(); + if !ood_answers.is_empty() { + let mut ood_points = vec![::ZERO; ood_answers.len()]; + merlin + .fill_challenge_scalars(&mut ood_points) + .map_err(Error::ProofError)?; + merlin + .add_scalars(&ood_answers) + .map_err(Error::ProofError)?; + } + // Now the Merlin transcript is ready to pass to the verifier. + + let prover = Prover(params); + let statement = Statement { + points: vec![MultilinearPoint(point.to_vec())], + evaluations: vec![*eval], + }; + + let proof = Spec::MerkleConfigWrapper::prove_with_merlin( + &prover, + &mut merlin, + statement, + witness.witness.clone().into(), + )?; + + Ok(WhirProofWrapper { + proof, + transcript: merlin.transcript().to_vec(), + }) + } + + fn simple_batch_open( + _pp: &Self::Param, + witness: &Self::CommitmentWithWitness, + point: &[E], + evals: &[E], + ) -> Result { + let params = Spec::prepare_whir_config(witness.witness.polys[0].num_variables(), true); + let io = + Spec::prepare_batch_io_pattern(witness.witness.polys[0].num_variables(), evals.len()); + let mut merlin = io.to_merlin(); + // In WHIR, the prover writes the commitment to the transcript, then + // the commitment is read from the transcript by the verifier, after + // the transcript is transformed into a arthur transcript. + // Here we repeat whatever the prover does. + // TODO: This is a hack. There should be a better design that does not + // require non-black-box knowledge of the inner working of WHIR. + + >::add_digest_to_merlin( + &mut merlin, + witness.commitment.clone(), + ) + .map_err(Error::ProofError)?; + let ood_answers = witness.ood_answers(); + if !ood_answers.is_empty() { + let mut ood_points = + vec![::ZERO; ood_answers.len() / evals.len()]; + merlin + .fill_challenge_scalars(&mut ood_points) + .map_err(Error::ProofError)?; + merlin + .add_scalars(&ood_answers) + .map_err(Error::ProofError)?; + } + // Now the Merlin transcript is ready to pass to the verifier. + + let prover = Prover(params); + + let proof = Spec::MerkleConfigWrapper::prove_with_merlin_simple_batch( + &prover, + &mut merlin, + point, + evals, + &witness.witness, + )?; + + Ok(WhirProofWrapper { + proof, + transcript: merlin.transcript().to_vec(), + }) + } + + fn verify( + _vp: &Self::Param, + comm: &Self::Commitment, + point: &[E], + eval: &E, + proof: &Self::Proof, + ) -> Result<(), Error> { + let params = Spec::prepare_whir_config(point.len(), false); + let verifier = Verifier::new(params); + let io = Spec::prepare_io_pattern(point.len()); + let mut arthur = io.to_arthur(&proof.transcript); + + let statement = Statement { + points: vec![MultilinearPoint(point.to_vec())], + evaluations: vec![*eval], + }; + + let digest = Spec::MerkleConfigWrapper::verify_with_arthur( + &verifier, + &mut arthur, + &statement, + &proof.proof, + )?; + + if &digest != comm { + return Err(Error::CommitmentMismatchFromDigest); + } + + Ok(()) + } + + fn simple_batch_verify( + _vp: &Self::Param, + comm: &Self::Commitment, + point: &[E], + evals: &[E], + proof: &Self::Proof, + ) -> Result<(), Error> { + let params = Spec::prepare_whir_config(point.len(), true); + let verifier = Verifier::new(params); + let io = Spec::prepare_batch_io_pattern(point.len(), evals.len()); + let mut arthur = io.to_arthur(&proof.transcript); + + let digest = Spec::MerkleConfigWrapper::verify_with_arthur_simple_batch( + &verifier, + &mut arthur, + point, + evals, + &proof.proof, + )?; + + if &digest != comm { + return Err(Error::CommitmentMismatchFromDigest); + } + + Ok(()) + } +} diff --git a/whir/src/cmdline_utils.rs b/whir/src/cmdline_utils.rs new file mode 100644 index 000000000..9572fa328 --- /dev/null +++ b/whir/src/cmdline_utils.rs @@ -0,0 +1,75 @@ +use std::str::FromStr; + +use serde::Serialize; + +#[derive(Debug, Clone, Copy, Serialize)] +pub enum WhirType { + LDT, + PCS, +} + +impl FromStr for WhirType { + type Err = String; + + fn from_str(s: &str) -> Result { + if s == "LDT" { + Ok(Self::LDT) + } else if s == "PCS" { + Ok(Self::PCS) + } else { + Err(format!("Invalid field: {}", s)) + } + } +} + +#[derive(Debug, Clone, Copy, Serialize)] +pub enum AvailableFields { + Goldilocks1, // Just Goldilocks + Goldilocks2, // Quadratic extension of Goldilocks + Goldilocks3, // Cubic extension of Goldilocks + Field128, // 128-bit prime field + Field192, // 192-bit prime field + Field256, // 256-bit prime field +} + +impl FromStr for AvailableFields { + type Err = String; + + fn from_str(s: &str) -> Result { + if s == "Field128" { + Ok(Self::Field128) + } else if s == "Field192" { + Ok(Self::Field192) + } else if s == "Field256" { + Ok(Self::Field256) + } else if s == "Goldilocks1" { + Ok(Self::Goldilocks1) + } else if s == "Goldilocks2" { + Ok(Self::Goldilocks2) + } else if s == "Goldilocks3" { + Ok(Self::Goldilocks3) + } else { + Err(format!("Invalid field: {}", s)) + } + } +} + +#[derive(Debug, Clone, Copy, Serialize)] +pub enum AvailableMerkle { + Keccak256, + Blake3, +} + +impl FromStr for AvailableMerkle { + type Err = String; + + fn from_str(s: &str) -> Result { + if s == "Keccak" { + Ok(Self::Keccak256) + } else if s == "Blake3" { + Ok(Self::Blake3) + } else { + Err(format!("Invalid hash: {}", s)) + } + } +} diff --git a/whir/src/crypto/fields.rs b/whir/src/crypto/fields.rs new file mode 100644 index 000000000..6fcb2f909 --- /dev/null +++ b/whir/src/crypto/fields.rs @@ -0,0 +1,92 @@ +use ark_ff::{ + Field, Fp2, Fp2Config, Fp3, Fp3Config, Fp64, Fp128, Fp192, Fp256, MontBackend, MontConfig, + MontFp, PrimeField, +}; + +pub trait FieldWithSize { + fn field_size_in_bits() -> usize; +} + +impl FieldWithSize for F +where + F: Field, +{ + fn field_size_in_bits() -> usize { + F::BasePrimeField::MODULUS_BIT_SIZE as usize * F::extension_degree() as usize + } +} + +#[derive(MontConfig)] +#[modulus = "21888242871839275222246405745257275088548364400416034343698204186575808495617"] +#[generator = "5"] +pub struct BN254Config; +pub type Field256 = Fp256>; + +#[derive(MontConfig)] +#[modulus = "3801539170989320091464968600173246866371124347557388484609"] +#[generator = "3"] +pub struct FConfig192; +pub type Field192 = Fp192>; + +#[derive(MontConfig)] +#[modulus = "340282366920938463463374557953744961537"] +#[generator = "3"] +pub struct FrConfig128; +pub type Field128 = Fp128>; + +#[derive(MontConfig)] +#[modulus = "18446744069414584321"] +#[generator = "7"] +pub struct FConfig64; +pub type Field64 = Fp64>; + +pub type Field64_2 = Fp2; +pub struct F2Config64; +impl Fp2Config for F2Config64 { + type Fp = Field64; + + const NONRESIDUE: Self::Fp = MontFp!("7"); + + const FROBENIUS_COEFF_FP2_C1: &'static [Self::Fp] = &[ + // Fq(7)**(((q^0) - 1) / 2) + MontFp!("1"), + // Fq(7)**(((q^1) - 1) / 2) + MontFp!("18446744069414584320"), + ]; +} + +pub type Field64_3 = Fp3; +pub struct F3Config64; + +impl Fp3Config for F3Config64 { + type Fp = Field64; + + const NONRESIDUE: Self::Fp = MontFp!("2"); + + const FROBENIUS_COEFF_FP3_C1: &'static [Self::Fp] = &[ + MontFp!("1"), + // Fq(2)^(((q^1) - 1) / 3) + MontFp!("4294967295"), + // Fq(2)^(((q^2) - 1) / 3) + MontFp!("18446744065119617025"), + ]; + + const FROBENIUS_COEFF_FP3_C2: &'static [Self::Fp] = &[ + MontFp!("1"), + // Fq(2)^(((2q^1) - 2) / 3) + MontFp!("18446744065119617025"), + // Fq(2)^(((2q^2) - 2) / 3) + MontFp!("4294967295"), + ]; + + // (q^3 - 1) = 2^32 * T where T = 1461501636310055817916238417282618014431694553085 + const TWO_ADICITY: u32 = 32; + + // 11^T + const QUADRATIC_NONRESIDUE_TO_T: Fp3 = + Fp3::new(MontFp!("5944137876247729999"), MontFp!("0"), MontFp!("0")); + + // T - 1 / 2 + const TRACE_MINUS_ONE_DIV_TWO: &'static [u64] = + &[0x80000002fffffffe, 0x80000002fffffffc, 0x7ffffffe]; +} diff --git a/whir/src/crypto/merkle_tree/blake3.rs b/whir/src/crypto/merkle_tree/blake3.rs new file mode 100644 index 000000000..2260f3984 --- /dev/null +++ b/whir/src/crypto/merkle_tree/blake3.rs @@ -0,0 +1,158 @@ +use std::{borrow::Borrow, marker::PhantomData}; + +use super::{HashCounter, IdentityDigestConverter}; +use crate::whir::{ + fs_utils::{DigestReader, DigestWriter}, + iopattern::DigestIOPattern, +}; +use ark_crypto_primitives::{ + crh::{CRHScheme, TwoToOneCRHScheme}, + merkle_tree::Config, + sponge::Absorb, +}; +use ark_ff::Field; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use nimue::{ + Arthur, ByteIOPattern, ByteReader, ByteWriter, IOPattern, Merlin, ProofError, ProofResult, +}; +use rand::RngCore; +use serde::{Deserialize, Serialize}; + +#[derive( + Debug, Default, Clone, Copy, Eq, PartialEq, Hash, CanonicalSerialize, CanonicalDeserialize, +)] +pub struct Blake3Digest([u8; 32]); + +impl AsRef<[u8]> for Blake3Digest { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl From<[u8; 32]> for Blake3Digest { + fn from(value: [u8; 32]) -> Self { + Self(value) + } +} + +impl Absorb for Blake3Digest { + fn to_sponge_bytes(&self, dest: &mut Vec) { + dest.extend_from_slice(&self.0); + } + + fn to_sponge_field_elements(&self, dest: &mut Vec) { + let mut buf = [0; 32]; + buf.copy_from_slice(&self.0); + dest.push(F::from_be_bytes_mod_order(&buf)); + } +} + +pub struct Blake3LeafHash(PhantomData); +pub struct Blake3TwoToOneCRHScheme; + +impl CRHScheme for Blake3LeafHash { + type Input = [F]; + type Output = Blake3Digest; + type Parameters = (); + + fn setup(_: &mut R) -> Result { + Ok(()) + } + + fn evaluate>( + _: &Self::Parameters, + input: T, + ) -> Result { + let mut buf = vec![]; + CanonicalSerialize::serialize_compressed(input.borrow(), &mut buf)?; + + let mut h = blake3::Hasher::new(); + h.update(&buf); + + let mut output = [0; 32]; + output.copy_from_slice(h.finalize().as_bytes()); + HashCounter::add(); + Ok(Blake3Digest(output)) + } +} + +impl TwoToOneCRHScheme for Blake3TwoToOneCRHScheme { + type Input = Blake3Digest; + type Output = Blake3Digest; + type Parameters = (); + + fn setup(_: &mut R) -> Result { + Ok(()) + } + + fn evaluate>( + _: &Self::Parameters, + left_input: T, + right_input: T, + ) -> Result { + let mut h = blake3::Hasher::new(); + h.update(&left_input.borrow().0); + h.update(&right_input.borrow().0); + let mut output = [0; 32]; + output.copy_from_slice(h.finalize().as_bytes()); + HashCounter::add(); + Ok(Blake3Digest(output)) + } + + fn compress>( + parameters: &Self::Parameters, + left_input: T, + right_input: T, + ) -> Result { + ::evaluate(parameters, left_input, right_input) + } +} + +pub type LeafH = Blake3LeafHash; +pub type CompressH = Blake3TwoToOneCRHScheme; + +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +pub struct MerkleTreeParams(PhantomData); + +impl Config for MerkleTreeParams { + type Leaf = [F]; + + type LeafDigest = as CRHScheme>::Output; + type LeafInnerDigestConverter = IdentityDigestConverter; + type InnerDigest = ::Output; + + type LeafHash = LeafH; + type TwoToOneHash = CompressH; +} + +pub fn default_config( + rng: &mut impl RngCore, +) -> ( + as CRHScheme>::Parameters, + ::Parameters, +) { + as CRHScheme>::setup(rng).unwrap(); + ::setup(rng).unwrap(); + + ((), ()) +} + +impl DigestIOPattern> for IOPattern { + fn add_digest(self, label: &str) -> Self { + self.add_bytes(32, label) + } +} + +impl DigestWriter> for Merlin { + fn add_digest(&mut self, digest: Blake3Digest) -> ProofResult<()> { + self.add_bytes(&digest.0).map_err(ProofError::InvalidIO) + } +} + +impl DigestReader> for Arthur<'_> { + fn read_digest(&mut self) -> ProofResult { + let mut digest = [0; 32]; + self.fill_next_bytes(&mut digest)?; + Ok(Blake3Digest(digest)) + } +} diff --git a/whir/src/crypto/merkle_tree/keccak.rs b/whir/src/crypto/merkle_tree/keccak.rs new file mode 100644 index 000000000..8bdd0af07 --- /dev/null +++ b/whir/src/crypto/merkle_tree/keccak.rs @@ -0,0 +1,158 @@ +use std::{borrow::Borrow, marker::PhantomData}; + +use super::{HashCounter, IdentityDigestConverter}; +use crate::whir::{ + fs_utils::{DigestReader, DigestWriter}, + iopattern::DigestIOPattern, +}; +use ark_crypto_primitives::{ + crh::{CRHScheme, TwoToOneCRHScheme}, + merkle_tree::Config, + sponge::Absorb, +}; +use ark_ff::Field; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use nimue::{ + Arthur, ByteIOPattern, ByteReader, ByteWriter, IOPattern, Merlin, ProofError, ProofResult, +}; +use rand::RngCore; +use sha3::Digest; + +#[derive( + Debug, Default, Clone, Copy, Eq, PartialEq, Hash, CanonicalSerialize, CanonicalDeserialize, +)] +pub struct KeccakDigest([u8; 32]); + +impl Absorb for KeccakDigest { + fn to_sponge_bytes(&self, dest: &mut Vec) { + dest.extend_from_slice(&self.0); + } + + fn to_sponge_field_elements(&self, dest: &mut Vec) { + let mut buf = [0; 32]; + buf.copy_from_slice(&self.0); + dest.push(F::from_be_bytes_mod_order(&buf)); + } +} + +impl From<[u8; 32]> for KeccakDigest { + fn from(value: [u8; 32]) -> Self { + KeccakDigest(value) + } +} + +impl AsRef<[u8]> for KeccakDigest { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +pub struct KeccakLeafHash(PhantomData); +pub struct KeccakTwoToOneCRHScheme; + +impl CRHScheme for KeccakLeafHash { + type Input = [F]; + type Output = KeccakDigest; + type Parameters = (); + + fn setup(_: &mut R) -> Result { + Ok(()) + } + + fn evaluate>( + _: &Self::Parameters, + input: T, + ) -> Result { + let mut buf = vec![]; + CanonicalSerialize::serialize_compressed(input.borrow(), &mut buf)?; + + let mut h = sha3::Keccak256::new(); + h.update(&buf); + + let mut output = [0; 32]; + output.copy_from_slice(&h.finalize()[..]); + HashCounter::add(); + Ok(KeccakDigest(output)) + } +} + +impl TwoToOneCRHScheme for KeccakTwoToOneCRHScheme { + type Input = KeccakDigest; + type Output = KeccakDigest; + type Parameters = (); + + fn setup(_: &mut R) -> Result { + Ok(()) + } + + fn evaluate>( + _: &Self::Parameters, + left_input: T, + right_input: T, + ) -> Result { + let mut h = sha3::Keccak256::new(); + h.update(left_input.borrow().0); + h.update(right_input.borrow().0); + let mut output = [0; 32]; + output.copy_from_slice(&h.finalize()[..]); + HashCounter::add(); + Ok(KeccakDigest(output)) + } + + fn compress>( + parameters: &Self::Parameters, + left_input: T, + right_input: T, + ) -> Result { + ::evaluate(parameters, left_input, right_input) + } +} + +pub type LeafH = KeccakLeafHash; +pub type CompressH = KeccakTwoToOneCRHScheme; + +#[derive(Debug, Default, Clone)] +pub struct MerkleTreeParams(PhantomData); + +impl Config for MerkleTreeParams { + type Leaf = [F]; + + type LeafDigest = as CRHScheme>::Output; + type LeafInnerDigestConverter = IdentityDigestConverter; + type InnerDigest = ::Output; + + type LeafHash = LeafH; + type TwoToOneHash = CompressH; +} + +pub fn default_config( + rng: &mut impl RngCore, +) -> ( + as CRHScheme>::Parameters, + ::Parameters, +) { + as CRHScheme>::setup(rng).unwrap(); + ::setup(rng).unwrap(); + + ((), ()) +} + +impl DigestIOPattern> for IOPattern { + fn add_digest(self, label: &str) -> Self { + self.add_bytes(32, label) + } +} + +impl DigestWriter> for Merlin { + fn add_digest(&mut self, digest: KeccakDigest) -> ProofResult<()> { + self.add_bytes(&digest.0).map_err(ProofError::InvalidIO) + } +} + +impl DigestReader> for Arthur<'_> { + fn read_digest(&mut self) -> ProofResult { + let mut digest = [0; 32]; + self.fill_next_bytes(&mut digest)?; + Ok(KeccakDigest(digest)) + } +} diff --git a/whir/src/crypto/merkle_tree/mock.rs b/whir/src/crypto/merkle_tree/mock.rs new file mode 100644 index 000000000..4becc42da --- /dev/null +++ b/whir/src/crypto/merkle_tree/mock.rs @@ -0,0 +1,67 @@ +use std::{borrow::Borrow, marker::PhantomData}; + +use ark_crypto_primitives::{ + crh::{CRHScheme, TwoToOneCRHScheme}, + merkle_tree::{ByteDigestConverter, Config}, +}; +use ark_serialize::CanonicalSerialize; +use rand::RngCore; + +pub struct Mock; + +impl TwoToOneCRHScheme for Mock { + type Input = [u8]; + type Output = Vec; + type Parameters = (); + + fn setup(_: &mut R) -> Result { + Ok(()) + } + + fn evaluate>( + _: &Self::Parameters, + _: T, + _: T, + ) -> Result { + Ok(vec![0u8; 32]) + } + + fn compress>( + _: &Self::Parameters, + _: T, + _: T, + ) -> Result { + Ok(vec![0u8; 32]) + } +} + +pub type LeafH = super::LeafIdentityHasher; +pub type CompressH = Mock; + +#[derive(Debug, Default)] +pub struct MerkleTreeParams(PhantomData); + +impl Config for MerkleTreeParams { + type Leaf = F; + + type LeafDigest = as CRHScheme>::Output; + type LeafInnerDigestConverter = ByteDigestConverter; + type InnerDigest = ::Output; + + type LeafHash = LeafH; + type TwoToOneHash = CompressH; +} + +pub fn default_config( + rng: &mut impl RngCore, +) -> ( + as CRHScheme>::Parameters, + ::Parameters, +) { + as CRHScheme>::setup(rng).unwrap(); + { + ::setup(rng).unwrap(); + }; + + ((), ()) +} diff --git a/whir/src/crypto/merkle_tree/mod.rs b/whir/src/crypto/merkle_tree/mod.rs new file mode 100644 index 000000000..092358ff5 --- /dev/null +++ b/whir/src/crypto/merkle_tree/mod.rs @@ -0,0 +1,73 @@ +pub mod blake3; +pub mod keccak; +pub mod mock; + +use std::{borrow::Borrow, marker::PhantomData, sync::atomic::AtomicUsize}; + +use ark_crypto_primitives::{Error, crh::CRHScheme, merkle_tree::DigestConverter}; +use ark_serialize::CanonicalSerialize; +use lazy_static::lazy_static; +use rand::RngCore; + +#[derive(Debug, Default)] +pub struct HashCounter { + counter: AtomicUsize, +} + +lazy_static! { + static ref HASH_COUNTER: HashCounter = HashCounter::default(); +} + +impl HashCounter { + pub(crate) fn add() -> usize { + HASH_COUNTER + .counter + .fetch_add(1, std::sync::atomic::Ordering::SeqCst) + } + + pub fn reset() { + HASH_COUNTER + .counter + .store(0, std::sync::atomic::Ordering::SeqCst) + } + + pub fn get() -> usize { + HASH_COUNTER + .counter + .load(std::sync::atomic::Ordering::SeqCst) + } +} + +#[derive(Debug, Default)] +pub struct LeafIdentityHasher(PhantomData); + +impl CRHScheme for LeafIdentityHasher { + type Input = F; + type Output = Vec; + type Parameters = (); + + fn setup(_: &mut R) -> Result { + Ok(()) + } + + fn evaluate>( + _: &Self::Parameters, + input: T, + ) -> Result { + let mut buf = vec![]; + CanonicalSerialize::serialize_compressed(input.borrow(), &mut buf)?; + Ok(buf) + } +} + +/// A trivial converter where digest of previous layer's hash is the same as next layer's input. +pub struct IdentityDigestConverter { + _prev_layer_digest: T, +} + +impl DigestConverter for IdentityDigestConverter { + type TargetType = T; + fn convert(item: T) -> Result { + Ok(item) + } +} diff --git a/whir/src/crypto/mod.rs b/whir/src/crypto/mod.rs new file mode 100644 index 000000000..039b38f23 --- /dev/null +++ b/whir/src/crypto/mod.rs @@ -0,0 +1,2 @@ +pub mod fields; +pub mod merkle_tree; diff --git a/whir/src/domain.rs b/whir/src/domain.rs new file mode 100644 index 000000000..3c2bfc9b1 --- /dev/null +++ b/whir/src/domain.rs @@ -0,0 +1,144 @@ +use ark_ff::FftField; +use ark_poly::{ + EvaluationDomain, GeneralEvaluationDomain, MixedRadixEvaluationDomain, Radix2EvaluationDomain, +}; + +#[derive(Debug, Clone)] +pub struct Domain +where + F: FftField, +{ + pub base_domain: Option>, // The domain (in the base + // field) for the initial FFT + pub backing_domain: GeneralEvaluationDomain, +} + +impl Domain +where + F: FftField, +{ + pub fn new(degree: usize, log_rho_inv: usize) -> Option { + let size = degree * (1 << log_rho_inv); + let base_domain = GeneralEvaluationDomain::new(size)?; + let backing_domain = Self::to_extension_domain(&base_domain); + + Some(Self { + backing_domain, + base_domain: Some(base_domain), + }) + } + + // returns the size of the domain after folding folding_factor many times. + // + // This asserts that the domain size is divisible by 1 << folding_factor + pub fn folded_size(&self, folding_factor: usize) -> usize { + assert!(self.backing_domain.size() % (1 << folding_factor) == 0); + self.backing_domain.size() / (1 << folding_factor) + } + + pub fn size(&self) -> usize { + self.backing_domain.size() + } + + pub fn scale(&self, power: usize) -> Self { + Self { + backing_domain: self.scale_generator_by(power), + base_domain: None, // Set to zero because we only care for the initial + } + } + + fn to_extension_domain( + domain: &GeneralEvaluationDomain, + ) -> GeneralEvaluationDomain { + let group_gen = F::from_base_prime_field(domain.group_gen()); + let group_gen_inv = F::from_base_prime_field(domain.group_gen_inv()); + let size = domain.size() as u64; + let log_size_of_group = domain.log_size_of_group() as u32; + let size_as_field_element = F::from_base_prime_field(domain.size_as_field_element()); + let size_inv = F::from_base_prime_field(domain.size_inv()); + let offset = F::from_base_prime_field(domain.coset_offset()); + let offset_inv = F::from_base_prime_field(domain.coset_offset_inv()); + let offset_pow_size = F::from_base_prime_field(domain.coset_offset_pow_size()); + match domain { + GeneralEvaluationDomain::Radix2(_) => { + GeneralEvaluationDomain::Radix2(Radix2EvaluationDomain { + size, + log_size_of_group, + size_as_field_element, + size_inv, + group_gen, + group_gen_inv, + offset, + offset_inv, + offset_pow_size, + }) + } + GeneralEvaluationDomain::MixedRadix(_) => { + GeneralEvaluationDomain::MixedRadix(MixedRadixEvaluationDomain { + size, + log_size_of_group, + size_as_field_element, + size_inv, + group_gen, + group_gen_inv, + offset, + offset_inv, + offset_pow_size, + }) + } + } + } + + // Takes the underlying backing_domain = , and computes the new domain + // (note this will have size |L| / power) + fn scale_generator_by(&self, power: usize) -> GeneralEvaluationDomain { + let starting_size = self.size(); + assert_eq!(starting_size % power, 0); + let new_size = starting_size / power; + let log_size_of_group = new_size.trailing_zeros(); + let size_as_field_element = F::from(new_size as u64); + + match self.backing_domain { + GeneralEvaluationDomain::Radix2(r2) => { + let group_gen = r2.group_gen.pow([power as u64]); + let group_gen_inv = group_gen.inverse().unwrap(); + + let offset = r2.offset.pow([power as u64]); + let offset_inv = r2.offset_inv.pow([power as u64]); + let offset_pow_size = offset.pow([new_size as u64]); + + GeneralEvaluationDomain::Radix2(Radix2EvaluationDomain { + size: new_size as u64, + log_size_of_group, + size_as_field_element, + size_inv: size_as_field_element.inverse().unwrap(), + group_gen, + group_gen_inv, + offset, + offset_inv, + offset_pow_size, + }) + } + GeneralEvaluationDomain::MixedRadix(mr) => { + let group_gen = mr.group_gen.pow([power as u64]); + let group_gen_inv = mr.group_gen_inv.pow([power as u64]); + + let offset = mr.offset.pow([power as u64]); + let offset_inv = mr.offset_inv.pow([power as u64]); + let offset_pow_size = offset.pow([new_size as u64]); + + GeneralEvaluationDomain::MixedRadix(MixedRadixEvaluationDomain { + size: new_size as u64, + log_size_of_group, + size_as_field_element, + size_inv: size_as_field_element.inverse().unwrap(), + group_gen, + group_gen_inv, + offset, + offset_inv, + offset_pow_size, + }) + } + } + } +} diff --git a/whir/src/fs_utils.rs b/whir/src/fs_utils.rs new file mode 100644 index 000000000..9d9180415 --- /dev/null +++ b/whir/src/fs_utils.rs @@ -0,0 +1,38 @@ +use ark_ff::Field; +use nimue::plugins::ark::FieldIOPattern; +use nimue_pow::PoWIOPattern; +pub trait OODIOPattern { + fn add_ood(self, num_samples: usize) -> Self; +} + +impl OODIOPattern for IOPattern +where + F: Field, + IOPattern: FieldIOPattern, +{ + fn add_ood(self, num_samples: usize) -> Self { + if num_samples > 0 { + self.challenge_scalars(num_samples, "ood_query") + .add_scalars(num_samples, "ood_ans") + } else { + self + } + } +} + +pub trait WhirPoWIOPattern { + fn pow(self, bits: f64) -> Self; +} + +impl WhirPoWIOPattern for IOPattern +where + IOPattern: PoWIOPattern, +{ + fn pow(self, bits: f64) -> Self { + if bits > 0. { + self.challenge_pow("pow_queries") + } else { + self + } + } +} diff --git a/whir/src/lib.rs b/whir/src/lib.rs new file mode 100644 index 000000000..fe4ccb09a --- /dev/null +++ b/whir/src/lib.rs @@ -0,0 +1,13 @@ +#![allow(dead_code)] +#[cfg(feature = "ceno")] +pub mod ceno_binding; // Connect whir with ceno +pub mod cmdline_utils; +pub mod crypto; // Crypto utils +pub mod domain; // Domain that we are evaluating over +pub mod fs_utils; +pub mod ntt; +pub mod parameters; +pub mod poly_utils; // Utils for polynomials +pub mod sumcheck; // Sumcheck specialised +pub mod utils; // Utils in general +pub mod whir; // The real prover diff --git a/whir/src/ntt/matrix.rs b/whir/src/ntt/matrix.rs new file mode 100644 index 000000000..0f7eb8780 --- /dev/null +++ b/whir/src/ntt/matrix.rs @@ -0,0 +1,187 @@ +//! Minimal matrix class that supports strided access. +//! This abstracts over the unsafe pointer arithmetic required for transpose-like algorithms. + +#![allow(unsafe_code)] + +use std::{ + marker::PhantomData, + ops::{Index, IndexMut}, + ptr, slice, +}; + +/// Mutable reference to a matrix. +/// +/// The invariant this data structure maintains is that `data` has lifetime +/// `'a` and points to a collection of `rows` rowws, at intervals `row_stride`, +/// each of length `cols`. +pub struct MatrixMut<'a, T> { + data: *mut T, + rows: usize, + cols: usize, + row_stride: usize, + _lifetime: PhantomData<&'a mut T>, +} + +unsafe impl Send for MatrixMut<'_, T> {} + +unsafe impl Sync for MatrixMut<'_, T> {} + +impl<'a, T> MatrixMut<'a, T> { + /// creates a MatrixMut from `slice`, where slice is the concatenations of `rows` rows, each consisting of `cols` many entries. + pub fn from_mut_slice(slice: &'a mut [T], rows: usize, cols: usize) -> Self { + assert_eq!(slice.len(), rows * cols); + // Safety: The input slice is valid for the lifetime `'a` and has + // `rows` contiguous rows of length `cols`. + Self { + data: slice.as_mut_ptr(), + rows, + cols, + row_stride: cols, + _lifetime: PhantomData, + } + } + + /// returns the number of rows + pub fn rows(&self) -> usize { + self.rows + } + + /// returns the number of columns + pub fn cols(&self) -> usize { + self.cols + } + + /// checks whether the matrix is a square matrix + pub fn is_square(&self) -> bool { + self.rows == self.cols + } + + /// returns a mutable reference to the `row`'th row of the MatrixMut + pub fn row(&mut self, row: usize) -> &mut [T] { + assert!(row < self.rows); + // Safety: The structure invariant guarantees that at offset `row * self.row_stride` + // there is valid data of length `self.cols`. + unsafe { slice::from_raw_parts_mut(self.data.add(row * self.row_stride), self.cols) } + } + + /// Split the matrix into two vertically at the `row`'th row (meaning that in the returned pair (A,B), the matrix A has `row` rows). + /// + /// [A] + /// [ ] = self + /// [B] + pub fn split_vertical(self, row: usize) -> (Self, Self) { + assert!(row <= self.rows); + ( + Self { + data: self.data, + rows: row, + cols: self.cols, + row_stride: self.row_stride, + _lifetime: PhantomData, + }, + Self { + data: unsafe { self.data.add(row * self.row_stride) }, + rows: self.rows - row, + cols: self.cols, + row_stride: self.row_stride, + _lifetime: PhantomData, + }, + ) + } + + /// Split the matrix into two horizontally at the `col`th column (meaning that in the returned pair (A,B), the matrix A has `col` columns). + /// + /// [A B] = self + pub fn split_horizontal(self, col: usize) -> (Self, Self) { + assert!(col <= self.cols); + ( + // Safety: This reduces the number of cols, keeping all else the same. + Self { + data: self.data, + rows: self.rows, + cols: col, + row_stride: self.row_stride, + _lifetime: PhantomData, + }, + // Safety: This reduces the number of cols and offsets and, keeping all else the same. + Self { + data: unsafe { self.data.add(col) }, + rows: self.rows, + cols: self.cols - col, + row_stride: self.row_stride, + _lifetime: PhantomData, + }, + ) + } + + /// Split the matrix into four quadrants at the indicated `row` and `col` (meaning that in the returned 4-tuple (A,B,C,D), the matrix A is a `row`x`col` matrix) + /// + /// self = [A B] + /// [C D] + pub fn split_quadrants(self, row: usize, col: usize) -> (Self, Self, Self, Self) { + let (u, l) = self.split_vertical(row); // split into upper and lower parts + let (a, b) = u.split_horizontal(col); + let (c, d) = l.split_horizontal(col); + (a, b, c, d) + } + + /// Swap two elements `a` and `b` in the matrix. + /// Each of `a`, `b` is given as (row,column)-pair. + /// If the given coordinates are out-of-bounds, the behaviour is undefined. + pub unsafe fn swap(&mut self, a: (usize, usize), b: (usize, usize)) { + if a != b { + unsafe { + let a = self.ptr_at_mut(a.0, a.1); + let b = self.ptr_at_mut(b.0, b.1); + ptr::swap_nonoverlapping(a, b, 1) + } + } + } + + /// returns an immutable pointer to the element at (`row`, `col`). This performs no bounds checking and provining indices out-of-bounds is UB. + unsafe fn ptr_at(&self, row: usize, col: usize) -> *const T { + // Safe to call under the following assertion (checked by caller) + // assert!(row < self.rows); + // assert!(col < self.cols); + + // Safety: The structure invariant guarantees that at offset `row * self.row_stride + col` + // there is valid data. + self.data.add(row * self.row_stride + col) + } + + /// returns a mutable pointer to the element at (`row`, `col`). This performs no bounds checking and provining indices out-of-bounds is UB. + unsafe fn ptr_at_mut(&mut self, row: usize, col: usize) -> *mut T { + // Safe to call under the following assertion (checked by caller) + // + // assert!(row < self.rows); + // assert!(col < self.cols); + + // Safety: The structure invariant guarantees that at offset `row * self.row_stride + col` + // there is valid data. + self.data.add(row * self.row_stride + col) + } +} + +// Use MatrixMut::ptr_at and MatrixMut::ptr_at_mut to implement Index and IndexMut. The latter are not unsafe, since they contain bounds-checks. + +impl Index<(usize, usize)> for MatrixMut<'_, T> { + type Output = T; + + fn index(&self, (row, col): (usize, usize)) -> &T { + assert!(row < self.rows); + assert!(col < self.cols); + // Safety: The structure invariant guarantees that at offset `row * self.row_stride + col` + // there is valid data. + unsafe { &*self.ptr_at(row, col) } + } +} + +impl IndexMut<(usize, usize)> for MatrixMut<'_, T> { + fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut T { + assert!(row < self.rows); + assert!(col < self.cols); + // Safety: The structure invariant guarantees that at offset `row * self.row_stride + col` + // there is valid data. + unsafe { &mut *self.ptr_at_mut(row, col) } + } +} diff --git a/whir/src/ntt/mod.rs b/whir/src/ntt/mod.rs new file mode 100644 index 000000000..f2386c38a --- /dev/null +++ b/whir/src/ntt/mod.rs @@ -0,0 +1,61 @@ +//! NTT and related algorithms. + +mod matrix; +mod ntt_impl; +mod transpose; +mod utils; +mod wavelet; + +use self::matrix::MatrixMut; +use ark_ff::FftField; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +pub use self::{ + ntt_impl::{intt, intt_batch, ntt, ntt_batch}, + transpose::{transpose, transpose_bench_allocate, transpose_test}, + wavelet::wavelet_transform, +}; + +/// RS encode at a rate 1/`expansion`. +pub fn expand_from_coeff(coeffs: &[F], expansion: usize) -> Vec { + let engine = ntt_impl::NttEngine::::new_from_cache(); + let expanded_size = coeffs.len() * expansion; + let mut result = Vec::with_capacity(expanded_size); + // Note: We can also zero-extend the coefficients and do a larger NTT. + // But this is more efficient. + + // Do coset NTT. + let root = engine.root(expanded_size); + result.extend_from_slice(coeffs); + #[cfg(not(feature = "parallel"))] + for i in 1..expansion { + let root = root.pow([i as u64]); + let mut offset = F::ONE; + result.extend(coeffs.iter().map(|x| { + let val = *x * offset; + offset *= root; + val + })); + } + #[cfg(feature = "parallel")] + result.par_extend((1..expansion).into_par_iter().flat_map(|i| { + let root_i = root.pow([i as u64]); + coeffs + .par_iter() + .enumerate() + .map_with(F::ZERO, move |root_j, (j, coeff)| { + if root_j.is_zero() { + *root_j = root_i.pow([j as u64]); + } else { + *root_j *= root_i; + } + *coeff * *root_j + }) + })); + + ntt_batch(&mut result, coeffs.len()); + transpose(&mut result, expansion, coeffs.len()); + result +} diff --git a/whir/src/ntt/ntt_impl.rs b/whir/src/ntt/ntt_impl.rs new file mode 100644 index 000000000..ceb482f81 --- /dev/null +++ b/whir/src/ntt/ntt_impl.rs @@ -0,0 +1,408 @@ +//! Number-theoretic transforms (NTTs) over fields with high two-adicity. +//! +//! Implements the √N Cooley-Tukey six-step algorithm to achieve parallelism with good locality. +//! A global cache is used for twiddle factors. + +use super::{ + transpose, + utils::{lcm, sqrt_factor, workload_size}, +}; +use ark_ff::{FftField, Field}; +use std::{ + any::{Any, TypeId}, + collections::HashMap, + sync::{Arc, LazyLock, Mutex, RwLock, RwLockReadGuard}, +}; + +#[cfg(feature = "parallel")] +use {rayon::prelude::*, std::cmp::max}; + +/// Global cache for NTT engines, indexed by field. +// TODO: Skip `LazyLock` when `HashMap::with_hasher` becomes const. +// see https://github.com/rust-lang/rust/issues/102575 +static ENGINE_CACHE: LazyLock>>> = + LazyLock::new(|| Mutex::new(HashMap::new())); + +/// Enginge for computing NTTs over arbitrary fields. +/// Assumes the field has large two-adicity. +pub struct NttEngine { + order: usize, // order of omega_orger + omega_order: F, // primitive order'th root. + + // Roots of small order (zero if unavailable). The naming convention is that omega_foo has order foo. + half_omega_3_1_plus_2: F, // ½(ω₃ + ω₃²) + half_omega_3_1_min_2: F, // ½(ω₃ - ω₃²) + omega_4_1: F, + omega_8_1: F, + omega_8_3: F, + omega_16_1: F, + omega_16_3: F, + omega_16_9: F, + + // Root lookup table (extended on demand) + roots: RwLock>, +} + +/// Compute the NTT of a slice of field elements using a cached engine. +pub fn ntt(values: &mut [F]) { + NttEngine::::new_from_cache().ntt(values); +} + +/// Compute the many NTTs of size `size` using a cached engine. +pub fn ntt_batch(values: &mut [F], size: usize) { + NttEngine::::new_from_cache().ntt_batch(values, size); +} + +/// Compute the inverse NTT of a slice of field element without the 1/n scaling factor, using a cached engine. +pub fn intt(values: &mut [F]) { + NttEngine::::new_from_cache().intt(values); +} + +/// Compute the inverse NTT of multiple slice of field elements, each of size `size`, without the 1/n scaling factor and using a cached engine. +pub fn intt_batch(values: &mut [F], size: usize) { + NttEngine::::new_from_cache().intt_batch(values, size); +} + +impl NttEngine { + /// Get or create a cached engine for the field `F`. + pub fn new_from_cache() -> Arc { + let mut cache = ENGINE_CACHE.lock().unwrap(); + let type_id = TypeId::of::(); + if let Some(engine) = cache.get(&type_id) { + engine.clone().downcast::>().unwrap() + } else { + let engine = Arc::new(NttEngine::new_from_fftfield()); + cache.insert(type_id, engine.clone()); + engine + } + } + + /// Construct a new engine from the field's `FftField` trait. + fn new_from_fftfield() -> Self { + // TODO: Support SMALL_SUBGROUP + if F::TWO_ADICITY <= 63 { + Self::new(1 << F::TWO_ADICITY, F::TWO_ADIC_ROOT_OF_UNITY) + } else { + let mut generator = F::TWO_ADIC_ROOT_OF_UNITY; + for _ in 0..(F::TWO_ADICITY - 63) { + generator = generator.square(); + } + Self::new(1 << 63, generator) + } + } +} + +/// Creates a new NttEngine. `omega_order` must be a primitive root of unity of even order `omega`. +impl NttEngine { + pub fn new(order: usize, omega_order: F) -> Self { + assert!(order.trailing_zeros() > 0, "Order must be a multiple of 2."); + // TODO: Assert that omega factors into 2s and 3s. + assert_eq!(omega_order.pow([order as u64]), F::ONE); + assert_ne!(omega_order.pow([order as u64 / 2]), F::ONE); + let mut res = NttEngine { + order, + omega_order, + half_omega_3_1_plus_2: F::ZERO, + half_omega_3_1_min_2: F::ZERO, + omega_4_1: F::ZERO, + omega_8_1: F::ZERO, + omega_8_3: F::ZERO, + omega_16_1: F::ZERO, + omega_16_3: F::ZERO, + omega_16_9: F::ZERO, + roots: RwLock::new(Vec::new()), + }; + if order % 3 == 0 { + let omega_3_1 = res.root(3); + let omega_3_2 = omega_3_1 * omega_3_1; + // Note: char F cannot be 2 and so division by 2 works, because primitive roots of unity with even order exist. + res.half_omega_3_1_min_2 = (omega_3_1 - omega_3_2) / F::from(2u64); + res.half_omega_3_1_plus_2 = (omega_3_1 + omega_3_2) / F::from(2u64); + } + if order % 4 == 0 { + res.omega_4_1 = res.root(4); + } + if order % 8 == 0 { + res.omega_8_1 = res.root(8); + res.omega_8_3 = res.omega_8_1.pow([3]); + } + if order % 16 == 0 { + res.omega_16_1 = res.root(16); + res.omega_16_3 = res.omega_16_1.pow([3]); + res.omega_16_9 = res.omega_16_1.pow([9]); + } + res + } + + pub fn ntt(&self, values: &mut [F]) { + self.ntt_batch(values, values.len()) + } + + pub fn ntt_batch(&self, values: &mut [F], size: usize) { + assert!(values.len() % size == 0); + let roots = self.roots_table(size); + self.ntt_dispatch(values, &roots, size); + } + + /// Inverse NTT. Does not aply 1/n scaling factor. + pub fn intt(&self, values: &mut [F]) { + values[1..].reverse(); + self.ntt(values); + } + + /// Inverse batch NTT. Does not aply 1/n scaling factor. + pub fn intt_batch(&self, values: &mut [F], size: usize) { + assert!(values.len() % size == 0); + + #[cfg(not(feature = "parallel"))] + values.chunks_exact_mut(size).for_each(|values| { + values[1..].reverse(); + }); + + #[cfg(feature = "parallel")] + values.par_chunks_exact_mut(size).for_each(|values| { + values[1..].reverse(); + }); + + self.ntt_batch(values, size); + } + + pub fn root(&self, order: usize) -> F { + assert!( + self.order % order == 0, + "Subgroup of requested order does not exist." + ); + self.omega_order.pow([(self.order / order) as u64]) + } + + /// Returns a cached table of roots of unity of the given order. + fn roots_table(&self, order: usize) -> RwLockReadGuard> { + // Precompute more roots of unity if requested. + let roots = self.roots.read().unwrap(); + if roots.is_empty() || roots.len() % order != 0 { + // Obtain write lock to update the cache. + drop(roots); + let mut roots = self.roots.write().unwrap(); + // Race condition: check if another thread updated the cache. + if roots.is_empty() || roots.len() % order != 0 { + // Compute minimal size to support all sizes seen so far. + // TODO: Do we really need all of these? Can we leverage omege_2 = -1? + let size = if roots.is_empty() { + order + } else { + lcm(roots.len(), order) + }; + roots.clear(); + roots.reserve_exact(size); + + // Compute powers of roots of unity. + let root = self.root(size); + #[cfg(not(feature = "parallel"))] + { + let mut root_i = F::ONE; + for _ in 0..size { + roots.push(root_i); + root_i *= root; + } + } + #[cfg(feature = "parallel")] + roots.par_extend((0..size).into_par_iter().map_with(F::ZERO, |root_i, i| { + if root_i.is_zero() { + *root_i = root.pow([i as u64]); + } else { + *root_i *= root; + } + *root_i + })); + } + // Back to read lock. + drop(roots); + self.roots.read().unwrap() + } else { + roots + } + } + + /// Compute NTTs in place by splititng into two factors. + /// Recurses using the sqrt(N) Cooley-Tukey Six step NTT algorithm. + fn ntt_recurse(&self, values: &mut [F], roots: &[F], size: usize) { + debug_assert_eq!(values.len() % size, 0); + let n1 = sqrt_factor(size); + let n2 = size / n1; + + transpose(values, n1, n2); + self.ntt_dispatch(values, roots, n1); + transpose(values, n2, n1); + // TODO: When (n1, n2) are coprime we can use the + // Good-Thomas NTT algorithm and avoid the twiddle loop. + Self::apply_twiddles(values, roots, n1, n2); + self.ntt_dispatch(values, roots, n2); + transpose(values, n1, n2); + } + + #[cfg(not(feature = "parallel"))] + fn apply_twiddles(&self, values: &mut [F], roots: &[F], rows: usize, cols: usize) { + debug_assert_eq!(values.len() % (rows * cols), 0); + let step = roots.len() / (rows * cols); + for values in values.chunks_exact_mut(rows * cols) { + for (i, row) in values.chunks_exact_mut(cols).enumerate().skip(1) { + let step = (i * step) % roots.len(); + let mut index = step; + for value in row.iter_mut().skip(1) { + index %= roots.len(); + *value *= roots[index]; + index += step; + } + } + } + } + + #[cfg(feature = "parallel")] + fn apply_twiddles(values: &mut [F], roots: &[F], rows: usize, cols: usize) { + debug_assert_eq!(values.len() % (rows * cols), 0); + if values.len() > workload_size::() { + let size = rows * cols; + if values.len() != size { + let workload_size = size * max(1, workload_size::() / size); + values.par_chunks_mut(workload_size).for_each(|values| { + Self::apply_twiddles(values, roots, rows, cols); + }); + } else { + let step = roots.len() / (rows * cols); + values + .par_chunks_exact_mut(cols) + .enumerate() + .skip(1) + .for_each(|(i, row)| { + let step = (i * step) % roots.len(); + let mut index = step; + for value in row.iter_mut().skip(1) { + index %= roots.len(); + *value *= roots[index]; + index += step; + } + }); + } + } else { + let step = roots.len() / (rows * cols); + for values in values.chunks_exact_mut(rows * cols) { + for (i, row) in values.chunks_exact_mut(cols).enumerate().skip(1) { + let step = (i * step) % roots.len(); + let mut index = step; + for value in row.iter_mut().skip(1) { + index %= roots.len(); + *value *= roots[index]; + index += step; + } + } + } + } + } + + fn ntt_dispatch(&self, values: &mut [F], roots: &[F], size: usize) { + debug_assert_eq!(values.len() % size, 0); + debug_assert_eq!(roots.len() % size, 0); + #[cfg(feature = "parallel")] + if values.len() > workload_size::() && values.len() != size { + // Multiple NTTs, compute in parallel. + // Work size is largest multiple of `size` smaller than `WORKLOAD_SIZE`. + let workload_size = size * max(1, workload_size::() / size); + return values.par_chunks_mut(workload_size).for_each(|values| { + self.ntt_dispatch(values, roots, size); + }); + } + match size { + 0 | 1 => {} + 2 => { + for v in values.chunks_exact_mut(2) { + (v[0], v[1]) = (v[0] + v[1], v[0] - v[1]); + } + } + 3 => { + for v in values.chunks_exact_mut(3) { + // Rader NTT to reduce 3 to 2. + let v0 = v[0]; + (v[1], v[2]) = (v[1] + v[2], v[1] - v[2]); + v[0] += v[1]; + v[1] *= self.half_omega_3_1_plus_2; // ½(ω₃ + ω₃²) + v[2] *= self.half_omega_3_1_min_2; // ½(ω₃ - ω₃²) + v[1] += v0; + (v[1], v[2]) = (v[1] + v[2], v[1] - v[2]); + } + } + 4 => { + for v in values.chunks_exact_mut(4) { + (v[0], v[2]) = (v[0] + v[2], v[0] - v[2]); + (v[1], v[3]) = (v[1] + v[3], v[1] - v[3]); + v[3] *= self.omega_4_1; + (v[0], v[1]) = (v[0] + v[1], v[0] - v[1]); + (v[2], v[3]) = (v[2] + v[3], v[2] - v[3]); + (v[1], v[2]) = (v[2], v[1]); + } + } + 8 => { + for v in values.chunks_exact_mut(8) { + // Cooley-Tukey with v as 2x4 matrix. + (v[0], v[4]) = (v[0] + v[4], v[0] - v[4]); + (v[1], v[5]) = (v[1] + v[5], v[1] - v[5]); + (v[2], v[6]) = (v[2] + v[6], v[2] - v[6]); + (v[3], v[7]) = (v[3] + v[7], v[3] - v[7]); + v[5] *= self.omega_8_1; + v[6] *= self.omega_4_1; // == omega_8_2 + v[7] *= self.omega_8_3; + (v[0], v[2]) = (v[0] + v[2], v[0] - v[2]); + (v[1], v[3]) = (v[1] + v[3], v[1] - v[3]); + v[3] *= self.omega_4_1; + (v[0], v[1]) = (v[0] + v[1], v[0] - v[1]); + (v[2], v[3]) = (v[2] + v[3], v[2] - v[3]); + (v[4], v[6]) = (v[4] + v[6], v[4] - v[6]); + (v[5], v[7]) = (v[5] + v[7], v[5] - v[7]); + v[7] *= self.omega_4_1; + (v[4], v[5]) = (v[4] + v[5], v[4] - v[5]); + (v[6], v[7]) = (v[6] + v[7], v[6] - v[7]); + (v[1], v[4]) = (v[4], v[1]); + (v[3], v[6]) = (v[6], v[3]); + } + } + 16 => { + for v in values.chunks_exact_mut(16) { + // Cooley-Tukey with v as 4x4 matrix. + for i in 0..4 { + let v = &mut v[i..]; + (v[0], v[8]) = (v[0] + v[8], v[0] - v[8]); + (v[4], v[12]) = (v[4] + v[12], v[4] - v[12]); + v[12] *= self.omega_4_1; + (v[0], v[4]) = (v[0] + v[4], v[0] - v[4]); + (v[8], v[12]) = (v[8] + v[12], v[8] - v[12]); + (v[4], v[8]) = (v[8], v[4]); + } + v[5] *= self.omega_16_1; + v[6] *= self.omega_8_1; + v[7] *= self.omega_16_3; + v[9] *= self.omega_8_1; + v[10] *= self.omega_4_1; + v[11] *= self.omega_8_3; + v[13] *= self.omega_16_3; + v[14] *= self.omega_8_3; + v[15] *= self.omega_16_9; + for i in 0..4 { + let v = &mut v[i * 4..]; + (v[0], v[2]) = (v[0] + v[2], v[0] - v[2]); + (v[1], v[3]) = (v[1] + v[3], v[1] - v[3]); + v[3] *= self.omega_4_1; + (v[0], v[1]) = (v[0] + v[1], v[0] - v[1]); + (v[2], v[3]) = (v[2] + v[3], v[2] - v[3]); + (v[1], v[2]) = (v[2], v[1]); + } + (v[1], v[4]) = (v[4], v[1]); + (v[2], v[8]) = (v[8], v[2]); + (v[3], v[12]) = (v[12], v[3]); + (v[6], v[9]) = (v[9], v[6]); + (v[7], v[13]) = (v[13], v[7]); + (v[11], v[14]) = (v[14], v[11]); + } + } + size => self.ntt_recurse(values, roots, size), + } + } +} diff --git a/whir/src/ntt/transpose.rs b/whir/src/ntt/transpose.rs new file mode 100644 index 000000000..b07ad7497 --- /dev/null +++ b/whir/src/ntt/transpose.rs @@ -0,0 +1,550 @@ +use super::{super::utils::is_power_of_two, MatrixMut, utils::workload_size}; +use std::mem::swap; + +use ark_std::{end_timer, start_timer}; +use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; +#[cfg(feature = "parallel")] +use rayon::join; + +// NOTE: The assumption that rows and cols are a power of two are actually only relevant for the square matrix case. +// (This is because the algorithm recurses into 4 sub-matrices of half dimension; we assume those to be square matrices as well, which only works for powers of two). + +/// Transpose a matrix in-place. +/// Will batch transpose multiple matrices if the length of the slice is a multiple of rows * cols. +/// This algorithm assumes that both rows and cols are powers of two. +pub fn transpose(matrix: &mut [F], rows: usize, cols: usize) { + debug_assert_eq!(matrix.len() % (rows * cols), 0); + // eprintln!( + // "Transpose {} x {rows} x {cols} matrix.", + // matrix.len() / (rows * cols) + // ); + if rows == cols { + debug_assert!(is_power_of_two(rows)); + debug_assert!(is_power_of_two(cols)); + for matrix in matrix.chunks_exact_mut(rows * cols) { + let matrix = MatrixMut::from_mut_slice(matrix, rows, cols); + transpose_square(matrix); + } + } else { + // TODO: Special case for rows = 2 * cols and cols = 2 * rows. + // TODO: Special case for very wide matrices (e.g. n x 16). + let mut scratch = Vec::with_capacity(rows * cols); + #[allow(clippy::uninit_vec)] + unsafe { + scratch.set_len(rows * cols); + } + for matrix in matrix.chunks_exact_mut(rows * cols) { + scratch.copy_from_slice(matrix); + let src = MatrixMut::from_mut_slice(scratch.as_mut_slice(), rows, cols); + let dst = MatrixMut::from_mut_slice(matrix, cols, rows); + transpose_copy(src, dst); + } + } +} + +pub fn transpose_bench_allocate( + matrix: &mut [F], + rows: usize, + cols: usize, +) { + debug_assert_eq!(matrix.len() % (rows * cols), 0); + // eprintln!( + // "Transpose {} x {rows} x {cols} matrix.", + // matrix.len() / (rows * cols) + // ); + if rows == cols { + debug_assert!(is_power_of_two(rows)); + debug_assert!(is_power_of_two(cols)); + for matrix in matrix.chunks_exact_mut(rows * cols) { + let matrix = MatrixMut::from_mut_slice(matrix, rows, cols); + transpose_square(matrix); + } + } else { + // TODO: Special case for rows = 2 * cols and cols = 2 * rows. + // TODO: Special case for very wide matrices (e.g. n x 16). + let allocate_timer = start_timer!(|| "Allocate scratch."); + let mut scratch = Vec::with_capacity(rows * cols); + #[allow(clippy::uninit_vec)] + unsafe { + scratch.set_len(rows * cols); + } + end_timer!(allocate_timer); + for matrix in matrix.chunks_exact_mut(rows * cols) { + let copy_timer = start_timer!(|| "Copy from slice."); + scratch.copy_from_slice(matrix); + end_timer!(copy_timer); + let src = MatrixMut::from_mut_slice(scratch.as_mut_slice(), rows, cols); + let dst = MatrixMut::from_mut_slice(matrix, cols, rows); + let transpose_copy_timer = start_timer!(|| "Transpose Copy."); + transpose_copy(src, dst); + end_timer!(transpose_copy_timer); + } + } +} + +pub fn transpose_test( + matrix: &mut [F], + rows: usize, + cols: usize, + buffer: &mut [F], +) { + debug_assert_eq!(matrix.len() % (rows * cols), 0); + // eprintln!( + // "Transpose {} x {rows} x {cols} matrix.", + // matrix.len() / (rows * cols) + // ); + if rows == cols { + debug_assert!(is_power_of_two(rows)); + debug_assert!(is_power_of_two(cols)); + for matrix in matrix.chunks_exact_mut(rows * cols) { + let matrix = MatrixMut::from_mut_slice(matrix, rows, cols); + transpose_square(matrix); + } + } else { + let buffer = &mut buffer[0..rows * cols]; + // TODO: Special case for rows = 2 * cols and cols = 2 * rows. + // TODO: Special case for very wide matrices (e.g. n x 16). + let transpose_timer = start_timer!(|| "Transpose."); + for matrix in matrix.chunks_exact_mut(rows * cols) { + let copy_timer = start_timer!(|| "Copy from slice."); + // buffer.copy_from_slice(matrix); + buffer + .par_iter_mut() + .zip(matrix.par_iter_mut()) + .for_each(|(dst, src)| { + *dst = *src; + }); + end_timer!(copy_timer); + let transform_timer = start_timer!(|| "From mut slice."); + let src = MatrixMut::from_mut_slice(buffer, rows, cols); + let dst = MatrixMut::from_mut_slice(matrix, cols, rows); + end_timer!(transform_timer); + let transpose_copy_timer = start_timer!(|| "Transpose copy."); + transpose_copy(src, dst); + end_timer!(transpose_copy_timer); + } + end_timer!(transpose_timer); + } +} + +// The following function have both a parallel and a non-parallel implementation. +// We fuly split those in a parallel and a non-parallel functions (rather than using #[cfg] within a single function) +// and have main entry point fun that just calls the appropriate version (either fun_parallel or fun_not_parallel). +// The sole reason is that this simplifies unit tests: We otherwise would need to build twice to cover both cases. +// For effiency, we assume the compiler inlines away the extra "indirection" that we add to the entry point function. + +// NOTE: We could lift the Send constraints on non-parallel build. + +fn transpose_copy(src: MatrixMut, dst: MatrixMut) { + #[cfg(not(feature = "parallel"))] + transpose_copy_not_parallel(src, dst); + #[cfg(feature = "parallel")] + transpose_copy_parallel(src, dst); +} + +/// Sets `dst` to the transpose of `src`. This will panic if the sizes of `src` and `dst` are not compatible. +#[cfg(feature = "parallel")] +fn transpose_copy_parallel( + src: MatrixMut<'_, F>, + mut dst: MatrixMut<'_, F>, +) { + assert_eq!(src.rows(), dst.cols()); + assert_eq!(src.cols(), dst.rows()); + if src.rows() * src.cols() > workload_size::() { + // Split along longest axis and recurse. + // This results in a cache-oblivious algorithm. + let ((a, b), (x, y)) = if src.rows() > src.cols() { + let n = src.rows() / 2; + (src.split_vertical(n), dst.split_horizontal(n)) + } else { + let n = src.cols() / 2; + (src.split_horizontal(n), dst.split_vertical(n)) + }; + join( + || transpose_copy_parallel(a, x), + || transpose_copy_parallel(b, y), + ); + } else { + for i in 0..src.rows() { + for j in 0..src.cols() { + dst[(j, i)] = src[(i, j)]; + } + } + } +} + +/// Sets `dst` to the transpose of `src`. This will panic if the sizes of `src` and `dst` are not compatible. +/// This is the non-parallel version +fn transpose_copy_not_parallel(src: MatrixMut<'_, F>, mut dst: MatrixMut<'_, F>) { + assert_eq!(src.rows(), dst.cols()); + assert_eq!(src.cols(), dst.rows()); + if src.rows() * src.cols() > workload_size::() { + // Split along longest axis and recurse. + // This results in a cache-oblivious algorithm. + let ((a, b), (x, y)) = if src.rows() > src.cols() { + let n = src.rows() / 2; + (src.split_vertical(n), dst.split_horizontal(n)) + } else { + let n = src.cols() / 2; + (src.split_horizontal(n), dst.split_vertical(n)) + }; + transpose_copy_not_parallel(a, x); + transpose_copy_not_parallel(b, y); + } else { + for i in 0..src.rows() { + for j in 0..src.cols() { + dst[(j, i)] = src[(i, j)]; + } + } + } +} + +/// Transpose a square matrix in-place. Asserts that the size of the matrix is a power of two. +fn transpose_square(m: MatrixMut) { + #[cfg(feature = "parallel")] + transpose_square_parallel(m); + #[cfg(not(feature = "parallel"))] + transpose_square_non_parallel(m); +} + +/// Transpose a square matrix in-place. Asserts that the size of the matrix is a power of two. +/// This is the parallel version. +#[cfg(feature = "parallel")] +fn transpose_square_parallel(mut m: MatrixMut) { + debug_assert!(m.is_square()); + debug_assert!(m.rows().is_power_of_two()); + let size = m.rows(); + if size * size > workload_size::() { + // Recurse into quadrants. + // This results in a cache-oblivious algorithm. + let n = size / 2; + let (a, b, c, d) = m.split_quadrants(n, n); + + join( + || transpose_square_swap_parallel(b, c), + || { + join( + || transpose_square_parallel(a), + || transpose_square_parallel(d), + ) + }, + ); + } else { + for i in 0..size { + for j in (i + 1)..size { + // unsafe needed due to lack of bounds-check by swap. We are guaranteed that (i,j) and (j,i) are within the bounds. + unsafe { + m.swap((i, j), (j, i)); + } + } + } + } +} + +/// Transpose a square matrix in-place. Asserts that the size of the matrix is a power of two. +/// This is the non-parallel version. +fn transpose_square_non_parallel(mut m: MatrixMut) { + debug_assert!(m.is_square()); + debug_assert!(m.rows().is_power_of_two()); + let size = m.rows(); + if size * size > workload_size::() { + // Recurse into quadrants. + // This results in a cache-oblivious algorithm. + let n = size / 2; + let (a, b, c, d) = m.split_quadrants(n, n); + transpose_square_non_parallel(a); + transpose_square_non_parallel(d); + transpose_square_swap_non_parallel(b, c); + } else { + for i in 0..size { + for j in (i + 1)..size { + // unsafe needed due to lack of bounds-check by swap. We are guaranteed that (i,j) and (j,i) are within the bounds. + unsafe { + m.swap((i, j), (j, i)); + } + } + } + } +} + +/// Transpose and swap two square size matrices. Sizes must be equal and a power of two. +fn transpose_square_swap(a: MatrixMut, b: MatrixMut) { + #[cfg(feature = "parallel")] + transpose_square_swap_parallel(a, b); + #[cfg(not(feature = "parallel"))] + transpose_square_swap_non_parallel(a, b); +} + +/// Transpose and swap two square size matrices (parallel version). The size must be a power of two. +#[cfg(feature = "parallel")] +fn transpose_square_swap_parallel(mut a: MatrixMut, mut b: MatrixMut) { + debug_assert!(a.is_square()); + debug_assert_eq!(a.rows(), b.cols()); + debug_assert_eq!(a.cols(), b.rows()); + debug_assert!(is_power_of_two(a.rows())); + debug_assert!(workload_size::() >= 2); // otherwise, we would recurse even if size == 1. + let size = a.rows(); + if 2 * size * size > workload_size::() { + // Recurse into quadrants. + // This results in a cache-oblivious algorithm. + let n = size / 2; + let (aa, ab, ac, ad) = a.split_quadrants(n, n); + let (ba, bb, bc, bd) = b.split_quadrants(n, n); + + join( + || { + join( + || transpose_square_swap_parallel(aa, ba), + || transpose_square_swap_parallel(ab, bc), + ) + }, + || { + join( + || transpose_square_swap_parallel(ac, bb), + || transpose_square_swap_parallel(ad, bd), + ) + }, + ); + } else { + for i in 0..size { + for j in 0..size { + swap(&mut a[(i, j)], &mut b[(j, i)]) + } + } + } +} + +/// Transpose and swap two square size matrices, whose sizes are a power of two (non-parallel version) +fn transpose_square_swap_non_parallel(mut a: MatrixMut, mut b: MatrixMut) { + debug_assert!(a.is_square()); + debug_assert_eq!(a.rows(), b.cols()); + debug_assert_eq!(a.cols(), b.rows()); + debug_assert!(is_power_of_two(a.rows())); + debug_assert!(workload_size::() >= 2); // otherwise, we would recurse even if size == 1. + + let size = a.rows(); + if 2 * size * size > workload_size::() { + // Recurse into quadrants. + // This results in a cache-oblivious algorithm. + let n = size / 2; + let (aa, ab, ac, ad) = a.split_quadrants(n, n); + let (ba, bb, bc, bd) = b.split_quadrants(n, n); + transpose_square_swap_non_parallel(aa, ba); + transpose_square_swap_non_parallel(ab, bc); + transpose_square_swap_non_parallel(ac, bb); + transpose_square_swap_non_parallel(ad, bd); + } else { + for i in 0..size { + for j in 0..size { + swap(&mut a[(i, j)], &mut b[(j, i)]) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::{super::utils::workload_size, *}; + + type Pair = (usize, usize); + type Triple = (usize, usize, usize); + + // create a vector (intended to be viewed as a matrix) whose (i,j)'th entry is the pair (i,j) itself. + // This is useful to debug transposition algorithms. + fn make_example_matrix(rows: usize, columns: usize) -> Vec { + let mut v: Vec = vec![(0, 0); rows * columns]; + let mut view = MatrixMut::from_mut_slice(&mut v, rows, columns); + for i in 0..rows { + for j in 0..columns { + view[(i, j)] = (i, j); + } + } + v + } + + // create a vector (intended to be viewed as a sequence of `instances` of matrices) where (i,j)'th entry of the `index`th matrix + // is the triple (index, i,j). + fn make_example_matrices(rows: usize, columns: usize, instances: usize) -> Vec { + let mut v: Vec = vec![(0, 0, 0); rows * columns * instances]; + for index in 0..instances { + let mut view = MatrixMut::from_mut_slice( + &mut v[rows * columns * index..rows * columns * (index + 1)], + rows, + columns, + ); + for i in 0..rows { + for j in 0..columns { + view[(i, j)] = (index, i, j); + } + } + } + v + } + + #[test] + fn test_transpose_copy() { + // iterate over both parallel and non-parallel implementation. + // Needs HRTB, otherwise it won't work. + #[allow(clippy::type_complexity)] + let mut funs: Vec<&dyn for<'a, 'b> Fn(MatrixMut<'a, Pair>, MatrixMut<'b, Pair>)> = vec![ + &transpose_copy_not_parallel::, + &transpose_copy::, + ]; + #[cfg(feature = "parallel")] + funs.push(&transpose_copy_parallel::); + + for f in funs { + let rows: usize = workload_size::() + 1; // intentionally not a power of two: The function is not described as only working for powers of two. + let columns: usize = 4; + let mut srcarray = make_example_matrix(rows, columns); + let mut dstarray: Vec<(usize, usize)> = vec![(0, 0); rows * columns]; + + let src1 = MatrixMut::::from_mut_slice(&mut srcarray[..], rows, columns); + let dst1 = MatrixMut::::from_mut_slice(&mut dstarray[..], columns, rows); + + f(src1, dst1); + let dst1 = MatrixMut::::from_mut_slice(&mut dstarray[..], columns, rows); + + for i in 0..rows { + for j in 0..columns { + assert_eq!(dst1[(j, i)], (i, j)); + } + } + } + } + + #[test] + fn test_transpose_square_swap() { + // iterate over parallel and non-parallel variants: + #[allow(clippy::type_complexity)] + let mut funs: Vec<&dyn for<'a> Fn(MatrixMut<'a, Triple>, MatrixMut<'a, Triple>)> = vec![ + &transpose_square_swap::, + &transpose_square_swap_non_parallel::, + ]; + #[cfg(feature = "parallel")] + funs.push(&transpose_square_swap_parallel::); + + for f in funs { + // Set rows manually. We want to be sure to trigger the actual recursion. + // (Computing this from workload_size was too much hassle.) + let rows = 1024; // workload_size::(); + assert!(rows * rows > 2 * workload_size::()); + + let examples: Vec = make_example_matrices(rows, rows, 2); + // Make copies for simplicity, because we borrow different parts. + let mut examples1 = Vec::from(&examples[0..rows * rows]); + let mut examples2 = Vec::from(&examples[rows * rows..2 * rows * rows]); + + let view1 = MatrixMut::from_mut_slice(&mut examples1, rows, rows); + let view2 = MatrixMut::from_mut_slice(&mut examples2, rows, rows); + for i in 0..rows { + for j in 0..rows { + assert_eq!(view1[(i, j)], (0, i, j)); + assert_eq!(view2[(i, j)], (1, i, j)); + } + } + f(view1, view2); + let view1 = MatrixMut::from_mut_slice(&mut examples1, rows, rows); + let view2 = MatrixMut::from_mut_slice(&mut examples2, rows, rows); + for i in 0..rows { + for j in 0..rows { + assert_eq!(view1[(i, j)], (1, j, i)); + assert_eq!(view2[(i, j)], (0, j, i)); + } + } + } + } + + #[test] + fn test_transpose_square() { + let mut funs: Vec<&dyn for<'a> Fn(MatrixMut<'a, _>)> = vec![ + &transpose_square::, + &transpose_square_parallel::, + ]; + #[cfg(feature = "parallel")] + funs.push(&transpose_square::); + for f in funs { + // Set rows manually. We want to be sure to trigger the actual recursion. + // (Computing this from workload_size was too much hassle.) + let size = 1024; + assert!(size * size > 2 * workload_size::()); + + let mut example = make_example_matrix(size, size); + let view = MatrixMut::from_mut_slice(&mut example, size, size); + f(view); + let view = MatrixMut::from_mut_slice(&mut example, size, size); + for i in 0..size { + for j in 0..size { + assert_eq!(view[(i, j)], (j, i)); + } + } + } + } + + #[test] + fn test_transpose() { + let size = 1024; + + // rectangular matrix: + let rows = size; + let cols = 16; + let mut example = make_example_matrix(rows, cols); + transpose(&mut example, rows, cols); + let view = MatrixMut::from_mut_slice(&mut example, cols, rows); + for i in 0..cols { + for j in 0..rows { + assert_eq!(view[(i, j)], (j, i)); + } + } + + // square matrix: + let rows = size; + let cols = size; + let mut example = make_example_matrix(rows, cols); + transpose(&mut example, rows, cols); + let view = MatrixMut::from_mut_slice(&mut example, cols, rows); + for i in 0..cols { + for j in 0..rows { + assert_eq!(view[(i, j)], (j, i)); + } + } + + // 20 rectangular matrices: + let number_of_matrices = 20; + let rows = size; + let cols = 16; + let mut example = make_example_matrices(rows, cols, number_of_matrices); + transpose(&mut example, rows, cols); + for index in 0..number_of_matrices { + let view = MatrixMut::from_mut_slice( + &mut example[index * rows * cols..(index + 1) * rows * cols], + cols, + rows, + ); + for i in 0..cols { + for j in 0..rows { + assert_eq!(view[(i, j)], (index, j, i)); + } + } + } + + // 20 square matrices: + let number_of_matrices = 20; + let rows = size; + let cols = size; + let mut example = make_example_matrices(rows, cols, number_of_matrices); + transpose(&mut example, rows, cols); + for index in 0..number_of_matrices { + let view = MatrixMut::from_mut_slice( + &mut example[index * rows * cols..(index + 1) * rows * cols], + cols, + rows, + ); + for i in 0..cols { + for j in 0..rows { + assert_eq!(view[(i, j)], (index, j, i)); + } + } + } + } +} diff --git a/whir/src/ntt/utils.rs b/whir/src/ntt/utils.rs new file mode 100644 index 000000000..c2d7f18f5 --- /dev/null +++ b/whir/src/ntt/utils.rs @@ -0,0 +1,143 @@ +/// Target single-thread workload size for `T`. +/// Should ideally be a multiple of a cache line (64 bytes) +/// and close to the L1 cache size (32 KB). +pub const fn workload_size() -> usize { + const CACHE_SIZE: usize = 1 << 15; + CACHE_SIZE / size_of::() +} + +/// Cast a slice into chunks of size N. +/// +/// TODO: Replace with `slice::as_chunks` when stable. +pub fn as_chunks_exact_mut(slice: &mut [T]) -> &mut [[T; N]] { + assert!(N != 0, "chunk size must be non-zero"); + assert_eq!( + slice.len() % N, + 0, + "slice length must be a multiple of chunk size" + ); + // SAFETY: Caller must guarantee that `N` is nonzero and exactly divides the slice length + let new_len = slice.len() / N; + // SAFETY: We cast a slice of `new_len * N` elements into + // a slice of `new_len` many `N` elements chunks. + unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), new_len) } +} + +/// Compute the largest factor of n that is <= sqrt(n). +/// Assumes n is of the form 2^k * {1,3,9}. +pub fn sqrt_factor(n: usize) -> usize { + let twos = n.trailing_zeros(); + match n >> twos { + 1 => 1 << (twos / 2), + 3 | 9 => 3 << (twos / 2), + _ => panic!(), + } +} + +/// Least common multiple. +/// +/// Note that lcm(0,0) will panic (rather than give the correct answer 0). +pub fn lcm(a: usize, b: usize) -> usize { + a * (b / gcd(a, b)) +} + +/// Greatest common divisor. +pub fn gcd(mut a: usize, mut b: usize) -> usize { + while b != 0 { + (a, b) = (b, a % b); + } + a +} + +#[cfg(test)] +mod tests { + use super::{as_chunks_exact_mut, gcd, lcm, sqrt_factor}; + + #[test] + fn test_gcd() { + assert_eq!(gcd(4, 6), 2); + assert_eq!(gcd(0, 4), 4); + assert_eq!(gcd(4, 0), 4); + assert_eq!(gcd(1, 1), 1); + assert_eq!(gcd(64, 16), 16); + assert_eq!(gcd(81, 9), 9); + assert_eq!(gcd(0, 0), 0); + } + + #[test] + fn test_lcm() { + assert_eq!(lcm(5, 6), 30); + assert_eq!(lcm(3, 7), 21); + assert_eq!(lcm(0, 10), 0); + } + #[test] + fn test_sqrt_factor() { + // naive brute-force search for largest divisor up to sqrt n. + // This is not supposed to be efficient, but optimized for "ease of convincing yourself it's correct (provided none of the asserts trigger)". + fn get_largest_divisor_up_to_sqrt(x: usize) -> usize { + if x == 0 { + return 0; + } + let mut result = 1; + let isqrt_of_x: usize = { + // use x.isqrt() once this is stabilized. That would be MUCH simpler. + + assert!(x < (1 << f64::MANTISSA_DIGITS)); // guarantees that each of {x, floor(sqrt(x)), ceil(sqrt(x))} can be represented exactly by f64. + let x_as_float = x as f64; + // sqrt is guaranteed to be the exact result, then rounded. Due to the above assert, the rounded value is between floor(sqrt(x)) and ceil(sqrt(x)). + let sqrt_x = x_as_float.sqrt(); + // We return sqrt_x, rounded to 0; for correctness, we need to rule out that we rounded from a non-integer up to the integer ceil(sqrt(x)). + if sqrt_x.fract() == 0.0 { + assert!(sqrt_x * sqrt_x == x_as_float); + } + unsafe { sqrt_x.to_int_unchecked() } + }; + for i in 1..=isqrt_of_x { + if x % i == 0 { + result = i; + } + } + result + } + + for i in 0..10 { + assert_eq!(sqrt_factor(1 << i), get_largest_divisor_up_to_sqrt(1 << i)); + } + for i in 0..10 { + assert_eq!(sqrt_factor(1 << i), get_largest_divisor_up_to_sqrt(1 << i)); + } + + for i in 0..10 { + assert_eq!(sqrt_factor(1 << i), get_largest_divisor_up_to_sqrt(1 << i)); + } + } + + #[test] + fn test_as_chunks_exact_mut() { + let v = &mut [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + assert_eq!(as_chunks_exact_mut::<_, 12>(v), &[[ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 + ]]); + assert_eq!(as_chunks_exact_mut::<_, 6>(v), &[[1, 2, 3, 4, 5, 6], [ + 7, 8, 9, 10, 11, 12 + ]]); + assert_eq!(as_chunks_exact_mut::<_, 1>(v), &[ + [1], + [2], + [3], + [4], + [5], + [6], + [7], + [8], + [9], + [10], + [11], + [12] + ]); + let should_not_work = std::panic::catch_unwind(|| { + as_chunks_exact_mut::<_, 2>(&mut [1, 2, 3]); + }); + assert!(should_not_work.is_err()) + } +} diff --git a/whir/src/ntt/wavelet.rs b/whir/src/ntt/wavelet.rs new file mode 100644 index 000000000..88ed2aea6 --- /dev/null +++ b/whir/src/ntt/wavelet.rs @@ -0,0 +1,90 @@ +use super::{transpose, utils::workload_size}; +use ark_ff::Field; +use std::cmp::max; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +/// Fast Wavelet Transform. +/// +/// The input slice must have a length that is a power of two. +/// Recursively applies the kernel +/// [1 0] +/// [1 1] +pub fn wavelet_transform(values: &mut [F]) { + debug_assert!(values.len().is_power_of_two()); + wavelet_transform_batch(values, values.len()) +} + +pub fn wavelet_transform_batch(values: &mut [F], size: usize) { + debug_assert_eq!(values.len() % size, 0); + debug_assert!(size.is_power_of_two()); + #[cfg(feature = "parallel")] + if values.len() > workload_size::() && values.len() != size { + // Multiple wavelet transforms, compute in parallel. + // Work size is largest multiple of `size` smaller than `WORKLOAD_SIZE`. + let workload_size = size * max(1, workload_size::() / size); + return values.par_chunks_mut(workload_size).for_each(|values| { + wavelet_transform_batch(values, size); + }); + } + match size { + 0 | 1 => {} + 2 => { + for v in values.chunks_exact_mut(2) { + v[1] += v[0] + } + } + 4 => { + for v in values.chunks_exact_mut(4) { + v[1] += v[0]; + v[3] += v[2]; + v[2] += v[0]; + v[3] += v[1]; + } + } + 8 => { + for v in values.chunks_exact_mut(8) { + v[1] += v[0]; + v[3] += v[2]; + v[2] += v[0]; + v[3] += v[1]; + v[5] += v[4]; + v[7] += v[6]; + v[6] += v[4]; + v[7] += v[5]; + v[4] += v[0]; + v[5] += v[1]; + v[6] += v[2]; + v[7] += v[3]; + } + } + 16 => { + for v in values.chunks_exact_mut(16) { + for v in v.chunks_exact_mut(4) { + v[1] += v[0]; + v[3] += v[2]; + v[2] += v[0]; + v[3] += v[1]; + } + let (a, v) = v.split_at_mut(4); + let (b, v) = v.split_at_mut(4); + let (c, d) = v.split_at_mut(4); + for i in 0..4 { + b[i] += a[i]; + d[i] += c[i]; + c[i] += a[i]; + d[i] += b[i]; + } + } + } + n => { + let n1 = 1 << (n.trailing_zeros() / 2); + let n2 = n / n1; + wavelet_transform_batch(values, n1); + transpose(values, n2, n1); + wavelet_transform_batch(values, n2); + transpose(values, n1, n2); + } + } +} diff --git a/whir/src/parameters.rs b/whir/src/parameters.rs new file mode 100644 index 000000000..5f806efa2 --- /dev/null +++ b/whir/src/parameters.rs @@ -0,0 +1,213 @@ +use std::{fmt::Display, marker::PhantomData, str::FromStr}; + +use ark_crypto_primitives::merkle_tree::{Config, LeafParam, TwoToOneParam}; +use serde::{Deserialize, Serialize}; + +pub fn default_max_pow(num_variables: usize, log_inv_rate: usize) -> usize { + num_variables + log_inv_rate - 3 +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub enum SoundnessType { + UniqueDecoding, + ProvableList, + ConjectureList, +} + +impl Display for SoundnessType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", match &self { + SoundnessType::ProvableList => "ProvableList", + SoundnessType::ConjectureList => "ConjectureList", + SoundnessType::UniqueDecoding => "UniqueDecoding", + }) + } +} + +impl FromStr for SoundnessType { + type Err = String; + fn from_str(s: &str) -> Result { + if s == "ProvableList" { + Ok(SoundnessType::ProvableList) + } else if s == "ConjectureList" { + Ok(SoundnessType::ConjectureList) + } else if s == "UniqueDecoding" { + Ok(SoundnessType::UniqueDecoding) + } else { + Err(format!("Invalid soundness specification: {}", s)) + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct MultivariateParameters { + pub(crate) num_variables: usize, + _field: PhantomData, +} + +impl MultivariateParameters { + pub fn new(num_variables: usize) -> Self { + Self { + num_variables, + _field: PhantomData, + } + } +} + +impl Display for MultivariateParameters { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Number of variables: {}", self.num_variables) + } +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub enum FoldType { + Naive, + ProverHelps, +} + +impl FromStr for FoldType { + type Err = String; + fn from_str(s: &str) -> Result { + if s == "Naive" { + Ok(FoldType::Naive) + } else if s == "ProverHelps" { + Ok(FoldType::ProverHelps) + } else { + Err(format!("Invalid fold type specification: {}", s)) + } + } +} + +impl Display for FoldType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", match self { + FoldType::Naive => "Naive", + FoldType::ProverHelps => "ProverHelps", + }) + } +} + +#[derive(Debug, Clone, Copy)] +pub enum FoldingFactor { + Constant(usize), // Use the same folding factor for all rounds + ConstantFromSecondRound(usize, usize), /* Use the same folding factor for all rounds, but the first round uses a different folding factor */ +} + +impl FoldingFactor { + pub fn at_round(&self, round: usize) -> usize { + match self { + FoldingFactor::Constant(factor) => *factor, + FoldingFactor::ConstantFromSecondRound(first_round_factor, factor) => { + if round == 0 { + *first_round_factor + } else { + *factor + } + } + } + } + + pub fn check_validity(&self, _num_variables: usize) -> Result<(), String> { + match self { + FoldingFactor::Constant(factor) => { + if *factor == 0 { + // We should at least fold some time + Err("Folding factor shouldn't be zero.".to_string()) + } else { + Ok(()) + } + } + FoldingFactor::ConstantFromSecondRound(first_round_factor, factor) => { + if *factor == 0 || *first_round_factor == 0 { + // We should at least fold some time + Err("Folding factor shouldn't be zero.".to_string()) + } else { + Ok(()) + } + } + } + } + + /// Compute the number of WHIR rounds and the number of rounds in the final + /// sumcheck. + pub fn compute_number_of_rounds(&self, num_variables: usize) -> (usize, usize) { + match self { + FoldingFactor::Constant(factor) => { + // It's checked that factor > 0 and factor <= num_variables + let final_sumcheck_rounds = num_variables % factor; + ( + (num_variables - final_sumcheck_rounds) / factor - 1, + final_sumcheck_rounds, + ) + } + FoldingFactor::ConstantFromSecondRound(first_round_factor, factor) => { + let nv_except_first_round = num_variables - *first_round_factor; + if nv_except_first_round < *factor { + // This case is equivalent to Constant(first_round_factor) + return (0, nv_except_first_round); + } + let final_sumcheck_rounds = nv_except_first_round % *factor; + ( + // No need to minus 1 because the initial round is already + // excepted out + (nv_except_first_round - final_sumcheck_rounds) / factor, + final_sumcheck_rounds, + ) + } + } + } + + /// Compute folding_factor(0) + ... + folding_factor(n_rounds) + pub fn total_number(&self, n_rounds: usize) -> usize { + match self { + FoldingFactor::Constant(factor) => { + // It's checked that factor > 0 and factor <= num_variables + factor * (n_rounds + 1) + } + FoldingFactor::ConstantFromSecondRound(first_round_factor, factor) => { + first_round_factor + factor * n_rounds + } + } + } +} + +#[derive(Clone)] +pub struct WhirParameters +where + MerkleConfig: Config, +{ + pub initial_statement: bool, + pub starting_log_inv_rate: usize, + pub folding_factor: FoldingFactor, + pub soundness_type: SoundnessType, + pub security_level: usize, + pub pow_bits: usize, + + pub fold_optimisation: FoldType, + + // PoW parameters + pub _pow_parameters: PhantomData, + + // Merkle tree parameters + pub leaf_hash_params: LeafParam, + pub two_to_one_params: TwoToOneParam, +} + +impl Display for WhirParameters +where + MerkleConfig: Config, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!( + f, + "Targeting {}-bits of security with {}-bits of PoW - soundness: {:?}", + self.security_level, self.pow_bits, self.soundness_type + )?; + writeln!( + f, + "Starting rate: 2^-{}, folding_factor: {:?}, fold_opt_type: {}", + self.starting_log_inv_rate, self.folding_factor, self.fold_optimisation, + ) + } +} diff --git a/whir/src/poly_utils/coeffs.rs b/whir/src/poly_utils/coeffs.rs new file mode 100644 index 000000000..649977743 --- /dev/null +++ b/whir/src/poly_utils/coeffs.rs @@ -0,0 +1,467 @@ +use super::{MultilinearPoint, evals::EvaluationsList, hypercube::BinaryHypercubePoint}; +use crate::ntt::wavelet_transform; +use ark_ff::Field; +use ark_poly::{DenseUVPolynomial, Polynomial, univariate::DensePolynomial}; +use serde::{Deserialize, Serialize}; +#[cfg(feature = "parallel")] +use { + rayon::{join, prelude::*}, + std::mem::size_of, +}; + +/// A CoefficientList models a (multilinear) polynomial in `num_variable` variables in coefficient form. +/// +/// The order of coefficients follows the following convention: coeffs[j] corresponds to the monomial +/// determined by the binary decomposition of j with an X_i-variable present if the +/// i-th highest-significant bit among the `num_variables` least significant bits is set. +/// +/// e.g. is `num_variables` is 3 with variables X_0, X_1, X_2, then +/// - coeffs[0] is the coefficient of 1 +/// - coeffs[1] is the coefficient of X_2 +/// - coeffs[2] is the coefficient of X_1 +/// - coeffs[4] is the coefficient of X_0 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CoefficientList { + coeffs: Vec, /* list of coefficients. For multilinear polynomials, we have coeffs.len() == 1 << num_variables. */ + num_variables: usize, // number of variables +} + +impl CoefficientList +where + F: Field, +{ + fn coeff_at(&self, index: usize) -> F { + self.coeffs[index] + } + + pub fn pad_to_num_vars(&mut self, num_vars: usize) { + if self.num_variables < num_vars { + let pad = (1usize << num_vars) - (1usize << self.num_variables); + self.coeffs.extend(vec![F::ZERO; pad]); + self.num_variables = num_vars; + } + } + + /// Linearly combine the given polynomials using the given coefficients + pub fn combine(polys: &[Self], coeffs: &[F]) -> Self { + Self::new( + (0..polys[0].coeffs.len()) + .into_par_iter() + .map(|i| { + let mut acc = F::ZERO; + for (poly, coeff) in polys.iter().zip(coeffs) { + acc += poly.coeff_at(i) * coeff; + } + acc + }) + .collect(), + ) + } + + /// Evaluate the given polynomial at `point` from {0,1}^n + pub fn evaluate_hypercube(&self, point: BinaryHypercubePoint) -> F { + assert_eq!(self.coeffs.len(), 1 << self.num_variables); + assert!(point.0 < (1 << self.num_variables)); + // TODO: Optimized implementation + self.evaluate(&MultilinearPoint::from_binary_hypercube_point( + point, + self.num_variables, + )) + } + + /// Evaluate the given polynomial at `point` from F^n. + pub fn evaluate(&self, point: &MultilinearPoint) -> F { + assert_eq!(self.num_variables, point.n_variables()); + eval_multivariate(&self.coeffs, &point.0) + } + + #[inline] + fn eval_extension>(coeff: &[F], eval: &[E], scalar: E) -> E { + // explicit "return" just to simplify static code-analyzers' tasks (that can't figure out the cfg's are disjoint) + #[cfg(not(feature = "parallel"))] + return Self::eval_extension_nonparallel(coeff, eval, scalar); + #[cfg(feature = "parallel")] + return Self::eval_extension_parallel(coeff, eval, scalar); + } + + // NOTE (Gotti): This algorithm uses 2^{n+1}-1 multiplications for a polynomial in n variables. + // You could do with 2^{n}-1 by just doing a + x * b (and not forwarding scalar through the recursion at all). + // The difference comes from multiplications by E::ONE at the leaves of the recursion tree. + + // recursive helper function for polynomial evaluation: + // Note that eval(coeffs, [X_0, X1,...]) = eval(coeffs_left, [X_1,...]) + X_0 * eval(coeffs_right, [X_1,...]) + + /// Recursively compute scalar * poly_eval(coeffs;eval) where poly_eval interprets coeffs as a polynomial and eval are the evaluation points. + fn eval_extension_nonparallel>( + coeff: &[F], + eval: &[E], + scalar: E, + ) -> E { + debug_assert_eq!(coeff.len(), 1 << eval.len()); + if let Some((&x, tail)) = eval.split_first() { + let (low, high) = coeff.split_at(coeff.len() / 2); + let a = Self::eval_extension_nonparallel(low, tail, scalar); + let b = Self::eval_extension_nonparallel(high, tail, scalar * x); + a + b + } else { + scalar.mul_by_base_prime_field(&coeff[0]) + } + } + + #[cfg(feature = "parallel")] + fn eval_extension_parallel>( + coeff: &[F], + eval: &[E], + scalar: E, + ) -> E { + const PARALLEL_THRESHOLD: usize = 10; + debug_assert_eq!(coeff.len(), 1 << eval.len()); + if let Some((&x, tail)) = eval.split_first() { + let (low, high) = coeff.split_at(coeff.len() / 2); + if tail.len() > PARALLEL_THRESHOLD { + let (a, b) = rayon::join( + || Self::eval_extension_parallel(low, tail, scalar), + || Self::eval_extension_parallel(high, tail, scalar * x), + ); + a + b + } else { + Self::eval_extension_nonparallel(low, tail, scalar) + + Self::eval_extension_nonparallel(high, tail, scalar * x) + } + } else { + scalar.mul_by_base_prime_field(&coeff[0]) + } + } + + /// Evaluate self at `point`, where `point` is from a field extension extending the field over which the polynomial `self` is defined. + /// + /// Note that we only support the case where F is a prime field. + pub fn evaluate_at_extension>( + &self, + point: &MultilinearPoint, + ) -> E { + assert_eq!(self.num_variables, point.n_variables()); + Self::eval_extension(&self.coeffs, &point.0, E::ONE) + } + + /// Interprets self as a univariate polynomial (with coefficients of X^i in order of ascending i) and evaluates it at each point in `points`. + /// We return the vector of evaluations. + /// + /// NOTE: For the `usual` mapping between univariate and multilinear polynomials, the coefficient ordering is such that + /// for a single point x, we have (extending notation to a single point) + /// self.evaluate_at_univariate(x) == self.evaluate([x^(2^n), x^(2^{n-1}), ..., x^2, x]) + pub fn evaluate_at_univariate(&self, points: &[F]) -> Vec { + // DensePolynomial::from_coefficients_slice converts to a dense univariate polynomial. + // The coefficient order is "coefficient of 1 first". + let univariate = DensePolynomial::from_coefficients_slice(&self.coeffs); + points + .iter() + .map(|point| univariate.evaluate(point)) + .collect() + } +} + +impl CoefficientList { + pub fn new(coeffs: Vec) -> Self { + let len = coeffs.len(); + assert!(len.is_power_of_two()); + let num_variables = len.ilog2(); + + CoefficientList { + coeffs, + num_variables: num_variables as usize, + } + } + + pub fn coeffs(&self) -> &[F] { + &self.coeffs + } + + pub fn num_variables(&self) -> usize { + self.num_variables + } + + pub fn num_coeffs(&self) -> usize { + self.coeffs.len() + } + + /// Map the polynomial `self` from F[X_1,...,X_n] to E[X_1,...,X_n], where E is a field extension of F. + /// + /// Note that this is currently restricted to the case where F is a prime field. + pub fn to_extension>(self) -> CoefficientList { + CoefficientList::new( + self.coeffs + .into_iter() + .map(E::from_base_prime_field) + .collect(), + ) + } +} + +/// Multivariate evaluation in coefficient form. +fn eval_multivariate(coeffs: &[F], point: &[F]) -> F { + debug_assert_eq!(coeffs.len(), 1 << point.len()); + match point { + [] => coeffs[0], + [x] => coeffs[0] + coeffs[1] * x, + [x0, x1] => { + let b0 = coeffs[0] + coeffs[1] * x1; + let b1 = coeffs[2] + coeffs[3] * x1; + b0 + b1 * x0 + } + [x0, x1, x2] => { + let b00 = coeffs[0] + coeffs[1] * x2; + let b01 = coeffs[2] + coeffs[3] * x2; + let b10 = coeffs[4] + coeffs[5] * x2; + let b11 = coeffs[6] + coeffs[7] * x2; + let b0 = b00 + b01 * x1; + let b1 = b10 + b11 * x1; + b0 + b1 * x0 + } + [x0, x1, x2, x3] => { + let b000 = coeffs[0] + coeffs[1] * x3; + let b001 = coeffs[2] + coeffs[3] * x3; + let b010 = coeffs[4] + coeffs[5] * x3; + let b011 = coeffs[6] + coeffs[7] * x3; + let b100 = coeffs[8] + coeffs[9] * x3; + let b101 = coeffs[10] + coeffs[11] * x3; + let b110 = coeffs[12] + coeffs[13] * x3; + let b111 = coeffs[14] + coeffs[15] * x3; + let b00 = b000 + b001 * x2; + let b01 = b010 + b011 * x2; + let b10 = b100 + b101 * x2; + let b11 = b110 + b111 * x2; + let b0 = b00 + b01 * x1; + let b1 = b10 + b11 * x1; + b0 + b1 * x0 + } + [x, tail @ ..] => { + let (b0t, b1t) = coeffs.split_at(coeffs.len() / 2); + #[cfg(not(feature = "parallel"))] + let (b0t, b1t) = (eval_multivariate(b0t, tail), eval_multivariate(b1t, tail)); + #[cfg(feature = "parallel")] + let (b0t, b1t) = { + let work_size: usize = (1 << 15) / size_of::(); + if coeffs.len() > work_size { + join( + || eval_multivariate(b0t, tail), + || eval_multivariate(b1t, tail), + ) + } else { + (eval_multivariate(b0t, tail), eval_multivariate(b1t, tail)) + } + }; + b0t + b1t * x + } + } +} + +impl CoefficientList +where + F: Field, +{ + /// fold folds the polynomial at the provided folding_randomness. + /// + /// Namely, when self is interpreted as a multi-linear polynomial f in X_0, ..., X_{n-1}, + /// it partially evaluates f at the provided `folding_randomness`. + /// Our ordering convention is to evaluate at the higher indices, i.e. we return f(X_0,X_1,..., folding_randomness[0], folding_randomness[1],...) + pub fn fold(&self, folding_randomness: &MultilinearPoint) -> Self { + let folding_factor = folding_randomness.n_variables(); + #[cfg(not(feature = "parallel"))] + let coeffs = self + .coeffs + .chunks_exact(1 << folding_factor) + .map(|coeffs| eval_multivariate(coeffs, &folding_randomness.0)) + .collect(); + #[cfg(feature = "parallel")] + let coeffs = self + .coeffs + .par_chunks_exact(1 << folding_factor) + .map(|coeffs| eval_multivariate(coeffs, &folding_randomness.0)) + .collect(); + + CoefficientList { + coeffs, + num_variables: self.num_variables() - folding_factor, + } + } +} + +impl From> for DensePolynomial +where + F: Field, +{ + fn from(value: CoefficientList) -> Self { + DensePolynomial::from_coefficients_vec(value.coeffs) + } +} + +impl From> for CoefficientList +where + F: Field, +{ + fn from(value: DensePolynomial) -> Self { + CoefficientList::new(value.coeffs) + } +} + +impl From> for EvaluationsList +where + F: Field, +{ + fn from(value: CoefficientList) -> Self { + let mut evals = value.coeffs; + wavelet_transform(&mut evals); + EvaluationsList::new(evals) + } +} + +// Previous recursive version +// impl From> for EvaluationsList +// where +// F: Field, +// { +// fn from(value: CoefficientList) -> Self { +// let num_coeffs = value.num_coeffs(); +// Base case +// if num_coeffs == 1 { +// return EvaluationsList::new(value.coeffs); +// } +// +// let half_coeffs = num_coeffs / 2; +// +// Left is polynomial with last variable set to 0 +// let mut left = Vec::with_capacity(half_coeffs); +// +// Right is polynomial with last variable set to 1 +// let mut right = Vec::with_capacity(half_coeffs); +// +// for i in 0..half_coeffs { +// left.push(value.coeffs[2 * i]); +// right.push(value.coeffs[2 * i] + value.coeffs[2 * i + 1]); +// } +// +// let left_poly = CoefficientList { +// coeffs: left, +// num_variables: value.num_variables - 1, +// }; +// let right_poly = CoefficientList { +// coeffs: right, +// num_variables: value.num_variables - 1, +// }; +// +// Compute evaluation of right and left +// let left_eval = EvaluationsList::from(left_poly); +// let right_eval = EvaluationsList::from(right_poly); +// +// Combine +// let mut evaluation_list = Vec::with_capacity(num_coeffs); +// for i in 0..half_coeffs { +// evaluation_list.push(left_eval[i]); +// evaluation_list.push(right_eval[i]); +// } +// +// EvaluationsList::new(evaluation_list) +// } +// } + +#[cfg(test)] +mod tests { + use ark_poly::{Polynomial, univariate::DensePolynomial}; + + use crate::{ + crypto::fields::Field64, + poly_utils::{MultilinearPoint, coeffs::CoefficientList, evals::EvaluationsList}, + }; + + type F = Field64; + + #[test] + fn test_evaluation_conversion() { + let coeffs = vec![F::from(22), F::from(5), F::from(10), F::from(97)]; + let coeffs_list = CoefficientList::new(coeffs.clone()); + + let evaluations = EvaluationsList::from(coeffs_list); + + assert_eq!(evaluations[0], coeffs[0]); + assert_eq!(evaluations[1], coeffs[0] + coeffs[1]); + assert_eq!(evaluations[2], coeffs[0] + coeffs[2]); + assert_eq!( + evaluations[3], + coeffs[0] + coeffs[1] + coeffs[2] + coeffs[3] + ); + } + + #[test] + fn test_folding() { + let coeffs = vec![F::from(22), F::from(5), F::from(00), F::from(00)]; + let coeffs_list = CoefficientList::new(coeffs); + + let alpha = F::from(100); + let beta = F::from(32); + + let folded = coeffs_list.fold(&MultilinearPoint(vec![beta])); + + assert_eq!( + coeffs_list.evaluate(&MultilinearPoint(vec![alpha, beta])), + folded.evaluate(&MultilinearPoint(vec![alpha])) + ) + } + + #[test] + fn test_folding_and_evaluation() { + let num_variables = 10; + let coeffs = (0..(1 << num_variables)).map(F::from).collect(); + let coeffs_list = CoefficientList::new(coeffs); + + let randomness: Vec<_> = (0..num_variables).map(|i| F::from(35 * i as u64)).collect(); + for k in 0..num_variables { + let fold_part = randomness[0..k].to_vec(); + let eval_part = randomness[k..randomness.len()].to_vec(); + + let fold_random = MultilinearPoint(fold_part.clone()); + let eval_point = MultilinearPoint([eval_part.clone(), fold_part].concat()); + + let folded = coeffs_list.fold(&fold_random); + assert_eq!( + folded.evaluate(&MultilinearPoint(eval_part)), + coeffs_list.evaluate(&eval_point) + ); + } + } + + #[test] + fn test_evaluation_mv() { + let polynomial = vec![ + F::from(0), + F::from(1), + F::from(2), + F::from(3), + F::from(4), + F::from(5), + F::from(6), + F::from(7), + F::from(8), + F::from(9), + F::from(10), + F::from(11), + F::from(12), + F::from(13), + F::from(14), + F::from(15), + ]; + + let mv_poly = CoefficientList::new(polynomial); + let uv_poly: DensePolynomial<_> = mv_poly.clone().into(); + + let eval_point = F::from(4999); + assert_eq!( + uv_poly.evaluate(&F::from(1)), + F::from((0..=15).sum::()) + ); + assert_eq!( + uv_poly.evaluate(&eval_point), + mv_poly.evaluate(&MultilinearPoint::expand_from_univariate(eval_point, 4)) + ) + } +} diff --git a/whir/src/poly_utils/evals.rs b/whir/src/poly_utils/evals.rs new file mode 100644 index 000000000..d65f92954 --- /dev/null +++ b/whir/src/poly_utils/evals.rs @@ -0,0 +1,95 @@ +use std::ops::Index; + +use ark_ff::Field; + +use super::{MultilinearPoint, sequential_lag_poly::LagrangePolynomialIterator}; + +/// An EvaluationsList models a multi-linear polynomial f in `num_variables` +/// unknowns, stored via their evaluations at {0,1}^{num_variables} +/// +/// `evals` stores the evaluation in lexicographic order. +#[derive(Debug, Clone)] +pub struct EvaluationsList { + evals: Vec, + num_variables: usize, +} + +impl EvaluationsList +where + F: Field, +{ + /// Constructs a EvaluationList from the given vector `eval` of evaluations. + /// + /// The provided `evals` is supposed to be the list of evaluations, where the ordering of evaluation points in {0,1}^n + /// is lexicographic. + pub fn new(evals: Vec) -> Self { + let len = evals.len(); + assert!(len.is_power_of_two()); + let num_variables = len.ilog2(); + + EvaluationsList { + evals, + num_variables: num_variables as usize, + } + } + + /// evaluate the polynomial at `point` + pub fn evaluate(&self, point: &MultilinearPoint) -> F { + if let Some(point) = point.to_hypercube() { + return self.evals[point.0]; + } + + let mut sum = F::ZERO; + for (b, lag) in LagrangePolynomialIterator::new(point) { + sum += lag * self.evals[b.0] + } + + sum + } + + pub fn evals(&self) -> &[F] { + &self.evals + } + + pub fn evals_mut(&mut self) -> &mut [F] { + &mut self.evals + } + + pub fn num_evals(&self) -> usize { + self.evals.len() + } + + pub fn num_variables(&self) -> usize { + self.num_variables + } +} + +impl Index for EvaluationsList { + type Output = F; + fn index(&self, index: usize) -> &Self::Output { + &self.evals[index] + } +} + +#[cfg(test)] +mod tests { + use crate::poly_utils::hypercube::BinaryHypercube; + + use super::*; + use ark_ff::*; + + type F = crate::crypto::fields::Field64; + + #[test] + fn test_evaluation() { + let evaluations_vec = vec![F::ZERO, F::ONE, F::ZERO, F::ONE]; + let evals = EvaluationsList::new(evaluations_vec.clone()); + + for i in BinaryHypercube::new(2) { + assert_eq!( + evaluations_vec[i.0], + evals.evaluate(&MultilinearPoint::from_binary_hypercube_point(i, 2)) + ); + } + } +} diff --git a/whir/src/poly_utils/fold.rs b/whir/src/poly_utils/fold.rs new file mode 100644 index 000000000..5d10d036f --- /dev/null +++ b/whir/src/poly_utils/fold.rs @@ -0,0 +1,223 @@ +use crate::{ntt::intt_batch, parameters::FoldType}; +use ark_ff::{FftField, Field}; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +/// Given the evaluation of f on the coset specified by coset_offset * +/// Compute the fold on that point +pub fn compute_fold( + answers: &[F], + folding_randomness: &[F], + mut coset_offset_inv: F, + mut coset_gen_inv: F, + two_inv: F, + folding_factor: usize, +) -> F { + let mut answers = answers.to_vec(); + + // We recursively compute the fold, rec is where it is + for rec in 0..folding_factor { + let offset = answers.len() / 2; + let mut new_answers = vec![F::ZERO; offset]; + let mut coset_index_inv = F::ONE; + for i in 0..offset { + let f_value_0 = answers[i]; + let f_value_1 = answers[i + offset]; + let point_inv = coset_offset_inv * coset_index_inv; + + let left = f_value_0 + f_value_1; + let right = point_inv * (f_value_0 - f_value_1); + + new_answers[i] = + two_inv * (left + folding_randomness[folding_randomness.len() - 1 - rec] * right); + coset_index_inv *= coset_gen_inv; + } + answers = new_answers; + + // Update for next one + coset_offset_inv = coset_offset_inv * coset_offset_inv; + coset_gen_inv = coset_gen_inv * coset_gen_inv; + } + + answers[0] +} + +pub fn restructure_evaluations( + mut stacked_evaluations: Vec, + fold_type: FoldType, + _domain_gen: F, + domain_gen_inv: F, + folding_factor: usize, +) -> Vec { + let folding_size = 1_u64 << folding_factor; + assert_eq!(stacked_evaluations.len() % (folding_size as usize), 0); + match fold_type { + FoldType::Naive => stacked_evaluations, + FoldType::ProverHelps => { + // TODO: This partially undoes the NTT transform from tne encoding. + // Maybe there is a way to not do the full transform in the first place. + + // Batch inverse NTTs + intt_batch(&mut stacked_evaluations, folding_size as usize); + + // Apply coset and size correction. + // Stacked evaluation at i is f(B_l) where B_l = w^i * + let size_inv = F::from(folding_size).inverse().unwrap(); + #[cfg(not(feature = "parallel"))] + { + let mut coset_offset_inv = F::ONE; + for answers in stacked_evaluations.chunks_exact_mut(folding_size as usize) { + let mut scale = size_inv; + for v in answers.iter_mut() { + *v *= scale; + scale *= coset_offset_inv; + } + coset_offset_inv *= domain_gen_inv; + } + } + #[cfg(feature = "parallel")] + stacked_evaluations + .par_chunks_exact_mut(folding_size as usize) + .enumerate() + .for_each_with(F::ZERO, |offset, (i, answers)| { + if *offset == F::ZERO { + *offset = domain_gen_inv.pow([i as u64]); + } else { + *offset *= domain_gen_inv; + } + let mut scale = size_inv; + for v in answers.iter_mut() { + *v *= scale; + scale *= &*offset; + } + }); + + stacked_evaluations + } + } +} + +#[cfg(test)] +mod tests { + use ark_ff::{FftField, Field}; + + use crate::{ + crypto::fields::Field64, + poly_utils::{MultilinearPoint, coeffs::CoefficientList}, + utils::stack_evaluations, + }; + + use super::{compute_fold, restructure_evaluations}; + + type F = Field64; + + #[test] + fn test_folding() { + let num_variables = 5; + let num_coeffs = 1 << num_variables; + + let domain_size = 256; + let folding_factor = 3; // We fold in 8 + let folding_factor_exp = 1 << folding_factor; + + let poly = CoefficientList::new((0..num_coeffs).map(F::from).collect()); + + let root_of_unity = F::get_root_of_unity(domain_size).unwrap(); + + let index = 15; + let folding_randomness: Vec<_> = (0..folding_factor).map(|i| F::from(i as u64)).collect(); + + let coset_offset = root_of_unity.pow([index]); + let coset_gen = root_of_unity.pow([domain_size / folding_factor_exp]); + + // Evaluate the polynomial on the coset + let poly_eval: Vec<_> = (0..folding_factor_exp) + .map(|i| { + poly.evaluate(&MultilinearPoint::expand_from_univariate( + coset_offset * coset_gen.pow([i]), + num_variables, + )) + }) + .collect(); + + let fold_value = compute_fold( + &poly_eval, + &folding_randomness, + coset_offset.inverse().unwrap(), + coset_gen.inverse().unwrap(), + F::from(2).inverse().unwrap(), + folding_factor, + ); + + let truth_value = poly.fold(&MultilinearPoint(folding_randomness)).evaluate( + &MultilinearPoint::expand_from_univariate( + root_of_unity.pow([folding_factor_exp * index]), + 2, + ), + ); + + assert_eq!(fold_value, truth_value); + } + + #[test] + fn test_folding_optimised() { + let num_variables = 5; + let num_coeffs = 1 << num_variables; + + let domain_size = 256; + let folding_factor = 3; // We fold in 8 + let folding_factor_exp = 1 << folding_factor; + + let poly = CoefficientList::new((0..num_coeffs).map(F::from).collect()); + + let root_of_unity = F::get_root_of_unity(domain_size).unwrap(); + let root_of_unity_inv = root_of_unity.inverse().unwrap(); + + let folding_randomness: Vec<_> = (0..folding_factor).map(|i| F::from(i as u64)).collect(); + + // Evaluate the polynomial on the domain + let domain_evaluations: Vec<_> = (0..domain_size) + .map(|w| root_of_unity.pow([w])) + .map(|point| { + poly.evaluate(&MultilinearPoint::expand_from_univariate( + point, + num_variables, + )) + }) + .collect(); + + let unprocessed = stack_evaluations(domain_evaluations, folding_factor); + + let processed = restructure_evaluations( + unprocessed.clone(), + crate::parameters::FoldType::ProverHelps, + root_of_unity, + root_of_unity_inv, + folding_factor, + ); + + let num = domain_size / folding_factor_exp; + let coset_gen_inv = root_of_unity_inv.pow([num]); + + for index in 0..num { + let offset_inv = root_of_unity_inv.pow([index]); + let span = + (index * folding_factor_exp) as usize..((index + 1) * folding_factor_exp) as usize; + + let answer_unprocessed = compute_fold( + &unprocessed[span.clone()], + &folding_randomness, + offset_inv, + coset_gen_inv, + F::from(2).inverse().unwrap(), + folding_factor, + ); + + let answer_processed = CoefficientList::new(processed[span].to_vec()) + .evaluate(&MultilinearPoint(folding_randomness.clone())); + + assert_eq!(answer_processed, answer_unprocessed); + } + } +} diff --git a/whir/src/poly_utils/gray_lag_poly.rs b/whir/src/poly_utils/gray_lag_poly.rs new file mode 100644 index 000000000..e2749ffe8 --- /dev/null +++ b/whir/src/poly_utils/gray_lag_poly.rs @@ -0,0 +1,151 @@ +// NOTE: This is the one from Ron's + +use ark_ff::{Field, batch_inversion}; + +use super::{MultilinearPoint, hypercube::BinaryHypercubePoint}; + +pub struct LagrangePolynomialGray { + position_bin: usize, + position_gray: usize, + value: F, + precomputed: Vec, + num_variables: usize, +} + +impl LagrangePolynomialGray { + pub fn new(point: &MultilinearPoint) -> Self { + let num_variables = point.n_variables(); + // Limitation for bin hypercube + assert!(point.0.iter().all(|&p| p != F::ZERO && p != F::ONE)); + + // This is negated[i] = eq_poly(z_i, 0) = 1 - z_i + let negated_points: Vec<_> = point.0.iter().map(|z| F::ONE - z).collect(); + // This is points[i] = eq_poly(z_i, 1) = z_i + let points = point.0.to_vec(); + + let mut to_invert = [negated_points.clone(), points.clone()].concat(); + batch_inversion(&mut to_invert); + + let (denom_0, denom_1) = to_invert.split_at(num_variables); + + let mut precomputed = vec![F::ZERO; 2 * num_variables]; + for n in 0..num_variables { + precomputed[2 * n] = points[n] * denom_0[n]; + precomputed[2 * n + 1] = negated_points[n] * denom_1[n]; + } + + LagrangePolynomialGray { + position_gray: gray_encode(0), + position_bin: 0, + value: negated_points.into_iter().product(), + num_variables, + precomputed, + } + } +} + +pub fn gray_encode(integer: usize) -> usize { + (integer >> 1) ^ integer +} + +pub fn gray_decode(integer: usize) -> usize { + match integer { + 0 => 0, + _ => integer ^ gray_decode(integer >> 1), + } +} + +impl Iterator for LagrangePolynomialGray { + type Item = (BinaryHypercubePoint, F); + + fn next(&mut self) -> Option { + if self.position_bin >= (1 << self.num_variables) { + return None; + } + + let result = (BinaryHypercubePoint(self.position_gray), self.value); + + let prev = self.position_gray; + + self.position_bin += 1; + self.position_gray = gray_encode(self.position_bin); + + if self.position_bin < (1 << self.num_variables) { + let diff = prev ^ self.position_gray; + let i = (self.num_variables - 1) - diff.trailing_zeros() as usize; + let flip = (diff & self.position_gray == 0) as usize; + + self.value *= self.precomputed[2 * i + flip]; + } + + Some(result) + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeSet; + + use crate::{ + crypto::fields::Field64, + poly_utils::{ + MultilinearPoint, eq_poly, + gray_lag_poly::{LagrangePolynomialGray, gray_decode}, + hypercube::BinaryHypercubePoint, + }, + }; + + use super::gray_encode; + + type F = Field64; + + #[test] + fn test_gray_ordering() { + let values = [ + (0b0000, 0b0000), + (0b0001, 0b0001), + (0b0010, 0b0011), + (0b0011, 0b0010), + (0b0100, 0b0110), + (0b0101, 0b0111), + (0b0110, 0b0101), + (0b0111, 0b0100), + (0b1000, 0b1100), + (0b1001, 0b1101), + (0b1010, 0b1111), + (0b1011, 0b1110), + (0b1100, 0b1010), + (0b1101, 0b1011), + (0b1110, 0b1001), + (0b1111, 0b1000), + ]; + + for (bin, gray) in values { + assert_eq!(gray_encode(bin), gray); + assert_eq!(gray_decode(gray), bin); + } + } + + #[test] + fn test_gray_ordering_iterator() { + let point = MultilinearPoint(vec![F::from(2), F::from(3), F::from(4)]); + + for (i, (b, _)) in LagrangePolynomialGray::new(&point).enumerate() { + assert_eq!(b.0, gray_encode(i)); + } + } + + #[test] + fn test_gray() { + let point = MultilinearPoint(vec![F::from(2), F::from(3), F::from(4)]); + + let eq_poly_res: BTreeSet<_> = (0..(1 << 3)) + .map(BinaryHypercubePoint) + .map(|b| (b, eq_poly(&point, b))) + .collect(); + + let gray_res: BTreeSet<_> = LagrangePolynomialGray::new(&point).collect(); + + assert_eq!(eq_poly_res, gray_res); + } +} diff --git a/whir/src/poly_utils/hypercube.rs b/whir/src/poly_utils/hypercube.rs new file mode 100644 index 000000000..a5cd53d71 --- /dev/null +++ b/whir/src/poly_utils/hypercube.rs @@ -0,0 +1,40 @@ +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +// TODO (Gotti): Should pos rather be a u64? usize is platform-dependent, giving a platform-dependent limit on the number of variables. +// num_variables may be smaller as well. + +// NOTE: Conversion BinaryHypercube <-> MultilinearPoint is Big Endian, using only the num_variables least significant bits of the number stored inside BinaryHypercube. + +/// point on the binary hypercube {0,1}^n for some n. +/// +/// The point is encoded via the n least significant bits of a usize in big endian order and we do not store n. +pub struct BinaryHypercubePoint(pub usize); + +/// BinaryHypercube is an Iterator that is used to range over the points of the hypercube {0,1}^n, where n == `num_variables` +pub struct BinaryHypercube { + pos: usize, // current position, encoded via the bits of pos + num_variables: usize, // dimension of the hypercube +} + +impl BinaryHypercube { + pub fn new(num_variables: usize) -> Self { + debug_assert!(num_variables < usize::BITS as usize); // Note that we need strictly smaller, since some code would overflow otherwise. + BinaryHypercube { + pos: 0, + num_variables, + } + } +} + +impl Iterator for BinaryHypercube { + type Item = BinaryHypercubePoint; + + fn next(&mut self) -> Option { + let curr = self.pos; + if curr < (1 << self.num_variables) { + self.pos += 1; + Some(BinaryHypercubePoint(curr)) + } else { + None + } + } +} diff --git a/whir/src/poly_utils/mod.rs b/whir/src/poly_utils/mod.rs new file mode 100644 index 000000000..107574263 --- /dev/null +++ b/whir/src/poly_utils/mod.rs @@ -0,0 +1,325 @@ +use ark_ff::Field; +use rand::{ + Rng, RngCore, + distributions::{Distribution, Standard}, +}; + +use crate::utils::to_binary; + +use self::hypercube::BinaryHypercubePoint; + +pub mod coeffs; +pub mod evals; +pub mod fold; +pub mod gray_lag_poly; +pub mod hypercube; +pub mod sequential_lag_poly; +pub mod streaming_evaluation_helper; + +/// Point (x_1,..., x_n) in F^n for some n. Often, the x_i are binary. +/// For the latter case, we also have BinaryHypercubePoint. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MultilinearPoint(pub Vec); + +impl MultilinearPoint +where + F: Field, +{ + /// returns the number of variables. + pub fn n_variables(&self) -> usize { + self.0.len() + } + + // NOTE: Conversion BinaryHypercube <-> MultilinearPoint converts a + // multilinear point (x1,x2,...,x_n) into the number with bit-pattern 0...0 x_1 x_2 ... x_n, provided all x_i are in {0,1}. + // That means we pad zero bits in BinaryHypercube from the msb end and use big-endian for the actual conversion. + + /// Creates a MultilinearPoint from a BinaryHypercubePoint; the latter models the same thing, but is restricted to binary entries. + pub fn from_binary_hypercube_point(point: BinaryHypercubePoint, num_variables: usize) -> Self { + Self( + to_binary(point.0, num_variables) + .into_iter() + .map(|x| if x { F::ONE } else { F::ZERO }) + .collect(), + ) + } + + /// Converts to a BinaryHypercubePoint, provided the MultilinearPoint is actually in {0,1}^n. + pub fn to_hypercube(&self) -> Option { + let mut counter = 0; + for &coord in &self.0 { + if coord == F::ZERO { + counter <<= 1; + } else if coord == F::ONE { + counter = (counter << 1) + 1; + } else { + return None; + } + } + + Some(BinaryHypercubePoint(counter)) + } + + /// converts a univariate evaluation point into a multilinear one. + /// + /// Notably, consider the usual bijection + /// {multilinear polys in n variables} <-> {univariate polys of deg < 2^n} + /// f(x_1,...x_n) <-> g(y) := f(y^(2^(n-1), ..., y^4, y^2, y). + /// x_1^i_1 * ... *x_n^i_n <-> y^i, where (i_1,...,i_n) is the (big-endian) binary decomposition of i. + /// + /// expand_from_univariate maps the evaluation points to the multivariate domain, i.e. + /// f(expand_from_univariate(y)) == g(y). + /// in a way that is compatible with our endianness choices. + pub fn expand_from_univariate(point: F, num_variables: usize) -> Self { + let mut res = Vec::with_capacity(num_variables); + let mut cur = point; + for _ in 0..num_variables { + res.push(cur); + cur = cur * cur; + } + + // Reverse so higher power is first + res.reverse(); + + MultilinearPoint(res) + } +} + +/// creates a random MultilinearPoint of length `num_variables` using the RNG `rng`. +impl MultilinearPoint +where + Standard: Distribution, +{ + pub fn rand(rng: &mut impl RngCore, num_variables: usize) -> Self { + MultilinearPoint((0..num_variables).map(|_| rng.gen()).collect()) + } +} + +/// creates a MultilinearPoint of length 1 from a single field element +impl From for MultilinearPoint { + fn from(value: F) -> Self { + MultilinearPoint(vec![value]) + } +} + +/// Compute eq(coords,point), where eq is the equality polynomial, where point is binary. +/// +/// Recall that the equality polynomial eq(c, p) is defined as eq(c,p) == \prod_i c_i * p_i + (1-c_i)*(1-p_i). +/// Note that for fixed p, viewed as a polynomial in c, it is the interpolation polynomial associated to the evaluation point p in the evaluation set {0,1}^n. +pub fn eq_poly(coords: &MultilinearPoint, point: BinaryHypercubePoint) -> F +where + F: Field, +{ + let mut point = point.0; + let n_variables = coords.n_variables(); + assert!(point < (1 << n_variables)); // check that the lengths of coords and point match. + + let mut acc = F::ONE; + + for val in coords.0.iter().rev() { + let b = point % 2; + acc *= if b == 1 { *val } else { F::ONE - *val }; + point >>= 1; + } + + acc +} + +/// Compute eq(coords,point), where eq is the equality polynomial and where point is not neccessarily binary. +/// +/// Recall that the equality polynomial eq(c, p) is defined as eq(c,p) == \prod_i c_i * p_i + (1-c_i)*(1-p_i). +/// Note that for fixed p, viewed as a polynomial in c, it is the interpolation polynomial associated to the evaluation point p in the evaluation set {0,1}^n. +pub fn eq_poly_outside(coords: &MultilinearPoint, point: &MultilinearPoint) -> F +where + F: Field, +{ + assert_eq!(coords.n_variables(), point.n_variables()); + + let mut acc = F::ONE; + + for (&l, &r) in coords.0.iter().zip(&point.0) { + acc *= l * r + (F::ONE - l) * (F::ONE - r); + } + + acc +} + +// TODO: Precompute two_inv? +// Alternatively, compute it directly without the general (and slow) .inverse() map. + +/// Compute eq3(coords,point), where eq3 is the equality polynomial for {0,1,2}^n and point is interpreted as an element from {0,1,2}^n via (big Endian) ternary decomposition. +/// +/// eq3(coords, point) is the unique polynomial of degree <=2 in each variable, s.t. +/// for coords, point in {0,1,2}^n, we have: +/// eq3(coords,point) = 1 if coords == point and 0 otherwise. +pub fn eq_poly3(coords: &MultilinearPoint, mut point: usize) -> F +where + F: Field, +{ + let two = F::ONE + F::ONE; + let two_inv = two.inverse().unwrap(); + + let n_variables = coords.n_variables(); + assert!(point < 3usize.pow(n_variables as u32)); + + let mut acc = F::ONE; + + // Note: This iterates over the ternary decomposition least-significant trit(?) first. + // Since our convention is big endian, we reverse the order of coords to account for this. + for &val in coords.0.iter().rev() { + let b = point % 3; + acc *= match b { + 0 => (val - F::ONE) * (val - two) * two_inv, + 1 => val * (val - two) * (-F::ONE), + 2 => val * (val - F::ONE) * two_inv, + _ => unreachable!(), + }; + point /= 3; + } + + acc +} + +#[cfg(test)] +mod tests { + use crate::{ + crypto::fields::Field64, + poly_utils::{eq_poly, eq_poly3, hypercube::BinaryHypercube}, + }; + + use super::{BinaryHypercubePoint, MultilinearPoint, coeffs::CoefficientList}; + + type F = Field64; + + #[test] + fn test_equality() { + let point = MultilinearPoint(vec![F::from(0), F::from(0)]); + assert_eq!(eq_poly(&point, BinaryHypercubePoint(0b00)), F::from(1)); + assert_eq!(eq_poly(&point, BinaryHypercubePoint(0b01)), F::from(0)); + assert_eq!(eq_poly(&point, BinaryHypercubePoint(0b10)), F::from(0)); + assert_eq!(eq_poly(&point, BinaryHypercubePoint(0b11)), F::from(0)); + + let point = MultilinearPoint(vec![F::from(1), F::from(0)]); + assert_eq!(eq_poly(&point, BinaryHypercubePoint(0b00)), F::from(0)); + assert_eq!(eq_poly(&point, BinaryHypercubePoint(0b01)), F::from(0)); + assert_eq!(eq_poly(&point, BinaryHypercubePoint(0b10)), F::from(1)); + assert_eq!(eq_poly(&point, BinaryHypercubePoint(0b11)), F::from(0)); + } + + #[test] + fn test_equality_again() { + let poly = CoefficientList::new(vec![F::from(35), F::from(97), F::from(10), F::from(32)]); + let point = MultilinearPoint(vec![F::from(42), F::from(36)]); + let eval = poly.evaluate(&point); + + assert_eq!( + eval, + BinaryHypercube::new(2) + .map( + |i| poly.evaluate(&MultilinearPoint::from_binary_hypercube_point(i, 2)) + * eq_poly(&point, i) + ) + .sum() + ); + } + + #[test] + fn test_equality3() { + let point = MultilinearPoint(vec![F::from(0), F::from(0)]); + + assert_eq!(eq_poly3(&point, 0), F::from(1)); + assert_eq!(eq_poly3(&point, 1), F::from(0)); + assert_eq!(eq_poly3(&point, 2), F::from(0)); + assert_eq!(eq_poly3(&point, 3), F::from(0)); + assert_eq!(eq_poly3(&point, 4), F::from(0)); + assert_eq!(eq_poly3(&point, 5), F::from(0)); + assert_eq!(eq_poly3(&point, 6), F::from(0)); + assert_eq!(eq_poly3(&point, 7), F::from(0)); + assert_eq!(eq_poly3(&point, 8), F::from(0)); + + let point = MultilinearPoint(vec![F::from(1), F::from(0)]); + + assert_eq!(eq_poly3(&point, 0), F::from(0)); + assert_eq!(eq_poly3(&point, 1), F::from(0)); + assert_eq!(eq_poly3(&point, 2), F::from(0)); + assert_eq!(eq_poly3(&point, 3), F::from(1)); // 3 corresponds to ternary (1,0) + assert_eq!(eq_poly3(&point, 4), F::from(0)); + assert_eq!(eq_poly3(&point, 5), F::from(0)); + assert_eq!(eq_poly3(&point, 6), F::from(0)); + assert_eq!(eq_poly3(&point, 7), F::from(0)); + assert_eq!(eq_poly3(&point, 8), F::from(0)); + + let point = MultilinearPoint(vec![F::from(0), F::from(2)]); + + assert_eq!(eq_poly3(&point, 0), F::from(0)); + assert_eq!(eq_poly3(&point, 1), F::from(0)); + assert_eq!(eq_poly3(&point, 2), F::from(1)); // 2 corresponds to ternary (0,2) + assert_eq!(eq_poly3(&point, 3), F::from(0)); + assert_eq!(eq_poly3(&point, 4), F::from(0)); + assert_eq!(eq_poly3(&point, 5), F::from(0)); + assert_eq!(eq_poly3(&point, 6), F::from(0)); + assert_eq!(eq_poly3(&point, 7), F::from(0)); + assert_eq!(eq_poly3(&point, 8), F::from(0)); + + let point = MultilinearPoint(vec![F::from(2), F::from(2)]); + + assert_eq!(eq_poly3(&point, 0), F::from(0)); + assert_eq!(eq_poly3(&point, 1), F::from(0)); + assert_eq!(eq_poly3(&point, 2), F::from(0)); + assert_eq!(eq_poly3(&point, 3), F::from(0)); + assert_eq!(eq_poly3(&point, 4), F::from(0)); + assert_eq!(eq_poly3(&point, 5), F::from(0)); + assert_eq!(eq_poly3(&point, 6), F::from(0)); + assert_eq!(eq_poly3(&point, 7), F::from(0)); + assert_eq!(eq_poly3(&point, 8), F::from(1)); // 8 corresponds to ternary (2,2) + } + + #[test] + #[should_panic] + fn test_equality_2() { + let coords = MultilinearPoint(vec![F::from(0), F::from(0)]); + + // implicit length of BinaryHypercubePoint is (at least) 3, exceeding lenth of coords + let _x = eq_poly(&coords, BinaryHypercubePoint(0b100)); + } + + #[test] + fn expand_from_univariate() { + let num_variables = 4; + + let point0 = MultilinearPoint::expand_from_univariate(F::from(0), num_variables); + let point1 = MultilinearPoint::expand_from_univariate(F::from(1), num_variables); + let point2 = MultilinearPoint::expand_from_univariate(F::from(2), num_variables); + + assert_eq!(point0.n_variables(), num_variables); + assert_eq!(point1.n_variables(), num_variables); + assert_eq!(point2.n_variables(), num_variables); + + assert_eq!( + MultilinearPoint::from_binary_hypercube_point(BinaryHypercubePoint(0), num_variables), + point0 + ); + + assert_eq!( + MultilinearPoint::from_binary_hypercube_point( + BinaryHypercubePoint((1 << num_variables) - 1), + num_variables + ), + point1 + ); + + assert_eq!( + MultilinearPoint(vec![F::from(256), F::from(16), F::from(4), F::from(2)]), + point2 + ); + } + + #[test] + fn from_hypercube_and_back() { + let hypercube_point = BinaryHypercubePoint(24); + assert_eq!( + Some(hypercube_point), + MultilinearPoint::::from_binary_hypercube_point(hypercube_point, 5).to_hypercube() + ); + } +} diff --git a/whir/src/poly_utils/sequential_lag_poly.rs b/whir/src/poly_utils/sequential_lag_poly.rs new file mode 100644 index 000000000..56284dedf --- /dev/null +++ b/whir/src/poly_utils/sequential_lag_poly.rs @@ -0,0 +1,174 @@ +// NOTE: This is the one from Blendy + +use ark_ff::Field; + +use super::{MultilinearPoint, hypercube::BinaryHypercubePoint}; + +/// There is an alternative (possibly more efficient) implementation that iterates over the x in Gray code ordering. +/// LagrangePolynomialIterator for a given multilinear n-dimensional `point` iterates over pairs (x, y) +/// where x ranges over all possible {0,1}^n +/// and y equals the product y_1 * ... * y_n where +/// +/// y_i = point[i] if x_i == 1 +/// y_i = 1-point[i] if x_i == 0 +/// +/// This means that y == eq_poly(point, x) +pub struct LagrangePolynomialIterator { + last_position: Option, /* the previously output BinaryHypercubePoint (encoded as usize). None before the first output. */ + point: Vec, /* stores a copy of the `point` given when creating the iterator. For easier(?) bit-fiddling, we store in in reverse order. */ + point_negated: Vec, /* stores the precomputed values 1-point[i] in the same ordering as point. */ + /// stack Stores the n+1 values (in order) 1, y_1, y_1*y_2, y_1*y_2*y_3, ..., y_1*...*y_n for the previously output y. + /// Before the first iteration (if last_position == None), it stores the values for the next (i.e. first) output instead. + stack: Vec, + num_variables: usize, // dimension +} + +impl LagrangePolynomialIterator { + pub fn new(point: &MultilinearPoint) -> Self { + let num_variables = point.0.len(); + + // Initialize a stack with capacity for messages/ message_hats and the identity element + let mut stack: Vec = Vec::with_capacity(point.0.len() + 1); + stack.push(F::ONE); + + let mut point = point.0.clone(); + let mut point_negated: Vec<_> = point.iter().map(|x| F::ONE - *x).collect(); + // Iterate over the message_hats, update the running product, and push it onto the stack + let mut running_product: F = F::ONE; + for point_neg in &point_negated { + running_product *= point_neg; + stack.push(running_product); + } + + point.reverse(); + point_negated.reverse(); + + // Return + Self { + num_variables, + point, + point_negated, + stack, + last_position: None, + } + } +} + +impl Iterator for LagrangePolynomialIterator { + type Item = (BinaryHypercubePoint, F); + // Iterator implementation for the struct + fn next(&mut self) -> Option { + // a) Check if this is the first iteration + if self.last_position.is_none() { + // Initialize last position + self.last_position = Some(0); + // Return the top of the stack + return Some((BinaryHypercubePoint(0), *self.stack.last().unwrap())); + } + + // b) Check if in the last iteration we finished iterating + if self.last_position.unwrap() + 1 >= 1 << self.num_variables { + return None; + } + + // c) Everything else, first get bit diff + let last_position = self.last_position.unwrap(); + let next_position = last_position + 1; + let bit_diff = last_position ^ next_position; + + // Determine the shared prefix of the most significant bits + let low_index_of_prefix = (bit_diff + 1).trailing_zeros() as usize; + + // Discard any stack values outside of this prefix + self.stack.truncate(self.stack.len() - low_index_of_prefix); + + // Iterate up to this prefix computing lag poly correctly + for bit_index in (0..low_index_of_prefix).rev() { + let last_element = self.stack.last().unwrap(); + let next_bit: bool = (next_position & (1 << bit_index)) != 0; + self.stack.push(match next_bit { + true => *last_element * self.point[bit_index], + false => *last_element * self.point_negated[bit_index], + }); + } + + // Don't forget to update the last position + self.last_position = Some(next_position); + + // Return the top of the stack + Some(( + BinaryHypercubePoint(next_position), + *self.stack.last().unwrap(), + )) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + crypto::fields::Field64, + poly_utils::{MultilinearPoint, eq_poly, hypercube::BinaryHypercubePoint}, + }; + + use super::LagrangePolynomialIterator; + + type F = Field64; + + #[test] + fn test_blendy() { + let one = F::from(1); + let (a, b) = (F::from(2), F::from(3)); + let point_1 = MultilinearPoint(vec![a, b]); + + let mut lag_iterator = LagrangePolynomialIterator::new(&point_1); + + assert_eq!( + lag_iterator.next().unwrap(), + (BinaryHypercubePoint(0), (one - a) * (one - b)) + ); + assert_eq!( + lag_iterator.next().unwrap(), + (BinaryHypercubePoint(1), (one - a) * b) + ); + assert_eq!( + lag_iterator.next().unwrap(), + (BinaryHypercubePoint(2), a * (one - b)) + ); + assert_eq!( + lag_iterator.next().unwrap(), + (BinaryHypercubePoint(3), a * b) + ); + assert_eq!(lag_iterator.next(), None); + } + + #[test] + fn test_blendy_2() { + let point = MultilinearPoint(vec![F::from(12), F::from(13), F::from(32)]); + + let mut last_b = None; + for (b, lag) in LagrangePolynomialIterator::new(&point) { + assert_eq!(eq_poly(&point, b), lag); + assert!(b.0 < 1 << 3); + last_b = Some(b); + } + assert_eq!(last_b, Some(BinaryHypercubePoint(7))); + } + + #[test] + fn test_blendy_3() { + let point = MultilinearPoint(vec![ + F::from(414151), + F::from(109849018), + F::from(33184190), + F::from(33184190), + F::from(33184190), + ]); + + let mut last_b = None; + for (b, lag) in LagrangePolynomialIterator::new(&point) { + assert_eq!(eq_poly(&point, b), lag); + last_b = Some(b); + } + assert_eq!(last_b, Some(BinaryHypercubePoint(31))); + } +} diff --git a/whir/src/poly_utils/streaming_evaluation_helper.rs b/whir/src/poly_utils/streaming_evaluation_helper.rs new file mode 100644 index 000000000..6e0a3a853 --- /dev/null +++ b/whir/src/poly_utils/streaming_evaluation_helper.rs @@ -0,0 +1,82 @@ +// NOTE: This is the one from Blendy adapted for streaming evals + +use ark_ff::Field; + +use super::{MultilinearPoint, hypercube::BinaryHypercubePoint}; + +pub struct TermPolynomialIterator { + last_position: Option, + point: Vec, + stack: Vec, + num_variables: usize, +} + +impl TermPolynomialIterator { + pub fn new(point: &MultilinearPoint) -> Self { + let num_variables = point.0.len(); + + // Initialize a stack with capacity for messages/ message_hats and the identity element + let stack: Vec = vec![F::ONE; point.0.len() + 1]; + + let mut point = point.0.clone(); + + point.reverse(); + + // Return + Self { + num_variables, + point, + stack, + last_position: None, + } + } +} + +impl Iterator for TermPolynomialIterator { + type Item = (BinaryHypercubePoint, F); + // Iterator implementation for the struct + fn next(&mut self) -> Option { + // a) Check if this is the first iteration + if self.last_position.is_none() { + // Initialize last position + self.last_position = Some(0); + // Return the top of the stack + return Some((BinaryHypercubePoint(0), *self.stack.last().unwrap())); + } + + // b) Check if in the last iteration we finished iterating + if self.last_position.unwrap() + 1 >= 1 << self.num_variables { + return None; + } + + // c) Everything else, first get bit diff + let last_position = self.last_position.unwrap(); + let next_position = last_position + 1; + let bit_diff = last_position ^ next_position; + + // Determine the shared prefix of the most significant bits + let low_index_of_prefix = (bit_diff + 1).trailing_zeros() as usize; + + // Discard any stack values outside of this prefix + self.stack.truncate(self.stack.len() - low_index_of_prefix); + + // Iterate up to this prefix computing lag poly correctly + for bit_index in (0..low_index_of_prefix).rev() { + let last_element = self.stack.last().unwrap(); + let next_bit: bool = (next_position & (1 << bit_index)) != 0; + self.stack.push(match next_bit { + true => *last_element * self.point[bit_index], + false => *last_element, + }); + } + + // Don't forget to update the last position + self.last_position = Some(next_position); + + // Return the top of the stack + Some(( + BinaryHypercubePoint(next_position), + *self.stack.last().unwrap(), + )) + } +} diff --git a/whir/src/sumcheck/mod.rs b/whir/src/sumcheck/mod.rs new file mode 100644 index 000000000..943a51d02 --- /dev/null +++ b/whir/src/sumcheck/mod.rs @@ -0,0 +1,321 @@ +pub mod proof; +pub mod prover_batched; +pub mod prover_core; +pub mod prover_not_skipping; +pub mod prover_not_skipping_batched; +pub mod prover_single; + +#[cfg(test)] +mod tests { + use crate::{ + crypto::fields::Field64, + poly_utils::{MultilinearPoint, coeffs::CoefficientList, eq_poly_outside}, + }; + + use super::prover_core::SumcheckCore; + + type F = Field64; + + #[test] + fn test_sumcheck_folding_factor_1() { + let folding_factor = 1; + let eval_point = MultilinearPoint(vec![F::from(10), F::from(11)]); + let polynomial = + CoefficientList::new(vec![F::from(1), F::from(5), F::from(10), F::from(14)]); + + let claimed_value = polynomial.evaluate(&eval_point); + + let mut prover = SumcheckCore::new(polynomial, &[eval_point], &[F::from(1)]); + + let poly_1 = prover.compute_sumcheck_polynomial(folding_factor); + + // First, check that is sums to the right value over the hypercube + assert_eq!(poly_1.sum_over_hypercube(), claimed_value); + + let combination_randomness = F::from(100101); + let folding_randomness = MultilinearPoint(vec![F::from(4999)]); + + prover.compress(folding_factor, combination_randomness, &folding_randomness); + + let poly_2 = prover.compute_sumcheck_polynomial(folding_factor); + + assert_eq!( + poly_2.sum_over_hypercube(), + combination_randomness * poly_1.evaluate_at_point(&folding_randomness) + ); + } + + #[test] + fn test_single_folding() { + let num_variables = 2; + let folding_factor = 2; + let polynomial = CoefficientList::new(vec![F::from(1), F::from(2), F::from(3), F::from(4)]); + + let ood_point = MultilinearPoint::expand_from_univariate(F::from(2), num_variables); + let statement_point = MultilinearPoint::expand_from_univariate(F::from(3), num_variables); + + let ood_answer = polynomial.evaluate(&ood_point); + let statement_answer = polynomial.evaluate(&statement_point); + + let epsilon_1 = F::from(10); + let epsilon_2 = F::from(100); + + let prover = SumcheckCore::new( + polynomial.clone(), + &[ood_point.clone(), statement_point.clone()], + &[epsilon_1, epsilon_2], + ); + + let poly_1 = prover.compute_sumcheck_polynomial(folding_factor); + + assert_eq!( + poly_1.sum_over_hypercube(), + epsilon_1 * ood_answer + epsilon_2 * statement_answer + ); + + let folding_randomness = MultilinearPoint(vec![F::from(400000), F::from(800000)]); + + let poly_eval = polynomial.evaluate(&folding_randomness); + let v_eval = epsilon_1 * eq_poly_outside(&ood_point, &folding_randomness) + + epsilon_2 * eq_poly_outside(&statement_point, &folding_randomness); + + assert_eq!( + poly_1.evaluate_at_point(&folding_randomness), + poly_eval * v_eval + ); + } + + #[test] + fn test_sumcheck_folding_factor_2() { + let num_variables = 6; + let folding_factor = 2; + let eval_point = MultilinearPoint(vec![F::from(97); num_variables]); + let polynomial = CoefficientList::new((0..1 << num_variables).map(F::from).collect()); + + let claimed_value = polynomial.evaluate(&eval_point); + + let mut prover = SumcheckCore::new(polynomial.clone(), &[eval_point], &[F::from(1)]); + + let poly_1 = prover.compute_sumcheck_polynomial(folding_factor); + + // First, check that is sums to the right value over the hypercube + assert_eq!(poly_1.sum_over_hypercube(), claimed_value); + + let combination_randomness = [F::from(293), F::from(42)]; + let folding_randomness = MultilinearPoint(vec![F::from(335), F::from(222)]); + + let new_eval_point = MultilinearPoint(vec![F::from(32); num_variables - folding_factor]); + let folded_polynomial = polynomial.fold(&folding_randomness); + let new_fold_eval = folded_polynomial.evaluate(&new_eval_point); + + prover.compress( + folding_factor, + combination_randomness[0], + &folding_randomness, + ); + prover.add_new_equality(&[new_eval_point], &combination_randomness[1..]); + + let poly_2 = prover.compute_sumcheck_polynomial(folding_factor); + + assert_eq!( + poly_2.sum_over_hypercube(), + combination_randomness[0] * poly_1.evaluate_at_point(&folding_randomness) + + combination_randomness[1] * new_fold_eval + ); + + let combination_randomness = F::from(23212); + prover.compress(folding_factor, combination_randomness, &folding_randomness); + + let poly_3 = prover.compute_sumcheck_polynomial(folding_factor); + + assert_eq!( + poly_3.sum_over_hypercube(), + combination_randomness * poly_2.evaluate_at_point(&folding_randomness) + ) + } + + #[test] + fn test_e2e() { + let num_variables = 4; + let folding_factor = 2; + let polynomial = CoefficientList::new((0..1 << num_variables).map(F::from).collect()); + + // Initial stuff + let ood_point = MultilinearPoint::expand_from_univariate(F::from(42), num_variables); + let statement_point = MultilinearPoint::expand_from_univariate(F::from(97), num_variables); + + // All the randomness + let [epsilon_1, epsilon_2] = [F::from(15), F::from(32)]; + let folding_randomness_1 = MultilinearPoint(vec![F::from(11), F::from(31)]); + let fold_point = MultilinearPoint(vec![F::from(31), F::from(15)]); + let combination_randomness = [F::from(31), F::from(4999)]; + let folding_randomness_2 = MultilinearPoint(vec![F::from(97), F::from(36)]); + + let mut prover = SumcheckCore::new( + polynomial.clone(), + &[ood_point.clone(), statement_point.clone()], + &[epsilon_1, epsilon_2], + ); + + let sumcheck_poly_1 = prover.compute_sumcheck_polynomial(folding_factor); + + let folded_poly_1 = polynomial.fold(&folding_randomness_1.clone()); + prover.compress( + folding_factor, + combination_randomness[0], + &folding_randomness_1, + ); + prover.add_new_equality(&[fold_point.clone()], &combination_randomness[1..]); + + let sumcheck_poly_2 = prover.compute_sumcheck_polynomial(folding_factor); + + let ood_answer = polynomial.evaluate(&ood_point); + let statement_answer = polynomial.evaluate(&statement_point); + + assert_eq!( + sumcheck_poly_1.sum_over_hypercube(), + epsilon_1 * ood_answer + epsilon_2 * statement_answer + ); + + let fold_answer = folded_poly_1.evaluate(&fold_point); + + assert_eq!( + sumcheck_poly_2.sum_over_hypercube(), + combination_randomness[0] * sumcheck_poly_1.evaluate_at_point(&folding_randomness_1) + + combination_randomness[1] * fold_answer + ); + + let full_folding = + MultilinearPoint([folding_randomness_2.0.clone(), folding_randomness_1.0].concat()); + let eval_coeff = folded_poly_1.fold(&folding_randomness_2).coeffs()[0]; + assert_eq!( + sumcheck_poly_2.evaluate_at_point(&folding_randomness_2), + eval_coeff + * (combination_randomness[0] + * (epsilon_1 * eq_poly_outside(&full_folding, &ood_point) + + epsilon_2 * eq_poly_outside(&full_folding, &statement_point)) + + combination_randomness[1] + * eq_poly_outside(&folding_randomness_2, &fold_point)) + ) + } + + #[test] + fn test_e2e_larger() { + let num_variables = 6; + let folding_factor = 2; + let polynomial = CoefficientList::new((0..1 << num_variables).map(F::from).collect()); + + // Initial stuff + let ood_point = MultilinearPoint::expand_from_univariate(F::from(42), num_variables); + let statement_point = MultilinearPoint::expand_from_univariate(F::from(97), num_variables); + + // All the randomness + let [epsilon_1, epsilon_2] = [F::from(15), F::from(32)]; + let folding_randomness_1 = MultilinearPoint(vec![F::from(11), F::from(31)]); + let folding_randomness_2 = MultilinearPoint(vec![F::from(97), F::from(36)]); + let folding_randomness_3 = MultilinearPoint(vec![F::from(11297), F::from(42136)]); + let fold_point_11 = + MultilinearPoint(vec![F::from(31), F::from(15), F::from(31), F::from(15)]); + let fold_point_12 = + MultilinearPoint(vec![F::from(1231), F::from(15), F::from(4231), F::from(15)]); + let fold_point_2 = MultilinearPoint(vec![F::from(311), F::from(115)]); + let combination_randomness_1 = [F::from(1289), F::from(3281), F::from(10921)]; + let combination_randomness_2 = [F::from(3281), F::from(3232)]; + + let mut prover = SumcheckCore::new( + polynomial.clone(), + &[ood_point.clone(), statement_point.clone()], + &[epsilon_1, epsilon_2], + ); + + let sumcheck_poly_1 = prover.compute_sumcheck_polynomial(folding_factor); + + let folded_poly_1 = polynomial.fold(&folding_randomness_1.clone()); + prover.compress( + folding_factor, + combination_randomness_1[0], + &folding_randomness_1, + ); + prover.add_new_equality( + &[fold_point_11.clone(), fold_point_12.clone()], + &combination_randomness_1[1..], + ); + + let sumcheck_poly_2 = prover.compute_sumcheck_polynomial(folding_factor); + + let folded_poly_2 = folded_poly_1.fold(&folding_randomness_2.clone()); + prover.compress( + folding_factor, + combination_randomness_2[0], + &folding_randomness_2, + ); + prover.add_new_equality(&[fold_point_2.clone()], &combination_randomness_2[1..]); + + let sumcheck_poly_3 = prover.compute_sumcheck_polynomial(folding_factor); + let final_coeff = folded_poly_2.fold(&folding_randomness_3.clone()).coeffs()[0]; + + // Compute all evaluations + let ood_answer = polynomial.evaluate(&ood_point); + let statement_answer = polynomial.evaluate(&statement_point); + let fold_answer_11 = folded_poly_1.evaluate(&fold_point_11); + let fold_answer_12 = folded_poly_1.evaluate(&fold_point_12); + let fold_answer_2 = folded_poly_2.evaluate(&fold_point_2); + + assert_eq!( + sumcheck_poly_1.sum_over_hypercube(), + epsilon_1 * ood_answer + epsilon_2 * statement_answer + ); + + assert_eq!( + sumcheck_poly_2.sum_over_hypercube(), + combination_randomness_1[0] * sumcheck_poly_1.evaluate_at_point(&folding_randomness_1) + + combination_randomness_1[1] * fold_answer_11 + + combination_randomness_1[2] * fold_answer_12 + ); + + assert_eq!( + sumcheck_poly_3.sum_over_hypercube(), + combination_randomness_2[0] * sumcheck_poly_2.evaluate_at_point(&folding_randomness_2) + + combination_randomness_2[1] * fold_answer_2 + ); + + let full_folding = MultilinearPoint( + [ + folding_randomness_3.0.clone(), + folding_randomness_2.0.clone(), + folding_randomness_1.0, + ] + .concat(), + ); + + assert_eq!( + sumcheck_poly_3.evaluate_at_point(&folding_randomness_3), + final_coeff + * (combination_randomness_2[0] + * (combination_randomness_1[0] + * (epsilon_1 * eq_poly_outside(&full_folding, &ood_point) + + epsilon_2 * eq_poly_outside(&full_folding, &statement_point)) + + combination_randomness_1[1] + * eq_poly_outside( + &fold_point_11, + &MultilinearPoint( + [ + folding_randomness_3.0.clone(), + folding_randomness_2.0.clone() + ] + .concat() + ) + ) + + combination_randomness_1[2] + * eq_poly_outside( + &fold_point_12, + &MultilinearPoint( + [folding_randomness_3.0.clone(), folding_randomness_2.0] + .concat() + ) + )) + + combination_randomness_2[1] + * eq_poly_outside(&folding_randomness_3, &fold_point_2)) + ) + } +} diff --git a/whir/src/sumcheck/proof.rs b/whir/src/sumcheck/proof.rs new file mode 100644 index 000000000..b3e486403 --- /dev/null +++ b/whir/src/sumcheck/proof.rs @@ -0,0 +1,101 @@ +use ark_ff::Field; + +use crate::{ + poly_utils::{MultilinearPoint, eq_poly3}, + utils::base_decomposition, +}; + +// Stored in evaluation form +#[derive(Debug, Clone)] +pub struct SumcheckPolynomial { + n_variables: usize, // number of variables; + // evaluations has length 3^{n_variables} + // The order in which it is stored is such that evaluations[i] + // corresponds to the evaluation at utils::base_decomposition(i, 3, n_variables), + // which performs (big-endian) ternary decomposition. + // (in other words, the ordering is lexicographic wrt the evaluation point) + evaluations: Vec, /* Each of our polynomials will be in F^{<3}[X_1, \dots, X_k], + * so it us uniquely determined by it's evaluations over {0, 1, 2}^k */ +} + +impl SumcheckPolynomial +where + F: Field, +{ + pub fn new(evaluations: Vec, n_variables: usize) -> Self { + SumcheckPolynomial { + evaluations, + n_variables, + } + } + + /// Returns the vector of evaluations at {0,1,2}^n_variables of the polynomial f + /// in the following order: [f(0,0,..,0), f(0,0,..,1), f(0,0,...,2), f(0,0,...,1,0), ...] + /// (i.e. lexicographic wrt. to the evaluation points. + pub fn evaluations(&self) -> &[F] { + &self.evaluations + } + + // TODO(Gotti): Rename to sum_over_binary_hypercube for clarity? + // TODO(Gotti): Make more efficient; the base_decomposition and filtering is unneccessary. + + /// Returns the sum of evaluations of f, when summed only over {0,1}^n_variables + /// + /// (and not over {0,1,2}^n_variable) + pub fn sum_over_hypercube(&self) -> F { + let num_evaluation_points = 3_usize.pow(self.n_variables as u32); + + let mut sum = F::ZERO; + for point in 0..num_evaluation_points { + if base_decomposition(point, 3, self.n_variables) + .into_iter() + .all(|v| matches!(v, 0 | 1)) + { + sum += self.evaluations[point]; + } + } + + sum + } + + /// evaluates the polynomial at an arbitrary point, not neccessarily in {0,1,2}^n_variables. + /// + /// We assert that point.n_variables() == self.n_variables + pub fn evaluate_at_point(&self, point: &MultilinearPoint) -> F { + assert!(point.n_variables() == self.n_variables); + let num_evaluation_points = 3_usize.pow(self.n_variables as u32); + + let mut evaluation = F::ZERO; + + for index in 0..num_evaluation_points { + evaluation += self.evaluations[index] * eq_poly3(point, index); + } + + evaluation + } +} + +#[cfg(test)] +mod tests { + use crate::{crypto::fields::Field64, poly_utils::MultilinearPoint, utils::base_decomposition}; + + use super::SumcheckPolynomial; + + type F = Field64; + + #[test] + fn test_evaluation() { + let num_variables = 2; + + let num_evaluation_points = 3_usize.pow(num_variables as u32); + let evaluations = (0..num_evaluation_points as u64).map(F::from).collect(); + + let poly = SumcheckPolynomial::new(evaluations, num_variables); + + for i in 0..num_evaluation_points { + let decomp = base_decomposition(i, 3, num_variables); + let point = MultilinearPoint(decomp.into_iter().map(F::from).collect()); + assert_eq!(poly.evaluate_at_point(&point), poly.evaluations()[i]); + } + } +} diff --git a/whir/src/sumcheck/prover_batched.rs b/whir/src/sumcheck/prover_batched.rs new file mode 100644 index 000000000..49ee27bf3 --- /dev/null +++ b/whir/src/sumcheck/prover_batched.rs @@ -0,0 +1,307 @@ +use super::proof::SumcheckPolynomial; +use crate::{ + poly_utils::{MultilinearPoint, coeffs::CoefficientList, evals::EvaluationsList}, + sumcheck::prover_single::SumcheckSingle, +}; +use ark_ff::Field; +#[cfg(feature = "parallel")] +use rayon::{join, prelude::*}; + +pub struct SumcheckBatched { + // The evaluation on each p and eq + evaluations_of_p: Vec>, + evaluations_of_equality: Vec>, + comb_coeff: Vec, + num_polys: usize, + num_variables: usize, + sum: F, +} + +impl SumcheckBatched +where + F: Field, +{ + // Input includes the following: + // coeffs: coefficient of a list of polynomials p + // points: one point per poly + // and initialises the table of the initial polynomial + // v(X_1, ..., X_n) = p0(..) * eq0(..) + p1(..) * eq1(..) + ... + pub fn new( + coeffs: Vec>, + points: &[MultilinearPoint], + poly_comb_coeff: &[F], // random coefficients for combining each poly + evals: &[F], + ) -> Self { + let num_polys = coeffs.len(); + assert_eq!(poly_comb_coeff.len(), num_polys); + assert_eq!(points.len(), num_polys); + assert_eq!(evals.len(), num_polys); + let num_variables = coeffs[0].num_variables(); + + let mut prover = SumcheckBatched { + evaluations_of_p: coeffs.into_iter().map(|c| c.into()).collect(), + evaluations_of_equality: vec![ + EvaluationsList::new(vec![F::ZERO; 1 << num_variables]); + num_polys + ], + comb_coeff: poly_comb_coeff.to_vec(), + num_polys, + num_variables, + sum: F::ZERO, + }; + + // Eval points + for (i, point) in points.iter().enumerate() { + SumcheckSingle::eval_eq( + &point.0, + prover.evaluations_of_equality[i].evals_mut(), + F::from(1), + ); + prover.sum += poly_comb_coeff[i] * evals[i]; + } + prover + } + + pub fn get_folded_polys(&self) -> Vec { + self.evaluations_of_p + .iter() + .map(|e| { + assert_eq!(e.num_variables(), 0); + e.evals()[0] + }) + .collect() + } + + pub fn get_folded_eqs(&self) -> Vec { + self.evaluations_of_equality + .iter() + .map(|e| { + assert_eq!(e.num_variables(), 0); + e.evals()[0] + }) + .collect() + } + + #[cfg(not(feature = "parallel"))] + pub fn compute_sumcheck_polynomial(&self) -> SumcheckPolynomial { + panic!("Non-parallel version not supported!"); + assert!(self.num_variables >= 1); + + // Compute coefficients of the quadratic result polynomial + let eval_p_iter = self.evaluation_of_p.evals().chunks_exact(2); + let eval_eq_iter = self.evaluation_of_equality.evals().chunks_exact(2); + let (c0, c2) = eval_p_iter + .zip(eval_eq_iter) + .map(|(p_at, eq_at)| { + // Convert evaluations to coefficients for the linear fns p and eq. + let (p_0, p_1) = (p_at[0], p_at[1] - p_at[0]); + let (eq_0, eq_1) = (eq_at[0], eq_at[1] - eq_at[0]); + + // Now we need to add the contribution of p(x) * eq(x) + (p_0 * eq_0, p_1 * eq_1) + }) + .reduce(|(a0, a2), (b0, b2)| (a0 + b0, a2 + b2)) + .unwrap_or((F::ZERO, F::ZERO)); + + // Use the fact that self.sum = p(0) + p(1) = 2 * c0 + c1 + c2 + let c1 = self.sum - c0.double() - c2; + + // Evaluate the quadratic polynomial at 0, 1, 2 + let eval_0 = c0; + let eval_1 = c0 + c1 + c2; + let eval_2 = eval_1 + c1 + c2 + c2.double(); + + SumcheckPolynomial::new(vec![eval_0, eval_1, eval_2], 1) + } + + #[cfg(feature = "parallel")] + pub fn compute_sumcheck_polynomial(&self) -> SumcheckPolynomial { + assert!(self.num_variables >= 1); + + // Compute coefficients of the quadratic result polynomial + let (_, c0, c2) = self + .comb_coeff + .par_iter() + .zip(&self.evaluations_of_p) + .zip(&self.evaluations_of_equality) + .map(|((rand, eval_p), eval_eq)| { + let eval_p_iter = eval_p.evals().par_chunks_exact(2); + let eval_eq_iter = eval_eq.evals().par_chunks_exact(2); + let (c0, c2) = eval_p_iter + .zip(eval_eq_iter) + .map(|(p_at, eq_at)| { + // Convert evaluations to coefficients for the linear fns p and eq. + let (p_0, p_1) = (p_at[0], p_at[1] - p_at[0]); + let (eq_0, eq_1) = (eq_at[0], eq_at[1] - eq_at[0]); + + // Now we need to add the contribution of p(x) * eq(x) + (p_0 * eq_0, p_1 * eq_1) + }) + .reduce( + || (F::ZERO, F::ZERO), + |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2), + ); + (*rand, c0, c2) + }) + .reduce( + || (F::ONE, F::ZERO, F::ZERO), + |(r0, a0, a2), (r1, b0, b2)| (F::ONE, r0 * a0 + r1 * b0, r0 * a2 + r1 * b2), + ); + + // Use the fact that self.sum = p(0) + p(1) = 2 * coeff_0 + coeff_1 + coeff_2 + let c1 = self.sum - c0.double() - c2; + + // Evaluate the quadratic polynomial at 0, 1, 2 + let eval_0 = c0; + let eval_1 = c0 + c1 + c2; + let eval_2 = eval_1 + c1 + c2 + c2.double(); + + SumcheckPolynomial::new(vec![eval_0, eval_1, eval_2], 1) + } + + // When the folding randomness arrives, compress the table accordingly (adding the new points) + #[cfg(not(feature = "parallel"))] + pub fn compress( + &mut self, + combination_randomness: F, // Scale the initial point + folding_randomness: &MultilinearPoint, + sumcheck_poly: &SumcheckPolynomial, + ) { + panic!("Non-parallel version not supported!"); + assert_eq!(folding_randomness.n_variables(), 1); + assert!(self.num_variables >= 1); + + let randomness = folding_randomness.0[0]; + let evaluations_of_p = self + .evaluation_of_p + .evals() + .chunks_exact(2) + .map(|at| (at[1] - at[0]) * randomness + at[0]) + .collect(); + let evaluations_of_eq = self + .evaluation_of_equality + .evals() + .chunks_exact(2) + .map(|at| (at[1] - at[0]) * randomness + at[0]) + .collect(); + + // Update + self.num_variables -= 1; + self.evaluation_of_p = EvaluationsList::new(evaluations_of_p); + self.evaluation_of_equality = EvaluationsList::new(evaluations_of_eq); + self.sum = combination_randomness * sumcheck_poly.evaluate_at_point(folding_randomness); + } + + #[cfg(feature = "parallel")] + pub fn compress( + &mut self, + combination_randomness: F, // Scale the initial point + folding_randomness: &MultilinearPoint, + sumcheck_poly: &SumcheckPolynomial, + ) { + assert_eq!(folding_randomness.n_variables(), 1); + assert!(self.num_variables >= 1); + + let randomness = folding_randomness.0[0]; + let evaluations: Vec<_> = self + .evaluations_of_p + .par_iter() + .zip(&self.evaluations_of_equality) + .map(|(eval_p, eval_eq)| { + let (evaluation_of_p, evaluation_of_eq) = join( + || { + eval_p + .evals() + .par_chunks_exact(2) + .map(|at| (at[1] - at[0]) * randomness + at[0]) + .collect() + }, + || { + eval_eq + .evals() + .par_chunks_exact(2) + .map(|at| (at[1] - at[0]) * randomness + at[0]) + .collect() + }, + ); + ( + EvaluationsList::new(evaluation_of_p), + EvaluationsList::new(evaluation_of_eq), + ) + }) + .collect(); + let (evaluations_of_p, evaluations_of_eq) = evaluations.into_iter().unzip(); + + // Update + self.num_variables -= 1; + self.evaluations_of_p = evaluations_of_p; + self.evaluations_of_equality = evaluations_of_eq; + self.sum = combination_randomness * sumcheck_poly.evaluate_at_point(folding_randomness); + } +} + +#[cfg(test)] +mod tests { + use crate::{ + crypto::fields::Field64, + poly_utils::{MultilinearPoint, coeffs::CoefficientList}, + }; + + use super::SumcheckBatched; + + type F = Field64; + + #[test] + fn test_sumcheck_folding_factor_1() { + let num_rounds = 2; + let eval_points = vec![ + MultilinearPoint(vec![F::from(10), F::from(11)]), + MultilinearPoint(vec![F::from(7), F::from(8)]), + ]; + let polynomials = vec![ + CoefficientList::new(vec![F::from(1), F::from(5), F::from(10), F::from(14)]), + CoefficientList::new(vec![F::from(2), F::from(6), F::from(11), F::from(13)]), + ]; + let poly_comb_coeffs = vec![F::from(2), F::from(3)]; + + let evals: Vec = polynomials + .iter() + .zip(&eval_points) + .map(|(poly, point)| poly.evaluate(point)) + .collect(); + let mut claimed_value: F = evals + .iter() + .zip(&poly_comb_coeffs) + .fold(F::from(0), |sum, (eval, poly_rand)| eval * poly_rand + sum); + + let mut prover = + SumcheckBatched::new(polynomials.clone(), &eval_points, &poly_comb_coeffs, &evals); + let mut comb_randomness_list = Vec::new(); + let mut fold_randomness_list = Vec::new(); + + for _ in 0..num_rounds { + let poly = prover.compute_sumcheck_polynomial(); + + // First, check that is sums to the right value over the hypercube + assert_eq!(poly.sum_over_hypercube(), claimed_value); + + let next_comb_randomness = F::from(100101); + let next_fold_randomness = MultilinearPoint(vec![F::from(4999)]); + + prover.compress(next_comb_randomness, &next_fold_randomness, &poly); + claimed_value = next_comb_randomness * poly.evaluate_at_point(&next_fold_randomness); + + comb_randomness_list.push(next_comb_randomness); + fold_randomness_list.extend(next_fold_randomness.0); + } + println!("CLAIM:"); + for poly in prover.evaluations_of_p { + println!("POLY: {:?}", poly); + } + println!("EXPECTED:"); + let fold_randomness_list = MultilinearPoint(fold_randomness_list); + for poly in polynomials { + println!("EVAL: {:?}", poly.evaluate(&fold_randomness_list)); + } + } +} diff --git a/whir/src/sumcheck/prover_core.rs b/whir/src/sumcheck/prover_core.rs new file mode 100644 index 000000000..9ecdfeeea --- /dev/null +++ b/whir/src/sumcheck/prover_core.rs @@ -0,0 +1,151 @@ +use ark_ff::Field; + +use crate::{ + poly_utils::{ + MultilinearPoint, coeffs::CoefficientList, evals::EvaluationsList, + sequential_lag_poly::LagrangePolynomialIterator, + }, + utils::base_decomposition, +}; + +use super::proof::SumcheckPolynomial; + +pub struct SumcheckCore { + // The evaluation of p + evaluation_of_p: EvaluationsList, + evaluation_of_equality: EvaluationsList, + num_variables: usize, +} + +impl SumcheckCore +where + F: Field, +{ + // Get the coefficient of polynomial p and a list of points + // and initialises the table of the initial polynomial + // v(X_1, ..., X_n) = p(X_1, ... X_n) * (epsilon_1 eq_z_1(X) + epsilon_2 eq_z_2(X) ...) + pub fn new( + coeffs: CoefficientList, // multilinear polynomial in n variables + points: &[MultilinearPoint], // list of points, each of length n. + combination_randomness: &[F], + ) -> Self { + assert_eq!(points.len(), combination_randomness.len()); + let num_variables = coeffs.num_variables(); + + let mut prover = SumcheckCore { + evaluation_of_p: coeffs.into(), // transform coefficient form -> evaluation form + evaluation_of_equality: EvaluationsList::new(vec![F::ZERO; 1 << num_variables]), + num_variables, + }; + + prover.add_new_equality(points, combination_randomness); + prover + } + + pub fn compute_sumcheck_polynomial(&self, folding_factor: usize) -> SumcheckPolynomial { + let two = F::ONE + F::ONE; // Enlightening + + assert!(self.num_variables >= folding_factor); + + let num_evaluation_points = 3_usize.pow(folding_factor as u32); + let suffix_len = 1 << folding_factor; + let prefix_len = (1 << self.num_variables) / suffix_len; + + // sets evaluation_points to the set of all {0,1,2}^folding_factor + let evaluation_points: Vec<_> = (0..num_evaluation_points) + .map(|point| { + MultilinearPoint( + base_decomposition(point, 3, folding_factor) + .into_iter() + .map(|v| match v { + 0 => F::ZERO, + 1 => F::ONE, + 2 => two, + _ => unreachable!(), + }) + .collect(), + ) + }) + .collect(); + let mut evaluations = vec![F::ZERO; num_evaluation_points]; + + // NOTE: This can probably be optimised a fair bit, there are a bunch of lagranges that can + // be computed at the same time, some allocations to save ecc ecc. + for beta_prefix in 0..prefix_len { + // Gather the evaluations that we are concerned about + let indexes: Vec<_> = (0..suffix_len) + .map(|beta_suffix| suffix_len * beta_prefix + beta_suffix) + .collect(); + let left_poly = + EvaluationsList::new(indexes.iter().map(|&i| self.evaluation_of_p[i]).collect()); + let right_poly = EvaluationsList::new( + indexes + .iter() + .map(|&i| self.evaluation_of_equality[i]) + .collect(), + ); + + // For each evaluation point, update with the right added + for point in 0..num_evaluation_points { + evaluations[point] += left_poly.evaluate(&evaluation_points[point]) + * right_poly.evaluate(&evaluation_points[point]); + } + } + + SumcheckPolynomial::new(evaluations, folding_factor) + } + + pub fn add_new_equality( + &mut self, + points: &[MultilinearPoint], + combination_randomness: &[F], + ) { + assert_eq!(combination_randomness.len(), points.len()); + for (point, rand) in points.iter().zip(combination_randomness) { + for (prefix, lag) in LagrangePolynomialIterator::new(point) { + self.evaluation_of_equality.evals_mut()[prefix.0] += *rand * lag; + } + } + } + + // When the folding randomness arrives, compress the table accordingly (adding the new points) + pub fn compress( + &mut self, + folding_factor: usize, + combination_randomness: F, // Scale the initial point + folding_randomness: &MultilinearPoint, + ) { + assert_eq!(folding_randomness.n_variables(), folding_factor); + assert!(self.num_variables >= folding_factor); + + let suffix_len = 1 << folding_factor; + let prefix_len = (1 << self.num_variables) / suffix_len; + let mut evaluations_of_p = Vec::with_capacity(prefix_len); + let mut evaluations_of_eq = Vec::with_capacity(prefix_len); + + // Compress the table + for beta_prefix in 0..prefix_len { + let indexes: Vec<_> = (0..suffix_len) + .map(|beta_suffix| suffix_len * beta_prefix + beta_suffix) + .collect(); + + let left_poly = + EvaluationsList::new(indexes.iter().map(|&i| self.evaluation_of_p[i]).collect()); + let right_poly = EvaluationsList::new( + indexes + .iter() + .map(|&i| self.evaluation_of_equality[i]) + .collect(), + ); + + evaluations_of_p.push(left_poly.evaluate(folding_randomness)); + evaluations_of_eq + .push(combination_randomness * right_poly.evaluate(folding_randomness)); + } + + // Update + self.num_variables -= folding_factor; + self.evaluation_of_p = EvaluationsList::new(evaluations_of_p); + self.evaluation_of_equality = EvaluationsList::new(evaluations_of_eq); + } +} diff --git a/whir/src/sumcheck/prover_not_skipping.rs b/whir/src/sumcheck/prover_not_skipping.rs new file mode 100644 index 000000000..602e9ab1c --- /dev/null +++ b/whir/src/sumcheck/prover_not_skipping.rs @@ -0,0 +1,329 @@ +use ark_ff::Field; +use nimue::{ + ProofResult, + plugins::ark::{FieldChallenges, FieldIOPattern, FieldWriter}, +}; +use nimue_pow::{PoWChallenge, PowStrategy}; + +use crate::{ + fs_utils::WhirPoWIOPattern, + poly_utils::{MultilinearPoint, coeffs::CoefficientList}, +}; + +use super::prover_single::SumcheckSingle; + +pub trait SumcheckNotSkippingIOPattern { + fn add_sumcheck(self, folding_factor: usize, pow_bits: f64) -> Self; +} + +impl SumcheckNotSkippingIOPattern for IOPattern +where + F: Field, + IOPattern: FieldIOPattern + WhirPoWIOPattern, +{ + fn add_sumcheck(mut self, folding_factor: usize, pow_bits: f64) -> Self { + for _ in 0..folding_factor { + self = self + .add_scalars(3, "sumcheck_poly") + .challenge_scalars(1, "folding_randomness") + .pow(pow_bits); + } + self + } +} + +pub struct SumcheckProverNotSkipping { + sumcheck_prover: SumcheckSingle, +} + +impl SumcheckProverNotSkipping +where + F: Field, +{ + // Get the coefficient of polynomial p and a list of points + // and initialises the table of the initial polynomial + // v(X_1, ..., X_n) = p(X_1, ... X_n) * (epsilon_1 eq_z_1(X) + epsilon_2 eq_z_2(X) ...) + pub fn new( + coeffs: CoefficientList, + points: &[MultilinearPoint], + combination_randomness: &[F], + evaluations: &[F], + ) -> Self { + Self { + sumcheck_prover: SumcheckSingle::new( + coeffs, + points, + combination_randomness, + evaluations, + ), + } + } + + pub fn compute_sumcheck_polynomials( + &mut self, + merlin: &mut Merlin, + folding_factor: usize, + pow_bits: f64, + ) -> ProofResult> + where + S: PowStrategy, + Merlin: FieldChallenges + FieldWriter + PoWChallenge, + { + let mut res = Vec::with_capacity(folding_factor); + + for _ in 0..folding_factor { + let sumcheck_poly = self.sumcheck_prover.compute_sumcheck_polynomial(); + merlin.add_scalars(sumcheck_poly.evaluations())?; + let [folding_randomness]: [F; 1] = merlin.challenge_scalars()?; + res.push(folding_randomness); + + // Do PoW if needed + if pow_bits > 0. { + merlin.challenge_pow::(pow_bits)?; + } + + self.sumcheck_prover + .compress(F::ONE, &folding_randomness.into(), &sumcheck_poly); + } + + res.reverse(); + Ok(MultilinearPoint(res)) + } + + pub fn add_new_equality( + &mut self, + points: &[MultilinearPoint], + combination_randomness: &[F], + evaluations: &[F], + ) { + self.sumcheck_prover + .add_new_equality(points, combination_randomness, evaluations) + } +} + +#[cfg(test)] +mod tests { + use ark_ff::Field; + use nimue::{ + IOPattern, Merlin, ProofResult, + plugins::ark::{FieldChallenges, FieldIOPattern, FieldReader}, + }; + use nimue_pow::blake3::Blake3PoW; + + use crate::{ + crypto::fields::Field64, + poly_utils::{MultilinearPoint, coeffs::CoefficientList, eq_poly_outside}, + sumcheck::{proof::SumcheckPolynomial, prover_not_skipping::SumcheckProverNotSkipping}, + }; + + type F = Field64; + + #[test] + fn test_e2e_short() -> ProofResult<()> { + let num_variables = 2; + let folding_factor = 2; + let polynomial = CoefficientList::new((0..1 << num_variables).map(F::from).collect()); + + // Initial stuff + let ood_point = MultilinearPoint::expand_from_univariate(F::from(42), num_variables); + let statement_point = MultilinearPoint::expand_from_univariate(F::from(97), num_variables); + + // All the randomness + let [epsilon_1, epsilon_2] = [F::from(15), F::from(32)]; + + fn add_sumcheck_io_pattern() -> IOPattern + where + F: Field, + IOPattern: FieldIOPattern, + { + IOPattern::new("test") + .add_scalars(3, "sumcheck_poly") + .challenge_scalars(1, "folding_randomness") + .add_scalars(3, "sumcheck_poly") + .challenge_scalars(1, "folding_randomness") + } + + let iopattern = add_sumcheck_io_pattern::(); + + // Prover part + let mut merlin = iopattern.to_merlin(); + let mut prover = SumcheckProverNotSkipping::new( + polynomial.clone(), + &[ood_point.clone(), statement_point.clone()], + &[epsilon_1, epsilon_2], + &[ + polynomial.evaluate_at_extension(&ood_point), + polynomial.evaluate_at_extension(&statement_point), + ], + ); + + let folding_randomness_1 = prover.compute_sumcheck_polynomials::( + &mut merlin, + folding_factor, + 0., + )?; + + // Compute the answers + let folded_poly_1 = polynomial.fold(&folding_randomness_1); + + let ood_answer = polynomial.evaluate(&ood_point); + let statement_answer = polynomial.evaluate(&statement_point); + + // Verifier part + let mut arthur = iopattern.to_arthur(merlin.transcript()); + let sumcheck_poly_11: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly_11 = SumcheckPolynomial::new(sumcheck_poly_11.to_vec(), 1); + let [folding_randomness_11]: [F; 1] = arthur.challenge_scalars()?; + let sumcheck_poly_12: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly_12 = SumcheckPolynomial::new(sumcheck_poly_12.to_vec(), 1); + let [folding_randomness_12]: [F; 1] = arthur.challenge_scalars()?; + + assert_eq!( + sumcheck_poly_11.sum_over_hypercube(), + epsilon_1 * ood_answer + epsilon_2 * statement_answer + ); + + assert_eq!( + sumcheck_poly_12.sum_over_hypercube(), + sumcheck_poly_11.evaluate_at_point(&folding_randomness_11.into()) + ); + + let full_folding = MultilinearPoint(vec![folding_randomness_12, folding_randomness_11]); + + let eval_coeff = folded_poly_1.coeffs()[0]; + assert_eq!( + sumcheck_poly_12.evaluate_at_point(&folding_randomness_12.into()), + eval_coeff + * (epsilon_1 * eq_poly_outside(&full_folding, &ood_point) + + epsilon_2 * eq_poly_outside(&full_folding, &statement_point)) + ); + + Ok(()) + } + + #[test] + fn test_e2e() -> ProofResult<()> { + let num_variables = 4; + let folding_factor = 2; + let polynomial = CoefficientList::new((0..1 << num_variables).map(F::from).collect()); + + // Initial stuff + let ood_point = MultilinearPoint::expand_from_univariate(F::from(42), num_variables); + let statement_point = MultilinearPoint::expand_from_univariate(F::from(97), num_variables); + + // All the randomness + let [epsilon_1, epsilon_2] = [F::from(15), F::from(32)]; + let fold_point = MultilinearPoint(vec![F::from(31), F::from(15)]); + let combination_randomness = vec![F::from(1000)]; + + fn add_sumcheck_io_pattern() -> IOPattern + where + F: Field, + IOPattern: FieldIOPattern, + { + IOPattern::new("test") + .add_scalars(3, "sumcheck_poly") + .challenge_scalars(1, "folding_randomness") + .add_scalars(3, "sumcheck_poly") + .challenge_scalars(1, "folding_randomness") + .add_scalars(3, "sumcheck_poly") + .challenge_scalars(1, "folding_randomness") + .add_scalars(3, "sumcheck_poly") + .challenge_scalars(1, "folding_randomness") + } + + let iopattern = add_sumcheck_io_pattern::(); + + // Prover part + let mut merlin = iopattern.to_merlin(); + let mut prover = SumcheckProverNotSkipping::new( + polynomial.clone(), + &[ood_point.clone(), statement_point.clone()], + &[epsilon_1, epsilon_2], + &[ + polynomial.evaluate_at_extension(&ood_point), + polynomial.evaluate_at_extension(&statement_point), + ], + ); + + let folding_randomness_1 = prover.compute_sumcheck_polynomials::( + &mut merlin, + folding_factor, + 0., + )?; + + let folded_poly_1 = polynomial.fold(&folding_randomness_1); + let fold_eval = folded_poly_1.evaluate_at_extension(&fold_point); + prover.add_new_equality(&[fold_point.clone()], &combination_randomness, &[fold_eval]); + + let folding_randomness_2 = prover.compute_sumcheck_polynomials::( + &mut merlin, + folding_factor, + 0., + )?; + + // Compute the answers + let folded_poly_1 = polynomial.fold(&folding_randomness_1); + let folded_poly_2 = folded_poly_1.fold(&folding_randomness_2); + + let ood_answer = polynomial.evaluate(&ood_point); + let statement_answer = polynomial.evaluate(&statement_point); + let fold_answer = folded_poly_1.evaluate(&fold_point); + + // Verifier part + let mut arthur = iopattern.to_arthur(merlin.transcript()); + let sumcheck_poly_11: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly_11 = SumcheckPolynomial::new(sumcheck_poly_11.to_vec(), 1); + let [folding_randomness_11]: [F; 1] = arthur.challenge_scalars()?; + let sumcheck_poly_12: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly_12 = SumcheckPolynomial::new(sumcheck_poly_12.to_vec(), 1); + let [folding_randomness_12]: [F; 1] = arthur.challenge_scalars()?; + let sumcheck_poly_21: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly_21 = SumcheckPolynomial::new(sumcheck_poly_21.to_vec(), 1); + let [folding_randomness_21]: [F; 1] = arthur.challenge_scalars()?; + let sumcheck_poly_22: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly_22 = SumcheckPolynomial::new(sumcheck_poly_22.to_vec(), 1); + let [folding_randomness_22]: [F; 1] = arthur.challenge_scalars()?; + + assert_eq!( + sumcheck_poly_11.sum_over_hypercube(), + epsilon_1 * ood_answer + epsilon_2 * statement_answer + ); + + assert_eq!( + sumcheck_poly_12.sum_over_hypercube(), + sumcheck_poly_11.evaluate_at_point(&folding_randomness_11.into()) + ); + + assert_eq!( + sumcheck_poly_21.sum_over_hypercube(), + sumcheck_poly_12.evaluate_at_point(&folding_randomness_12.into()) + + combination_randomness[0] * fold_answer + ); + + assert_eq!( + sumcheck_poly_22.sum_over_hypercube(), + sumcheck_poly_21.evaluate_at_point(&folding_randomness_21.into()) + ); + + let full_folding = MultilinearPoint(vec![ + folding_randomness_22, + folding_randomness_21, + folding_randomness_12, + folding_randomness_11, + ]); + + let partial_folding = MultilinearPoint(vec![folding_randomness_22, folding_randomness_21]); + + let eval_coeff = folded_poly_2.coeffs()[0]; + assert_eq!( + sumcheck_poly_22.evaluate_at_point(&folding_randomness_22.into()), + eval_coeff + * ((epsilon_1 * eq_poly_outside(&full_folding, &ood_point) + + epsilon_2 * eq_poly_outside(&full_folding, &statement_point)) + + combination_randomness[0] * eq_poly_outside(&partial_folding, &fold_point)) + ); + + Ok(()) + } +} diff --git a/whir/src/sumcheck/prover_not_skipping_batched.rs b/whir/src/sumcheck/prover_not_skipping_batched.rs new file mode 100644 index 000000000..5c5e85770 --- /dev/null +++ b/whir/src/sumcheck/prover_not_skipping_batched.rs @@ -0,0 +1,186 @@ +use ark_ff::Field; +use nimue::{ + ProofResult, + plugins::ark::{FieldChallenges, FieldWriter}, +}; +use nimue_pow::{PoWChallenge, PowStrategy}; + +use crate::poly_utils::{MultilinearPoint, coeffs::CoefficientList}; + +use super::prover_batched::SumcheckBatched; + +pub struct SumcheckProverNotSkippingBatched { + sumcheck_prover: SumcheckBatched, +} + +impl SumcheckProverNotSkippingBatched +where + F: Field, +{ + // Get the coefficient of polynomial p and a list of points + // and initialises the table of the initial polynomial + // v(X_1, ..., X_n) = p(X_1, ... X_n) * (epsilon_1 eq_z_1(X) + epsilon_2 eq_z_2(X) ...) + pub fn new( + coeffs: Vec>, + points: &[MultilinearPoint], + poly_comb_coeff: &[F], // random coefficients for combining each poly + evals: &[F], + ) -> Self { + Self { + sumcheck_prover: SumcheckBatched::new(coeffs, points, poly_comb_coeff, evals), + } + } + + pub fn get_folded_polys(&self) -> Vec { + self.sumcheck_prover.get_folded_polys() + } + + pub fn _get_folded_eqs(&self) -> Vec { + self.sumcheck_prover.get_folded_eqs() + } + + pub fn compute_sumcheck_polynomials( + &mut self, + merlin: &mut Merlin, + folding_factor: usize, + pow_bits: f64, + ) -> ProofResult> + where + S: PowStrategy, + Merlin: FieldChallenges + FieldWriter + PoWChallenge, + { + let mut res = Vec::with_capacity(folding_factor); + + for _ in 0..folding_factor { + let sumcheck_poly = self.sumcheck_prover.compute_sumcheck_polynomial(); + merlin.add_scalars(sumcheck_poly.evaluations())?; + let [folding_randomness]: [F; 1] = merlin.challenge_scalars()?; + res.push(folding_randomness); + + // Do PoW if needed + if pow_bits > 0. { + merlin.challenge_pow::(pow_bits)?; + } + + self.sumcheck_prover + .compress(F::ONE, &folding_randomness.into(), &sumcheck_poly); + } + + res.reverse(); + Ok(MultilinearPoint(res)) + } +} + +#[cfg(test)] +mod tests { + use ark_ff::Field; + use nimue::{ + IOPattern, Merlin, ProofResult, + plugins::ark::{FieldChallenges, FieldIOPattern, FieldReader}, + }; + use nimue_pow::blake3::Blake3PoW; + + use crate::{ + crypto::fields::Field64, + poly_utils::{MultilinearPoint, coeffs::CoefficientList, eq_poly_outside}, + sumcheck::{ + proof::SumcheckPolynomial, + prover_not_skipping_batched::SumcheckProverNotSkippingBatched, + }, + }; + + type F = Field64; + + #[test] + fn test_e2e_short() -> ProofResult<()> { + let num_variables = 2; + let folding_factor = 2; + let polynomials = vec![ + CoefficientList::new((0..1 << num_variables).map(F::from).collect()), + CoefficientList::new((1..(1 << num_variables) + 1).map(F::from).collect()), + ]; + + // Initial stuff + let statement_points = vec![ + MultilinearPoint::expand_from_univariate(F::from(97), num_variables), + MultilinearPoint::expand_from_univariate(F::from(75), num_variables), + ]; + + // Poly randomness + let [alpha_1, alpha_2] = [F::from(15), F::from(32)]; + + fn add_sumcheck_io_pattern() -> IOPattern + where + F: Field, + IOPattern: FieldIOPattern, + { + IOPattern::new("test") + .add_scalars(3, "sumcheck_poly") + .challenge_scalars(1, "folding_randomness") + .add_scalars(3, "sumcheck_poly") + .challenge_scalars(1, "folding_randomness") + } + + let iopattern = add_sumcheck_io_pattern::(); + + // Prover part + let mut merlin = iopattern.to_merlin(); + let mut prover = SumcheckProverNotSkippingBatched::new( + polynomials.clone(), + &statement_points, + &[alpha_1, alpha_2], + &[ + polynomials[0].evaluate_at_extension(&statement_points[0]), + polynomials[1].evaluate_at_extension(&statement_points[1]), + ], + ); + + let folding_randomness_1 = prover.compute_sumcheck_polynomials::( + &mut merlin, + folding_factor, + 0., + )?; + + // Compute the answers + let folded_polys_1: Vec<_> = polynomials + .iter() + .map(|poly| poly.fold(&folding_randomness_1)) + .collect(); + + let statement_answers: Vec = polynomials + .iter() + .zip(&statement_points) + .map(|(poly, point)| poly.evaluate(point)) + .collect(); + + // Verifier part + let mut arthur = iopattern.to_arthur(merlin.transcript()); + let sumcheck_poly_11: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly_11 = SumcheckPolynomial::new(sumcheck_poly_11.to_vec(), 1); + let [folding_randomness_11]: [F; 1] = arthur.challenge_scalars()?; + let sumcheck_poly_12: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly_12 = SumcheckPolynomial::new(sumcheck_poly_12.to_vec(), 1); + let [folding_randomness_12]: [F; 1] = arthur.challenge_scalars()?; + + assert_eq!( + sumcheck_poly_11.sum_over_hypercube(), + alpha_1 * statement_answers[0] + alpha_2 * statement_answers[1] + ); + + assert_eq!( + sumcheck_poly_12.sum_over_hypercube(), + sumcheck_poly_11.evaluate_at_point(&folding_randomness_11.into()) + ); + + let full_folding = MultilinearPoint(vec![folding_randomness_12, folding_randomness_11]); + + let eval_coeff = [folded_polys_1[0].coeffs()[0], folded_polys_1[1].coeffs()[0]]; + assert_eq!( + sumcheck_poly_12.evaluate_at_point(&folding_randomness_12.into()), + eval_coeff[0] * alpha_1 * eq_poly_outside(&full_folding, &statement_points[0]) + + eval_coeff[1] * alpha_2 * eq_poly_outside(&full_folding, &statement_points[1]) + ); + + Ok(()) + } +} diff --git a/whir/src/sumcheck/prover_single.rs b/whir/src/sumcheck/prover_single.rs new file mode 100644 index 000000000..5e04ec5ca --- /dev/null +++ b/whir/src/sumcheck/prover_single.rs @@ -0,0 +1,299 @@ +use super::proof::SumcheckPolynomial; +use crate::poly_utils::{MultilinearPoint, coeffs::CoefficientList, evals::EvaluationsList}; +use ark_ff::Field; +#[cfg(feature = "parallel")] +use rayon::{join, prelude::*}; + +pub struct SumcheckSingle { + // The evaluation of p + evaluation_of_p: EvaluationsList, + evaluation_of_equality: EvaluationsList, + num_variables: usize, + sum: F, +} + +impl SumcheckSingle +where + F: Field, +{ + // Get the coefficient of polynomial p and a list of points + // and initialises the table of the initial polynomial + // v(X_1, ..., X_n) = p(X_1, ... X_n) * (epsilon_1 eq_z_1(X) + epsilon_2 eq_z_2(X) ...) + pub fn new( + coeffs: CoefficientList, + points: &[MultilinearPoint], + combination_randomness: &[F], + evaluations: &[F], + ) -> Self { + assert_eq!(points.len(), combination_randomness.len()); + assert_eq!(points.len(), evaluations.len()); + let num_variables = coeffs.num_variables(); + + let mut prover = SumcheckSingle { + evaluation_of_p: coeffs.into(), + evaluation_of_equality: EvaluationsList::new(vec![F::ZERO; 1 << num_variables]), + num_variables, + sum: F::ZERO, + }; + + prover.add_new_equality(points, combination_randomness, evaluations); + prover + } + + #[cfg(not(feature = "parallel"))] + pub(crate) fn compute_sumcheck_polynomial(&self) -> SumcheckPolynomial { + assert!(self.num_variables >= 1); + + // Compute coefficients of the quadratic result polynomial + let eval_p_iter = self.evaluation_of_p.evals().chunks_exact(2); + let eval_eq_iter = self.evaluation_of_equality.evals().chunks_exact(2); + let (c0, c2) = eval_p_iter + .zip(eval_eq_iter) + .map(|(p_at, eq_at)| { + // Convert evaluations to coefficients for the linear fns p and eq. + let (p_0, p_1) = (p_at[0], p_at[1] - p_at[0]); + let (eq_0, eq_1) = (eq_at[0], eq_at[1] - eq_at[0]); + + // Now we need to add the contribution of p(x) * eq(x) + (p_0 * eq_0, p_1 * eq_1) + }) + .reduce(|(a0, a2), (b0, b2)| (a0 + b0, a2 + b2)) + .unwrap_or((F::ZERO, F::ZERO)); + + // Use the fact that self.sum = p(0) + p(1) = 2 * c0 + c1 + c2 + let c1 = self.sum - c0.double() - c2; + + // Evaluate the quadratic polynomial at 0, 1, 2 + let eval_0 = c0; + let eval_1 = c0 + c1 + c2; + let eval_2 = eval_1 + c1 + c2 + c2.double(); + + SumcheckPolynomial::new(vec![eval_0, eval_1, eval_2], 1) + } + + #[cfg(feature = "parallel")] + pub(crate) fn compute_sumcheck_polynomial(&self) -> SumcheckPolynomial { + assert!(self.num_variables >= 1); + + // Compute coefficients of the quadratic result polynomial + let eval_p_iter = self.evaluation_of_p.evals().par_chunks_exact(2); + let eval_eq_iter = self.evaluation_of_equality.evals().par_chunks_exact(2); + let (c0, c2) = eval_p_iter + .zip(eval_eq_iter) + .map(|(p_at, eq_at)| { + // Convert evaluations to coefficients for the linear fns p and eq. + let (p_0, p_1) = (p_at[0], p_at[1] - p_at[0]); + let (eq_0, eq_1) = (eq_at[0], eq_at[1] - eq_at[0]); + + // Now we need to add the contribution of p(x) * eq(x) + (p_0 * eq_0, p_1 * eq_1) + }) + .reduce( + || (F::ZERO, F::ZERO), + |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2), + ); + + // Use the fact that self.sum = p(0) + p(1) = 2 * coeff_0 + coeff_1 + coeff_2 + let c1 = self.sum - c0.double() - c2; + + // Evaluate the quadratic polynomial at 0, 1, 2 + let eval_0 = c0; + let eval_1 = c0 + c1 + c2; + let eval_2 = eval_1 + c1 + c2 + c2.double(); + + SumcheckPolynomial::new(vec![eval_0, eval_1, eval_2], 1) + } + + // Evaluate the eq function on for a given point on the hypercube, and add + // the result multiplied by the scalar to the output. + #[cfg(not(feature = "parallel"))] + pub(crate) fn eval_eq(eval: &[F], out: &mut [F], scalar: F) { + debug_assert_eq!(out.len(), 1 << eval.len()); + if let Some((&x, tail)) = eval.split_first() { + let (low, high) = out.split_at_mut(out.len() / 2); + let s1 = scalar * x; + let s0 = scalar - s1; + Self::eval_eq(tail, low, s0); + Self::eval_eq(tail, high, s1); + } else { + out[0] += scalar; + } + } + + // Evaluate the eq function on a given point on the hypercube, and add + // the result multiplied by the scalar to the output. + #[cfg(feature = "parallel")] + pub(crate) fn eval_eq(eval: &[F], out: &mut [F], scalar: F) { + const PARALLEL_THRESHOLD: usize = 10; + debug_assert_eq!(out.len(), 1 << eval.len()); + if let Some((&x, tail)) = eval.split_first() { + let (low, high) = out.split_at_mut(out.len() / 2); + // Update scalars using a single mul. Note that this causes a data dependency, + // so for small fields it might be better to use two muls. + // This data dependency should go away once we implement parallel point evaluation. + let s1 = scalar * x; + let s0 = scalar - s1; + if tail.len() > PARALLEL_THRESHOLD { + join( + || Self::eval_eq(tail, low, s0), + || Self::eval_eq(tail, high, s1), + ); + } else { + Self::eval_eq(tail, low, s0); + Self::eval_eq(tail, high, s1); + } + } else { + out[0] += scalar; + } + } + + pub fn add_new_equality( + &mut self, + points: &[MultilinearPoint], + combination_randomness: &[F], + evaluations: &[F], + ) { + assert_eq!(combination_randomness.len(), points.len()); + assert_eq!(combination_randomness.len(), evaluations.len()); + for (point, rand) in points.iter().zip(combination_randomness) { + // TODO: We might want to do all points simultaneously so we + // do only a single pass over the data. + Self::eval_eq(&point.0, self.evaluation_of_equality.evals_mut(), *rand); + } + + // Update the sum + for (rand, eval) in combination_randomness.iter().zip(evaluations.iter()) { + self.sum += *rand * eval; + } + } + + // When the folding randomness arrives, compress the table accordingly (adding the new points) + #[cfg(not(feature = "parallel"))] + pub fn compress( + &mut self, + combination_randomness: F, // Scale the initial point + folding_randomness: &MultilinearPoint, + sumcheck_poly: &SumcheckPolynomial, + ) { + assert_eq!(folding_randomness.n_variables(), 1); + assert!(self.num_variables >= 1); + + let randomness = folding_randomness.0[0]; + let evaluations_of_p = self + .evaluation_of_p + .evals() + .chunks_exact(2) + .map(|at| (at[1] - at[0]) * randomness + at[0]) + .collect(); + let evaluations_of_eq = self + .evaluation_of_equality + .evals() + .chunks_exact(2) + .map(|at| (at[1] - at[0]) * randomness + at[0]) + .collect(); + + // Update + self.num_variables -= 1; + self.evaluation_of_p = EvaluationsList::new(evaluations_of_p); + self.evaluation_of_equality = EvaluationsList::new(evaluations_of_eq); + self.sum = combination_randomness * sumcheck_poly.evaluate_at_point(folding_randomness); + } + + #[cfg(feature = "parallel")] + pub fn compress( + &mut self, + combination_randomness: F, // Scale the initial point + folding_randomness: &MultilinearPoint, + sumcheck_poly: &SumcheckPolynomial, + ) { + assert_eq!(folding_randomness.n_variables(), 1); + assert!(self.num_variables >= 1); + + let randomness = folding_randomness.0[0]; + let (evaluations_of_p, evaluations_of_eq) = join( + || { + self.evaluation_of_p + .evals() + .par_chunks_exact(2) + .map(|at| (at[1] - at[0]) * randomness + at[0]) + .collect() + }, + || { + self.evaluation_of_equality + .evals() + .par_chunks_exact(2) + .map(|at| (at[1] - at[0]) * randomness + at[0]) + .collect() + }, + ); + + // Update + self.num_variables -= 1; + self.evaluation_of_p = EvaluationsList::new(evaluations_of_p); + self.evaluation_of_equality = EvaluationsList::new(evaluations_of_eq); + self.sum = combination_randomness * sumcheck_poly.evaluate_at_point(folding_randomness); + } +} + +#[cfg(test)] +mod tests { + use crate::{ + crypto::fields::Field64, + poly_utils::{MultilinearPoint, coeffs::CoefficientList}, + }; + + use super::SumcheckSingle; + + type F = Field64; + + #[test] + fn test_sumcheck_folding_factor_1() { + let eval_point = MultilinearPoint(vec![F::from(10), F::from(11)]); + let polynomial = + CoefficientList::new(vec![F::from(1), F::from(5), F::from(10), F::from(14)]); + + let claimed_value = polynomial.evaluate(&eval_point); + + let eval = polynomial.evaluate(&eval_point); + let mut prover = SumcheckSingle::new(polynomial, &[eval_point], &[F::from(1)], &[eval]); + + let poly_1 = prover.compute_sumcheck_polynomial(); + + // First, check that is sums to the right value over the hypercube + assert_eq!(poly_1.sum_over_hypercube(), claimed_value); + + let combination_randomness = F::from(100101); + let folding_randomness = MultilinearPoint(vec![F::from(4999)]); + + prover.compress(combination_randomness, &folding_randomness, &poly_1); + + let poly_2 = prover.compute_sumcheck_polynomial(); + + assert_eq!( + poly_2.sum_over_hypercube(), + combination_randomness * poly_1.evaluate_at_point(&folding_randomness) + ); + } +} + +#[test] +fn test_eval_eq() { + use crate::{ + crypto::fields::Field64 as F, poly_utils::sequential_lag_poly::LagrangePolynomialIterator, + }; + use ark_ff::AdditiveGroup; + + let eval = vec![F::from(3), F::from(5)]; + let mut out = vec![F::ZERO; 4]; + SumcheckSingle::eval_eq(&eval, &mut out, F::ONE); + dbg!(&out); + + let point = MultilinearPoint(eval.clone()); + let mut expected = vec![F::ZERO; 4]; + for (prefix, lag) in LagrangePolynomialIterator::new(&point) { + expected[prefix.0] = lag; + } + dbg!(&expected); + + assert_eq!(&out, &expected); +} diff --git a/whir/src/utils.rs b/whir/src/utils.rs new file mode 100644 index 000000000..328553196 --- /dev/null +++ b/whir/src/utils.rs @@ -0,0 +1,152 @@ +use crate::ntt::{transpose, transpose_bench_allocate}; +use ark_ff::Field; +use std::collections::BTreeSet; + +// checks whether the given number n is a power of two. +pub fn is_power_of_two(n: usize) -> bool { + n != 0 && (n & (n - 1) == 0) +} + +/// performs big-endian binary decomposition of `value` and returns the result. +/// +/// `n_bits` must be at must usize::BITS. If it is strictly smaller, the most significant bits of `value` are ignored. +/// The returned vector v ends with the least significant bit of `value` and always has exactly `n_bits` many elements. +pub fn to_binary(value: usize, n_bits: usize) -> Vec { + // Ensure that n is within the bounds of the input integer type + assert!(n_bits <= usize::BITS as usize); + let mut result = vec![false; n_bits]; + for i in 0..n_bits { + result[n_bits - 1 - i] = (value & (1 << i)) != 0; + } + result +} + +// TODO(Gotti): n_bits is a misnomer if base > 2. Should be n_limbs or sth. +// Also, should the behaviour for value >= base^n_bits be specified as part of the API or asserted not to happen? +// Currently, we compute the decomposition of value % (base^n_bits). + +/// decomposes value into its big-endian base-ary decomposition, meaning we return a vector v, s.t. +/// +/// value = v[0]*base^(n_bits-1) + v[1] * base^(n_bits-2) + ... + v[n_bits-1] * 1, +/// where each v[i] is in 0..base. +/// The returned vector always has length exactly n_bits (we pad with leading zeros); +pub fn base_decomposition(value: usize, base: u8, n_bits: usize) -> Vec { + // Initialize the result vector with zeros of the specified length + let mut result = vec![0u8; n_bits]; + + // Create a mutable copy of the value for computation + // Note: We could just make the local passed-by-value argument `value` mutable, but this is clearer. + let mut value = value; + + // Compute the base decomposition + for i in 0..n_bits { + result[n_bits - 1 - i] = (value % (base as usize)) as u8; + value /= base as usize; + } + // TODO: Should we assert!(value == 0) here to check that the orginally passed `value` is < base^n_bits ? + + result +} + +// Gotti: Consider renaming this function. The name sounds like it's a PRG. +// TODO (Gotti): Check that ordering is actually correct at point of use (everything else is big-endian). + +/// expand_randomness outputs the vector [1, base, base^2, base^3, ...] of length len. +pub fn expand_randomness(base: F, len: usize) -> Vec { + let mut res = Vec::with_capacity(len); + let mut acc = F::ONE; + for _ in 0..len { + res.push(acc); + acc *= base; + } + + res +} + +/// Deduplicates AND orders a vector +pub fn dedup(v: impl IntoIterator) -> Vec { + Vec::from_iter(BTreeSet::from_iter(v)) +} + +// FIXME(Gotti): comment does not match what function does (due to mismatch between folding_factor and folding_factor_exp) +// Also, k should be defined: k = evals.len() / 2^{folding_factor}, I guess. + +/// Takes the vector of evaluations (assume that evals[i] = f(omega^i)) +/// and folds them into a vector of such that folded_evals[i] = [f(omega^(i + k * j)) for j in 0..folding_factor] +pub fn stack_evaluations(mut evals: Vec, folding_factor: usize) -> Vec { + let folding_factor_exp = 1 << folding_factor; + assert!(evals.len() % folding_factor_exp == 0); + let size_of_new_domain = evals.len() / folding_factor_exp; + + // interpret evals as (folding_factor_exp x size_of_new_domain)-matrix and transpose in-place + transpose(&mut evals, folding_factor_exp, size_of_new_domain); + evals +} + +pub fn stack_evaluations_bench_allocate( + mut evals: Vec, + folding_factor: usize, +) -> Vec { + let folding_factor_exp = 1 << folding_factor; + assert!(evals.len() % folding_factor_exp == 0); + let size_of_new_domain = evals.len() / folding_factor_exp; + + // interpret evals as (folding_factor_exp x size_of_new_domain)-matrix and transpose in-place + transpose_bench_allocate(&mut evals, folding_factor_exp, size_of_new_domain); + evals +} + +#[cfg(test)] +mod tests { + use crate::utils::base_decomposition; + + use super::{is_power_of_two, stack_evaluations, to_binary}; + + #[test] + fn test_evaluations_stack() { + use crate::crypto::fields::Field64 as F; + + let num = 256; + let folding_factor = 3; + let fold_size = 1 << folding_factor; + assert_eq!(num % fold_size, 0); + let evals: Vec<_> = (0..num as u64).map(F::from).collect(); + + let stacked = stack_evaluations(evals, folding_factor); + assert_eq!(stacked.len(), num); + + for (i, fold) in stacked.chunks_exact(fold_size).enumerate() { + assert_eq!(fold.len(), fold_size); + for (j, item) in fold.iter().copied().enumerate().take(fold_size) { + assert_eq!(item, F::from((i + j * num / fold_size) as u64)); + } + } + } + + #[test] + fn test_to_binary() { + assert_eq!(to_binary(0b10111, 5), vec![true, false, true, true, true]); + assert_eq!(to_binary(0b11001, 2), vec![false, true]); // truncate + let empty_vec: Vec = vec![]; // just for the explicit bool type. + assert_eq!(to_binary(1, 0), empty_vec); + assert_eq!(to_binary(0, 0), empty_vec); + } + + #[test] + fn test_is_power_of_two() { + assert!(!is_power_of_two(0)); + assert!(is_power_of_two(1)); + assert!(is_power_of_two(2)); + assert!(!is_power_of_two(3)); + assert!(!is_power_of_two(usize::MAX)); + } + + #[test] + fn test_base_decomposition() { + assert_eq!(base_decomposition(0b1011, 2, 6), vec![0, 0, 1, 0, 1, 1]); + assert_eq!(base_decomposition(15, 3, 3), vec![1, 2, 0]); + // check truncation: This checks the current (undocumented) behaviour (compute modulo base^number_of_limbs) works as believed. + // If we actually specify the API to have a different behaviour, this test should change. + assert_eq!(base_decomposition(15 + 81, 3, 3), vec![1, 2, 0]); + } +} diff --git a/whir/src/whir/batch/committer.rs b/whir/src/whir/batch/committer.rs new file mode 100644 index 000000000..11144ed68 --- /dev/null +++ b/whir/src/whir/batch/committer.rs @@ -0,0 +1,184 @@ +use crate::{ + ntt::expand_from_coeff, + poly_utils::{MultilinearPoint, coeffs::CoefficientList, fold::restructure_evaluations}, + utils, + whir::committer::{Committer, Witness}, +}; +use ark_crypto_primitives::merkle_tree::{Config, MerkleTree}; +use ark_ff::FftField; +use ark_poly::EvaluationDomain; +use ark_std::{end_timer, start_timer}; +use derive_more::Debug; +use nimue::{ + ByteWriter, ProofResult, + plugins::ark::{FieldChallenges, FieldWriter}, +}; + +use crate::whir::fs_utils::DigestWriter; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +#[derive(Debug, Clone)] +pub struct Witnesses +where + MerkleConfig: Config, +{ + pub(crate) polys: Vec>, + #[debug(skip)] + pub(crate) merkle_tree: MerkleTree, + pub(crate) merkle_leaves: Vec, + pub(crate) ood_points: Vec, + pub(crate) ood_answers: Vec, +} + +impl From> for Witnesses { + fn from(witness: Witness) -> Self { + Self { + polys: vec![witness.polynomial], + merkle_tree: witness.merkle_tree, + merkle_leaves: witness.merkle_leaves, + ood_points: witness.ood_points, + ood_answers: witness.ood_answers, + } + } +} + +impl From> for Witness { + fn from(witness: Witnesses) -> Self { + Self { + polynomial: witness.polys[0].clone(), + merkle_tree: witness.merkle_tree, + merkle_leaves: witness.merkle_leaves, + ood_points: witness.ood_points, + ood_answers: witness.ood_answers, + } + } +} + +impl Committer +where + F: FftField, + MerkleConfig: Config, + PowStrategy: Sync, +{ + pub fn batch_commit( + &self, + merlin: &mut Merlin, + polys: &[CoefficientList], + ) -> ProofResult> + where + Merlin: FieldWriter + FieldChallenges + ByteWriter + DigestWriter, + { + let timer = start_timer!(|| "Batch Commit"); + let base_domain = self.0.starting_domain.base_domain.unwrap(); + let expansion = base_domain.size() / polys[0].num_coeffs(); + let expand_timer = start_timer!(|| "Batch Expand"); + let evals = polys + .par_iter() + .map(|poly| expand_from_coeff(poly.coeffs(), expansion)) + .collect::>>(); + end_timer!(expand_timer); + + assert_eq!(base_domain.size(), evals[0].len()); + + // These stacking operations are bottleneck of the commitment process. + // Try to finish the tasks with as few allocations as possible. + let stack_evaluations_timer = start_timer!(|| "Stack Evaluations"); + let folded_evals = evals + .into_par_iter() + .map(|evals| { + let sub_stack_evaluations_timer = start_timer!(|| "Sub Stack Evaluations"); + let ret = utils::stack_evaluations(evals, self.0.folding_factor.at_round(0)); + end_timer!(sub_stack_evaluations_timer); + ret + }) + .map(|evals| { + let restructure_evaluations_timer = start_timer!(|| "Restructure Evaluations"); + let ret = restructure_evaluations( + evals, + self.0.fold_optimisation, + base_domain.group_gen(), + base_domain.group_gen_inv(), + self.0.folding_factor.at_round(0), + ); + end_timer!(restructure_evaluations_timer); + ret + }) + .flat_map(|evals| evals.into_par_iter().map(F::from_base_prime_field)) + .collect::>(); + end_timer!(stack_evaluations_timer); + + let allocate_timer = start_timer!(|| "Allocate buffer."); + + let mut buffer = Vec::with_capacity(folded_evals.len()); + #[allow(clippy::uninit_vec)] + unsafe { + buffer.set_len(folded_evals.len()); + } + end_timer!(allocate_timer); + let horizontal_stacking_timer = start_timer!(|| "Horizontal Stacking"); + let folded_evals = super::utils::horizontal_stacking( + folded_evals, + base_domain.size(), + self.0.folding_factor.at_round(0), + buffer.as_mut_slice(), + ); + end_timer!(horizontal_stacking_timer); + + // Group folds together as a leaf. + let fold_size = 1 << self.0.folding_factor.at_round(0); + #[cfg(not(feature = "parallel"))] + let leafs_iter = folded_evals.chunks_exact(fold_size * polys.len()); + #[cfg(feature = "parallel")] + let leafs_iter = folded_evals.par_chunks_exact(fold_size * polys.len()); + + let merkle_build_timer = start_timer!(|| "Build Merkle Tree"); + let merkle_tree = MerkleTree::::new( + &self.0.leaf_hash_params, + &self.0.two_to_one_params, + leafs_iter, + ) + .unwrap(); + end_timer!(merkle_build_timer); + + let root = merkle_tree.root(); + + merlin.add_digest(root)?; + + let mut ood_points = vec![F::ZERO; self.0.committment_ood_samples]; + let mut ood_answers = vec![F::ZERO; polys.len() * self.0.committment_ood_samples]; + if self.0.committment_ood_samples > 0 { + merlin.fill_challenge_scalars(&mut ood_points)?; + ood_points + .par_iter() + .zip(ood_answers.par_chunks_mut(polys.len())) + .for_each(|(ood_point, ood_answers)| { + for j in 0..polys.len() { + let eval = polys[j].evaluate_at_extension( + &MultilinearPoint::expand_from_univariate( + *ood_point, + self.0.mv_parameters.num_variables, + ), + ); + ood_answers[j] = eval; + } + }); + merlin.add_scalars(&ood_answers)?; + } + + let polys = polys + .into_par_iter() + .map(|poly| poly.clone().to_extension()) + .collect::>(); + + end_timer!(timer); + + Ok(Witnesses { + polys, + merkle_tree, + merkle_leaves: folded_evals, + ood_points, + ood_answers, + }) + } +} diff --git a/whir/src/whir/batch/iopattern.rs b/whir/src/whir/batch/iopattern.rs new file mode 100644 index 000000000..13ec986ec --- /dev/null +++ b/whir/src/whir/batch/iopattern.rs @@ -0,0 +1,124 @@ +use ark_crypto_primitives::merkle_tree::Config; +use ark_ff::FftField; +use nimue::plugins::ark::*; + +use crate::{ + fs_utils::{OODIOPattern, WhirPoWIOPattern}, + sumcheck::prover_not_skipping::SumcheckNotSkippingIOPattern, + whir::iopattern::DigestIOPattern, +}; + +use crate::whir::parameters::WhirConfig; + +pub trait WhirBatchIOPattern { + fn commit_batch_statement( + self, + params: &WhirConfig, + batch_size: usize, + ) -> Self; + fn add_whir_unify_proof( + self, + params: &WhirConfig, + batch_size: usize, + ) -> Self; + fn add_whir_batch_proof( + self, + params: &WhirConfig, + batch_size: usize, + ) -> Self; +} + +impl WhirBatchIOPattern for IOPattern +where + F: FftField, + MerkleConfig: Config, + IOPattern: ByteIOPattern + + FieldIOPattern + + SumcheckNotSkippingIOPattern + + WhirPoWIOPattern + + OODIOPattern + + DigestIOPattern, +{ + fn commit_batch_statement( + self, + params: &WhirConfig, + batch_size: usize, + ) -> Self { + // TODO: Add params + let mut this = self.add_digest("merkle_digest"); + if params.committment_ood_samples > 0 { + assert!(params.initial_statement); + this = this + .challenge_scalars(params.committment_ood_samples, "ood_query") + .add_scalars(params.committment_ood_samples * batch_size, "ood_ans"); + } + this + } + + fn add_whir_unify_proof( + mut self, + params: &WhirConfig, + batch_size: usize, + ) -> Self { + if batch_size > 1 { + self = self.challenge_scalars(1, "batch_poly_combination_randomness"); + } + self = self + // .challenge_scalars(1, "initial_combination_randomness") + .add_sumcheck(params.mv_parameters.num_variables, 0.); + self.add_scalars(batch_size, "unified_folded_evals") + } + + fn add_whir_batch_proof( + mut self, + params: &WhirConfig, + batch_size: usize, + ) -> Self { + if batch_size > 1 { + self = self.challenge_scalars(1, "batch_poly_combination_randomness"); + } + + // TODO: Add statement + if params.initial_statement { + self = self + .challenge_scalars(1, "initial_combination_randomness") + .add_sumcheck( + params.folding_factor.at_round(0), + params.starting_folding_pow_bits, + ); + } else { + self = self + .challenge_scalars(params.folding_factor.at_round(0), "folding_randomness") + .pow(params.starting_folding_pow_bits); + } + + let mut domain_size = params.starting_domain.size(); + + for (round, r) in params.round_parameters.iter().enumerate() { + let folded_domain_size = domain_size >> params.folding_factor.at_round(round); + let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8; + self = self + .add_digest("merkle_digest") + .add_ood(r.ood_samples) + .challenge_bytes(r.num_queries * domain_size_bytes, "stir_queries") + .pow(r.pow_bits) + .challenge_scalars(1, "combination_randomness") + .add_sumcheck( + params.folding_factor.at_round(round + 1), + r.folding_pow_bits, + ); + domain_size >>= 1; + } + + let folded_domain_size = domain_size + >> params + .folding_factor + .at_round(params.round_parameters.len()); + let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8; + + self.add_scalars(1 << params.final_sumcheck_rounds, "final_coeffs") + .challenge_bytes(domain_size_bytes * params.final_queries, "final_queries") + .pow(params.final_pow_bits) + .add_sumcheck(params.final_sumcheck_rounds, params.final_folding_pow_bits) + } +} diff --git a/whir/src/whir/batch/mod.rs b/whir/src/whir/batch/mod.rs new file mode 100644 index 000000000..315e9ad92 --- /dev/null +++ b/whir/src/whir/batch/mod.rs @@ -0,0 +1,8 @@ +mod committer; +mod iopattern; +mod prover; +mod utils; +mod verifier; + +pub use committer::Witnesses; +pub use iopattern::WhirBatchIOPattern; diff --git a/whir/src/whir/batch/prover.rs b/whir/src/whir/batch/prover.rs new file mode 100644 index 000000000..04bfa2e46 --- /dev/null +++ b/whir/src/whir/batch/prover.rs @@ -0,0 +1,540 @@ +use super::committer::Witnesses; +use crate::{ + ntt::expand_from_coeff, + parameters::FoldType, + poly_utils::{ + MultilinearPoint, + coeffs::CoefficientList, + fold::{compute_fold, restructure_evaluations}, + }, + sumcheck::{ + prover_not_skipping::SumcheckProverNotSkipping, + prover_not_skipping_batched::SumcheckProverNotSkippingBatched, + }, + utils::{self, expand_randomness}, + whir::{ + WhirProof, + prover::{Prover, RoundState}, + }, +}; +use ark_crypto_primitives::merkle_tree::{Config, MerkleTree}; +use ark_ff::FftField; +use ark_poly::EvaluationDomain; +use ark_std::{end_timer, start_timer}; +use itertools::zip_eq; +use nimue::{ + ByteChallenges, ByteWriter, ProofResult, + plugins::ark::{FieldChallenges, FieldWriter}, +}; +use nimue_pow::{self, PoWChallenge}; + +use crate::whir::fs_utils::{DigestWriter, get_challenge_stir_queries}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +struct RoundStateBatch<'a, F, MerkleConfig> +where + F: FftField, + MerkleConfig: Config, +{ + round_state: RoundState, + batching_randomness: Vec, + prev_merkle: &'a MerkleTree, + prev_merkle_answers: &'a Vec, +} + +impl Prover +where + F: FftField, + MerkleConfig: Config, + PowStrategy: nimue_pow::PowStrategy, +{ + fn validate_witnesses(&self, witness: &Witnesses) -> bool { + assert_eq!( + witness.ood_points.len() * witness.polys.len(), + witness.ood_answers.len() + ); + if !self.0.initial_statement { + assert!(witness.ood_points.is_empty()); + } + assert!(!witness.polys.is_empty(), "Input polys cannot be empty"); + witness.polys.iter().skip(1).for_each(|poly| { + assert_eq!( + poly.num_variables(), + witness.polys[0].num_variables(), + "All polys must have the same number of variables" + ); + }); + witness.polys[0].num_variables() == self.0.mv_parameters.num_variables + } + + /// batch open the same points for multiple polys + pub fn simple_batch_prove( + &self, + merlin: &mut Merlin, + points: &[MultilinearPoint], + evals_per_point: &[Vec], // outer loop on each point, inner loop on each poly + witness: &Witnesses, + ) -> ProofResult> + where + Merlin: FieldChallenges + + FieldWriter + + ByteChallenges + + ByteWriter + + PoWChallenge + + DigestWriter, + { + let prove_timer = start_timer!(|| "prove"); + let initial_timer = start_timer!(|| "init"); + assert!(self.0.initial_statement, "must be true for pcs"); + assert!(self.validate_parameters()); + assert!(self.validate_witnesses(witness)); + for point in points { + assert_eq!( + point.0.len(), + self.0.mv_parameters.num_variables, + "number of variables mismatch" + ); + } + let num_polys = witness.polys.len(); + for evals in evals_per_point { + assert_eq!( + evals.len(), + num_polys, + "number of polynomials not equal number of evaluations" + ); + } + + let compute_dot_product = + |evals: &[F], coeff: &[F]| -> F { zip_eq(evals, coeff).map(|(a, b)| *a * *b).sum() }; + end_timer!(initial_timer); + + let random_coeff_timer = start_timer!(|| "random coeff"); + let random_coeff = + super::utils::generate_random_vector_batch_open(merlin, witness.polys.len())?; + end_timer!(random_coeff_timer); + + let initial_claims_timer = start_timer!(|| "initial claims"); + let initial_claims: Vec<_> = witness + .ood_points + .par_iter() + .map(|ood_point| { + MultilinearPoint::expand_from_univariate( + *ood_point, + self.0.mv_parameters.num_variables, + ) + }) + .chain(points.to_vec()) + .collect(); + end_timer!(initial_claims_timer); + + let ood_answers_timer = start_timer!(|| "ood answers"); + let ood_answers = witness + .ood_answers + .par_chunks_exact(witness.polys.len()) + .map(|answer| compute_dot_product(answer, &random_coeff)) + .collect::>(); + end_timer!(ood_answers_timer); + + let eval_timer = start_timer!(|| "eval"); + let eval_per_point: Vec = evals_per_point + .par_iter() + .map(|evals| compute_dot_product(evals, &random_coeff)) + .collect(); + end_timer!(eval_timer); + + let combine_timer = start_timer!(|| "Combine polynomial"); + let initial_answers: Vec<_> = ood_answers.into_iter().chain(eval_per_point).collect(); + + let polynomial = CoefficientList::combine(&witness.polys, &random_coeff); + end_timer!(combine_timer); + + let comb_timer = start_timer!(|| "combination randomness"); + let [combination_randomness_gen] = merlin.challenge_scalars()?; + let combination_randomness = + expand_randomness(combination_randomness_gen, initial_claims.len()); + end_timer!(comb_timer); + + let sumcheck_timer = start_timer!(|| "sumcheck"); + let mut sumcheck_prover = Some(SumcheckProverNotSkipping::new( + polynomial.clone(), + &initial_claims, + &combination_randomness, + &initial_answers, + )); + end_timer!(sumcheck_timer); + + let sumcheck_prover_timer = start_timer!(|| "sumcheck_prover"); + let folding_randomness = sumcheck_prover + .as_mut() + .unwrap() + .compute_sumcheck_polynomials::( + merlin, + self.0.folding_factor.at_round(0), + self.0.starting_folding_pow_bits, + )?; + end_timer!(sumcheck_prover_timer); + + let timer = start_timer!(|| "round_batch"); + let round_state = RoundStateBatch { + round_state: RoundState { + domain: self.0.starting_domain.clone(), + round: 0, + sumcheck_prover, + folding_randomness, + coefficients: polynomial, + prev_merkle: MerkleTree::blank( + &self.0.leaf_hash_params, + &self.0.two_to_one_params, + 2, + ) + .unwrap(), + prev_merkle_answers: Vec::new(), + merkle_proofs: vec![], + }, + prev_merkle: &witness.merkle_tree, + prev_merkle_answers: &witness.merkle_leaves, + batching_randomness: random_coeff, + }; + + let result = self.simple_round_batch(merlin, round_state, num_polys); + end_timer!(timer); + end_timer!(prove_timer); + + result + } + + fn simple_round_batch( + &self, + merlin: &mut Merlin, + round_state: RoundStateBatch, + num_polys: usize, + ) -> ProofResult> + where + Merlin: FieldChallenges + + ByteChallenges + + FieldWriter + + ByteWriter + + PoWChallenge + + DigestWriter, + { + let batching_randomness = round_state.batching_randomness; + let prev_merkle = round_state.prev_merkle; + let prev_merkle_answers = round_state.prev_merkle_answers; + let mut round_state = round_state.round_state; + // Fold the coefficients + let folded_coefficients = round_state + .coefficients + .fold(&round_state.folding_randomness); + + let num_variables = self.0.mv_parameters.num_variables + - self.0.folding_factor.total_number(round_state.round); + + // Base case + if round_state.round == self.0.n_rounds() { + // Coefficients of the polynomial + merlin.add_scalars(folded_coefficients.coeffs())?; + + // Final verifier queries and answers + let final_challenge_indexes = get_challenge_stir_queries( + round_state.domain.size(), + self.0.folding_factor.at_round(round_state.round), + self.0.final_queries, + merlin, + )?; + + let merkle_proof = prev_merkle + .generate_multi_proof(final_challenge_indexes.clone()) + .unwrap(); + let fold_size = 1 << self.0.folding_factor.at_round(round_state.round); + let answers = final_challenge_indexes + .into_par_iter() + .map(|i| { + prev_merkle_answers + [i * (fold_size * num_polys)..(i + 1) * (fold_size * num_polys)] + .to_vec() + }) + .collect(); + + round_state.merkle_proofs.push((merkle_proof, answers)); + + // PoW + if self.0.final_pow_bits > 0. { + merlin.challenge_pow::(self.0.final_pow_bits)?; + } + + // Final sumcheck + if self.0.final_sumcheck_rounds > 0 { + round_state + .sumcheck_prover + .unwrap_or_else(|| { + SumcheckProverNotSkipping::new(folded_coefficients.clone(), &[], &[], &[]) + }) + .compute_sumcheck_polynomials::( + merlin, + self.0.final_sumcheck_rounds, + self.0.final_folding_pow_bits, + )?; + } + + return Ok(WhirProof(round_state.merkle_proofs)); + } + + let round_params = &self.0.round_parameters[round_state.round]; + + // Fold the coefficients, and compute fft of polynomial (and commit) + let new_domain = round_state.domain.scale(2); + let expansion = new_domain.size() / folded_coefficients.num_coeffs(); + let evals = expand_from_coeff(folded_coefficients.coeffs(), expansion); + // TODO: `stack_evaluations` and `restructure_evaluations` are really in-place algorithms. + // They also partially overlap and undo one another. We should merge them. + let folded_evals = + utils::stack_evaluations(evals, self.0.folding_factor.at_round(round_state.round + 1)); + let folded_evals = restructure_evaluations( + folded_evals, + self.0.fold_optimisation, + new_domain.backing_domain.group_gen(), + new_domain.backing_domain.group_gen_inv(), + self.0.folding_factor.at_round(round_state.round + 1), + ); + + #[cfg(not(feature = "parallel"))] + let leafs_iter = + folded_evals.chunks_exact(1 << self.0.folding_factor.at_round(round_state.round + 1)); + #[cfg(feature = "parallel")] + let leafs_iter = folded_evals + .par_chunks_exact(1 << self.0.folding_factor.at_round(round_state.round + 1)); + let merkle_tree = MerkleTree::::new( + &self.0.leaf_hash_params, + &self.0.two_to_one_params, + leafs_iter, + ) + .unwrap(); + + let root = merkle_tree.root(); + merlin.add_digest(root)?; + + // OOD Samples + let mut ood_points = vec![F::ZERO; round_params.ood_samples]; + let mut ood_answers = Vec::with_capacity(round_params.ood_samples); + if round_params.ood_samples > 0 { + merlin.fill_challenge_scalars(&mut ood_points)?; + ood_answers.extend(ood_points.iter().map(|ood_point| { + folded_coefficients.evaluate(&MultilinearPoint::expand_from_univariate( + *ood_point, + num_variables, + )) + })); + merlin.add_scalars(&ood_answers)?; + } + + // STIR queries + let stir_challenges_indexes = get_challenge_stir_queries( + round_state.domain.size(), + self.0.folding_factor.at_round(round_state.round), + round_params.num_queries, + merlin, + )?; + let domain_scaled_gen = round_state + .domain + .backing_domain + .element(1 << self.0.folding_factor.at_round(round_state.round)); + let stir_challenges: Vec<_> = ood_points + .into_par_iter() + .chain( + stir_challenges_indexes + .par_iter() + .map(|i| domain_scaled_gen.pow([*i as u64])), + ) + .map(|univariate| MultilinearPoint::expand_from_univariate(univariate, num_variables)) + .collect(); + + let merkle_proof = prev_merkle + .generate_multi_proof(stir_challenges_indexes.clone()) + .unwrap(); + let fold_size = (1 << self.0.folding_factor.at_round(round_state.round)) * num_polys; + let answers = stir_challenges_indexes + .par_iter() + .map(|i| prev_merkle_answers[i * fold_size..(i + 1) * fold_size].to_vec()) + .collect::>(); + let batched_answers = answers + .par_iter() + .map(|answer| { + let chunk_size = 1 << self.0.folding_factor.at_round(round_state.round); + let mut res = vec![F::ZERO; chunk_size]; + for i in 0..chunk_size { + for j in 0..num_polys { + res[i] += answer[i + j * chunk_size] * batching_randomness[j]; + } + } + res + }) + .collect::>(); + // Evaluate answers in the folding randomness. + let mut stir_evaluations = ood_answers.clone(); + match self.0.fold_optimisation { + FoldType::Naive => { + // See `Verifier::compute_folds_full` + let domain_size = round_state.domain.backing_domain.size(); + let domain_gen = round_state.domain.backing_domain.element(1); + let domain_gen_inv = domain_gen.inverse().unwrap(); + let coset_domain_size = 1 << self.0.folding_factor.at_round(round_state.round); + let coset_generator_inv = + domain_gen_inv.pow([(domain_size / coset_domain_size) as u64]); + stir_evaluations.extend(stir_challenges_indexes.iter().zip(&batched_answers).map( + |(index, batched_answers)| { + // The coset is w^index * + // let _coset_offset = domain_gen.pow(&[*index as u64]); + let coset_offset_inv = domain_gen_inv.pow([*index as u64]); + + compute_fold( + batched_answers, + &round_state.folding_randomness.0, + coset_offset_inv, + coset_generator_inv, + F::from(2).inverse().unwrap(), + self.0.folding_factor.at_round(round_state.round), + ) + }, + )) + } + FoldType::ProverHelps => { + stir_evaluations.extend(batched_answers.iter().map(|batched_answers| { + CoefficientList::new(batched_answers.to_vec()) + .evaluate(&round_state.folding_randomness) + })) + } + } + round_state.merkle_proofs.push((merkle_proof, answers)); + + // PoW + if round_params.pow_bits > 0. { + merlin.challenge_pow::(round_params.pow_bits)?; + } + + // Randomness for combination + let [combination_randomness_gen] = merlin.challenge_scalars()?; + let combination_randomness = + expand_randomness(combination_randomness_gen, stir_challenges.len()); + + let mut sumcheck_prover = round_state + .sumcheck_prover + .take() + .map(|mut sumcheck_prover| { + sumcheck_prover.add_new_equality( + &stir_challenges, + &combination_randomness, + &stir_evaluations, + ); + sumcheck_prover + }) + .unwrap_or_else(|| { + SumcheckProverNotSkipping::new( + folded_coefficients.clone(), + &stir_challenges, + &combination_randomness, + &stir_evaluations, + ) + }); + + let folding_randomness = sumcheck_prover + .compute_sumcheck_polynomials::( + merlin, + self.0.folding_factor.at_round(round_state.round + 1), + round_params.folding_pow_bits, + )?; + + let round_state = RoundState { + round: round_state.round + 1, + domain: new_domain, + sumcheck_prover: Some(sumcheck_prover), + folding_randomness, + coefficients: folded_coefficients, /* TODO: Is this redundant with `sumcheck_prover.coeff` ? */ + prev_merkle: merkle_tree, + prev_merkle_answers: folded_evals, + merkle_proofs: round_state.merkle_proofs, + }; + + self.round(merlin, round_state) + } +} + +impl Prover +where + F: FftField, + MerkleConfig: Config, + PowStrategy: nimue_pow::PowStrategy, +{ + /// each poly on a different point, same size + pub fn same_size_batch_prove( + &self, + merlin: &mut Merlin, + point_per_poly: &[MultilinearPoint], + eval_per_poly: &[F], + witness: &Witnesses, + ) -> ProofResult> + where + Merlin: FieldChallenges + + FieldWriter + + ByteChallenges + + ByteWriter + + PoWChallenge + + DigestWriter, + { + let prove_timer = start_timer!(|| "prove"); + let initial_timer = start_timer!(|| "init"); + assert!(self.0.initial_statement, "must be true for pcs"); + assert!(self.validate_parameters()); + assert!(self.validate_witnesses(witness)); + for point in point_per_poly { + assert_eq!( + point.0.len(), + self.0.mv_parameters.num_variables, + "number of variables mismatch" + ); + } + let num_polys = witness.polys.len(); + assert_eq!( + eval_per_poly.len(), + num_polys, + "number of polynomials not equal number of evaluations" + ); + end_timer!(initial_timer); + + let poly_comb_randomness_timer = start_timer!(|| "poly comb randomness"); + let poly_comb_randomness = + super::utils::generate_random_vector_batch_open(merlin, witness.polys.len())?; + end_timer!(poly_comb_randomness_timer); + + let initial_claims_timer = start_timer!(|| "initial claims"); + let initial_eval_claims = point_per_poly; + end_timer!(initial_claims_timer); + + let sumcheck_timer = start_timer!(|| "unifying sumcheck"); + let mut sumcheck_prover = SumcheckProverNotSkippingBatched::new( + witness.polys.clone(), + initial_eval_claims, + &poly_comb_randomness, + eval_per_poly, + ); + + // Perform the entire sumcheck + let folded_point = sumcheck_prover.compute_sumcheck_polynomials::( + merlin, + self.0.mv_parameters.num_variables, + 0., + )?; + let folded_evals = sumcheck_prover.get_folded_polys(); + merlin.add_scalars(&folded_evals)?; + end_timer!(sumcheck_timer); + // Problem now reduced to the polys(folded_point) =?= folded_evals + + let timer = start_timer!(|| "simple_batch"); + // perform simple_batch on folded_point and folded_evals + let result = self.simple_batch_prove(merlin, &[folded_point], &[folded_evals], witness)?; + end_timer!(timer); + end_timer!(prove_timer); + + Ok(result) + } +} diff --git a/whir/src/whir/batch/utils.rs b/whir/src/whir/batch/utils.rs new file mode 100644 index 000000000..8b1fae172 --- /dev/null +++ b/whir/src/whir/batch/utils.rs @@ -0,0 +1,101 @@ +use crate::{ + ntt::{transpose, transpose_test}, + utils::expand_randomness, + whir::fs_utils::{DigestReader, DigestWriter}, +}; +use ark_crypto_primitives::merkle_tree::Config; +use ark_ff::Field; +use ark_std::{end_timer, start_timer}; +use nimue::{ + ByteReader, ByteWriter, ProofResult, + plugins::ark::{FieldChallenges, FieldReader, FieldWriter}, +}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +pub fn stack_evaluations( + mut evals: Vec, + folding_factor: usize, + buffer: &mut [F], +) -> Vec { + assert!(evals.len() % folding_factor == 0); + let size_of_new_domain = evals.len() / folding_factor; + + // interpret evals as (folding_factor_exp x size_of_new_domain)-matrix and transpose in-place + transpose_test(&mut evals, folding_factor, size_of_new_domain, buffer); + evals +} + +/// Takes the vector of evaluations (assume that evals[i] = f(omega^i)) +/// and folds them into a vector of such that folded_evals[i] = [f(omega^(i + k * j)) for j in 0..folding_factor] +/// This function will mutate the function without return +pub fn stack_evaluations_mut(evals: &mut [F], folding_factor: usize) { + let folding_factor_exp = 1 << folding_factor; + assert!(evals.len() % folding_factor_exp == 0); + let size_of_new_domain = evals.len() / folding_factor_exp; + + // interpret evals as (folding_factor_exp x size_of_new_domain)-matrix and transpose in-place + transpose(evals, folding_factor_exp, size_of_new_domain); +} + +/// Takes a vector of matrix and stacking them horizontally +/// Use in-place matrix transposes to avoid data copy +/// each matrix has domain_size elements +/// each matrix has shape (*, 1<( + evals: Vec, + domain_size: usize, + folding_factor: usize, + buffer: &mut [F], +) -> Vec { + let fold_size = 1 << folding_factor; + let num_polys: usize = evals.len() / domain_size; + + let stack_evaluation_timer = start_timer!(|| "Stack Evaluation"); + let mut evals = stack_evaluations(evals, num_polys, buffer); + end_timer!(stack_evaluation_timer); + #[cfg(not(feature = "parallel"))] + let stacked_evals = evals.chunks_exact_mut(fold_size * num_polys); + #[cfg(feature = "parallel")] + let stacked_evals = evals.par_chunks_exact_mut(fold_size * num_polys); + let stack_evaluation_mut_timer = start_timer!(|| "Stack Evaluation Mut"); + stacked_evals.for_each(|eval| stack_evaluations_mut(eval, folding_factor)); + end_timer!(stack_evaluation_mut_timer); + evals +} + +// generate a random vector for batching open +pub fn generate_random_vector_batch_open( + merlin: &mut Merlin, + size: usize, +) -> ProofResult> +where + F: Field, + MerkleConfig: Config, + Merlin: FieldChallenges + FieldWriter + ByteWriter + DigestWriter, +{ + if size == 1 { + return Ok(vec![F::one()]); + } + let [gamma] = merlin.challenge_scalars()?; + let res = expand_randomness(gamma, size); + Ok(res) +} + +// generate a random vector for batching verify +pub fn generate_random_vector_batch_verify( + arthur: &mut Arthur, + size: usize, +) -> ProofResult> +where + F: Field, + MerkleConfig: Config, + Arthur: FieldChallenges + FieldReader + ByteReader + DigestReader, +{ + if size == 1 { + return Ok(vec![F::one()]); + } + let [gamma] = arthur.challenge_scalars()?; + let res = expand_randomness(gamma, size); + Ok(res) +} diff --git a/whir/src/whir/batch/verifier.rs b/whir/src/whir/batch/verifier.rs new file mode 100644 index 000000000..98bd2ca0e --- /dev/null +++ b/whir/src/whir/batch/verifier.rs @@ -0,0 +1,726 @@ +use std::iter; + +use ark_crypto_primitives::merkle_tree::Config; +use ark_ff::FftField; +use ark_poly::EvaluationDomain; +use ark_std::{iterable::Iterable, log2}; +use itertools::zip_eq; +use nimue::{ + ByteChallenges, ByteReader, ProofError, ProofResult, + plugins::ark::{FieldChallenges, FieldReader}, +}; +use nimue_pow::{self, PoWChallenge}; + +use crate::{ + poly_utils::{MultilinearPoint, coeffs::CoefficientList, eq_poly_outside}, + sumcheck::proof::SumcheckPolynomial, + utils::expand_randomness, + whir::{ + Statement, WhirProof, + fs_utils::{DigestReader, get_challenge_stir_queries}, + verifier::{ParsedCommitment, ParsedProof, ParsedRound, Verifier}, + }, +}; + +impl Verifier +where + F: FftField, + MerkleConfig: Config, + PowStrategy: nimue_pow::PowStrategy, +{ + // Same multiple points on each polynomial + pub fn simple_batch_verify( + &self, + arthur: &mut Arthur, + num_polys: usize, + points: &[MultilinearPoint], + evals_per_point: &[Vec], + whir_proof: &WhirProof, + ) -> ProofResult + where + Arthur: FieldChallenges + + FieldReader + + ByteChallenges + + ByteReader + + PoWChallenge + + DigestReader, + { + for evals in evals_per_point { + assert_eq!(num_polys, evals.len()); + } + + // We first do a pass in which we rederive all the FS challenges + // Then we will check the algebraic part (so to optimise inversions) + let parsed_commitment = self.parse_commitment_batch(arthur, num_polys)?; + self.batch_verify_internal( + arthur, + num_polys, + points, + evals_per_point, + parsed_commitment, + whir_proof, + ) + } + + // Different points on each polynomial + pub fn same_size_batch_verify( + &self, + arthur: &mut Arthur, + num_polys: usize, + point_per_poly: &[MultilinearPoint], + eval_per_poly: &[F], // evaluations of the polys on individual points + whir_proof: &WhirProof, + ) -> ProofResult + where + Arthur: FieldChallenges + + FieldReader + + ByteChallenges + + ByteReader + + PoWChallenge + + DigestReader, + { + assert_eq!(num_polys, point_per_poly.len()); + assert_eq!(num_polys, eval_per_poly.len()); + + // We first do a pass in which we rederive all the FS challenges + // Then we will check the algebraic part (so to optimise inversions) + let parsed_commitment = self.parse_commitment_batch(arthur, num_polys)?; + + // parse proof + let poly_comb_randomness = + super::utils::generate_random_vector_batch_verify(arthur, num_polys)?; + let (folded_points, folded_evals) = + self.parse_unify_sumcheck(arthur, point_per_poly, poly_comb_randomness)?; + + self.batch_verify_internal( + arthur, + num_polys, + &[folded_points], + &[folded_evals.clone()], + parsed_commitment, + whir_proof, + ) + } + + fn batch_verify_internal( + &self, + arthur: &mut Arthur, + num_polys: usize, + points: &[MultilinearPoint], + evals_per_point: &[Vec], + parsed_commitment: ParsedCommitment, + whir_proof: &WhirProof, + ) -> ProofResult + where + Arthur: FieldChallenges + + FieldReader + + ByteChallenges + + ByteReader + + PoWChallenge + + DigestReader, + { + // parse proof + let compute_dot_product = + |evals: &[F], coeff: &[F]| -> F { zip_eq(evals, coeff).map(|(a, b)| *a * *b).sum() }; + + let random_coeff = super::utils::generate_random_vector_batch_verify(arthur, num_polys)?; + let initial_claims: Vec<_> = parsed_commitment + .ood_points + .clone() + .into_iter() + .map(|ood_point| { + MultilinearPoint::expand_from_univariate( + ood_point, + self.params.mv_parameters.num_variables, + ) + }) + .chain(points.to_vec()) + .collect(); + + let ood_answers = parsed_commitment + .ood_answers + .clone() + .chunks_exact(num_polys) + .map(|answer| compute_dot_product(answer, &random_coeff)) + .collect::>(); + let eval_per_point = evals_per_point + .iter() + .map(|evals| compute_dot_product(evals, &random_coeff)); + + let initial_answers: Vec<_> = ood_answers.into_iter().chain(eval_per_point).collect(); + + let statement = Statement { + points: initial_claims, + evaluations: initial_answers, + }; + let parsed = self.parse_proof_batch( + arthur, + &parsed_commitment, + &statement, + whir_proof, + random_coeff.clone(), + num_polys, + )?; + + let computed_folds = self.compute_folds(&parsed); + + let mut prev: Option<(SumcheckPolynomial, F)> = None; + if let Some(round) = parsed.initial_sumcheck_rounds.first() { + // Check the first polynomial + let (mut prev_poly, mut randomness) = round.clone(); + if prev_poly.sum_over_hypercube() + != statement + .evaluations + .clone() + .into_iter() + .zip(&parsed.initial_combination_randomness) + .map(|(ans, rand)| ans * rand) + .sum() + { + return Err(ProofError::InvalidProof); + } + + // Check the rest of the rounds + for (sumcheck_poly, new_randomness) in &parsed.initial_sumcheck_rounds[1..] { + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev_poly = sumcheck_poly.clone(); + randomness = *new_randomness; + } + + prev = Some((prev_poly, randomness)); + } + + for (round, folds) in parsed.rounds.iter().zip(&computed_folds) { + let (sumcheck_poly, new_randomness) = &round.sumcheck_rounds[0].clone(); + + let values = round.ood_answers.iter().copied().chain(folds.clone()); + + let prev_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + let claimed_sum = prev_eval + + values + .zip(&round.combination_randomness) + .map(|(val, rand)| val * rand) + .sum::(); + + if sumcheck_poly.sum_over_hypercube() != claimed_sum { + return Err(ProofError::InvalidProof); + } + + prev = Some((sumcheck_poly.clone(), *new_randomness)); + + // Check the rest of the round + for (sumcheck_poly, new_randomness) in &round.sumcheck_rounds[1..] { + let (prev_poly, randomness) = prev.unwrap(); + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev = Some((sumcheck_poly.clone(), *new_randomness)); + } + } + + // Check the foldings computed from the proof match the evaluations of the polynomial + let final_folds = &computed_folds[computed_folds.len() - 1]; + let final_evaluations = parsed + .final_coefficients + .evaluate_at_univariate(&parsed.final_randomness_points); + if !final_folds + .iter() + .zip(final_evaluations) + .all(|(&fold, eval)| fold == eval) + { + return Err(ProofError::InvalidProof); + } + + // Check the final sumchecks + if self.params.final_sumcheck_rounds > 0 { + let prev_sumcheck_poly_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + let (sumcheck_poly, new_randomness) = &parsed.final_sumcheck_rounds[0].clone(); + let claimed_sum = prev_sumcheck_poly_eval; + + if sumcheck_poly.sum_over_hypercube() != claimed_sum { + return Err(ProofError::InvalidProof); + } + + prev = Some((sumcheck_poly.clone(), *new_randomness)); + + // Check the rest of the round + for (sumcheck_poly, new_randomness) in &parsed.final_sumcheck_rounds[1..] { + let (prev_poly, randomness) = prev.unwrap(); + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev = Some((sumcheck_poly.clone(), *new_randomness)); + } + } + + let prev_sumcheck_poly_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + + // Check the final sumcheck evaluation + let evaluation_of_v_poly = self.compute_v_poly_for_batched(&statement, &parsed); + + if prev_sumcheck_poly_eval + != evaluation_of_v_poly + * parsed + .final_coefficients + .evaluate(&parsed.final_sumcheck_randomness) + { + return Err(ProofError::InvalidProof); + } + + Ok(parsed_commitment.root) + } + + fn parse_commitment_batch( + &self, + arthur: &mut Arthur, + num_polys: usize, + ) -> ProofResult> + where + Arthur: ByteReader + FieldReader + FieldChallenges + DigestReader, + { + let root = arthur.read_digest()?; + + let mut ood_points = vec![F::ZERO; self.params.committment_ood_samples]; + let mut ood_answers = vec![F::ZERO; self.params.committment_ood_samples * num_polys]; + if self.params.committment_ood_samples > 0 { + arthur.fill_challenge_scalars(&mut ood_points)?; + arthur.fill_next_scalars(&mut ood_answers)?; + } + + Ok(ParsedCommitment { + root, + ood_points, + ood_answers, + }) + } + + fn parse_unify_sumcheck( + &self, + arthur: &mut Arthur, + point_per_poly: &[MultilinearPoint], + poly_comb_randomness: Vec, + ) -> ProofResult<(MultilinearPoint, Vec)> + where + Arthur: FieldReader + + FieldChallenges + + PoWChallenge + + ByteReader + + ByteChallenges + + DigestReader, + { + let num_variables = self.params.mv_parameters.num_variables; + let mut sumcheck_rounds = Vec::new(); + + // Derive combination randomness and first sumcheck polynomial + // let [point_comb_randomness_gen]: [F; 1] = arthur.challenge_scalars()?; + // let point_comb_randomness = expand_randomness(point_comb_randomness_gen, num_points); + + // Unifying sumcheck + sumcheck_rounds.reserve_exact(num_variables); + for _ in 0..num_variables { + let sumcheck_poly_evals: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly = SumcheckPolynomial::new(sumcheck_poly_evals.to_vec(), 1); + let [folding_randomness_single] = arthur.challenge_scalars()?; + sumcheck_rounds.push((sumcheck_poly, folding_randomness_single)); + + if self.params.starting_folding_pow_bits > 0. { + arthur.challenge_pow::(self.params.starting_folding_pow_bits)?; + } + } + let folded_point = + MultilinearPoint(sumcheck_rounds.iter().map(|&(_, r)| r).rev().collect()); + let folded_eqs: Vec = point_per_poly + .iter() + .zip(&poly_comb_randomness) + .map(|(point, randomness)| *randomness * eq_poly_outside(point, &folded_point)) + .collect(); + let mut folded_evals = vec![F::ZERO; point_per_poly.len()]; + arthur.fill_next_scalars(&mut folded_evals)?; + let sumcheck_claim = + sumcheck_rounds[num_variables - 1] + .0 + .evaluate_at_point(&MultilinearPoint(vec![ + sumcheck_rounds[num_variables - 1].1, + ])); + let sumcheck_expected: F = folded_evals + .iter() + .zip(&folded_eqs) + .map(|(eval, eq)| *eval * *eq) + .sum(); + if sumcheck_claim != sumcheck_expected { + return Err(ProofError::InvalidProof); + } + + Ok((folded_point, folded_evals)) + } + + fn pow_with_precomputed_squares(squares: &[F], mut index: usize) -> F { + let mut result = F::one(); + let mut i = 0; + while index > 0 { + if index & 1 == 1 { + result *= squares[i]; + } + index >>= 1; + i += 1; + } + result + } + + fn parse_proof_batch( + &self, + arthur: &mut Arthur, + parsed_commitment: &ParsedCommitment, + statement: &Statement, // Will be needed later + whir_proof: &WhirProof, + batched_randomness: Vec, + num_polys: usize, + ) -> ProofResult> + where + Arthur: FieldReader + + FieldChallenges + + PoWChallenge + + ByteReader + + ByteChallenges + + DigestReader, + { + let mut sumcheck_rounds = Vec::new(); + let mut folding_randomness: MultilinearPoint; + let initial_combination_randomness; + + if self.params.initial_statement { + // Derive combination randomness and first sumcheck polynomial + let [combination_randomness_gen]: [F; 1] = arthur.challenge_scalars()?; + initial_combination_randomness = expand_randomness( + combination_randomness_gen, + parsed_commitment.ood_points.len() + statement.points.len(), + ); + + // Initial sumcheck + sumcheck_rounds.reserve_exact(self.params.folding_factor.at_round(0)); + for _ in 0..self.params.folding_factor.at_round(0) { + let sumcheck_poly_evals: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly = SumcheckPolynomial::new(sumcheck_poly_evals.to_vec(), 1); + let [folding_randomness_single] = arthur.challenge_scalars()?; + sumcheck_rounds.push((sumcheck_poly, folding_randomness_single)); + + if self.params.starting_folding_pow_bits > 0. { + arthur.challenge_pow::(self.params.starting_folding_pow_bits)?; + } + } + + folding_randomness = + MultilinearPoint(sumcheck_rounds.iter().map(|&(_, r)| r).rev().collect()); + } else { + assert_eq!(parsed_commitment.ood_points.len(), 0); + assert_eq!(statement.points.len(), 0); + + initial_combination_randomness = vec![F::ONE]; + + let mut folding_randomness_vec = vec![F::ZERO; self.params.folding_factor.at_round(0)]; + arthur.fill_challenge_scalars(&mut folding_randomness_vec)?; + folding_randomness = MultilinearPoint(folding_randomness_vec); + + // PoW + if self.params.starting_folding_pow_bits > 0. { + arthur.challenge_pow::(self.params.starting_folding_pow_bits)?; + } + }; + + let mut prev_root = parsed_commitment.root.clone(); + let domain_gen = self.params.starting_domain.backing_domain.group_gen(); + // Precompute the powers of the domain generator, so that + // we can always compute domain_gen.pow(1 << i) by domain_gen_powers[i] + let domain_gen_powers = std::iter::successors(Some(domain_gen), |&curr| Some(curr * curr)) + .take(log2(self.params.starting_domain.size()) as usize) + .collect::>(); + // Since the generator of the domain will be repeatedly squared in + // the future, keep track of the log of the power (i.e., how many times + // it has been squared from domain_gen). + // In another word, always ensure current domain generator = domain_gen_powers[log_based_on_domain_gen] + let mut log_based_on_domain_gen: usize = 0; + let mut domain_gen_inv = self.params.starting_domain.backing_domain.group_gen_inv(); + let mut domain_size = self.params.starting_domain.size(); + let mut rounds = vec![]; + + for r in 0..self.params.n_rounds() { + let (merkle_proof, answers) = &whir_proof.0[r]; + let round_params = &self.params.round_parameters[r]; + + let new_root = arthur.read_digest()?; + + let mut ood_points = vec![F::ZERO; round_params.ood_samples]; + let mut ood_answers = vec![F::ZERO; round_params.ood_samples]; + if round_params.ood_samples > 0 { + arthur.fill_challenge_scalars(&mut ood_points)?; + arthur.fill_next_scalars(&mut ood_answers)?; + } + + let stir_challenges_indexes = get_challenge_stir_queries( + domain_size, + self.params.folding_factor.at_round(r), + round_params.num_queries, + arthur, + )?; + + let stir_challenges_points = stir_challenges_indexes + .iter() + .map(|index| { + Self::pow_with_precomputed_squares( + &domain_gen_powers.as_slice() + [log_based_on_domain_gen + self.params.folding_factor.at_round(r)..], + *index, + ) + }) + .collect(); + + if !merkle_proof + .verify( + &self.params.leaf_hash_params, + &self.params.two_to_one_params, + &prev_root, + answers.iter().map(|a| a.as_ref()), + ) + .unwrap() + || merkle_proof.leaf_indexes != stir_challenges_indexes + { + return Err(ProofError::InvalidProof); + } + + let answers: Vec<_> = if r == 0 { + answers + .iter() + .map(|raw_answer| { + if !batched_randomness.is_empty() { + let chunk_size = 1 << self.params.folding_factor.at_round(r); + let mut res = vec![F::ZERO; chunk_size]; + for i in 0..chunk_size { + for j in 0..num_polys { + res[i] += + raw_answer[i + j * chunk_size] * batched_randomness[j]; + } + } + res + } else { + raw_answer.clone() + } + }) + .collect() + } else { + answers.to_vec() + }; + + if round_params.pow_bits > 0. { + arthur.challenge_pow::(round_params.pow_bits)?; + } + + let [combination_randomness_gen] = arthur.challenge_scalars()?; + let combination_randomness = expand_randomness( + combination_randomness_gen, + stir_challenges_indexes.len() + round_params.ood_samples, + ); + + let mut sumcheck_rounds = + Vec::with_capacity(self.params.folding_factor.at_round(r + 1)); + for _ in 0..self.params.folding_factor.at_round(r + 1) { + let sumcheck_poly_evals: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly = SumcheckPolynomial::new(sumcheck_poly_evals.to_vec(), 1); + let [folding_randomness_single] = arthur.challenge_scalars()?; + sumcheck_rounds.push((sumcheck_poly, folding_randomness_single)); + + if round_params.folding_pow_bits > 0. { + arthur.challenge_pow::(round_params.folding_pow_bits)?; + } + } + + let new_folding_randomness = + MultilinearPoint(sumcheck_rounds.iter().map(|&(_, r)| r).rev().collect()); + + rounds.push(ParsedRound { + folding_randomness, + ood_points, + ood_answers, + stir_challenges_indexes, + stir_challenges_points, + stir_challenges_answers: answers, + combination_randomness, + sumcheck_rounds, + domain_gen_inv, + }); + + folding_randomness = new_folding_randomness; + + prev_root = new_root.clone(); + log_based_on_domain_gen += 1; + domain_gen_inv = domain_gen_inv * domain_gen_inv; + domain_size >>= 1; + } + + let mut final_coefficients = vec![F::ZERO; 1 << self.params.final_sumcheck_rounds]; + arthur.fill_next_scalars(&mut final_coefficients)?; + let final_coefficients = CoefficientList::new(final_coefficients); + + // Final queries verify + let final_randomness_indexes = get_challenge_stir_queries( + domain_size, + self.params.folding_factor.at_round(self.params.n_rounds()), + self.params.final_queries, + arthur, + )?; + let final_randomness_points = final_randomness_indexes + .iter() + .map(|index| { + Self::pow_with_precomputed_squares( + &domain_gen_powers.as_slice()[log_based_on_domain_gen + + self.params.folding_factor.at_round(self.params.n_rounds())..], + *index, + ) + }) + .collect(); + + let (final_merkle_proof, final_randomness_answers) = &whir_proof.0[whir_proof.0.len() - 1]; + if !final_merkle_proof + .verify( + &self.params.leaf_hash_params, + &self.params.two_to_one_params, + &prev_root, + final_randomness_answers.iter().map(|a| a.as_ref()), + ) + .unwrap() + || final_merkle_proof.leaf_indexes != final_randomness_indexes + { + return Err(ProofError::InvalidProof); + } + + let final_randomness_answers: Vec<_> = if self.params.n_rounds() == 0 { + final_randomness_answers + .iter() + .map(|raw_answer| { + if !batched_randomness.is_empty() { + let chunk_size = + 1 << self.params.folding_factor.at_round(self.params.n_rounds()); + let mut res = vec![F::ZERO; chunk_size]; + for i in 0..chunk_size { + for j in 0..num_polys { + res[i] += raw_answer[i + j * chunk_size] * batched_randomness[j]; + } + } + res + } else { + raw_answer.clone() + } + }) + .collect() + } else { + final_randomness_answers.to_vec() + }; + + if self.params.final_pow_bits > 0. { + arthur.challenge_pow::(self.params.final_pow_bits)?; + } + + let mut final_sumcheck_rounds = Vec::with_capacity(self.params.final_sumcheck_rounds); + for _ in 0..self.params.final_sumcheck_rounds { + let sumcheck_poly_evals: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly = SumcheckPolynomial::new(sumcheck_poly_evals.to_vec(), 1); + let [folding_randomness_single] = arthur.challenge_scalars()?; + final_sumcheck_rounds.push((sumcheck_poly, folding_randomness_single)); + + if self.params.final_folding_pow_bits > 0. { + arthur.challenge_pow::(self.params.final_folding_pow_bits)?; + } + } + let final_sumcheck_randomness = MultilinearPoint( + final_sumcheck_rounds + .iter() + .map(|&(_, r)| r) + .rev() + .collect(), + ); + + Ok(ParsedProof { + initial_combination_randomness, + initial_sumcheck_rounds: sumcheck_rounds, + rounds, + final_domain_gen_inv: domain_gen_inv, + final_folding_randomness: folding_randomness, + final_randomness_indexes, + final_randomness_points, + final_randomness_answers: final_randomness_answers.to_vec(), + final_sumcheck_rounds, + final_sumcheck_randomness, + final_coefficients, + }) + } + + /// this is copied and modified from `fn compute_v_poly` + /// to avoid modify the original function for compatibility + fn compute_v_poly_for_batched(&self, statement: &Statement, proof: &ParsedProof) -> F { + let mut num_variables = self.params.mv_parameters.num_variables; + + let mut folding_randomness = MultilinearPoint( + iter::once(&proof.final_sumcheck_randomness.0) + .chain(iter::once(&proof.final_folding_randomness.0)) + .chain(proof.rounds.iter().rev().map(|r| &r.folding_randomness.0)) + .flatten() + .copied() + .collect(), + ); + + let mut value = statement + .points + .iter() + .zip(&proof.initial_combination_randomness) + .map(|(point, randomness)| *randomness * eq_poly_outside(point, &folding_randomness)) + .sum(); + + for (round, round_proof) in proof.rounds.iter().enumerate() { + num_variables -= self.params.folding_factor.at_round(round); + folding_randomness = MultilinearPoint(folding_randomness.0[..num_variables].to_vec()); + + let ood_points = &round_proof.ood_points; + let stir_challenges_points = &round_proof.stir_challenges_points; + let stir_challenges: Vec<_> = ood_points + .iter() + .chain(stir_challenges_points) + .cloned() + .map(|univariate| { + MultilinearPoint::expand_from_univariate(univariate, num_variables) + // TODO: + // Maybe refactor outside + }) + .collect(); + + let sum_of_claims: F = stir_challenges + .into_iter() + .map(|point| eq_poly_outside(&point, &folding_randomness)) + .zip(&round_proof.combination_randomness) + .map(|(point, rand)| point * rand) + .sum(); + + value += sum_of_claims; + } + + value + } +} diff --git a/whir/src/whir/committer.rs b/whir/src/whir/committer.rs new file mode 100644 index 000000000..3ca14d772 --- /dev/null +++ b/whir/src/whir/committer.rs @@ -0,0 +1,128 @@ +use super::parameters::WhirConfig; +use crate::{ + ntt::expand_from_coeff, + poly_utils::{MultilinearPoint, coeffs::CoefficientList, fold::restructure_evaluations}, + utils, +}; +use ark_crypto_primitives::merkle_tree::{Config, MerkleTree}; +use ark_ff::FftField; +use ark_poly::EvaluationDomain; +use ark_std::{end_timer, start_timer}; +use derive_more::Debug; +use nimue::{ + ByteWriter, ProofResult, + plugins::ark::{FieldChallenges, FieldWriter}, +}; + +use crate::whir::fs_utils::DigestWriter; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +#[derive(Clone, Debug)] +pub struct Witness +where + MerkleConfig: Config, +{ + pub(crate) polynomial: CoefficientList, + #[debug(skip)] + pub(crate) merkle_tree: MerkleTree, + pub(crate) merkle_leaves: Vec, + pub(crate) ood_points: Vec, + pub(crate) ood_answers: Vec, +} + +pub struct Committer( + pub(crate) WhirConfig, +) +where + F: FftField, + MerkleConfig: Config; + +impl Committer +where + F: FftField, + MerkleConfig: Config, +{ + pub fn new(config: WhirConfig) -> Self { + Self(config) + } + + pub fn commit( + &self, + merlin: &mut Merlin, + mut polynomial: CoefficientList, + ) -> ProofResult> + where + Merlin: FieldWriter + FieldChallenges + ByteWriter + DigestWriter, + { + let timer = start_timer!(|| "Single Commit"); + // If size of polynomial < folding factor, keep doubling polynomial size by cloning itself + polynomial.pad_to_num_vars(self.0.folding_factor.at_round(0)); + + let base_domain = self.0.starting_domain.base_domain.unwrap(); + let expansion = base_domain.size() / polynomial.num_coeffs(); + let evals = expand_from_coeff(polynomial.coeffs(), expansion); + // TODO: `stack_evaluations` and `restructure_evaluations` are really in-place algorithms. + // They also partially overlap and undo one another. We should merge them. + let folded_evals = utils::stack_evaluations(evals, self.0.folding_factor.at_round(0)); + let folded_evals = restructure_evaluations( + folded_evals, + self.0.fold_optimisation, + base_domain.group_gen(), + base_domain.group_gen_inv(), + self.0.folding_factor.at_round(0), + ); + + // Convert to extension field. + // This is not necessary for the commit, but in further rounds + // we will need the extension field. For symplicity we do it here too. + // TODO: Commit to base field directly. + let folded_evals = folded_evals + .into_iter() + .map(F::from_base_prime_field) + .collect::>(); + + // Group folds together as a leaf. + let fold_size = 1 << self.0.folding_factor.at_round(0); + #[cfg(not(feature = "parallel"))] + let leafs_iter = folded_evals.chunks_exact(fold_size); + #[cfg(feature = "parallel")] + let leafs_iter = folded_evals.par_chunks_exact(fold_size); + + let merkle_build_timer = start_timer!(|| "Single Merkle Tree Build"); + let merkle_tree = MerkleTree::::new( + &self.0.leaf_hash_params, + &self.0.two_to_one_params, + leafs_iter, + ) + .unwrap(); + end_timer!(merkle_build_timer); + + let root = merkle_tree.root(); + + merlin.add_digest(root)?; + + let mut ood_points = vec![F::ZERO; self.0.committment_ood_samples]; + let mut ood_answers = Vec::with_capacity(self.0.committment_ood_samples); + if self.0.committment_ood_samples > 0 { + merlin.fill_challenge_scalars(&mut ood_points)?; + ood_answers.extend(ood_points.iter().map(|ood_point| { + polynomial.evaluate_at_extension(&MultilinearPoint::expand_from_univariate( + *ood_point, + self.0.mv_parameters.num_variables, + )) + })); + merlin.add_scalars(&ood_answers)?; + } + + end_timer!(timer); + + Ok(Witness { + polynomial: polynomial.to_extension(), + merkle_tree, + merkle_leaves: folded_evals, + ood_points, + ood_answers, + }) + } +} diff --git a/whir/src/whir/fs_utils.rs b/whir/src/whir/fs_utils.rs new file mode 100644 index 000000000..e9a2922e8 --- /dev/null +++ b/whir/src/whir/fs_utils.rs @@ -0,0 +1,39 @@ +use crate::utils::dedup; +use ark_crypto_primitives::merkle_tree::Config; +use nimue::{ByteChallenges, ProofResult}; + +pub fn get_challenge_stir_queries( + domain_size: usize, + folding_factor: usize, + num_queries: usize, + transcript: &mut T, +) -> ProofResult> +where + T: ByteChallenges, +{ + let folded_domain_size = domain_size / (1 << folding_factor); + // How many bytes do we need to represent an index in the folded domain? + // domain_size_bytes = log2(folded_domain_size) / 8 + // (both operations are rounded up) + let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8; + // We need these many bytes to represent the query indices + let mut queries = vec![0u8; num_queries * domain_size_bytes]; + transcript.fill_challenge_bytes(&mut queries)?; + let indices = queries.chunks_exact(domain_size_bytes).map(|chunk| { + let mut result = 0; + for byte in chunk { + result <<= 8; + result |= *byte as usize; + } + result % folded_domain_size + }); + Ok(dedup(indices)) +} + +pub trait DigestWriter { + fn add_digest(&mut self, digest: MerkleConfig::InnerDigest) -> ProofResult<()>; +} + +pub trait DigestReader { + fn read_digest(&mut self) -> ProofResult; +} diff --git a/whir/src/whir/iopattern.rs b/whir/src/whir/iopattern.rs new file mode 100644 index 000000000..a97d5e3b6 --- /dev/null +++ b/whir/src/whir/iopattern.rs @@ -0,0 +1,95 @@ +use ark_crypto_primitives::merkle_tree::Config; +use ark_ff::FftField; +use nimue::plugins::ark::*; + +use crate::{ + fs_utils::{OODIOPattern, WhirPoWIOPattern}, + sumcheck::prover_not_skipping::SumcheckNotSkippingIOPattern, +}; + +use super::parameters::WhirConfig; + +pub trait DigestIOPattern { + fn add_digest(self, label: &str) -> Self; +} + +pub trait WhirIOPattern { + fn commit_statement( + self, + params: &WhirConfig, + ) -> Self; + fn add_whir_proof(self, params: &WhirConfig) + -> Self; +} + +impl WhirIOPattern for IOPattern +where + F: FftField, + MerkleConfig: Config, + IOPattern: ByteIOPattern + + FieldIOPattern + + SumcheckNotSkippingIOPattern + + WhirPoWIOPattern + + OODIOPattern + + DigestIOPattern, +{ + fn commit_statement( + self, + params: &WhirConfig, + ) -> Self { + // TODO: Add params + let mut this = self.add_digest("merkle_digest"); + if params.committment_ood_samples > 0 { + assert!(params.initial_statement); + this = this.add_ood(params.committment_ood_samples); + } + this + } + + fn add_whir_proof( + mut self, + params: &WhirConfig, + ) -> Self { + // TODO: Add statement + if params.initial_statement { + self = self + .challenge_scalars(1, "initial_combination_randomness") + .add_sumcheck( + params.folding_factor.at_round(0), + params.starting_folding_pow_bits, + ); + } else { + self = self + .challenge_scalars(params.folding_factor.at_round(0), "folding_randomness") + .pow(params.starting_folding_pow_bits); + } + + let mut domain_size = params.starting_domain.size(); + for (round, r) in params.round_parameters.iter().enumerate() { + let folded_domain_size = domain_size >> params.folding_factor.at_round(round); + let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8; + self = self + .add_digest("merkle_digest") + .add_ood(r.ood_samples) + .challenge_bytes(r.num_queries * domain_size_bytes, "stir_queries") + .pow(r.pow_bits) + .challenge_scalars(1, "combination_randomness") + .add_sumcheck( + params.folding_factor.at_round(round + 1), + r.folding_pow_bits, + ); + domain_size >>= 1; + } + + let folded_domain_size = domain_size + >> params + .folding_factor + .at_round(params.round_parameters.len()); + let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8; + + self.add_scalars(1 << params.final_sumcheck_rounds, "final_coeffs") + .challenge_bytes(domain_size_bytes * params.final_queries, "final_queries") + .pow(params.final_pow_bits) + .add_sumcheck(params.final_sumcheck_rounds, params.final_folding_pow_bits) + } +} diff --git a/whir/src/whir/mod.rs b/whir/src/whir/mod.rs new file mode 100644 index 000000000..816ac3c84 --- /dev/null +++ b/whir/src/whir/mod.rs @@ -0,0 +1,369 @@ +use ark_crypto_primitives::merkle_tree::{Config, MultiPath}; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; + +use crate::poly_utils::MultilinearPoint; + +pub mod batch; +pub mod committer; +pub mod fs_utils; +pub mod iopattern; +pub mod parameters; +pub mod prover; +pub mod verifier; + +#[derive(Debug, Clone, Default)] +pub struct Statement { + pub points: Vec>, + pub evaluations: Vec, +} + +// Only includes the authentication paths +#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] +pub struct WhirProof(pub(crate) Vec<(MultiPath, Vec>)>) +where + MerkleConfig: Config, + F: Sized + Clone + CanonicalSerialize + CanonicalDeserialize; + +pub fn whir_proof_size( + transcript: &[u8], + whir_proof: &WhirProof, +) -> usize +where + MerkleConfig: Config, + F: Sized + Clone + CanonicalSerialize + CanonicalDeserialize, +{ + transcript.len() + whir_proof.serialized_size(ark_serialize::Compress::Yes) +} + +#[cfg(test)] +mod tests { + use nimue::{DefaultHash, IOPattern}; + use nimue_pow::blake3::Blake3PoW; + + use crate::{ + crypto::{fields::Field64, merkle_tree::blake3 as merkle_tree}, + parameters::{ + FoldType, FoldingFactor, MultivariateParameters, SoundnessType, WhirParameters, + }, + poly_utils::{MultilinearPoint, coeffs::CoefficientList}, + whir::{ + Statement, batch::WhirBatchIOPattern, committer::Committer, iopattern::WhirIOPattern, + parameters::WhirConfig, prover::Prover, verifier::Verifier, + }, + }; + + type MerkleConfig = merkle_tree::MerkleTreeParams; + type PowStrategy = Blake3PoW; + type F = Field64; + + fn make_whir_things( + num_variables: usize, + folding_factor: FoldingFactor, + num_points: usize, + soundness_type: SoundnessType, + pow_bits: usize, + fold_type: FoldType, + ) { + let num_coeffs = 1 << num_variables; + + let mut rng = ark_std::test_rng(); + let (leaf_hash_params, two_to_one_params) = merkle_tree::default_config::(&mut rng); + + let mv_params = MultivariateParameters::::new(num_variables); + + let whir_params = WhirParameters:: { + initial_statement: true, + security_level: 32, + pow_bits, + folding_factor, + leaf_hash_params, + two_to_one_params, + soundness_type, + _pow_parameters: Default::default(), + starting_log_inv_rate: 1, + fold_optimisation: fold_type, + }; + + let params = WhirConfig::::new(mv_params, whir_params); + + let polynomial = CoefficientList::new(vec![F::from(1); num_coeffs]); + + let points: Vec<_> = (0..num_points) + .map(|_| MultilinearPoint::rand(&mut rng, num_variables)) + .collect(); + + let statement = Statement { + points: points.clone(), + evaluations: points + .iter() + .map(|point| polynomial.evaluate(point)) + .collect(), + }; + + let io = IOPattern::::new("🌪️") + .commit_statement(¶ms) + .add_whir_proof(¶ms) + .clone(); + + let mut merlin = io.to_merlin(); + + let committer = Committer::new(params.clone()); + let witness = committer.commit(&mut merlin, polynomial).unwrap(); + + let prover = Prover(params.clone()); + + let proof = prover + .prove(&mut merlin, statement.clone(), witness) + .unwrap(); + + let verifier = Verifier::new(params); + let mut arthur = io.to_arthur(merlin.transcript()); + assert!(verifier.verify(&mut arthur, &statement, &proof).is_ok()); + } + + fn make_whir_batch_things_same_point( + num_polynomials: usize, + num_variables: usize, + num_points: usize, + folding_factor: usize, + soundness_type: SoundnessType, + pow_bits: usize, + fold_type: FoldType, + ) { + println!( + "NP = {num_polynomials}, NE = {num_points}, NV = {num_variables}, FOLD_TYPE = {:?}", + fold_type + ); + let num_coeffs = 1 << num_variables; + + let mut rng = ark_std::test_rng(); + let (leaf_hash_params, two_to_one_params) = merkle_tree::default_config::(&mut rng); + + let mv_params = MultivariateParameters::::new(num_variables); + + let whir_params = WhirParameters:: { + initial_statement: true, + security_level: 32, + pow_bits, + folding_factor: FoldingFactor::Constant(folding_factor), + leaf_hash_params, + two_to_one_params, + soundness_type, + _pow_parameters: Default::default(), + starting_log_inv_rate: 1, + fold_optimisation: fold_type, + }; + + let params = WhirConfig::::new(mv_params, whir_params); + + let polynomials: Vec> = (0..num_polynomials) + .map(|i| CoefficientList::new(vec![F::from((i + 1) as i32); num_coeffs])) + .collect(); + + let points: Vec> = (0..num_points) + .map(|_| MultilinearPoint::rand(&mut rng, num_variables)) + .collect(); + let evals_per_point: Vec> = points + .iter() + .map(|point| { + polynomials + .iter() + .map(|poly| poly.evaluate(point)) + .collect() + }) + .collect(); + + let io = IOPattern::::new("🌪️") + .commit_batch_statement(¶ms, num_polynomials) + .add_whir_batch_proof(¶ms, num_polynomials) + .clone(); + let mut merlin = io.to_merlin(); + + let committer = Committer::new(params.clone()); + let witnesses = committer.batch_commit(&mut merlin, &polynomials).unwrap(); + + let prover = Prover(params.clone()); + + let proof = prover + .simple_batch_prove(&mut merlin, &points, &evals_per_point, &witnesses) + .unwrap(); + + let verifier = Verifier::new(params); + let mut arthur = io.to_arthur(merlin.transcript()); + assert!( + verifier + .simple_batch_verify( + &mut arthur, + num_polynomials, + &points, + &evals_per_point, + &proof + ) + .is_ok() + ); + println!("PASSED!"); + } + + fn make_whir_batch_things_diff_point( + num_polynomials: usize, + num_variables: usize, + folding_factor: usize, + soundness_type: SoundnessType, + pow_bits: usize, + fold_type: FoldType, + ) { + println!( + "NP = {num_polynomials}, NV = {num_variables}, FOLD_TYPE = {:?}", + fold_type + ); + let num_coeffs = 1 << num_variables; + + let mut rng = ark_std::test_rng(); + let (leaf_hash_params, two_to_one_params) = merkle_tree::default_config::(&mut rng); + + let mv_params = MultivariateParameters::::new(num_variables); + + let whir_params = WhirParameters:: { + initial_statement: true, + security_level: 32, + pow_bits, + folding_factor: FoldingFactor::Constant(folding_factor), + leaf_hash_params, + two_to_one_params, + soundness_type, + _pow_parameters: Default::default(), + starting_log_inv_rate: 1, + fold_optimisation: fold_type, + }; + + let params = WhirConfig::::new(mv_params, whir_params); + + let polynomials: Vec> = (0..num_polynomials) + .map(|i| CoefficientList::new(vec![F::from((i + 1) as i32); num_coeffs])) + .collect(); + + let point_per_poly: Vec> = (0..num_polynomials) + .map(|_| MultilinearPoint::rand(&mut rng, num_variables)) + .collect(); + let eval_per_poly: Vec = polynomials + .iter() + .zip(&point_per_poly) + .map(|(poly, point)| poly.evaluate(point)) + .collect(); + + let io = IOPattern::::new("🌪️") + .commit_batch_statement(¶ms, num_polynomials) + .add_whir_unify_proof(¶ms, num_polynomials) + .add_whir_batch_proof(¶ms, num_polynomials) + .clone(); + let mut merlin = io.to_merlin(); + + let committer = Committer::new(params.clone()); + let witnesses = committer.batch_commit(&mut merlin, &polynomials).unwrap(); + + let prover = Prover(params.clone()); + + let proof = prover + .same_size_batch_prove(&mut merlin, &point_per_poly, &eval_per_poly, &witnesses) + .unwrap(); + + let verifier = Verifier::new(params); + let mut arthur = io.to_arthur(merlin.transcript()); + verifier + .same_size_batch_verify( + &mut arthur, + num_polynomials, + &point_per_poly, + &eval_per_poly, + &proof, + ) + .unwrap(); + // assert!(verifier + // .same_size_batch_verify(&mut arthur, num_polynomials, &point_per_poly, &eval_per_poly, &proof) + // .is_ok()); + println!("PASSED!"); + } + + #[test] + fn test_whir() { + let folding_factors = [2, 3, 4, 5]; + let soundness_type = [ + SoundnessType::ConjectureList, + SoundnessType::ProvableList, + SoundnessType::UniqueDecoding, + ]; + let fold_types = [FoldType::Naive, FoldType::ProverHelps]; + let num_points = [0, 1, 2]; + let num_polys = [1, 2, 3]; + let pow_bits = [0, 5, 10]; + + for folding_factor in folding_factors { + let num_variables = folding_factor - 1..=2 * folding_factor; + for num_variables in num_variables { + for fold_type in fold_types { + for num_points in num_points { + for soundness_type in soundness_type { + for pow_bits in pow_bits { + make_whir_things( + num_variables, + FoldingFactor::Constant(folding_factor), + num_points, + soundness_type, + pow_bits, + fold_type, + ); + } + } + } + } + } + } + + for folding_factor in folding_factors { + let num_variables = folding_factor..=2 * folding_factor; + for num_variables in num_variables { + for fold_type in fold_types { + for num_points in num_points { + for num_polys in num_polys { + for soundness_type in soundness_type { + for pow_bits in pow_bits { + make_whir_batch_things_same_point( + num_polys, + num_variables, + num_points, + folding_factor, + soundness_type, + pow_bits, + fold_type, + ); + } + } + } + } + } + } + } + + for folding_factor in folding_factors { + let num_variables = folding_factor..=2 * folding_factor; + for num_variables in num_variables { + for fold_type in fold_types { + for num_polys in num_polys { + for soundness_type in soundness_type { + for pow_bits in pow_bits { + make_whir_batch_things_diff_point( + num_polys, + num_variables, + folding_factor, + soundness_type, + pow_bits, + fold_type, + ); + } + } + } + } + } + } + } +} diff --git a/whir/src/whir/parameters.rs b/whir/src/whir/parameters.rs new file mode 100644 index 000000000..23a0bcb7b --- /dev/null +++ b/whir/src/whir/parameters.rs @@ -0,0 +1,633 @@ +use core::{fmt, panic}; +use derive_more::Debug; +use std::{f64::consts::LOG2_10, fmt::Display, marker::PhantomData}; + +use ark_crypto_primitives::merkle_tree::{Config, LeafParam, TwoToOneParam}; +use ark_ff::FftField; +use serde::{Deserialize, Serialize}; + +use crate::{ + crypto::fields::FieldWithSize, + domain::Domain, + parameters::{FoldType, FoldingFactor, MultivariateParameters, SoundnessType, WhirParameters}, +}; + +#[derive(Clone, Debug)] +pub struct WhirConfig +where + F: FftField, + MerkleConfig: Config, +{ + pub(crate) mv_parameters: MultivariateParameters, + pub(crate) soundness_type: SoundnessType, + pub(crate) security_level: usize, + pub(crate) max_pow_bits: usize, + + pub(crate) committment_ood_samples: usize, + // The WHIR protocol can prove either: + // 1. The commitment is a valid low degree polynomial. In that case, the + // initial statement is set to false. + // 2. The commitment is a valid folded polynomial, and an additional + // polynomial evaluation statement. In that case, the initial statement + // is set to true. + pub(crate) initial_statement: bool, + pub(crate) starting_domain: Domain, + pub(crate) starting_log_inv_rate: usize, + pub(crate) starting_folding_pow_bits: f64, + + pub(crate) folding_factor: FoldingFactor, + pub(crate) round_parameters: Vec, + pub(crate) fold_optimisation: FoldType, + + pub(crate) final_queries: usize, + pub(crate) final_pow_bits: f64, + pub(crate) final_log_inv_rate: usize, + pub(crate) final_sumcheck_rounds: usize, + pub(crate) final_folding_pow_bits: f64, + + // PoW parameters + pub(crate) pow_strategy: PhantomData, + + // Merkle tree parameters + #[debug(skip)] + pub(crate) leaf_hash_params: LeafParam, + #[debug(skip)] + pub(crate) two_to_one_params: TwoToOneParam, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct RoundConfig { + pub(crate) pow_bits: f64, + pub(crate) folding_pow_bits: f64, + pub(crate) num_queries: usize, + pub(crate) ood_samples: usize, + pub(crate) log_inv_rate: usize, +} + +impl WhirConfig +where + F: FftField + FieldWithSize, + MerkleConfig: Config, +{ + pub fn new( + mut mv_parameters: MultivariateParameters, + whir_parameters: WhirParameters, + ) -> Self { + // Pad the number of variables to folding factor + if mv_parameters.num_variables < whir_parameters.folding_factor.at_round(0) { + mv_parameters.num_variables = whir_parameters.folding_factor.at_round(0); + } + whir_parameters + .folding_factor + .check_validity(mv_parameters.num_variables) + .unwrap(); + + let protocol_security_level = + 0.max(whir_parameters.security_level - whir_parameters.pow_bits); + + let starting_domain = Domain::new( + 1 << mv_parameters.num_variables, + whir_parameters.starting_log_inv_rate, + ) + .expect("Should have found an appropriate domain - check Field 2 adicity?"); + + let (num_rounds, final_sumcheck_rounds) = whir_parameters + .folding_factor + .compute_number_of_rounds(mv_parameters.num_variables); + + let field_size_bits = F::field_size_in_bits(); + + let committment_ood_samples = if whir_parameters.initial_statement { + Self::ood_samples( + whir_parameters.security_level, + whir_parameters.soundness_type, + mv_parameters.num_variables, + whir_parameters.starting_log_inv_rate, + Self::log_eta( + whir_parameters.soundness_type, + whir_parameters.starting_log_inv_rate, + ), + field_size_bits, + ) + } else { + 0 + }; + + let starting_folding_pow_bits = if whir_parameters.initial_statement { + Self::folding_pow_bits( + whir_parameters.security_level, + whir_parameters.soundness_type, + field_size_bits, + mv_parameters.num_variables, + whir_parameters.starting_log_inv_rate, + Self::log_eta( + whir_parameters.soundness_type, + whir_parameters.starting_log_inv_rate, + ), + ) + } else { + let prox_gaps_error = Self::rbr_soundness_fold_prox_gaps( + whir_parameters.soundness_type, + field_size_bits, + mv_parameters.num_variables, + whir_parameters.starting_log_inv_rate, + Self::log_eta( + whir_parameters.soundness_type, + whir_parameters.starting_log_inv_rate, + ), + ) + (whir_parameters.folding_factor.at_round(0) as f64).log2(); + 0_f64.max(whir_parameters.security_level as f64 - prox_gaps_error) + }; + + let mut round_parameters = Vec::with_capacity(num_rounds); + let mut num_variables = + mv_parameters.num_variables - whir_parameters.folding_factor.at_round(0); + let mut log_inv_rate = whir_parameters.starting_log_inv_rate; + for round in 0..num_rounds { + // Queries are set w.r.t. to old rate, while the rest to the new rate + let next_rate = log_inv_rate + (whir_parameters.folding_factor.at_round(round) - 1); + + let log_next_eta = Self::log_eta(whir_parameters.soundness_type, next_rate); + let num_queries = Self::queries( + whir_parameters.soundness_type, + protocol_security_level, + log_inv_rate, + ); + + let ood_samples = Self::ood_samples( + whir_parameters.security_level, + whir_parameters.soundness_type, + num_variables, + next_rate, + log_next_eta, + field_size_bits, + ); + + let query_error = + Self::rbr_queries(whir_parameters.soundness_type, log_inv_rate, num_queries); + let combination_error = Self::rbr_soundness_queries_combination( + whir_parameters.soundness_type, + field_size_bits, + num_variables, + next_rate, + log_next_eta, + ood_samples, + num_queries, + ); + + let pow_bits = 0_f64 + .max(whir_parameters.security_level as f64 - (query_error.min(combination_error))); + + let folding_pow_bits = Self::folding_pow_bits( + whir_parameters.security_level, + whir_parameters.soundness_type, + field_size_bits, + num_variables, + next_rate, + log_next_eta, + ); + + round_parameters.push(RoundConfig { + ood_samples, + num_queries, + pow_bits, + folding_pow_bits, + log_inv_rate, + }); + + num_variables -= whir_parameters.folding_factor.at_round(round + 1); + log_inv_rate = next_rate; + } + + let final_queries = Self::queries( + whir_parameters.soundness_type, + protocol_security_level, + log_inv_rate, + ); + + let final_pow_bits = 0_f64.max( + whir_parameters.security_level as f64 + - Self::rbr_queries(whir_parameters.soundness_type, log_inv_rate, final_queries), + ); + + let final_folding_pow_bits = + 0_f64.max(whir_parameters.security_level as f64 - (field_size_bits - 1) as f64); + + WhirConfig { + security_level: whir_parameters.security_level, + max_pow_bits: whir_parameters.pow_bits, + initial_statement: whir_parameters.initial_statement, + committment_ood_samples, + mv_parameters, + starting_domain, + soundness_type: whir_parameters.soundness_type, + starting_log_inv_rate: whir_parameters.starting_log_inv_rate, + starting_folding_pow_bits, + folding_factor: whir_parameters.folding_factor, + round_parameters, + final_queries, + final_pow_bits, + final_sumcheck_rounds, + final_folding_pow_bits, + pow_strategy: PhantomData, + fold_optimisation: whir_parameters.fold_optimisation, + final_log_inv_rate: log_inv_rate, + leaf_hash_params: whir_parameters.leaf_hash_params, + two_to_one_params: whir_parameters.two_to_one_params, + } + } + + pub fn n_rounds(&self) -> usize { + self.round_parameters.len() + } + + pub fn check_pow_bits(&self) -> bool { + [ + self.starting_folding_pow_bits, + self.final_pow_bits, + self.final_folding_pow_bits, + ] + .into_iter() + .all(|x| x <= self.max_pow_bits as f64) + && self.round_parameters.iter().all(|r| { + r.pow_bits <= self.max_pow_bits as f64 + && r.folding_pow_bits <= self.max_pow_bits as f64 + }) + } + + pub fn log_eta(soundness_type: SoundnessType, log_inv_rate: usize) -> f64 { + // Ask me how I did this? At the time, only God and I knew. Now only God knows + match soundness_type { + SoundnessType::ProvableList => -(0.5 * log_inv_rate as f64 + LOG2_10 + 1.), + SoundnessType::UniqueDecoding => 0., + SoundnessType::ConjectureList => -(log_inv_rate as f64 + 1.), + } + } + + pub fn list_size_bits( + soundness_type: SoundnessType, + num_variables: usize, + log_inv_rate: usize, + log_eta: f64, + ) -> f64 { + match soundness_type { + SoundnessType::ConjectureList => (num_variables + log_inv_rate) as f64 - log_eta, + SoundnessType::ProvableList => { + let log_inv_sqrt_rate: f64 = log_inv_rate as f64 / 2.; + log_inv_sqrt_rate - (1. + log_eta) + } + SoundnessType::UniqueDecoding => 0.0, + } + } + + pub fn rbr_ood_sample( + soundness_type: SoundnessType, + num_variables: usize, + log_inv_rate: usize, + log_eta: f64, + field_size_bits: usize, + ood_samples: usize, + ) -> f64 { + let list_size_bits = + Self::list_size_bits(soundness_type, num_variables, log_inv_rate, log_eta); + + let error = 2. * list_size_bits + (num_variables * ood_samples) as f64; + (ood_samples * field_size_bits) as f64 + 1. - error + } + + pub fn ood_samples( + security_level: usize, // We don't do PoW for OOD + soundness_type: SoundnessType, + num_variables: usize, + log_inv_rate: usize, + log_eta: f64, + field_size_bits: usize, + ) -> usize { + if matches!(soundness_type, SoundnessType::UniqueDecoding) { + 0 + } else { + for ood_samples in 1..64 { + if Self::rbr_ood_sample( + soundness_type, + num_variables, + log_inv_rate, + log_eta, + field_size_bits, + ood_samples, + ) >= security_level as f64 + { + return ood_samples; + } + } + + panic!("Could not find an appropriate number of OOD samples"); + } + } + + // Compute the proximity gaps term of the fold + pub fn rbr_soundness_fold_prox_gaps( + soundness_type: SoundnessType, + field_size_bits: usize, + num_variables: usize, + log_inv_rate: usize, + log_eta: f64, + ) -> f64 { + // Recall, at each round we are only folding by two at a time + let error = match soundness_type { + SoundnessType::ConjectureList => (num_variables + log_inv_rate) as f64 - log_eta, + SoundnessType::ProvableList => { + LOG2_10 + 3.5 * log_inv_rate as f64 + 2. * num_variables as f64 + } + SoundnessType::UniqueDecoding => (num_variables + log_inv_rate) as f64, + }; + + field_size_bits as f64 - error + } + + pub fn rbr_soundness_fold_sumcheck( + soundness_type: SoundnessType, + field_size_bits: usize, + num_variables: usize, + log_inv_rate: usize, + log_eta: f64, + ) -> f64 { + let list_size = Self::list_size_bits(soundness_type, num_variables, log_inv_rate, log_eta); + + field_size_bits as f64 - (list_size + 1.) + } + + pub fn folding_pow_bits( + security_level: usize, + soundness_type: SoundnessType, + field_size_bits: usize, + num_variables: usize, + log_inv_rate: usize, + log_eta: f64, + ) -> f64 { + let prox_gaps_error = Self::rbr_soundness_fold_prox_gaps( + soundness_type, + field_size_bits, + num_variables, + log_inv_rate, + log_eta, + ); + let sumcheck_error = Self::rbr_soundness_fold_sumcheck( + soundness_type, + field_size_bits, + num_variables, + log_inv_rate, + log_eta, + ); + + let error = prox_gaps_error.min(sumcheck_error); + + 0_f64.max(security_level as f64 - error) + } + + // Used to select the number of queries + pub fn queries( + soundness_type: SoundnessType, + protocol_security_level: usize, + log_inv_rate: usize, + ) -> usize { + let num_queries_f = match soundness_type { + SoundnessType::UniqueDecoding => { + let rate = 1. / ((1 << log_inv_rate) as f64); + let denom = (0.5 * (1. + rate)).log2(); + + -(protocol_security_level as f64) / denom + } + SoundnessType::ProvableList => { + (2 * protocol_security_level) as f64 / log_inv_rate as f64 + } + SoundnessType::ConjectureList => protocol_security_level as f64 / log_inv_rate as f64, + }; + num_queries_f.ceil() as usize + } + + // This is the bits of security of the query step + pub fn rbr_queries( + soundness_type: SoundnessType, + log_inv_rate: usize, + num_queries: usize, + ) -> f64 { + let num_queries = num_queries as f64; + + match soundness_type { + SoundnessType::UniqueDecoding => { + let rate = 1. / ((1 << log_inv_rate) as f64); + let denom = -(0.5 * (1. + rate)).log2(); + + num_queries * denom + } + SoundnessType::ProvableList => num_queries * 0.5 * log_inv_rate as f64, + SoundnessType::ConjectureList => num_queries * log_inv_rate as f64, + } + } + + pub fn rbr_soundness_queries_combination( + soundness_type: SoundnessType, + field_size_bits: usize, + num_variables: usize, + log_inv_rate: usize, + log_eta: f64, + ood_samples: usize, + num_queries: usize, + ) -> f64 { + let list_size = Self::list_size_bits(soundness_type, num_variables, log_inv_rate, log_eta); + + let log_combination = ((ood_samples + num_queries) as f64).log2(); + + field_size_bits as f64 - (log_combination + list_size + 1.) + } +} + +impl Display for WhirConfig +where + F: FftField, + MerkleConfig: Config, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt::Display::fmt(&self.mv_parameters, f)?; + writeln!(f, ", folding factor: {:?}", self.folding_factor)?; + writeln!( + f, + "Security level: {} bits using {} security and {} bits of PoW", + self.security_level, self.soundness_type, self.max_pow_bits + )?; + + writeln!( + f, + "initial_folding_pow_bits: {}", + self.starting_folding_pow_bits + )?; + for r in &self.round_parameters { + fmt::Display::fmt(&r, f)?; + } + + writeln!( + f, + "final_queries: {}, final_rate: 2^-{}, final_pow_bits: {}, final_folding_pow_bits: {}", + self.final_queries, + self.final_log_inv_rate, + self.final_pow_bits, + self.final_folding_pow_bits, + )?; + + writeln!(f, "------------------------------------")?; + writeln!(f, "Round by round soundness analysis:")?; + writeln!(f, "------------------------------------")?; + + let field_size_bits = F::field_size_in_bits(); + let log_eta = Self::log_eta(self.soundness_type, self.starting_log_inv_rate); + let mut num_variables = self.mv_parameters.num_variables; + + if self.committment_ood_samples > 0 { + writeln!( + f, + "{:.1} bits -- OOD commitment", + Self::rbr_ood_sample( + self.soundness_type, + num_variables, + self.starting_log_inv_rate, + log_eta, + field_size_bits, + self.committment_ood_samples + ) + )?; + } + + let prox_gaps_error = Self::rbr_soundness_fold_prox_gaps( + self.soundness_type, + field_size_bits, + num_variables, + self.starting_log_inv_rate, + log_eta, + ); + let sumcheck_error = Self::rbr_soundness_fold_sumcheck( + self.soundness_type, + field_size_bits, + num_variables, + self.starting_log_inv_rate, + log_eta, + ); + writeln!( + f, + "{:.1} bits -- (x{:?}) prox gaps: {:.1}, sumcheck: {:.1}, pow: {:.1}", + prox_gaps_error.min(sumcheck_error) + self.starting_folding_pow_bits, + self.folding_factor, + prox_gaps_error, + sumcheck_error, + self.starting_folding_pow_bits, + )?; + + num_variables -= self.folding_factor.at_round(0); + + for (round, r) in self.round_parameters.iter().enumerate() { + let next_rate = r.log_inv_rate + (self.folding_factor.at_round(round) - 1); + let log_eta = Self::log_eta(self.soundness_type, next_rate); + + if r.ood_samples > 0 { + writeln!( + f, + "{:.1} bits -- OOD sample", + Self::rbr_ood_sample( + self.soundness_type, + num_variables, + next_rate, + log_eta, + field_size_bits, + r.ood_samples + ) + )?; + } + + let query_error = Self::rbr_queries(self.soundness_type, r.log_inv_rate, r.num_queries); + let combination_error = Self::rbr_soundness_queries_combination( + self.soundness_type, + field_size_bits, + num_variables, + next_rate, + log_eta, + r.ood_samples, + r.num_queries, + ); + writeln!( + f, + "{:.1} bits -- query error: {:.1}, combination: {:.1}, pow: {:.1}", + query_error.min(combination_error) + r.pow_bits, + query_error, + combination_error, + r.pow_bits, + )?; + + let prox_gaps_error = Self::rbr_soundness_fold_prox_gaps( + self.soundness_type, + field_size_bits, + num_variables, + next_rate, + log_eta, + ); + let sumcheck_error = Self::rbr_soundness_fold_sumcheck( + self.soundness_type, + field_size_bits, + num_variables, + next_rate, + log_eta, + ); + + writeln!( + f, + "{:.1} bits -- (x{:?}) prox gaps: {:.1}, sumcheck: {:.1}, pow: {:.1}", + prox_gaps_error.min(sumcheck_error) + r.folding_pow_bits, + self.folding_factor, + prox_gaps_error, + sumcheck_error, + r.folding_pow_bits, + )?; + + num_variables -= self.folding_factor.at_round(round + 1); + } + + let query_error = Self::rbr_queries( + self.soundness_type, + self.final_log_inv_rate, + self.final_queries, + ); + writeln!( + f, + "{:.1} bits -- query error: {:.1}, pow: {:.1}", + query_error + self.final_pow_bits, + query_error, + self.final_pow_bits, + )?; + + if self.final_sumcheck_rounds > 0 { + let combination_error = field_size_bits as f64 - 1.; + writeln!( + f, + "{:.1} bits -- (x{}) combination: {:.1}, pow: {:.1}", + combination_error + self.final_pow_bits, + self.final_sumcheck_rounds, + combination_error, + self.final_folding_pow_bits, + )?; + } + + Ok(()) + } +} + +impl Display for RoundConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!( + f, + "Num_queries: {}, rate: 2^-{}, pow_bits: {}, ood_samples: {}, folding_pow: {}", + self.num_queries, + self.log_inv_rate, + self.pow_bits, + self.ood_samples, + self.folding_pow_bits, + ) + } +} diff --git a/whir/src/whir/prover.rs b/whir/src/whir/prover.rs new file mode 100644 index 000000000..a1e0bcd8b --- /dev/null +++ b/whir/src/whir/prover.rs @@ -0,0 +1,456 @@ +use super::{Statement, WhirProof, committer::Witness, parameters::WhirConfig}; +use crate::{ + domain::Domain, + ntt::expand_from_coeff, + parameters::FoldType, + poly_utils::{ + MultilinearPoint, + coeffs::CoefficientList, + fold::{compute_fold, restructure_evaluations}, + }, + sumcheck::prover_not_skipping::SumcheckProverNotSkipping, + utils::{self, expand_randomness}, +}; +use ark_crypto_primitives::merkle_tree::{Config, MerkleTree, MultiPath}; +use ark_ff::FftField; +use ark_poly::EvaluationDomain; +use ark_std::{end_timer, start_timer}; +use nimue::{ + ByteChallenges, ByteWriter, ProofResult, + plugins::ark::{FieldChallenges, FieldWriter}, +}; +use nimue_pow::{self, PoWChallenge}; + +use crate::whir::fs_utils::{DigestWriter, get_challenge_stir_queries}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +pub struct Prover(pub WhirConfig) +where + F: FftField, + MerkleConfig: Config; + +impl Prover +where + F: FftField, + MerkleConfig: Config, + PowStrategy: nimue_pow::PowStrategy, +{ + pub(crate) fn validate_parameters(&self) -> bool { + self.0.mv_parameters.num_variables + == self.0.folding_factor.total_number(self.0.n_rounds()) + self.0.final_sumcheck_rounds + } + + fn validate_statement(&self, statement: &Statement) -> bool { + if statement.points.len() != statement.evaluations.len() { + return false; + } + if !statement + .points + .iter() + .all(|point| point.0.len() == self.0.mv_parameters.num_variables) + { + return false; + } + if !self.0.initial_statement && !statement.points.is_empty() { + return false; + } + true + } + + fn validate_witness(&self, witness: &Witness) -> bool { + assert_eq!(witness.ood_points.len(), witness.ood_answers.len()); + if !self.0.initial_statement { + assert!(witness.ood_points.is_empty()); + } + witness.polynomial.num_variables() == self.0.mv_parameters.num_variables + } + + pub fn prove( + &self, + merlin: &mut Merlin, + mut statement: Statement, + witness: Witness, + ) -> ProofResult> + where + Merlin: FieldChallenges + + FieldWriter + + ByteChallenges + + ByteWriter + + PoWChallenge + + DigestWriter, + { + // If any evaluation point is shorter than the folding factor, pad with 0 in front + for p in statement.points.iter_mut() { + while p.n_variables() < self.0.folding_factor.at_round(0) { + p.0.insert(0, F::ONE); + } + } + + assert!(self.validate_parameters()); + assert!(self.validate_statement(&statement)); + assert!(self.validate_witness(&witness)); + + let timer = start_timer!(|| "Single Prover"); + let initial_claims: Vec<_> = witness + .ood_points + .into_iter() + .map(|ood_point| { + MultilinearPoint::expand_from_univariate( + ood_point, + self.0.mv_parameters.num_variables, + ) + }) + .chain(statement.points) + .collect(); + let initial_answers: Vec<_> = witness + .ood_answers + .into_iter() + .chain(statement.evaluations) + .collect(); + + if !self.0.initial_statement { + // It is ensured that if there is no initial statement, the + // number of ood samples is also zero. + assert!( + initial_answers.is_empty(), + "Can not have initial answers without initial statement" + ); + } + + let mut sumcheck_prover = None; + let folding_randomness = if self.0.initial_statement { + // If there is initial statement, then we run the sum-check for + // this initial statement. + let [combination_randomness_gen] = merlin.challenge_scalars()?; + let combination_randomness = + expand_randomness(combination_randomness_gen, initial_claims.len()); + + sumcheck_prover = Some(SumcheckProverNotSkipping::new( + witness.polynomial.clone(), + &initial_claims, + &combination_randomness, + &initial_answers, + )); + + sumcheck_prover + .as_mut() + .unwrap() + .compute_sumcheck_polynomials::( + merlin, + self.0.folding_factor.at_round(0), + self.0.starting_folding_pow_bits, + )? + } else { + // If there is no initial statement, there is no need to run the + // initial rounds of the sum-check, and the verifier directly sends + // the initial folding randomnesses. + let mut folding_randomness = vec![F::ZERO; self.0.folding_factor.at_round(0)]; + merlin.fill_challenge_scalars(&mut folding_randomness)?; + + if self.0.starting_folding_pow_bits > 0. { + merlin.challenge_pow::(self.0.starting_folding_pow_bits)?; + } + MultilinearPoint(folding_randomness) + }; + + let round_state = RoundState { + domain: self.0.starting_domain.clone(), + round: 0, + sumcheck_prover, + folding_randomness, + coefficients: witness.polynomial, + prev_merkle: witness.merkle_tree, + prev_merkle_answers: witness.merkle_leaves, + merkle_proofs: vec![], + }; + + let round_timer = start_timer!(|| "Single Round"); + let result = self.round(merlin, round_state); + end_timer!(round_timer); + + end_timer!(timer); + + result + } + + pub(crate) fn round( + &self, + merlin: &mut Merlin, + mut round_state: RoundState, + ) -> ProofResult> + where + Merlin: FieldChallenges + + ByteChallenges + + FieldWriter + + ByteWriter + + PoWChallenge + + DigestWriter, + { + // Fold the coefficients + let folded_coefficients = round_state + .coefficients + .fold(&round_state.folding_randomness); + + let num_variables = self.0.mv_parameters.num_variables + - self.0.folding_factor.total_number(round_state.round); + // num_variables should match the folded_coefficients here. + assert_eq!(num_variables, folded_coefficients.num_variables()); + + // Base case + if round_state.round == self.0.n_rounds() { + // Directly send coefficients of the polynomial to the verifier. + merlin.add_scalars(folded_coefficients.coeffs())?; + + // Final verifier queries and answers. The indices are over the + // *folded* domain. + let final_challenge_indexes = get_challenge_stir_queries( + round_state.domain.size(), // The size of the *original* domain before folding + self.0.folding_factor.at_round(round_state.round), /* The folding factor we used to fold the previous polynomial */ + self.0.final_queries, + merlin, + )?; + + let merkle_proof = round_state + .prev_merkle + .generate_multi_proof(final_challenge_indexes.clone()) + .unwrap(); + // Every query requires opening these many in the previous Merkle tree + let fold_size = 1 << self.0.folding_factor.at_round(round_state.round); + let answers = final_challenge_indexes + .into_iter() + .map(|i| { + round_state.prev_merkle_answers[i * fold_size..(i + 1) * fold_size].to_vec() + }) + .collect(); + round_state.merkle_proofs.push((merkle_proof, answers)); + + // PoW + if self.0.final_pow_bits > 0. { + merlin.challenge_pow::(self.0.final_pow_bits)?; + } + + // Final sumcheck + if self.0.final_sumcheck_rounds > 0 { + round_state + .sumcheck_prover + .unwrap_or_else(|| { + SumcheckProverNotSkipping::new(folded_coefficients.clone(), &[], &[], &[]) + }) + .compute_sumcheck_polynomials::( + merlin, + self.0.final_sumcheck_rounds, + self.0.final_folding_pow_bits, + )?; + } + + return Ok(WhirProof(round_state.merkle_proofs)); + } + + let round_params = &self.0.round_parameters[round_state.round]; + + // Fold the coefficients, and compute fft of polynomial (and commit) + let new_domain = round_state.domain.scale(2); + let expansion = new_domain.size() / folded_coefficients.num_coeffs(); + let evals = expand_from_coeff(folded_coefficients.coeffs(), expansion); + // Group the evaluations into leaves by the *next* round folding factor + // TODO: `stack_evaluations` and `restructure_evaluations` are really in-place algorithms. + // They also partially overlap and undo one another. We should merge them. + let folded_evals = utils::stack_evaluations( + evals, + self.0.folding_factor.at_round(round_state.round + 1), // Next round fold factor + ); + let folded_evals = restructure_evaluations( + folded_evals, + self.0.fold_optimisation, + new_domain.backing_domain.group_gen(), + new_domain.backing_domain.group_gen_inv(), + self.0.folding_factor.at_round(round_state.round + 1), + ); + + #[cfg(not(feature = "parallel"))] + let leafs_iter = folded_evals.chunks_exact( + 1 << self + .0 + .folding_factor + .get_folding_factor_of_round(round_state.round + 1), + ); + #[cfg(feature = "parallel")] + let leafs_iter = folded_evals + .par_chunks_exact(1 << self.0.folding_factor.at_round(round_state.round + 1)); + let merkle_tree = MerkleTree::::new( + &self.0.leaf_hash_params, + &self.0.two_to_one_params, + leafs_iter, + ) + .unwrap(); + + let root = merkle_tree.root(); + merlin.add_digest(root)?; + + // OOD Samples + let mut ood_points = vec![F::ZERO; round_params.ood_samples]; + let mut ood_answers = Vec::with_capacity(round_params.ood_samples); + if round_params.ood_samples > 0 { + merlin.fill_challenge_scalars(&mut ood_points)?; + ood_answers.extend(ood_points.iter().map(|ood_point| { + folded_coefficients.evaluate(&MultilinearPoint::expand_from_univariate( + *ood_point, + num_variables, + )) + })); + merlin.add_scalars(&ood_answers)?; + } + + // STIR queries + let stir_challenges_indexes = get_challenge_stir_queries( + round_state.domain.size(), // Current domain size *before* folding + self.0.folding_factor.at_round(round_state.round), // Current fold factor + round_params.num_queries, + merlin, + )?; + // Compute the generator of the folded domain, in the extension field + let domain_scaled_gen = round_state + .domain + .backing_domain + .element(1 << self.0.folding_factor.at_round(round_state.round)); + let stir_challenges: Vec<_> = ood_points + .into_iter() + .chain( + stir_challenges_indexes + .iter() + .map(|i| domain_scaled_gen.pow([*i as u64])), + ) + .map(|univariate| MultilinearPoint::expand_from_univariate(univariate, num_variables)) + .collect(); + + let merkle_proof = round_state + .prev_merkle + .generate_multi_proof(stir_challenges_indexes.clone()) + .unwrap(); + let fold_size = 1 << self.0.folding_factor.at_round(round_state.round); + let answers: Vec<_> = stir_challenges_indexes + .iter() + .map(|i| round_state.prev_merkle_answers[i * fold_size..(i + 1) * fold_size].to_vec()) + .collect(); + // Evaluate answers in the folding randomness. + let mut stir_evaluations = ood_answers.clone(); + match self.0.fold_optimisation { + FoldType::Naive => { + // See `Verifier::compute_folds_full` + let domain_size = round_state.domain.backing_domain.size(); + let domain_gen = round_state.domain.backing_domain.element(1); + let domain_gen_inv = domain_gen.inverse().unwrap(); + let coset_domain_size = 1 << self.0.folding_factor.at_round(round_state.round); + // The domain (before folding) is split into cosets of size + // `coset_domain_size` (which is just `fold_size`). Each coset + // is generated by powers of `coset_generator` (which is just the + // `fold_size`-root of unity) multiplied by a different + // `coset_offset`. + // For example, if `fold_size = 16`, and the domain size is N, then + // the domain is (1, w, w^2, ..., w^(N-1)), the domain generator + // is w, and the coset generator is w^(N/16). + // The first coset is (1, w^(N/16), w^(2N/16), ..., w^(15N/16)) + // which is also a subgroup itself (the coset_offset is 1). + // The second coset would be w * , the third coset would be + // w^2 * , and so on. Until w^(N/16-1) * . + let coset_generator_inv = + domain_gen_inv.pow([(domain_size / coset_domain_size) as u64]); + stir_evaluations.extend(stir_challenges_indexes.iter().zip(&answers).map( + |(index, answers)| { + // The coset is w^index * + // let _coset_offset = domain_gen.pow(&[*index as u64]); + let coset_offset_inv = domain_gen_inv.pow([*index as u64]); + + // In the Naive mode, the oracle consists directly of the + // evaluations of f over the domain. We leverage an + // algorithm to compute the evaluations of the folded f + // at the corresponding point in folded domain (which is + // coset_offset^fold_size). + compute_fold( + answers, + &round_state.folding_randomness.0, + coset_offset_inv, + coset_generator_inv, + F::from(2).inverse().unwrap(), + self.0.folding_factor.at_round(round_state.round), + ) + }, + )) + } + FoldType::ProverHelps => stir_evaluations.extend(answers.iter().map(|answers| { + // In the ProverHelps mode, the oracle values have been linearly + // transformed such that they are exactly the coefficients of the + // multilinear polynomial whose evaluation at the folding randomness + // is just the folding of f evaluated at the folded point. + CoefficientList::new(answers.to_vec()).evaluate(&round_state.folding_randomness) + })), + } + round_state.merkle_proofs.push((merkle_proof, answers)); + + // PoW + if round_params.pow_bits > 0. { + merlin.challenge_pow::(round_params.pow_bits)?; + } + + // Randomness for combination + let [combination_randomness_gen] = merlin.challenge_scalars()?; + let combination_randomness = + expand_randomness(combination_randomness_gen, stir_challenges.len()); + + let mut sumcheck_prover = round_state + .sumcheck_prover + .take() + .map(|mut sumcheck_prover| { + sumcheck_prover.add_new_equality( + &stir_challenges, + &combination_randomness, + &stir_evaluations, + ); + sumcheck_prover + }) + .unwrap_or_else(|| { + SumcheckProverNotSkipping::new( + folded_coefficients.clone(), + &stir_challenges, + &combination_randomness, + &stir_evaluations, + ) + }); + + let folding_randomness = sumcheck_prover + .compute_sumcheck_polynomials::( + merlin, + self.0.folding_factor.at_round(round_state.round + 1), + round_params.folding_pow_bits, + )?; + + let round_state = RoundState { + round: round_state.round + 1, + domain: new_domain, + sumcheck_prover: Some(sumcheck_prover), + folding_randomness, + coefficients: folded_coefficients, /* TODO: Is this redundant with `sumcheck_prover.coeff` ? */ + prev_merkle: merkle_tree, + prev_merkle_answers: folded_evals, + merkle_proofs: round_state.merkle_proofs, + }; + + self.round(merlin, round_state) + } +} + +pub(crate) struct RoundState +where + F: FftField, + MerkleConfig: Config, +{ + pub(crate) round: usize, + pub(crate) domain: Domain, + pub(crate) sumcheck_prover: Option>, + pub(crate) folding_randomness: MultilinearPoint, + pub(crate) coefficients: CoefficientList, + pub(crate) prev_merkle: MerkleTree, + pub(crate) prev_merkle_answers: Vec, + pub(crate) merkle_proofs: Vec<(MultiPath, Vec>)>, +} diff --git a/whir/src/whir/verifier.rs b/whir/src/whir/verifier.rs new file mode 100644 index 000000000..0cbe366e9 --- /dev/null +++ b/whir/src/whir/verifier.rs @@ -0,0 +1,631 @@ +use std::iter; + +use ark_crypto_primitives::merkle_tree::Config; +use ark_ff::FftField; +use ark_poly::EvaluationDomain; +use nimue::{ + ByteChallenges, ByteReader, ProofError, ProofResult, + plugins::ark::{FieldChallenges, FieldReader}, +}; +use nimue_pow::{self, PoWChallenge}; + +use super::{Statement, WhirProof, parameters::WhirConfig}; +use crate::{ + parameters::FoldType, + poly_utils::{MultilinearPoint, coeffs::CoefficientList, eq_poly_outside, fold::compute_fold}, + sumcheck::proof::SumcheckPolynomial, + utils::expand_randomness, + whir::fs_utils::{DigestReader, get_challenge_stir_queries}, +}; + +pub struct Verifier +where + F: FftField, + MerkleConfig: Config, +{ + pub(crate) params: WhirConfig, + pub(crate) two_inv: F, +} + +#[derive(Clone)] +pub(crate) struct ParsedCommitment { + pub(crate) root: D, + pub(crate) ood_points: Vec, + pub(crate) ood_answers: Vec, +} + +#[derive(Clone)] +pub(crate) struct ParsedProof { + pub(crate) initial_combination_randomness: Vec, + pub(crate) initial_sumcheck_rounds: Vec<(SumcheckPolynomial, F)>, + pub(crate) rounds: Vec>, + pub(crate) final_domain_gen_inv: F, + pub(crate) final_randomness_indexes: Vec, + pub(crate) final_randomness_points: Vec, + pub(crate) final_randomness_answers: Vec>, + pub(crate) final_folding_randomness: MultilinearPoint, + pub(crate) final_sumcheck_rounds: Vec<(SumcheckPolynomial, F)>, + pub(crate) final_sumcheck_randomness: MultilinearPoint, + pub(crate) final_coefficients: CoefficientList, +} + +#[derive(Debug, Clone)] +pub(crate) struct ParsedRound { + pub(crate) folding_randomness: MultilinearPoint, + pub(crate) ood_points: Vec, + pub(crate) ood_answers: Vec, + pub(crate) stir_challenges_indexes: Vec, + pub(crate) stir_challenges_points: Vec, + pub(crate) stir_challenges_answers: Vec>, + pub(crate) combination_randomness: Vec, + pub(crate) sumcheck_rounds: Vec<(SumcheckPolynomial, F)>, + pub(crate) domain_gen_inv: F, +} + +impl Verifier +where + F: FftField, + MerkleConfig: Config, + PowStrategy: nimue_pow::PowStrategy, +{ + pub fn new(params: WhirConfig) -> Self { + Verifier { + params, + two_inv: F::from(2).inverse().unwrap(), // The only inverse in the entire code :) + } + } + + fn parse_commitment( + &self, + arthur: &mut Arthur, + ) -> ProofResult> + where + Arthur: ByteReader + FieldReader + FieldChallenges + DigestReader, + { + let root = arthur.read_digest()?; + + let mut ood_points = vec![F::ZERO; self.params.committment_ood_samples]; + let mut ood_answers = vec![F::ZERO; self.params.committment_ood_samples]; + if self.params.committment_ood_samples > 0 { + arthur.fill_challenge_scalars(&mut ood_points)?; + arthur.fill_next_scalars(&mut ood_answers)?; + } + + Ok(ParsedCommitment { + root, + ood_points, + ood_answers, + }) + } + + fn parse_proof( + &self, + arthur: &mut Arthur, + parsed_commitment: &ParsedCommitment, + statement: &Statement, // Will be needed later + whir_proof: &WhirProof, + ) -> ProofResult> + where + Arthur: FieldReader + + FieldChallenges + + PoWChallenge + + ByteReader + + ByteChallenges + + DigestReader, + { + let mut sumcheck_rounds = Vec::new(); + let mut folding_randomness: MultilinearPoint; + let initial_combination_randomness; + if self.params.initial_statement { + // Derive combination randomness and first sumcheck polynomial + let [combination_randomness_gen]: [F; 1] = arthur.challenge_scalars()?; + initial_combination_randomness = expand_randomness( + combination_randomness_gen, + parsed_commitment.ood_points.len() + statement.points.len(), + ); + + // Initial sumcheck + sumcheck_rounds.reserve_exact(self.params.folding_factor.at_round(0)); + for _ in 0..self.params.folding_factor.at_round(0) { + let sumcheck_poly_evals: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly = SumcheckPolynomial::new(sumcheck_poly_evals.to_vec(), 1); + let [folding_randomness_single] = arthur.challenge_scalars()?; + sumcheck_rounds.push((sumcheck_poly, folding_randomness_single)); + + if self.params.starting_folding_pow_bits > 0. { + arthur.challenge_pow::(self.params.starting_folding_pow_bits)?; + } + } + + folding_randomness = + MultilinearPoint(sumcheck_rounds.iter().map(|&(_, r)| r).rev().collect()); + } else { + assert_eq!(parsed_commitment.ood_points.len(), 0); + assert_eq!(statement.points.len(), 0); + + initial_combination_randomness = vec![F::ONE]; + + let mut folding_randomness_vec = vec![F::ZERO; self.params.folding_factor.at_round(0)]; + arthur.fill_challenge_scalars(&mut folding_randomness_vec)?; + folding_randomness = MultilinearPoint(folding_randomness_vec); + + // PoW + if self.params.starting_folding_pow_bits > 0. { + arthur.challenge_pow::(self.params.starting_folding_pow_bits)?; + } + }; + + let mut prev_root = parsed_commitment.root.clone(); + let mut domain_gen = self.params.starting_domain.backing_domain.group_gen(); + let mut exp_domain_gen = domain_gen.pow([1 << self.params.folding_factor.at_round(0)]); + let mut domain_gen_inv = self.params.starting_domain.backing_domain.group_gen_inv(); + let mut domain_size = self.params.starting_domain.size(); + let mut rounds = vec![]; + + for r in 0..self.params.n_rounds() { + let (merkle_proof, answers) = &whir_proof.0[r]; + let round_params = &self.params.round_parameters[r]; + + let new_root = arthur.read_digest()?; + + let mut ood_points = vec![F::ZERO; round_params.ood_samples]; + let mut ood_answers = vec![F::ZERO; round_params.ood_samples]; + if round_params.ood_samples > 0 { + arthur.fill_challenge_scalars(&mut ood_points)?; + arthur.fill_next_scalars(&mut ood_answers)?; + } + + let stir_challenges_indexes = get_challenge_stir_queries( + domain_size, + self.params.folding_factor.at_round(r), + round_params.num_queries, + arthur, + )?; + + let stir_challenges_points = stir_challenges_indexes + .iter() + .map(|index| exp_domain_gen.pow([*index as u64])) + .collect(); + + if !merkle_proof + .verify( + &self.params.leaf_hash_params, + &self.params.two_to_one_params, + &prev_root, + answers.iter().map(|a| a.as_ref()), + ) + .unwrap() + || merkle_proof.leaf_indexes != stir_challenges_indexes + { + return Err(ProofError::InvalidProof); + } + + if round_params.pow_bits > 0. { + arthur.challenge_pow::(round_params.pow_bits)?; + } + + let [combination_randomness_gen] = arthur.challenge_scalars()?; + let combination_randomness = expand_randomness( + combination_randomness_gen, + stir_challenges_indexes.len() + round_params.ood_samples, + ); + + let mut sumcheck_rounds = + Vec::with_capacity(self.params.folding_factor.at_round(r + 1)); + for _ in 0..self.params.folding_factor.at_round(r + 1) { + let sumcheck_poly_evals: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly = SumcheckPolynomial::new(sumcheck_poly_evals.to_vec(), 1); + let [folding_randomness_single] = arthur.challenge_scalars()?; + sumcheck_rounds.push((sumcheck_poly, folding_randomness_single)); + + if round_params.folding_pow_bits > 0. { + arthur.challenge_pow::(round_params.folding_pow_bits)?; + } + } + + let new_folding_randomness = + MultilinearPoint(sumcheck_rounds.iter().map(|&(_, r)| r).rev().collect()); + + rounds.push(ParsedRound { + folding_randomness, + ood_points, + ood_answers, + stir_challenges_indexes, + stir_challenges_points, + stir_challenges_answers: answers.to_vec(), + combination_randomness, + sumcheck_rounds, + domain_gen_inv, + }); + + folding_randomness = new_folding_randomness; + + prev_root = new_root.clone(); + domain_gen = domain_gen * domain_gen; + exp_domain_gen = domain_gen.pow([1 << self.params.folding_factor.at_round(r + 1)]); + domain_gen_inv = domain_gen_inv * domain_gen_inv; + domain_size /= 2; + } + + let mut final_coefficients = vec![F::ZERO; 1 << self.params.final_sumcheck_rounds]; + arthur.fill_next_scalars(&mut final_coefficients)?; + let final_coefficients = CoefficientList::new(final_coefficients); + + // Final queries verify + let final_randomness_indexes = get_challenge_stir_queries( + domain_size, + self.params.folding_factor.at_round(self.params.n_rounds()), + self.params.final_queries, + arthur, + )?; + let final_randomness_points = final_randomness_indexes + .iter() + .map(|index| exp_domain_gen.pow([*index as u64])) + .collect(); + + let (final_merkle_proof, final_randomness_answers) = &whir_proof.0[whir_proof.0.len() - 1]; + if !final_merkle_proof + .verify( + &self.params.leaf_hash_params, + &self.params.two_to_one_params, + &prev_root, + final_randomness_answers.iter().map(|a| a.as_ref()), + ) + .unwrap() + || final_merkle_proof.leaf_indexes != final_randomness_indexes + { + return Err(ProofError::InvalidProof); + } + + if self.params.final_pow_bits > 0. { + arthur.challenge_pow::(self.params.final_pow_bits)?; + } + + let mut final_sumcheck_rounds = Vec::with_capacity(self.params.final_sumcheck_rounds); + for _ in 0..self.params.final_sumcheck_rounds { + let sumcheck_poly_evals: [F; 3] = arthur.next_scalars()?; + let sumcheck_poly = SumcheckPolynomial::new(sumcheck_poly_evals.to_vec(), 1); + let [folding_randomness_single] = arthur.challenge_scalars()?; + final_sumcheck_rounds.push((sumcheck_poly, folding_randomness_single)); + + if self.params.final_folding_pow_bits > 0. { + arthur.challenge_pow::(self.params.final_folding_pow_bits)?; + } + } + let final_sumcheck_randomness = MultilinearPoint( + final_sumcheck_rounds + .iter() + .map(|&(_, r)| r) + .rev() + .collect(), + ); + + Ok(ParsedProof { + initial_combination_randomness, + initial_sumcheck_rounds: sumcheck_rounds, + rounds, + final_domain_gen_inv: domain_gen_inv, + final_folding_randomness: folding_randomness, + final_randomness_indexes, + final_randomness_points, + final_randomness_answers: final_randomness_answers.to_vec(), + final_sumcheck_rounds, + final_sumcheck_randomness, + final_coefficients, + }) + } + + fn compute_v_poly( + &self, + parsed_commitment: &ParsedCommitment, + statement: &Statement, + proof: &ParsedProof, + ) -> F { + let mut num_variables = self.params.mv_parameters.num_variables; + + let mut folding_randomness = MultilinearPoint( + iter::once(&proof.final_sumcheck_randomness.0) + .chain(iter::once(&proof.final_folding_randomness.0)) + .chain(proof.rounds.iter().rev().map(|r| &r.folding_randomness.0)) + .flatten() + .copied() + .collect(), + ); + + let statement_points: Vec> = statement + .points + .clone() + .into_iter() + .map(|mut p| { + while p.n_variables() < self.params.folding_factor.at_round(0) { + p.0.insert(0, F::ONE); + } + p + }) + .collect(); + let mut value = parsed_commitment + .ood_points + .iter() + .map(|ood_point| MultilinearPoint::expand_from_univariate(*ood_point, num_variables)) + .chain(statement_points) + .zip(&proof.initial_combination_randomness) + .map(|(point, randomness)| *randomness * eq_poly_outside(&point, &folding_randomness)) + .sum(); + + for (round, round_proof) in proof.rounds.iter().enumerate() { + num_variables -= self.params.folding_factor.at_round(round); + folding_randomness = MultilinearPoint(folding_randomness.0[..num_variables].to_vec()); + + let ood_points = &round_proof.ood_points; + let stir_challenges_points = &round_proof.stir_challenges_points; + let stir_challenges: Vec<_> = ood_points + .iter() + .chain(stir_challenges_points) + .cloned() + .map(|univariate| { + MultilinearPoint::expand_from_univariate(univariate, num_variables) + // TODO: + // Maybe refactor outside + }) + .collect(); + + let sum_of_claims: F = stir_challenges + .into_iter() + .map(|point| eq_poly_outside(&point, &folding_randomness)) + .zip(&round_proof.combination_randomness) + .map(|(point, rand)| point * rand) + .sum(); + + value += sum_of_claims; + } + + value + } + + pub(crate) fn compute_folds(&self, parsed: &ParsedProof) -> Vec> { + match self.params.fold_optimisation { + FoldType::Naive => self.compute_folds_full(parsed), + FoldType::ProverHelps => self.compute_folds_helped(parsed), + } + } + + fn compute_folds_full(&self, parsed: &ParsedProof) -> Vec> { + let mut domain_size = self.params.starting_domain.backing_domain.size(); + + let mut result = Vec::new(); + + for (round_index, round) in parsed.rounds.iter().enumerate() { + let coset_domain_size = 1 << self.params.folding_factor.at_round(round_index); + // This is such that coset_generator^coset_domain_size = F::ONE + // let _coset_generator = domain_gen.pow(&[(domain_size / coset_domain_size) as u64]); + let coset_generator_inv = round + .domain_gen_inv + .pow([(domain_size / coset_domain_size) as u64]); + + let evaluations: Vec<_> = round + .stir_challenges_indexes + .iter() + .zip(&round.stir_challenges_answers) + .map(|(index, answers)| { + // The coset is w^index * + // let _coset_offset = domain_gen.pow(&[*index as u64]); + let coset_offset_inv = round.domain_gen_inv.pow([*index as u64]); + + compute_fold( + answers, + &round.folding_randomness.0, + coset_offset_inv, + coset_generator_inv, + self.two_inv, + self.params.folding_factor.at_round(round_index), + ) + }) + .collect(); + result.push(evaluations); + domain_size /= 2; + } + + let coset_domain_size = 1 << self.params.folding_factor.at_round(parsed.rounds.len()); + let domain_gen_inv = parsed.final_domain_gen_inv; + + // Final round + let coset_generator_inv = domain_gen_inv.pow([(domain_size / coset_domain_size) as u64]); + let evaluations: Vec<_> = parsed + .final_randomness_indexes + .iter() + .zip(&parsed.final_randomness_answers) + .map(|(index, answers)| { + // The coset is w^index * + // let _coset_offset = domain_gen.pow(&[*index as u64]); + let coset_offset_inv = domain_gen_inv.pow([*index as u64]); + + compute_fold( + answers, + &parsed.final_folding_randomness.0, + coset_offset_inv, + coset_generator_inv, + self.two_inv, + self.params.folding_factor.at_round(parsed.rounds.len()), + ) + }) + .collect(); + result.push(evaluations); + + result + } + + fn compute_folds_helped(&self, parsed: &ParsedProof) -> Vec> { + let mut result = Vec::new(); + + for round in &parsed.rounds { + let evaluations: Vec<_> = round + .stir_challenges_answers + .iter() + .map(|answers| { + CoefficientList::new(answers.to_vec()).evaluate(&round.folding_randomness) + }) + .collect(); + result.push(evaluations); + } + + // Final round + let evaluations: Vec<_> = parsed + .final_randomness_answers + .iter() + .map(|answers| { + CoefficientList::new(answers.to_vec()).evaluate(&parsed.final_folding_randomness) + }) + .collect(); + result.push(evaluations); + + result + } + + pub fn verify( + &self, + arthur: &mut Arthur, + statement: &Statement, + whir_proof: &WhirProof, + ) -> ProofResult + where + Arthur: FieldChallenges + + FieldReader + + ByteChallenges + + ByteReader + + PoWChallenge + + DigestReader, + { + // We first do a pass in which we rederive all the FS challenges + // Then we will check the algebraic part (so to optimise inversions) + let parsed_commitment = self.parse_commitment(arthur)?; + let parsed = self.parse_proof(arthur, &parsed_commitment, statement, whir_proof)?; + + let computed_folds = self.compute_folds(&parsed); + + let mut prev: Option<(SumcheckPolynomial, F)> = None; + if let Some(round) = parsed.initial_sumcheck_rounds.first() { + // Check the first polynomial + let (mut prev_poly, mut randomness) = round.clone(); + if prev_poly.sum_over_hypercube() + != parsed_commitment + .ood_answers + .iter() + .copied() + .chain(statement.evaluations.clone()) + .zip(&parsed.initial_combination_randomness) + .map(|(ans, rand)| ans * rand) + .sum() + { + return Err(ProofError::InvalidProof); + } + + // Check the rest of the rounds + for (sumcheck_poly, new_randomness) in &parsed.initial_sumcheck_rounds[1..] { + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev_poly = sumcheck_poly.clone(); + randomness = *new_randomness; + } + + prev = Some((prev_poly, randomness)); + } + + for (round, folds) in parsed.rounds.iter().zip(&computed_folds) { + let (sumcheck_poly, new_randomness) = &round.sumcheck_rounds[0].clone(); + + let values = round.ood_answers.iter().copied().chain(folds.clone()); + + let prev_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + let claimed_sum = prev_eval + + values + .zip(&round.combination_randomness) + .map(|(val, rand)| val * rand) + .sum::(); + + if sumcheck_poly.sum_over_hypercube() != claimed_sum { + return Err(ProofError::InvalidProof); + } + + prev = Some((sumcheck_poly.clone(), *new_randomness)); + + // Check the rest of the round + for (sumcheck_poly, new_randomness) in &round.sumcheck_rounds[1..] { + let (prev_poly, randomness) = prev.unwrap(); + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev = Some((sumcheck_poly.clone(), *new_randomness)); + } + } + + // Check the foldings computed from the proof match the evaluations of the polynomial + let final_folds = &computed_folds[computed_folds.len() - 1]; + let final_evaluations = parsed + .final_coefficients + .evaluate_at_univariate(&parsed.final_randomness_points); + if !final_folds + .iter() + .zip(final_evaluations) + .all(|(&fold, eval)| fold == eval) + { + return Err(ProofError::InvalidProof); + } + + // Check the final sumchecks + if self.params.final_sumcheck_rounds > 0 { + let prev_sumcheck_poly_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + let (sumcheck_poly, new_randomness) = &parsed.final_sumcheck_rounds[0].clone(); + let claimed_sum = prev_sumcheck_poly_eval; + + if sumcheck_poly.sum_over_hypercube() != claimed_sum { + return Err(ProofError::InvalidProof); + } + + prev = Some((sumcheck_poly.clone(), *new_randomness)); + + // Check the rest of the round + for (sumcheck_poly, new_randomness) in &parsed.final_sumcheck_rounds[1..] { + let (prev_poly, randomness) = prev.unwrap(); + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev = Some((sumcheck_poly.clone(), *new_randomness)); + } + } + + let prev_sumcheck_poly_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + + // Check the final sumcheck evaluation + let evaluation_of_v_poly = self.compute_v_poly(&parsed_commitment, statement, &parsed); + + if prev_sumcheck_poly_eval + != evaluation_of_v_poly + * parsed + .final_coefficients + .evaluate(&parsed.final_sumcheck_randomness) + { + return Err(ProofError::InvalidProof); + } + + Ok(parsed_commitment.root) + } +}