From 0cf3e299787de6b131752a398945418e0e6be436 Mon Sep 17 00:00:00 2001 From: Robert Remen Date: Tue, 5 Mar 2024 10:59:19 +0100 Subject: [PATCH] implement GpuProofConfig abstraction (#37) This PR implements a GpuProofConfig abstraction to remove a dependency on CSReferenceAssembly structure. Instantiation of the CSReferenceAssembly structure is expensive so we need to enable the possibility to supply the pieces of configuration the prover needs in an alternative way. This abstraction makes it possible to generate the information either from the CSReferenceAssembly structure or directly from a Circuit structure. - [x] PR title corresponds to the body of PR (we generate changelog entries from PRs). - [x] Tests for the changes have been added / updated. - [x] Documentation comments have been added / updated. - [x] Code has been formatted via `cargo fmt` and linted via `cargo check`. --- Cargo.lock | 150 +++++++++++++++---------------- Cargo.toml | 2 +- src/constraint_evaluation.rs | 57 ++++-------- src/data_structures/cache.rs | 35 ++++---- src/gpu_proof_config.rs | 166 +++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + src/prover.rs | 115 +++++++++++------------- src/synthesis_utils.rs | 59 +++++++------ src/test.rs | 51 +++++------ 9 files changed, 380 insertions(+), 256 deletions(-) create mode 100644 src/gpu_proof_config.rs diff --git a/Cargo.lock b/Cargo.lock index d23179d..2e2a12d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,9 +24,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.79" +version = "1.0.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" +checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1" [[package]] name = "arr_macro" @@ -145,7 +145,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.48", + "syn 2.0.52", "which", ] @@ -262,7 +262,7 @@ checksum = "8d696c370c750c948ada61c69a0ee2cbbb9c50b1019ddb86d9317157a99c2cae" [[package]] name = "boojum" version = "0.2.0" -source = "git+https://github.com/matter-labs/era-boojum?branch=main#03888f0fbf810a18e98d010dd11891fe32097352" +source = "git+https://github.com/matter-labs/era-boojum?branch=main#30300f043c9afaeeb35d0f7bd3cc0acaf69ccde4" dependencies = [ "arrayvec 0.7.4", "bincode", @@ -323,12 +323,9 @@ checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "cc" -version = "1.0.83" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" -dependencies = [ - "libc", -] +checksum = "a0ba8f7aaa012f30d5b2861462f6708eccd49c3c39863fe083a308035f63d723" [[package]] name = "cexpr" @@ -467,7 +464,7 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8" dependencies = [ - "crossbeam-channel 0.5.11", + "crossbeam-channel 0.5.12", "crossbeam-deque 0.8.5", "crossbeam-epoch 0.9.18", "crossbeam-queue 0.3.11", @@ -486,9 +483,9 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.11" +version = "0.5.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "176dc175b78f56c0f321911d9c8eb2b77a78a4860b9c19db83835fea1a46649b" +checksum = "ab3db02a9c5b5121e1e42fbdb1aeb65f5e02624cc58c43f2884c6ccac0b82f95" dependencies = [ "crossbeam-utils 0.8.19", ] @@ -616,7 +613,7 @@ dependencies = [ [[package]] name = "cs_derive" version = "0.1.0" -source = "git+https://github.com/matter-labs/era-boojum?branch=main#03888f0fbf810a18e98d010dd11891fe32097352" +source = "git+https://github.com/matter-labs/era-boojum?branch=main#30300f043c9afaeeb35d0f7bd3cc0acaf69ccde4" dependencies = [ "proc-macro-error", "proc-macro2 1.0.78", @@ -752,7 +749,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys", ] [[package]] @@ -1013,9 +1010,9 @@ checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" [[package]] name = "hermit-abi" -version = "0.3.6" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd5256b483761cd23699d0da46cc6fd2ee3be420bbe6d020ae4a091e70b7e9fd" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" [[package]] name = "hex" @@ -1038,7 +1035,7 @@ version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" dependencies = [ - "windows-sys 0.52.0", + "windows-sys", ] [[package]] @@ -1091,9 +1088,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.3" +version = "2.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "233cf39063f058ea2caae4091bf4a3ef70a653afbc026f5c4a4135d114e3c177" +checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4" dependencies = [ "equivalent", "hashbrown 0.14.3", @@ -1166,12 +1163,12 @@ checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "libloading" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c571b676ddfc9a8c12f1f3d3085a7b163966a8fd8098a90640953ce5f6170161" +checksum = "2caa5afb8bf9f3a2652760ce7d4f62d21c4d5a423e68466fca30df82f2330164" dependencies = [ "cfg-if 1.0.0", - "windows-sys 0.48.0", + "windows-targets 0.52.4", ] [[package]] @@ -1198,9 +1195,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.20" +version = "0.4.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" [[package]] name = "maybe-uninit" @@ -1382,7 +1379,7 @@ dependencies = [ "proc-macro-crate 1.3.1", "proc-macro2 1.0.78", "quote 1.0.35", - "syn 2.0.48", + "syn 2.0.52", ] [[package]] @@ -1393,9 +1390,9 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "opaque-debug" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" [[package]] name = "p256" @@ -1543,7 +1540,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a41cf62165e97c7f814d2221421dbb9afcbcdb0a88068e5ea206e19951c2cbb5" dependencies = [ "proc-macro2 1.0.78", - "syn 2.0.48", + "syn 2.0.52", ] [[package]] @@ -1720,9 +1717,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051" +checksum = "e4963ed1bc86e4f3ee217022bd855b297cef07fb9eac5dfa1f788b220b49b3bd" dependencies = [ "either", "rayon-core", @@ -1770,9 +1767,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" dependencies = [ "aho-corasick", "memchr", @@ -1851,14 +1848,14 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys", ] [[package]] name = "ryu" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" +checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" [[package]] name = "scopeguard" @@ -1888,29 +1885,29 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.196" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" +checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.196" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" +checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2 1.0.78", "quote 1.0.35", - "syn 2.0.48", + "syn 2.0.52", ] [[package]] name = "serde_json" -version = "1.0.113" +version = "1.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69801b70b1c3dac963ecb03a364ba0ceda9cf60c71cfe475e99864759c8b8a79" +checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" dependencies = [ "itoa", "ryu", @@ -1939,7 +1936,7 @@ checksum = "b93fb4adc70021ac1b47f7d45e8cc4169baaa7ea58483bc5b721d19a26202212" dependencies = [ "proc-macro2 1.0.78", "quote 1.0.35", - "syn 2.0.48", + "syn 2.0.52", ] [[package]] @@ -1999,7 +1996,7 @@ dependencies = [ [[package]] name = "shivini" -version = "0.1.0" +version = "0.2.0" dependencies = [ "bincode", "blake2 0.10.6", @@ -2113,9 +2110,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.48" +version = "2.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" +checksum = "b699d15b36d1f02c3e7c69f8ffef53de37aefae075d8488d4ba1a7788d574a07" dependencies = [ "proc-macro2 1.0.78", "quote 1.0.35", @@ -2158,7 +2155,7 @@ version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ - "indexmap 2.2.3", + "indexmap 2.2.5", "toml_datetime", "winnow", ] @@ -2169,7 +2166,7 @@ version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "396e4d48bbb2b7554c944bde63101b5ae446cff6ec4a24227428f15eb72ef338" dependencies = [ - "indexmap 2.2.3", + "indexmap 2.2.5", "toml_datetime", "winnow", ] @@ -2278,22 +2275,13 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" -[[package]] -name = "windows-sys" -version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" -dependencies = [ - "windows-targets 0.48.5", -] - [[package]] name = "windows-sys" version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.0", + "windows-targets 0.52.4", ] [[package]] @@ -2313,17 +2301,17 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" dependencies = [ - "windows_aarch64_gnullvm 0.52.0", - "windows_aarch64_msvc 0.52.0", - "windows_i686_gnu 0.52.0", - "windows_i686_msvc 0.52.0", - "windows_x86_64_gnu 0.52.0", - "windows_x86_64_gnullvm 0.52.0", - "windows_x86_64_msvc 0.52.0", + "windows_aarch64_gnullvm 0.52.4", + "windows_aarch64_msvc 0.52.4", + "windows_i686_gnu 0.52.4", + "windows_i686_msvc 0.52.4", + "windows_x86_64_gnu 0.52.4", + "windows_x86_64_gnullvm 0.52.4", + "windows_x86_64_msvc 0.52.4", ] [[package]] @@ -2334,9 +2322,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" +checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" [[package]] name = "windows_aarch64_msvc" @@ -2346,9 +2334,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" +checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" [[package]] name = "windows_i686_gnu" @@ -2358,9 +2346,9 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" +checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" [[package]] name = "windows_i686_msvc" @@ -2370,9 +2358,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" +checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" [[package]] name = "windows_x86_64_gnu" @@ -2382,9 +2370,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" +checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" [[package]] name = "windows_x86_64_gnullvm" @@ -2394,9 +2382,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" +checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" [[package]] name = "windows_x86_64_msvc" @@ -2406,9 +2394,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" +checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" [[package]] name = "winnow" diff --git a/Cargo.toml b/Cargo.toml index abb89fc..c614e18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "shivini" -version = "0.1.0" +version = "0.2.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/src/constraint_evaluation.rs b/src/constraint_evaluation.rs index d8ad7cd..ade8473 100644 --- a/src/constraint_evaluation.rs +++ b/src/constraint_evaluation.rs @@ -1,31 +1,20 @@ -use boojum::{ - config::CSConfig, - cs::{ - gates::lookup_marker::LookupFormalGate, - implementations::{reference_cs::CSReferenceAssembly, setup::TreeNode}, - traits::{evaluator::PerChunkOffset, gate::GatePlacementStrategy}, - }, +use crate::gpu_proof_config::GpuProofConfig; +use boojum::cs::{ + gates::lookup_marker::LookupFormalGate, + implementations::setup::TreeNode, + traits::{evaluator::PerChunkOffset, gate::GatePlacementStrategy}, }; use super::*; -pub fn get_evaluators_of_general_purpose_cols< - P: boojum::field::traits::field_like::PrimeFieldLikeVectorized, - CFG: CSConfig, ->( - cs: &CSReferenceAssembly, +pub fn get_evaluators_of_general_purpose_cols( + config: &GpuProofConfig, selectors_placement: &TreeNode, ) -> Vec { let mut gates = vec![]; - for (evaluator_idx, (evaluator, _gate_type_id)) in cs - .evaluation_data_over_general_purpose_columns + for (evaluator_idx, evaluator) in config .evaluators_over_general_purpose_columns .iter() - .zip( - cs.evaluation_data_over_general_purpose_columns - .gate_type_ids_for_general_purpose_columns - .iter(), - ) .enumerate() { if evaluator.debug_name @@ -72,34 +61,21 @@ pub fn get_evaluators_of_general_purpose_cols< gates } -pub fn get_specialized_evaluators_from_assembly< - P: boojum::field::traits::field_like::PrimeFieldLikeVectorized, - CFG: CSConfig, ->( - cs: &CSReferenceAssembly, +pub fn get_specialized_evaluators_from_assembly( + config: &GpuProofConfig, selectors_placement: &TreeNode, ) -> Vec { - if cs - .evaluation_data_over_specialized_columns - .evaluators_over_specialized_columns - .len() - < 1 - { + if config.evaluators_over_specialized_columns.len() < 1 { return vec![]; } let (_deg, _constants_for_gates_over_general_purpose_columns) = selectors_placement.compute_stats(); let mut gates = vec![]; - for (idx, (evaluator, gate_type_id)) in cs - .evaluation_data_over_specialized_columns + for (idx, (evaluator, gate_type_id)) in config .evaluators_over_specialized_columns .iter() - .zip( - cs.evaluation_data_over_specialized_columns - .gate_type_ids_for_specialized_columns - .iter(), - ) + .zip(config.gate_type_ids_for_specialized_columns.iter()) .enumerate() { if evaluator.debug_name @@ -120,7 +96,7 @@ pub fn get_specialized_evaluators_from_assembly< ); let num_terms = evaluator.num_quotient_terms; - let placement_strategy = cs + let placement_strategy = config .placement_strategies .get(&gate_type_id) .copied() @@ -136,9 +112,8 @@ pub fn get_specialized_evaluators_from_assembly< let total_terms = num_terms * num_repetitions; - let (initial_offset, per_repetition_offset, total_constants_available) = cs - .evaluation_data_over_specialized_columns - .offsets_for_specialized_evaluators[idx]; + let (initial_offset, per_repetition_offset, total_constants_available) = + config.offsets_for_specialized_evaluators[idx]; let _placement_data = ( num_repetitions, diff --git a/src/data_structures/cache.rs b/src/data_structures/cache.rs index 5cb9fd8..d26a33b 100644 --- a/src/data_structures/cache.rs +++ b/src/data_structures/cache.rs @@ -1,12 +1,9 @@ -use boojum::config::ProvingCSConfig; use boojum::cs::implementations::pow::PoWRunner; use boojum::cs::implementations::prover::ProofConfig; -use boojum::cs::implementations::reference_cs::CSReferenceAssembly; use boojum::cs::implementations::transcript::Transcript; -use boojum::cs::implementations::verifier::VerificationKey; +use boojum::cs::implementations::verifier::{VerificationKey, VerificationKeyCircuitGeometry}; use boojum::cs::implementations::witness::WitnessVec; use boojum::cs::oracle::TreeHasher; -use boojum::field::traits::field_like::PrimeFieldLikeVectorized; use boojum::worker::Worker; use cudart_sys::CudaError::ErrorMemoryAllocation; use std::collections::BTreeMap; @@ -48,6 +45,7 @@ impl StorageCacheStrategy { } use crate::cs::GpuSetup; +use crate::gpu_proof_config::GpuProofConfig; use crate::prover::{ compute_quotient_degree, gpu_prove_from_external_witness_data_with_cache_strategy, }; @@ -496,13 +494,12 @@ pub(crate) struct CacheStrategy { impl CacheStrategy { pub(crate) fn get< - P: PrimeFieldLikeVectorized, TR: Transcript, H: TreeHasher, POW: PoWRunner, A: GoodAllocator, >( - cs: &CSReferenceAssembly, + config: &GpuProofConfig, external_witness_data: &WitnessVec, proof_config: ProofConfig, setup: &GpuSetup, @@ -515,13 +512,14 @@ impl CacheStrategy { println!("reusing cache strategy"); Ok(*strategy) } else { - let strategies = Self::get_strategy_candidates(cs, &proof_config, setup); + let strategies = + Self::get_strategy_candidates(config, &proof_config, setup, &vk.fixed_parameters); for (_, strategy) in strategies.iter().copied() { _setup_cache_reset(); dry_run_start(); let result = - gpu_prove_from_external_witness_data_with_cache_strategy::( - cs, + gpu_prove_from_external_witness_data_with_cache_strategy::( + config, external_witness_data, proof_config.clone(), setup, @@ -548,27 +546,30 @@ impl CacheStrategy { } } - pub(crate) fn get_strategy_candidates< - P: PrimeFieldLikeVectorized, - A: GoodAllocator, - >( - cs: &CSReferenceAssembly, + pub(crate) fn get_strategy_candidates( + config: &GpuProofConfig, proof_config: &ProofConfig, setup: &GpuSetup, + geometry: &VerificationKeyCircuitGeometry, ) -> Vec<((usize, usize), CacheStrategy)> { let fri_lde_degree = proof_config.fri_lde_factor; - let quotient_degree = compute_quotient_degree(&cs, &setup.selectors_placement); + let quotient_degree = compute_quotient_degree(&config, &setup.selectors_placement); let used_lde_degree = usize::max(quotient_degree, fri_lde_degree); let setup_layout = setup.layout; + let domain_size = geometry.domain_size as usize; + let lookup_parameters = geometry.lookup_parameters; + let total_tables_len = geometry.total_tables_len as usize; + let num_multiplicity_cols = + lookup_parameters.num_multipicities_polys(total_tables_len, domain_size); let trace_layout = TraceLayout { num_variable_cols: setup.variables_hint.len(), num_witness_cols: setup.witnesses_hint.len(), - num_multiplicity_cols: cs.num_multipicities_polys(), + num_multiplicity_cols, }; let arguments_layout = ArgumentsLayout::from_trace_layout_and_lookup_params( trace_layout, quotient_degree, - cs.lookup_parameters.clone(), + geometry.lookup_parameters, ); let setup_num_polys = setup_layout.num_polys(); let trace_num_polys = trace_layout.num_polys(); diff --git a/src/gpu_proof_config.rs b/src/gpu_proof_config.rs new file mode 100644 index 0000000..87e3323 --- /dev/null +++ b/src/gpu_proof_config.rs @@ -0,0 +1,166 @@ +use crate::synthesis_utils::{ + get_verifier_for_base_layer_circuit, get_verifier_for_recursive_layer_circuit, +}; +use boojum::config::ProvingCSConfig; +use boojum::cs::implementations::reference_cs::CSReferenceAssembly; +use boojum::cs::implementations::verifier::{ + TypeErasedGateEvaluationVerificationFunction, Verifier, +}; +use boojum::cs::traits::evaluator::{ + GatePlacementType, PerChunkOffset, TypeErasedGateEvaluationFunction, +}; +use boojum::cs::traits::gate::GatePlacementStrategy; +use boojum::field::goldilocks::{GoldilocksExt2, GoldilocksField}; +use boojum::field::traits::field_like::PrimeFieldLikeVectorized; +use boojum::field::FieldExtension; +use circuit_definitions::aux_definitions::witness_oracle::VmWitnessOracle; +use circuit_definitions::circuit_definitions::base_layer::ZkSyncBaseLayerCircuit; +use circuit_definitions::circuit_definitions::recursion_layer::ZkSyncRecursiveLayerCircuit; +use circuit_definitions::ZkSyncDefaultRoundFunction; +use std::any::TypeId; +use std::collections::HashMap; + +type F = GoldilocksField; +type EXT = GoldilocksExt2; +type BaseLayerCircuit = ZkSyncBaseLayerCircuit, ZkSyncDefaultRoundFunction>; + +pub(crate) struct EvaluatorData { + pub debug_name: String, + pub unique_name: String, + pub max_constraint_degree: usize, + pub num_quotient_terms: usize, + pub total_quotient_terms_over_all_repetitions: usize, + pub num_repetitions_on_row: usize, + pub placement_type: GatePlacementType, +} + +impl> From<&TypeErasedGateEvaluationFunction> + for EvaluatorData +{ + fn from(value: &TypeErasedGateEvaluationFunction) -> Self { + let debug_name = value.debug_name.clone(); + let unique_name = value.unique_name.clone(); + let max_constraint_degree = value.max_constraint_degree; + let num_quotient_terms = value.num_quotient_terms; + let total_quotient_terms_over_all_repetitions = + value.total_quotient_terms_over_all_repetitions; + let num_repetitions_on_row = value.num_repetitions_on_row; + let placement_type = value.placement_type; + Self { + debug_name, + unique_name, + max_constraint_degree, + num_quotient_terms, + total_quotient_terms_over_all_repetitions, + num_repetitions_on_row, + placement_type, + } + } +} + +impl> + From<&TypeErasedGateEvaluationVerificationFunction> for EvaluatorData +{ + fn from(value: &TypeErasedGateEvaluationVerificationFunction) -> Self { + let debug_name = value.debug_name.clone(); + let unique_name = value.unique_name.clone(); + let max_constraint_degree = value.max_constraint_degree; + let num_quotient_terms = value.num_quotient_terms; + let total_quotient_terms_over_all_repetitions = + value.total_quotient_terms_over_all_repetitions; + let num_repetitions_on_row = value.num_repetitions_on_row; + let placement_type = value.placement_type; + Self { + debug_name, + unique_name, + max_constraint_degree, + num_quotient_terms, + total_quotient_terms_over_all_repetitions, + num_repetitions_on_row, + placement_type, + } + } +} + +pub struct GpuProofConfig { + pub(crate) gate_type_ids_for_specialized_columns: Vec, + pub(crate) evaluators_over_specialized_columns: Vec, + pub(crate) offsets_for_specialized_evaluators: Vec<(PerChunkOffset, PerChunkOffset, usize)>, + pub(crate) evaluators_over_general_purpose_columns: Vec, + pub(crate) placement_strategies: HashMap, +} + +impl GpuProofConfig { + pub fn from_assembly>( + cs: &CSReferenceAssembly, + ) -> Self { + let evaluation_data_over_specialized_columns = &cs.evaluation_data_over_specialized_columns; + let gate_type_ids_for_specialized_columns = evaluation_data_over_specialized_columns + .gate_type_ids_for_specialized_columns + .clone(); + let evaluators_over_specialized_columns = evaluation_data_over_specialized_columns + .evaluators_over_specialized_columns + .iter() + .map(|x| x.into()) + .collect(); + let evaluators_over_general_purpose_columns = cs + .evaluation_data_over_general_purpose_columns + .evaluators_over_general_purpose_columns + .iter() + .map(|x| x.into()) + .collect(); + let offsets_for_specialized_evaluators = evaluation_data_over_specialized_columns + .offsets_for_specialized_evaluators + .clone(); + let placement_strategies = cs.placement_strategies.clone(); + Self { + gate_type_ids_for_specialized_columns, + evaluators_over_specialized_columns, + offsets_for_specialized_evaluators, + evaluators_over_general_purpose_columns, + placement_strategies, + } + } + + pub fn from_verifier(verifier: &Verifier) -> Self { + let gate_type_ids_for_specialized_columns = + verifier.gate_type_ids_for_specialized_columns.clone(); + let evaluators_over_specialized_columns = verifier + .evaluators_over_specialized_columns + .iter() + .map(|x| x.into()) + .collect(); + let offsets_for_specialized_evaluators = + verifier.offsets_for_specialized_evaluators.clone(); + let evaluators_over_general_purpose_columns = verifier + .evaluators_over_general_purpose_columns + .iter() + .map(|x| x.into()) + .collect(); + let placement_strategies = verifier.placement_strategies.clone(); + Self { + gate_type_ids_for_specialized_columns, + evaluators_over_specialized_columns, + offsets_for_specialized_evaluators, + evaluators_over_general_purpose_columns, + placement_strategies, + } + } + + pub fn from_base_layer_circuit(circuit: &BaseLayerCircuit) -> Self { + Self::from_verifier(&get_verifier_for_base_layer_circuit(circuit)) + } + + pub fn from_recursive_layer_circuit(circuit: &ZkSyncRecursiveLayerCircuit) -> Self { + Self::from_verifier(&get_verifier_for_recursive_layer_circuit(circuit)) + } + + #[cfg(test)] + pub(crate) fn from_circuit_wrapper(wrapper: &crate::synthesis_utils::CircuitWrapper) -> Self { + use crate::synthesis_utils::CircuitWrapper::*; + match wrapper { + Base(circuit) => Self::from_base_layer_circuit(circuit), + Recursive(circuit) => Self::from_recursive_layer_circuit(circuit), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 77e989b..8284ad0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -42,6 +42,7 @@ use copy_permutation::*; use data_structures::*; use lookup::*; use poly::*; +pub mod gpu_proof_config; mod prover; mod quotient; #[cfg(feature = "zksync")] diff --git a/src/prover.rs b/src/prover.rs index b138520..031bd3d 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -2,7 +2,6 @@ use std::alloc::Global; use std::rc::Rc; use boojum::{ - config::ProvingCSConfig, cs::{ gates::lookup_marker::LookupFormalGate, implementations::{ @@ -10,7 +9,6 @@ use boojum::{ pow::{NoPow, PoWRunner}, proof::{OracleQuery, Proof, SingleRoundQueries}, prover::ProofConfig, - reference_cs::CSReferenceAssembly, setup::TreeNode, transcript::Transcript, utils::{domain_generator_for_size, materialize_powers_serial}, @@ -26,6 +24,7 @@ use boojum::{ }; use crate::cs::GpuSetup; +use crate::gpu_proof_config::GpuProofConfig; use crate::{ arith::{deep_quotient_except_public_inputs, deep_quotient_public_input}, cs::PACKED_PLACEHOLDER_BITMASK, @@ -34,14 +33,12 @@ use crate::{ use super::*; pub fn gpu_prove_from_external_witness_data< - P: boojum::field::traits::field_like::PrimeFieldLikeVectorized, TR: Transcript, H: TreeHasher, POW: PoWRunner, A: GoodAllocator, >( - // service layer reuses assembly for repeated proving, so that just borrow it - cs: &CSReferenceAssembly, + config: &GpuProofConfig, external_witness_data: &WitnessVec, // TODO: read data from Assembly pinned storage proof_config: ProofConfig, setup: &GpuSetup, @@ -49,8 +46,8 @@ pub fn gpu_prove_from_external_witness_data< transcript_params: TR::TransciptParameters, worker: &Worker, ) -> CudaResult> { - let cache_strategy = CacheStrategy::get::( - cs, + let cache_strategy = CacheStrategy::get::( + config, external_witness_data, proof_config.clone(), setup, @@ -58,8 +55,8 @@ pub fn gpu_prove_from_external_witness_data< transcript_params.clone(), worker, )?; - gpu_prove_from_external_witness_data_with_cache_strategy::( - cs, + gpu_prove_from_external_witness_data_with_cache_strategy::( + config, external_witness_data, proof_config, setup, @@ -71,14 +68,12 @@ pub fn gpu_prove_from_external_witness_data< } pub(crate) fn gpu_prove_from_external_witness_data_with_cache_strategy< - P: boojum::field::traits::field_like::PrimeFieldLikeVectorized, TR: Transcript, H: TreeHasher, POW: PoWRunner, A: GoodAllocator, >( - // service layer reuses assembly for repeated proving, so that just borrow it - cs: &CSReferenceAssembly, + config: &GpuProofConfig, external_witness_data: &WitnessVec, // TODO: read data from Assembly pinned storage proof_config: ProofConfig, setup: &GpuSetup, @@ -89,11 +84,6 @@ pub(crate) fn gpu_prove_from_external_witness_data_with_cache_strategy< ) -> CudaResult> { let mut timer = std::time::Instant::now(); let result = { - assert_eq!( - cs.next_available_place_idx(), - 0, - "CS should be empty and hold no data" - ); assert!( is_prover_context_initialized(), "prover context should be initialized" @@ -101,10 +91,14 @@ pub(crate) fn gpu_prove_from_external_witness_data_with_cache_strategy< let num_variable_cols = setup.variables_hint.len(); let num_witness_cols = setup.witnesses_hint.len(); - let num_multiplicity_cols = cs.num_multipicities_polys(); - let domain_size = cs.max_trace_len; + let geometry = vk.fixed_parameters.clone(); + let domain_size = geometry.domain_size as usize; + let lookup_parameters = geometry.lookup_parameters; + let total_tables_len = geometry.total_tables_len as usize; + let num_multiplicity_cols = + lookup_parameters.num_multipicities_polys(total_tables_len, domain_size); let fri_lde_degree = proof_config.fri_lde_factor; - let quotient_degree = compute_quotient_degree(&cs, &setup.selectors_placement); + let quotient_degree = compute_quotient_degree(config, &setup.selectors_placement); let used_lde_degree = usize::max(quotient_degree, fri_lde_degree); let cap_size = setup.setup_tree.cap_size; let setup_cache = SetupCache::new_from_gpu_setup( @@ -135,7 +129,7 @@ pub(crate) fn gpu_prove_from_external_witness_data_with_cache_strategy< let arguments_layout = ArgumentsLayout::from_trace_layout_and_lookup_params( trace_layout, quotient_degree, - cs.lookup_parameters.clone(), + lookup_parameters, ); let mut arguments_cache = ArgumentsCache::new( cache_strategy.arguments, @@ -156,7 +150,7 @@ pub(crate) fn gpu_prove_from_external_witness_data_with_cache_strategy< &setup_cache.aux.variable_indexes, &setup_cache.aux.witness_indexes, &external_witness_data, - &cs.lookup_parameters, + &lookup_parameters, worker, trace_evaluations_storage, )?; @@ -190,8 +184,8 @@ pub(crate) fn gpu_prove_from_external_witness_data_with_cache_strategy< let value = external_witness_data.all_values[variable_idx]; public_inputs_with_locations.push((col, row, value)); } - gpu_prove_from_trace::<_, TR, _, NoPow, _>( - cs, + gpu_prove_from_trace::( + config, public_inputs_with_locations, setup, setup_cache, @@ -211,12 +205,7 @@ pub(crate) fn gpu_prove_from_external_witness_data_with_cache_strategy< result } -pub fn compute_quotient_degree< - P: boojum::field::traits::field_like::PrimeFieldLikeVectorized, ->( - cs: &CSReferenceAssembly, - selectors_placement: &TreeNode, -) -> usize { +pub fn compute_quotient_degree(config: &GpuProofConfig, selectors_placement: &TreeNode) -> usize { let (max_constraint_contribution_degree, _number_of_constant_polys) = selectors_placement.compute_stats(); @@ -227,8 +216,7 @@ pub fn compute_quotient_degree< 0 }; - let max_degree_from_specialized_gates = cs - .evaluation_data_over_specialized_columns + let max_degree_from_specialized_gates = config .evaluators_over_specialized_columns .iter() .map(|el| el.max_constraint_degree - 1) @@ -250,13 +238,12 @@ pub fn compute_quotient_degree< } fn gpu_prove_from_trace< - P: boojum::field::traits::field_like::PrimeFieldLikeVectorized, TR: Transcript, H: TreeHasher, POW: PoWRunner, A: GoodAllocator, >( - cs: &CSReferenceAssembly, + config: &GpuProofConfig, public_inputs_with_locations: Vec<(usize, usize, F)>, setup_base: &GpuSetup, setup_cache: &mut SetupCache, @@ -269,9 +256,16 @@ fn gpu_prove_from_trace< transcript_params: TR::TransciptParameters, worker: &Worker, ) -> CudaResult> { + let geometry = vk.fixed_parameters.clone(); + let domain_size = geometry.domain_size as usize; + let lookup_parameters = geometry.lookup_parameters; + let total_tables_len = geometry.total_tables_len as usize; + let num_multiplicity_cols = + lookup_parameters.num_multipicities_polys(total_tables_len, domain_size); + assert!(domain_size.is_power_of_two()); + assert_eq!(setup_evaluations.domain_size, domain_size); assert!(proof_config.fri_lde_factor.is_power_of_two()); assert!(proof_config.fri_lde_factor > 1); - assert!(cs.max_trace_len.is_power_of_two()); let cap_size = proof_config.merkle_tree_cap_size; assert!(cap_size > 0); @@ -285,10 +279,8 @@ fn gpu_prove_from_trace< } let selectors_placement = setup_base.selectors_placement.clone(); - let domain_size = cs.max_trace_len; // counted in elements of P - assert_eq!(setup_base.constant_columns[0].len(), domain_size); - let quotient_degree = compute_quotient_degree(&cs, &selectors_placement); + let quotient_degree = compute_quotient_degree(config, &selectors_placement); let fri_lde_degree = proof_config.fri_lde_factor; let used_lde_degree = std::cmp::max(fri_lde_degree, quotient_degree); @@ -367,7 +359,7 @@ fn gpu_prove_from_trace< let num_intermediate_partial_product_relations = argument_raw_storage.as_polynomials().partial_products.len(); - let (lookup_beta, powers_of_gamma_for_lookup) = if cs.lookup_parameters + let (lookup_beta, powers_of_gamma_for_lookup) = if geometry.lookup_parameters != LookupParameters::NoLookup { let h_lookup_beta = if is_dry_run()? { @@ -391,16 +383,16 @@ fn gpu_prove_from_trace< .output_placement(lookup_evaluator_id) .expect("lookup gate must be placed"); - let variables_offset = cs.parameters.num_columns_under_copy_permutation; + let variables_offset = geometry.parameters.num_columns_under_copy_permutation; #[allow(unreachable_code)] - let powers_of_gamma = match cs.lookup_parameters { + let powers_of_gamma = match geometry.lookup_parameters { LookupParameters::NoLookup => { unreachable!() } LookupParameters::TableIdAsConstant { .. } | LookupParameters::TableIdAsVariable { .. } => { - let columns_per_subargument = cs.lookup_parameters.columns_per_subargument(); + let columns_per_subargument = geometry.lookup_parameters.columns_per_subargument(); let mut h_powers_of_gamma = vec![]; let mut current = EF::ONE; @@ -432,14 +424,12 @@ fn gpu_prove_from_trace< | a @ LookupParameters::UseSpecializedColumnsWithTableIdAsConstant { width, .. } => { // ensure proper setup assert_eq!( - cs.evaluation_data_over_specialized_columns - .gate_type_ids_for_specialized_columns[0], + config.gate_type_ids_for_specialized_columns[0], std::any::TypeId::of::(), "we expect first specialized gate to be the lookup -" ); - let (initial_offset, offset_per_repetition, _) = cs - .evaluation_data_over_specialized_columns - .offsets_for_specialized_evaluators[0]; + let (initial_offset, offset_per_repetition, _) = + config.offsets_for_specialized_evaluators[0]; assert_eq!(initial_offset.constants_offset, 0); if let LookupParameters::UseSpecializedColumnsWithTableIdAsConstant { @@ -469,7 +459,7 @@ fn gpu_prove_from_trace< &lookup_beta, &powers_of_gamma, variables_offset, - cs.lookup_parameters, + lookup_parameters, &mut argument_raw_storage, )?; @@ -497,21 +487,17 @@ fn gpu_prove_from_trace< let h_alpha = ExtensionField::::from_coeff_in_base(h_alpha); let _alpha: DExt = h_alpha.into(); - let num_lookup_subarguments = cs.num_sublookup_arguments(); - let num_multiplicities_polys = cs.num_multipicities_polys(); + let num_lookup_subarguments = + lookup_parameters.num_sublookup_arguments_for_geometry(&geometry.parameters); + let num_multiplicities_polys = num_multiplicity_cols; let total_num_lookup_argument_terms = num_lookup_subarguments + num_multiplicities_polys; - let total_num_gate_terms_for_specialized_columns = cs - .evaluation_data_over_specialized_columns + let total_num_gate_terms_for_specialized_columns = config .evaluators_over_specialized_columns .iter() - .zip( - cs.evaluation_data_over_specialized_columns - .gate_type_ids_for_specialized_columns - .iter(), - ) + .zip(config.gate_type_ids_for_specialized_columns.iter()) .map(|(evaluator, gate_type_id)| { - let placement_strategy = cs + let placement_strategy = config .placement_strategies .get(gate_type_id) .copied() @@ -529,8 +515,7 @@ fn gpu_prove_from_trace< }) .sum(); - let total_num_gate_terms_for_general_purpose_columns: usize = cs - .evaluation_data_over_general_purpose_columns + let total_num_gate_terms_for_general_purpose_columns: usize = config .evaluators_over_general_purpose_columns .iter() .map(|evaluator| evaluator.total_quotient_terms_over_all_repetitions) @@ -570,13 +555,13 @@ fn gpu_prove_from_trace< let mut quotient = ComplexPoly::::empty(quotient_degree * domain_size)?; - let variables_offset = cs.parameters.num_columns_under_copy_permutation; + let variables_offset = geometry.parameters.num_columns_under_copy_permutation; let general_purpose_gates = - get_evaluators_of_general_purpose_cols(&cs, &setup_base.selectors_placement); + get_evaluators_of_general_purpose_cols(config, &setup_base.selectors_placement); let specialized_gates = - get_specialized_evaluators_from_assembly(&cs, &setup_base.selectors_placement); + get_specialized_evaluators_from_assembly(config, &setup_base.selectors_placement); let num_cols_per_product = quotient_degree; let specialized_cols_challenge_power_offset = total_num_lookup_argument_terms; let general_purpose_cols_challenge_power_offset = @@ -587,7 +572,7 @@ fn gpu_prove_from_trace< trace_cache, setup_cache, arguments_cache, - cs.lookup_parameters, + geometry.lookup_parameters, &setup_base.table_ids_column_idxes, &setup_base.selectors_placement, &specialized_gates, @@ -897,7 +882,7 @@ fn gpu_prove_from_trace< quotient_holder.num_polys_in_base(), domain_size, proof_config, - cs.lookup_parameters.clone(), + geometry.lookup_parameters, num_queries, query_details_for_cosets.clone(), query_idx_and_coset_idx_map, diff --git a/src/synthesis_utils.rs b/src/synthesis_utils.rs index 2a291d8..0ff4680 100644 --- a/src/synthesis_utils.rs +++ b/src/synthesis_utils.rs @@ -8,7 +8,7 @@ use boojum::cs::implementations::proof::Proof; use boojum::cs::implementations::prover::ProofConfig; use boojum::cs::implementations::reference_cs::{CSReferenceAssembly, CSReferenceImplementation}; use boojum::cs::implementations::setup::FinalizationHintsForProver; -use boojum::cs::implementations::verifier::VerificationKey; +use boojum::cs::implementations::verifier::{VerificationKey, Verifier}; use boojum::cs::traits::GoodAllocator; use boojum::cs::{CSGeometry, GateConfigurationHolder, StaticToolboxHolder}; use boojum::field::goldilocks::{GoldilocksExt2, GoldilocksField}; @@ -35,7 +35,7 @@ type EXT = GoldilocksExt2; #[derive(Clone, serde::Serialize, serde::Deserialize)] pub(crate) enum CircuitWrapper { - Base(ZkSyncBaseLayerCircuit, ZkSyncDefaultRoundFunction>), + Base(BaseLayerCircuit), Recursive(ZkSyncRecursiveLayerCircuit), } @@ -73,15 +73,15 @@ impl CircuitWrapper { pub fn into_base_layer(self) -> BaseLayerCircuit { match self { CircuitWrapper::Base(inner) => inner, - CircuitWrapper::Recursive(_) => unimplemented!(), + _ => unimplemented!(), } } #[allow(dead_code)] pub fn into_recursive_layer(self) -> ZkSyncRecursiveLayerCircuit { match self { - CircuitWrapper::Base(_) => unimplemented!(), CircuitWrapper::Recursive(inner) => inner, + _ => unimplemented!(), } } @@ -89,24 +89,21 @@ impl CircuitWrapper { pub fn as_base_layer(&self) -> &BaseLayerCircuit { match self { CircuitWrapper::Base(inner) => inner, - CircuitWrapper::Recursive(_) => unimplemented!(), + _ => unimplemented!(), } } #[allow(dead_code)] pub fn as_recursive_layer(&self) -> &ZkSyncRecursiveLayerCircuit { match self { - CircuitWrapper::Base(_) => unimplemented!(), CircuitWrapper::Recursive(inner) => inner, + _ => unimplemented!(), } } #[allow(dead_code)] pub fn is_base_layer(&self) -> bool { - match self { - CircuitWrapper::Base(_) => true, - CircuitWrapper::Recursive(_) => false, - } + matches!(self, CircuitWrapper::Base(_)) } #[allow(dead_code)] @@ -123,24 +120,32 @@ impl CircuitWrapper { vk: &VerificationKey, proof: &ZksyncProof, ) -> bool { - let verifier = match self { - CircuitWrapper::Base(_base_circuit) => { - use circuit_definitions::circuit_definitions::verifier_builder::dyn_verifier_builder_for_circuit_type; - - let verifier_builder = - dyn_verifier_builder_for_circuit_type::( - self.numeric_circuit_type(), - ); - verifier_builder.create_verifier() - } - CircuitWrapper::Recursive(recursive_circuit) => { - let verifier_builder = recursive_circuit.into_dyn_verifier_builder(); - verifier_builder.create_verifier() - } - }; - + let verifier = self.get_verifier(); verifier.verify::((), vk, proof) } + + pub(crate) fn get_verifier(&self) -> Verifier { + match self { + CircuitWrapper::Base(inner) => get_verifier_for_base_layer_circuit(inner), + CircuitWrapper::Recursive(inner) => get_verifier_for_recursive_layer_circuit(inner), + } + } +} + +pub(crate) fn get_verifier_for_base_layer_circuit(circuit: &BaseLayerCircuit) -> Verifier { + use circuit_definitions::circuit_definitions::verifier_builder::dyn_verifier_builder_for_circuit_type; + let verifier_builder = + dyn_verifier_builder_for_circuit_type::( + circuit.numeric_circuit_type(), + ); + verifier_builder.create_verifier() +} + +pub(crate) fn get_verifier_for_recursive_layer_circuit( + circuit: &ZkSyncRecursiveLayerCircuit, +) -> Verifier { + let verifier_builder = circuit.into_dyn_verifier_builder(); + verifier_builder.create_verifier() } #[allow(dead_code)] @@ -195,7 +200,9 @@ pub(crate) fn init_cs_for_external_proving( // in init_or_synthesize_assembly, we expect CFG to be either // ProvingCSConfig or SetupCSConfig pub trait AllowInitOrSynthesize: CSConfig {} + impl AllowInitOrSynthesize for ProvingCSConfig {} + impl AllowInitOrSynthesize for SetupCSConfig {} pub(crate) fn init_or_synthesize_assembly( diff --git a/src/test.rs b/src/test.rs index 3db3400..261ce23 100644 --- a/src/test.rs +++ b/src/test.rs @@ -43,6 +43,7 @@ use boojum::field::traits::field_like::PrimeFieldLikeVectorized; #[allow(dead_code)] pub type DefaultDevCS = CSReferenceAssembly; type P = F; +use crate::gpu_proof_config::GpuProofConfig; use serial_test::serial; #[serial] @@ -80,14 +81,14 @@ fn test_proof_comparison_for_poseidon_gate_with_private_witnesses() { ProvingCSConfig, false, >(finalization_hint.as_ref()); + let config = GpuProofConfig::from_assembly(&reusable_cs); let proof = gpu_prove_from_external_witness_data::< - _, DefaultTranscript, DefaultTreeHasher, NoPow, Global, >( - &reusable_cs, + &config, &witness, prover_config.clone(), &gpu_setup, @@ -335,7 +336,7 @@ fn test_dry_runs() { let witness = proving_cs.witness.unwrap(); let (reusable_cs, _) = init_or_synth_cs_for_sha256::(finalization_hint.as_ref()); - + let config = GpuProofConfig::from_assembly(&reusable_cs); let worker = Worker::new(); let prover_config = init_proof_cfg(); let (setup_base, _setup, vk, setup_tree, vars_hint, wits_hint) = setup_cs.get_full_setup( @@ -355,18 +356,21 @@ fn test_dry_runs() { .unwrap(); assert!(domain_size.is_power_of_two()); - let candidates = - CacheStrategy::get_strategy_candidates(&reusable_cs, &prover_config, &gpu_setup); + let candidates = CacheStrategy::get_strategy_candidates( + &config, + &prover_config, + &gpu_setup, + &vk.fixed_parameters, + ); for (_, strategy) in candidates.iter().copied() { let proof = || { let _ = crate::prover::gpu_prove_from_external_witness_data_with_cache_strategy::< - _, DefaultTranscript, DefaultTreeHasher, NoPow, Global, >( - &reusable_cs, + &config, &witness, prover_config.clone(), &gpu_setup, @@ -439,14 +443,15 @@ fn test_proof_comparison_for_sha256() { let (reusable_cs, _) = init_or_synth_cs_for_sha256::( finalization_hint.as_ref(), ); + let config = GpuProofConfig::from_assembly(&reusable_cs); + let proof = gpu_prove_from_external_witness_data::< - _, DefaultTranscript, DefaultTreeHasher, NoPow, Global, >( - &reusable_cs, + &config, &witness, prover_config.clone(), &gpu_setup, @@ -802,8 +807,8 @@ mod zksync { const DEFAULT_CIRCUIT_INPUT: &str = "default.circuit"; use crate::synthesis_utils::{ - init_cs_for_external_proving, init_or_synthesize_assembly, synth_circuit_for_proving, - synth_circuit_for_setup, CircuitWrapper, + init_or_synthesize_assembly, synth_circuit_for_proving, synth_circuit_for_setup, + CircuitWrapper, }; #[allow(dead_code)] @@ -978,16 +983,14 @@ mod zksync { let gpu_proof = { let proving_cs = synth_circuit_for_proving(circuit.clone(), &finalization_hint); let witness = proving_cs.witness.unwrap(); - let reusable_cs = - init_cs_for_external_proving(circuit.clone(), &finalization_hint); + let config = GpuProofConfig::from_circuit_wrapper(&circuit); let proof = gpu_prove_from_external_witness_data::< - _, DefaultTranscript, DefaultTreeHasher, NoPow, Global, >( - &reusable_cs, + &config, &witness, proof_config.clone(), &gpu_setup, @@ -1128,15 +1131,14 @@ mod zksync { println!("gpu proving"); let gpu_proof = { let witness = proving_cs.witness.as_ref().unwrap(); - let reusable_cs = init_cs_for_external_proving(circuit.clone(), &finalization_hint); + let config = GpuProofConfig::from_circuit_wrapper(&circuit); gpu_prove_from_external_witness_data::< - _, DefaultTranscript, DefaultTreeHasher, NoPow, Global, >( - &reusable_cs, + &config, &witness, proof_cfg.clone(), &gpu_setup, @@ -1191,7 +1193,7 @@ mod zksync { ); let proving_cs = synth_circuit_for_proving(circuit.clone(), &finalization_hint); let witness = proving_cs.witness.unwrap(); - let reusable_cs = init_cs_for_external_proving(circuit.clone(), &finalization_hint); + let config = GpuProofConfig::from_circuit_wrapper(&circuit); let gpu_setup = { let _ctx = ProverContext::create().expect("gpu prover context"); GpuSetup::::from_setup_and_hints( @@ -1205,13 +1207,12 @@ mod zksync { }; let proof_fn = || { let _ = gpu_prove_from_external_witness_data::< - _, DefaultTranscript, DefaultTreeHasher, NoPow, Global, >( - &reusable_cs, + &config, &witness, proof_cfg.clone(), &gpu_setup, @@ -1230,8 +1231,8 @@ mod zksync { // but nice for peace of mind _setup_cache_reset(); let strategy = - CacheStrategy::get::<_, DefaultTranscript, DefaultTreeHasher, NoPow, Global>( - &reusable_cs, + CacheStrategy::get::( + &config, &witness, proof_cfg.clone(), &gpu_setup, @@ -1290,6 +1291,7 @@ mod zksync { let (reusable_cs, _) = init_or_synth_cs_for_sha256::( finalization_hint.as_ref(), ); + let config = GpuProofConfig::from_assembly(&reusable_cs); let mut gpu_setup = GpuSetup::::from_setup_and_hints( setup_base.clone(), clone_reference_tree(&setup_tree), @@ -1301,13 +1303,12 @@ mod zksync { witness.public_inputs_locations = vec![(0, 0)]; gpu_setup.variables_hint[0][0] = PACKED_PLACEHOLDER_BITMASK; let _ = gpu_prove_from_external_witness_data::< - _, DefaultTranscript, DefaultTreeHasher, NoPow, Global, >( - &reusable_cs, + &config, &witness, proof_config.clone(), &gpu_setup,