diff --git a/Cargo.lock b/Cargo.lock index cc8fe18eb0..0352f5283c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3789,6 +3789,7 @@ dependencies = [ "num-bigint", "rand", "rand_xorshift", + "recursion", "serde", "serde_derive", "serde_json", @@ -3889,6 +3890,21 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "recursion" +version = "0.12.0" +dependencies = [ + "aggregator", + "halo2_proofs", + "halo2curves", + "itertools 0.11.0", + "log", + "rand", + "serde_json", + "snark-verifier 0.1.8", + "snark-verifier-sdk 0.1.8", +] + [[package]] name = "redox_syscall" version = "0.4.1" diff --git a/Cargo.toml b/Cargo.toml index af3efe362a..208e5ef751 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ members = [ "aggregator", "prover", "compression", + "recursion", ] resolver = "2" diff --git a/aggregator/src/lib.rs b/aggregator/src/lib.rs index 08e43882b7..cbe17d5e12 100644 --- a/aggregator/src/lib.rs +++ b/aggregator/src/lib.rs @@ -6,8 +6,6 @@ mod aggregation; mod batch; /// blob struct and constants mod blob; -/// Config to recursive aggregate multiple aggregations -mod recursion; // This module implements `Chunk` related data types. // A chunk is a list of blocks. mod chunk; @@ -31,7 +29,6 @@ pub use chunk::ChunkInfo; pub use constants::MAX_AGG_SNARKS; pub(crate) use constants::*; pub use param::*; -pub use recursion::*; mod mock_chunk; pub use mock_chunk::MockChunkCircuit; diff --git a/aggregator/src/recursion/circuit.rs b/aggregator/src/recursion/circuit.rs deleted file mode 100644 index 38a3c5ba0c..0000000000 --- a/aggregator/src/recursion/circuit.rs +++ /dev/null @@ -1,577 +0,0 @@ -#![allow(clippy::type_complexity)] -use std::{fs::File, iter, marker::PhantomData, rc::Rc}; - -use halo2_proofs::{ - circuit::{Cell, Layouter, SimpleFloorPlanner, Value}, - poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, -}; -use snark_verifier::{ - loader::halo2::{halo2_ecc::halo2_base as sv_halo2_base, EccInstructions, IntegerInstructions}, - pcs::{ - kzg::{Bdfg21, Kzg, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey}, - AccumulationScheme, AccumulationSchemeProver, - }, - util::{ - arithmetic::{fe_to_fe, fe_to_limbs}, - hash, - }, -}; -use snark_verifier_sdk::{ - types::{Halo2Loader, Plonk}, - SnarkWitness, -}; -use sv_halo2_base::{ - gates::GateInstructions, halo2_proofs, AssignedValue, Context, ContextParams, - QuantumCell::Existing, -}; - -use crate::param::ConfigParams as RecursionCircuitConfigParams; - -use super::*; - -/// Convenience type to represent the verifying key. -type Svk = KzgSuccinctVerifyingKey; - -/// Convenience type to represent the polynomial commitment scheme. -type Pcs = Kzg; - -/// Convenience type to represent the accumulation scheme for accumulating proofs from multiple -/// SNARKs. -type As = KzgAs; - -/// Select condition ? LHS : RHS. -fn select_accumulator<'a>( - loader: &Rc>, - condition: &AssignedValue, - lhs: &KzgAccumulator>>, - rhs: &KzgAccumulator>>, -) -> Result>>, Error> { - let [lhs, rhs]: [_; 2] = [lhs.lhs.assigned(), lhs.rhs.assigned()] - .iter() - .zip([rhs.lhs.assigned(), rhs.rhs.assigned()].iter()) - .map(|(lhs, rhs)| { - loader - .ecc_chip() - .select(&mut loader.ctx_mut(), lhs, rhs, condition) - }) - .collect::>() - .try_into() - .unwrap(); - Ok(KzgAccumulator::new( - loader.ec_point_from_assigned(lhs), - loader.ec_point_from_assigned(rhs), - )) -} - -/// Accumulate a value into the current accumulator. -fn accumulate<'a>( - loader: &Rc>, - accumulators: Vec>>>, - as_proof: Value<&'_ [u8]>, -) -> KzgAccumulator>> { - let mut transcript = PoseidonTranscript::, _>::new(loader, as_proof); - let proof = As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); - As::verify(&Default::default(), &accumulators, &proof).unwrap() -} - -#[derive(Clone)] -pub struct RecursionCircuit { - /// The verifying key for the circuit. - svk: Svk, - /// The default accumulator to initialise the circuit. - default_accumulator: KzgAccumulator, - /// The SNARK witness from the k-th BatchCircuit. - app: SnarkWitness, - /// The SNARK witness from the previous RecursionCircuit, i.e. RecursionCircuit up to the (k-1)-th BatchCircuit. - previous: SnarkWitness, - /// The recursion round, starting at round=0 and incrementing at every subsequent recursion. - round: usize, - /// The public inputs to the RecursionCircuit itself. - instances: Vec, - /// The accumulation of the SNARK proofs recursed over thus far. - as_proof: Value>, - - _marker: PhantomData, -} - -impl RecursionCircuit { - /// The index of the preprocessed digest in the [`RecursionCircuit`]'s instances. Note that we - /// need a single cell to hold this value as it is a poseidon hash over the bn256 curve, hence - /// it fits within an [`Fr`] cell. - /// - /// [`Fr`]: halo2_proofs::halo2curves::bn256::Fr - const PREPROCESSED_DIGEST_ROW: usize = 4 * LIMBS; - - /// The index within the instances to find the "initial" state in the state transition. - const INITIAL_STATE_ROW: usize = Self::PREPROCESSED_DIGEST_ROW + 1; - - /// Construct a new instance of the [`RecursionCircuit`] given the SNARKs from the current and - /// previous [`BatchCircuit`], and the recursion round. - /// - /// [`BatchCircuit`]: aggregator::BatchCircuit - pub fn new( - params: &ParamsKZG, - app: Snark, - previous: Snark, - rng: impl Rng + Send, - round: usize, - ) -> Self { - let svk = params.get_g()[0].into(); - let default_accumulator = KzgAccumulator::new(params.get_g()[1], params.get_g()[0]); - - let succinct_verify = |snark: &Snark| { - let mut transcript = PoseidonTranscript::::new(snark.proof.as_slice()); - let proof = - Plonk::::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript); - Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof) - }; - - let accumulators = iter::empty() - .chain(succinct_verify(&app)) - .chain( - (round > 0) - .then(|| succinct_verify(&previous)) - .unwrap_or_else(|| { - let num_accumulator = 1 + previous.protocol.accumulator_indices.len(); - vec![default_accumulator.clone(); num_accumulator] - }), - ) - .collect_vec(); - - let (accumulator, as_proof) = { - let mut transcript = PoseidonTranscript::::new(Vec::new()); - let accumulator = - As::create_proof(&Default::default(), &accumulators, &mut transcript, rng).unwrap(); - (accumulator, transcript.finalize()) - }; - - let init_instances = if round > 0 { - // pick from prev snark - Vec::from( - &previous.instances[0][Self::INITIAL_STATE_ROW - ..Self::INITIAL_STATE_ROW + ST::num_transition_instance()], - ) - } else { - // pick from app - ST::state_prev_indices() - .into_iter() - .map(|i| app.instances[0][i]) - .collect::>() - }; - - let state_instances = ST::state_indices() - .into_iter() - .map(|i| &app.instances[0][i]) - .chain( - ST::additional_indices() - .into_iter() - .map(|i| &app.instances[0][i]), - ); - - let preprocessed_digest = { - let inputs = previous - .protocol - .preprocessed - .iter() - .flat_map(|preprocessed| [preprocessed.x, preprocessed.y]) - .map(fe_to_fe) - .chain(previous.protocol.transcript_initial_state) - .collect_vec(); - let mut hasher = hash::Poseidon::from_spec(&NativeLoader, POSEIDON_SPEC.clone()); - hasher.update(&inputs); - hasher.squeeze() - }; - - let instances = [ - accumulator.lhs.x, - accumulator.lhs.y, - accumulator.rhs.x, - accumulator.rhs.y, - ] - .into_iter() - .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) - .chain(iter::once(preprocessed_digest)) - .chain(init_instances) - .chain(state_instances.copied()) - .chain(iter::once(Fr::from(round as u64))) - .collect(); - - log::debug!("recursive instance: {:#?}", instances); - - Self { - svk, - default_accumulator, - app: app.into(), - previous: previous.into(), - round, - instances, - as_proof: Value::known(as_proof), - _marker: Default::default(), - } - } - - fn as_proof(&self) -> Value<&[u8]> { - self.as_proof.as_ref().map(Vec::as_slice) - } - - fn load_default_accumulator<'a>( - &self, - loader: &Rc>, - ) -> Result>>, Error> { - let [lhs, rhs] = - [self.default_accumulator.lhs, self.default_accumulator.rhs].map(|default| { - let assigned = loader - .ecc_chip() - .assign_constant(&mut loader.ctx_mut(), default) - .unwrap(); - loader.ec_point_from_assigned(assigned) - }); - Ok(KzgAccumulator::new(lhs, rhs)) - } - - /// Returns the number of instance cells in the Recursion Circuit, help to refine the CircuitExt trait - pub fn num_instance_fixed() -> usize { - // [ - // ..lhs (accumulator LHS), - // ..rhs (accumulator RHS), - // preprocessed_digest, - // initial_state, - // state, - // round - // ] - 4 * LIMBS + 2 * ST::num_transition_instance() + ST::num_additional_instance() + 2 - } -} - -impl Circuit for RecursionCircuit { - type Config = config::RecursionConfig; - type FloorPlanner = SimpleFloorPlanner; - type Params = (); - - fn without_witnesses(&self) -> Self { - Self { - svk: self.svk, - default_accumulator: self.default_accumulator.clone(), - app: self.app.without_witnesses(), - previous: self.previous.without_witnesses(), - round: self.round, - instances: self.instances.clone(), - as_proof: Value::unknown(), - _marker: Default::default(), - } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let path = std::env::var("BUNDLE_CONFIG") - .unwrap_or_else(|_| "configs/bundle_circuit.config".to_owned()); - let params: RecursionCircuitConfigParams = serde_json::from_reader( - File::open(path.as_str()).unwrap_or_else(|err| panic!("{err:?}")), - ) - .unwrap(); - - Self::Config::configure(meta, params) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - config.range().load_lookup_table(&mut layouter)?; - let max_rows = config.range().gate.max_rows; - let main_gate = config.gate(); - - let mut first_pass = halo2_base::SKIP_FIRST_PASS; // assume using simple floor planner - let assigned_instances = layouter.assign_region( - || "recursion circuit", - |region| -> Result, Error> { - if first_pass { - first_pass = false; - return Ok(vec![]); - } - let mut ctx = Context::new( - region, - ContextParams { - max_rows, - num_context_ids: 1, - fixed_columns: config.base_field_config.range.gate.constants.clone(), - }, - ); - - // The index of the "initial state", i.e. the state last finalised on L1. - let index_init_state = Self::INITIAL_STATE_ROW; - // The index of the "state", i.e. the state achieved post the current batch. - let index_state = index_init_state + ST::num_transition_instance(); - // The index where the "additional" fields required to define the state are - // present. The first field in the "additional" fields is the chain ID. - let index_additional_state = index_state + ST::num_transition_instance(); - // The index to find the "round" of recursion in the current instance of the - // Recursion Circuit. - let index_round = index_additional_state + ST::num_additional_instance(); - - log::debug!( - "indices within instances: init {} |cur {} | add {} | round {}", - index_init_state, - index_state, - index_additional_state, - index_round, - ); - - // Get the field elements representing the "preprocessed digest" and "recursion round". - let [preprocessed_digest, round] = [ - self.instances[Self::PREPROCESSED_DIGEST_ROW], - self.instances[index_round], - ] - .map(|instance| { - main_gate - .assign_integer(&mut ctx, Value::known(instance)) - .unwrap() - }); - - // Get the field elements representing the "initial state" - let initial_state = self.instances[index_init_state..index_state] - .iter() - .map(|&instance| { - main_gate - .assign_integer(&mut ctx, Value::known(instance)) - .unwrap() - }) - .collect::>(); - - // Get the field elements representing the "state" post batch. This includes the - // additional state fields as well. - let state = self.instances[index_state..index_round] - .iter() - .map(|&instance| { - main_gate - .assign_integer(&mut ctx, Value::known(instance)) - .unwrap() - }) - .collect::>(); - - // Whether or not we are in the first round of recursion. - let first_round = main_gate.is_zero(&mut ctx, &round); - let not_first_round = main_gate.not(&mut ctx, Existing(first_round)); - - let loader = Halo2Loader::new(config.ecc_chip(), ctx); - let (mut app_instances, app_accumulators) = - dynamic_verify::(&self.svk, &loader, &self.app, None); - let (mut previous_instances, previous_accumulators) = dynamic_verify::( - &self.svk, - &loader, - &self.previous, - Some(preprocessed_digest), - ); - - // Choose between the default accumulator or the previous accumulator depending on - // whether or not we are in the first round of recursion. - let default_accumulator = self.load_default_accumulator(&loader)?; - let previous_accumulators = previous_accumulators - .iter() - .map(|previous_accumulator| { - select_accumulator( - &loader, - &first_round, - &default_accumulator, - previous_accumulator, - ) - }) - .collect::, Error>>()?; - - // Accumulate the accumulators over the previous accumulators, to compute the - // accumulator values for this instance of the Recursion Circuit. - let KzgAccumulator { lhs, rhs } = accumulate( - &loader, - [app_accumulators, previous_accumulators].concat(), - self.as_proof(), - ); - - let lhs = lhs.into_assigned(); - let rhs = rhs.into_assigned(); - let app_instances = app_instances.pop().unwrap(); - let previous_instances = previous_instances.pop().unwrap(); - - let mut ctx = loader.ctx_mut(); - - ////////////////////////////////////////////////////////////////////////////////// - /////////////////////////////// CONSTRAINTS ////////////////////////////////////// - ////////////////////////////////////////////////////////////////////////////////// - - // Propagate the "initial state" - let initial_state_propagate = initial_state - .iter() - .zip_eq(previous_instances[index_init_state..index_state].iter()) - .zip_eq( - ST::state_prev_indices() - .into_iter() - .map(|i| &app_instances[i]), - ) - .flat_map(|((&st, &previous_st), &app_inst)| { - [ - // Verify initial_state is same as the first application snark in the - // first round of recursion. - ( - "initial state equal to app's initial (first round)", - main_gate.mul(&mut ctx, Existing(st), Existing(first_round)), - main_gate.mul(&mut ctx, Existing(app_inst), Existing(first_round)), - ), - // Propagate initial_state for subsequent rounds of recursion. - ( - "initial state equal to prev_recursion's initial (not first round)", - main_gate.mul(&mut ctx, Existing(st), Existing(not_first_round)), - previous_st, - ), - ] - }) - .collect::>(); - - // Verify that the current "state" is the same as the state defined in the - // application SNARK. - let verify_app_state = state - .iter() - .zip_eq( - ST::state_indices() - .into_iter() - .map(|i| &app_instances[i]) - .chain( - ST::additional_indices() - .into_iter() - .map(|i| &app_instances[i]), - ), - ) - .map(|(&st, &app_inst)| ("passing cur state to app", st, app_inst)) - .collect::>(); - - // Pick additional inst part in "previous state", verify the items at the front - // is currently propagated to the app inst which is marked as "propagated" - let propagate_app_states = previous_instances[index_additional_state..index_round] - .iter() - .zip( - ST::propagate_indices() - .into_iter() - .map(|i| &app_instances[i]), - ) - .map(|(&st, &app_propagated_inst)| { - ( - "propagate additional states in app (not first round)", - main_gate.mul( - &mut ctx, - Existing(app_propagated_inst), - Existing(not_first_round), - ), - st, - ) - }) - .collect::>(); - - // Verify that the "previous state" (additional state not included) is the same - // as the previous state defined in the current application SNARK. This check is - // meaningful only in subsequent recursion rounds after the first round. - let verify_app_init_state = previous_instances[index_state..index_additional_state] - .iter() - .zip_eq( - ST::state_prev_indices() - .into_iter() - .map(|i| &app_instances[i]), - ) - .map(|(&st, &app_inst)| { - ( - "chain prev state with cur init state (not first round)", - main_gate.mul(&mut ctx, Existing(app_inst), Existing(not_first_round)), - st, - ) - }) - .collect::>(); - - // Finally apply the equality constraints between the (LHS, RHS) values constructed - // above. - for (comment, lhs, rhs) in [ - // Propagate the preprocessed digest. - ( - "propagate preprocessed digest", - main_gate.mul( - &mut ctx, - Existing(preprocessed_digest), - Existing(not_first_round), - ), - previous_instances[Self::PREPROCESSED_DIGEST_ROW], - ), - // Verify that "round" increments by 1 when not the first round of recursion. - ( - "increment recursion round", - round, - main_gate.add( - &mut ctx, - Existing(not_first_round), - Existing(previous_instances[index_round]), - ), - ), - ] - .into_iter() - .chain(initial_state_propagate) - .chain(verify_app_state) - .chain(verify_app_init_state) - .chain(propagate_app_states) - { - use halo2_proofs::dev::unwrap_value; - debug_assert_eq!( - unwrap_value(lhs.value()), - unwrap_value(rhs.value()), - "equality constraint fail: {}", - comment - ); - ctx.region.constrain_equal(lhs.cell(), rhs.cell())?; - } - - // Mark the end of this phase. - config.base_field_config.finalize(&mut ctx); - - #[cfg(feature = "display")] - dbg!(ctx.total_advice); - #[cfg(feature = "display")] - println!("Advice columns used: {}", ctx.advice_alloc[0][0].0 + 1); - - // Return the computed instance cells for this Recursion Circuit. - Ok([lhs.x(), lhs.y(), rhs.x(), rhs.y()] - .into_iter() - .flat_map(|coordinate| coordinate.limbs()) - .chain(iter::once(&preprocessed_digest)) - .chain(initial_state.iter()) - .chain(state.iter()) - .chain(iter::once(&round)) - .map(|assigned| assigned.cell()) - .collect()) - }, - )?; - - assert_eq!(assigned_instances.len(), self.num_instance()[0]); - - // Ensure that the computed instances are in fact the instances for this circuit. - for (row, limb) in assigned_instances.into_iter().enumerate() { - layouter.constrain_instance(limb, config.instance, row)?; - } - - Ok(()) - } -} - -impl CircuitExt for RecursionCircuit { - fn num_instance(&self) -> Vec { - vec![Self::num_instance_fixed()] - } - - fn instances(&self) -> Vec> { - vec![self.instances.clone()] - } - - fn accumulator_indices() -> Option> { - Some((0..4 * LIMBS).map(|idx| (0, idx)).collect()) - } - - fn selectors(config: &Self::Config) -> Vec { - config.base_field_config.range.gate.basic_gates[0] - .iter() - .map(|gate| gate.q_enable) - .collect() - } -} diff --git a/aggregator/src/recursion/config.rs b/aggregator/src/recursion/config.rs deleted file mode 100644 index 68f55ce55e..0000000000 --- a/aggregator/src/recursion/config.rs +++ /dev/null @@ -1,64 +0,0 @@ -use halo2_proofs::plonk::{Column, Instance}; -use snark_verifier::loader::halo2::halo2_ecc::{ - ecc::{BaseFieldEccChip, EccChip}, - fields::fp::FpConfig, - halo2_base::gates::{flex_gate::FlexGateConfig, range::RangeConfig}, -}; - -use crate::param::ConfigParams as RecursionCircuitConfigParams; - -use super::*; - -#[derive(Clone)] -pub struct RecursionConfig { - /// The non-native field arithmetic config from halo2-lib. - pub base_field_config: FpConfig, - /// The single instance column to hold the public input to the [`RecursionCircuit`]. - pub instance: Column, -} - -impl RecursionConfig { - pub fn configure( - meta: &mut ConstraintSystem, - params: RecursionCircuitConfigParams, - ) -> Self { - assert!( - params.limb_bits == BITS && params.num_limbs == LIMBS, - "For now we fix limb_bits = {}, otherwise change code", - BITS - ); - let base_field_config = FpConfig::configure( - meta, - params.strategy, - ¶ms.num_advice, - ¶ms.num_lookup_advice, - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - halo2_base::utils::modulus::(), - 0, - params.degree as usize, - ); - - let instance = meta.instance_column(); - meta.enable_equality(instance); - - Self { - base_field_config, - instance, - } - } - - pub fn gate(&self) -> &FlexGateConfig { - &self.base_field_config.range.gate - } - - pub fn range(&self) -> &RangeConfig { - &self.base_field_config.range - } - - pub fn ecc_chip(&self) -> BaseFieldEccChip { - EccChip::construct(self.base_field_config.clone()) - } -} diff --git a/aggregator/src/tests.rs b/aggregator/src/tests.rs index 93632136a7..a0dfc90e5a 100644 --- a/aggregator/src/tests.rs +++ b/aggregator/src/tests.rs @@ -1,6 +1,5 @@ mod aggregation; mod blob; -mod recursion; mod rlc; #[macro_export] diff --git a/prover/Cargo.toml b/prover/Cargo.toml index abe7ba558e..e8e7eb4885 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -12,6 +12,7 @@ bus-mapping = { path = "../bus-mapping", default-features = false } compression = { path = "../compression" } eth-types = { path = "../eth-types" } mpt-zktrie = { path = "../zktrie" } +recursion = { path = "../recursion" } zkevm-circuits = { path = "../zkevm-circuits", default-features = false } snark-verifier.workspace = true diff --git a/prover/src/common/prover/recursion.rs b/prover/src/common/prover/recursion.rs index 6046a85838..e9ef740760 100644 --- a/prover/src/common/prover/recursion.rs +++ b/prover/src/common/prover/recursion.rs @@ -1,8 +1,9 @@ use std::env; -use aggregator::{initial_recursion_snark, RecursionCircuit, StateTransition, MAX_AGG_SNARKS}; +use aggregator::MAX_AGG_SNARKS; use anyhow::Result; use rand::Rng; +use recursion::{initial_recursion_snark, RecursionCircuit, StateTransition}; use snark_verifier_sdk::{gen_snark_shplonk, Snark}; use crate::{ diff --git a/prover/src/recursion.rs b/prover/src/recursion.rs index 8e1694a56e..8b95281851 100644 --- a/prover/src/recursion.rs +++ b/prover/src/recursion.rs @@ -1,6 +1,7 @@ use halo2_proofs::halo2curves::bn256::Fr; -use aggregator::{BatchCircuit, StateTransition}; +use aggregator::BatchCircuit; +use recursion::StateTransition; use snark_verifier_sdk::Snark; /// 4 fields for 2 hashes (Hi, Lo) diff --git a/recursion/Cargo.toml b/recursion/Cargo.toml new file mode 100644 index 0000000000..61d48119da --- /dev/null +++ b/recursion/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "recursion" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +aggregator = { path = "../aggregator" } + +log.workspace = true +itertools.workspace = true +serde_json.workspace = true +rand.workspace = true +halo2_proofs.workspace = true +halo2curves.workspace = true +ce-snark-verifier.workspace = true +ce-snark-verifier-sdk.workspace = true diff --git a/recursion/src/circuit.rs b/recursion/src/circuit.rs new file mode 100644 index 0000000000..13e4f28562 --- /dev/null +++ b/recursion/src/circuit.rs @@ -0,0 +1,415 @@ +#![allow(clippy::type_complexity)] +use super::*; +use crate::{ + common::{poseidon, succinct_verify}, + types::{As, BaseFieldEccChip, PlonkSuccinctVerifier, Svk}, + SECURE_MDS, +}; +use aggregator::ConfigParams as RecursionCircuitConfigParams; +use ce_snark_verifier::{ + halo2_base::{ + gates::{ + circuit::{builder::BaseCircuitBuilder, BaseCircuitParams, BaseConfig}, + GateInstructions, RangeInstructions, + }, + AssignedValue, + }, + loader::halo2::{ + halo2_ecc::{bn254::FpChip, ecc::EcPoint}, + EccInstructions, IntegerInstructions, + }, + pcs::{kzg::KzgAccumulator, AccumulationScheme, AccumulationSchemeProver}, + util::arithmetic::{fe_to_fe, fe_to_limbs}, + verifier::SnarkVerifier, +}; +use ce_snark_verifier_sdk::{ + halo2::{aggregation::Halo2Loader, PoseidonTranscript}, + Snark, BITS, LIMBS, +}; +use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner}, + poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, +}; +use rand::rngs::OsRng; +use std::{fs::File, iter, marker::PhantomData, mem, rc::Rc}; + +/// Select condition ? LHS : RHS. +fn select_accumulator<'a>( + loader: &Rc>, + condition: &AssignedValue, + lhs: &KzgAccumulator>>, + rhs: &KzgAccumulator>>, +) -> Result>>, Error> { + let [lhs, rhs]: [_; 2] = [lhs.lhs.assigned(), lhs.rhs.assigned()] + .iter() + .zip([rhs.lhs.assigned(), rhs.rhs.assigned()].iter()) + .map(|(lhs, rhs)| { + loader.ecc_chip().select( + loader.ctx_mut().main(), + EcPoint::clone(lhs), + EcPoint::clone(rhs), + *condition, + ) + }) + .collect::>() + .try_into() + .unwrap(); + Ok(KzgAccumulator::new( + loader.ec_point_from_assigned(lhs), + loader.ec_point_from_assigned(rhs), + )) +} + +/// Accumulate a value into the current accumulator. +fn accumulate<'a>( + loader: &Rc>, + accumulators: Vec>>>, + as_proof: &[u8], +) -> KzgAccumulator>> { + let mut transcript = + PoseidonTranscript::, _>::new::(loader, as_proof); + let proof = As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); + As::verify(&Default::default(), &accumulators, &proof).unwrap() +} + +#[derive(Clone)] +pub struct RecursionCircuit { + /// The verifying key for the circuit. + svk: Svk, + /// The default accumulator to initialise the circuit. + default_accumulator: KzgAccumulator, + /// The SNARK witness from the k-th BatchCircuit. + app: Snark, + /// The SNARK witness from the previous RecursionCircuit, i.e. RecursionCircuit up to the + /// (k-1)-th BatchCircuit. + previous: Snark, + /// The recursion round, starting at round=0 and incrementing at every subsequent recursion. + round: usize, + /// The public inputs to the RecursionCircuit itself. + instances: Vec, + /// The accumulation of the SNARK proofs recursed over thus far. + as_proof: Vec, + + inner: BaseCircuitBuilder, + + _marker: PhantomData, +} + +impl RecursionCircuit { + /// The index of the preprocessed digest in the [`RecursionCircuit`]'s instances. Note that we + /// need a single cell to hold this value as it is a poseidon hash over the bn256 curve, hence + /// it fits within an [`Fr`] cell. + /// + /// [`Fr`]: halo2_proofs::halo2curves::bn256::Fr + const PREPROCESSED_DIGEST_ROW: usize = 4 * LIMBS; + + /// The index within the instances to find the "initial" state in the state transition. + const INITIAL_STATE_ROW: usize = Self::PREPROCESSED_DIGEST_ROW + 1; + + const STATE_ROW: usize = 4 * LIMBS + 2; + const ROUND_ROW: usize = 4 * LIMBS + 3; + + /// Construct a new instance of the [`RecursionCircuit`] given the SNARKs from the current and + /// previous [`BatchCircuit`], and the recursion round. + /// + /// [`BatchCircuit`]: aggregator::BatchCircuit + pub fn new( + params: &ParamsKZG, + app: Snark, + previous: Snark, + _rng: impl Rng + Send, + round: usize, + ) -> Self { + let svk = params.get_g()[0].into(); + let default_accumulator = KzgAccumulator::new(params.get_g()[1], params.get_g()[0]); + + let succinct_verify = |snark: &Snark| { + let mut transcript = + PoseidonTranscript::::new::(snark.proof.as_slice()); + let proof = PlonkSuccinctVerifier::read_proof( + &svk, + &snark.protocol, + &snark.instances, + &mut transcript, + ) + .unwrap(); + PlonkSuccinctVerifier::verify(&svk, &snark.protocol, &snark.instances, &proof).unwrap() + }; + + let accumulators = iter::empty() + .chain(succinct_verify(&app)) + .chain( + (round > 0) + .then(|| succinct_verify(&previous)) + .unwrap_or_else(|| { + let num_accumulator = 1 + previous.protocol.accumulator_indices.len(); + vec![default_accumulator.clone(); num_accumulator] + }), + ) + .collect_vec(); + + let (accumulator, as_proof) = { + let mut transcript = + PoseidonTranscript::::new::(Vec::new()); + let accumulator = + As::create_proof(&Default::default(), &accumulators, &mut transcript, OsRng) + .unwrap(); + (accumulator, transcript.finalize()) + }; + + let preprocessed_digest = { + let inputs = previous + .protocol + .preprocessed + .iter() + .flat_map(|preprocessed| [preprocessed.x, preprocessed.y]) + .map(fe_to_fe) + .chain(previous.protocol.transcript_initial_state) + .collect_vec(); + poseidon(&NativeLoader, &inputs) + }; + + // TODO: allow more than 1 element for state. + let state = ST::state_indices() + .into_iter() + .map(|i| &app.instances[0][i]) + .chain( + ST::additional_indices() + .into_iter() + .map(|i| &app.instances[0][i]), + ) + .next() + .unwrap() + .clone(); + let initial_state = if round > 0 { + // pick from prev snark + Vec::from( + &previous.instances[0][Self::INITIAL_STATE_ROW + ..Self::INITIAL_STATE_ROW + ST::num_transition_instance()], + ) + } else { + // pick from app + ST::state_prev_indices() + .into_iter() + .map(|i| app.instances[0][i]) + .collect::>() + } + .first() + .unwrap() + .clone(); + + let instances = [ + accumulator.lhs.x, + accumulator.lhs.y, + accumulator.rhs.x, + accumulator.rhs.y, + ] + .into_iter() + .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) + .chain([ + preprocessed_digest, + initial_state, + state, + Fr::from(round as u64), + ]) + .collect(); + + let inner = BaseCircuitBuilder::new(false).use_params(load_base_circuit_params()); + let mut circuit = Self { + svk, + default_accumulator, + app, + previous, + round, + instances, + as_proof, + inner, + _marker: Default::default(), + }; + circuit.build(); + circuit + } + + fn build(&mut self) { + let range = self.inner.range_chip(); + let main_gate = range.gate(); + let pool = self.inner.pool(0); + let [preprocessed_digest, initial_state, state, round] = [ + self.instances[Self::PREPROCESSED_DIGEST_ROW], + self.instances[Self::INITIAL_STATE_ROW], + self.instances[Self::STATE_ROW], + self.instances[Self::ROUND_ROW], + ] + .map(|instance| main_gate.assign_integer(pool, instance)); + let first_round = main_gate.is_zero(pool.main(), round); + let not_first_round = main_gate.not(pool.main(), first_round); + + let fp_chip = FpChip::::new(&range, BITS, LIMBS); + let ecc_chip = BaseFieldEccChip::new(&fp_chip); + let loader = Halo2Loader::new(ecc_chip, mem::take(self.inner.pool(0))); + let (mut app_instances, app_accumulators) = + succinct_verify(&self.svk, &loader, &self.app, None); + let (mut previous_instances, previous_accumulators) = succinct_verify( + &self.svk, + &loader, + &self.previous, + Some(preprocessed_digest), + ); + + let default_accmulator = self.load_default_accumulator(&loader).unwrap(); + let previous_accumulators = previous_accumulators + .iter() + .map(|previous_accumulator| { + select_accumulator( + &loader, + &first_round, + &default_accmulator, + previous_accumulator, + ) + .unwrap() + }) + .collect::>(); + + let KzgAccumulator { lhs, rhs } = accumulate( + &loader, + [app_accumulators, previous_accumulators].concat(), + self.as_proof(), + ); + + let lhs = lhs.into_assigned(); + let rhs = rhs.into_assigned(); + let app_instances = app_instances.pop().unwrap(); + let previous_instances = previous_instances.pop().unwrap(); + + let mut pool = loader.take_ctx(); + let ctx = pool.main(); + for (lhs, rhs) in [ + // Propagate preprocessed_digest + ( + &main_gate.mul(ctx, preprocessed_digest, not_first_round), + &previous_instances[Self::PREPROCESSED_DIGEST_ROW], + ), + // Propagate initial_state + ( + &main_gate.mul(ctx, initial_state, not_first_round), + &previous_instances[Self::INITIAL_STATE_ROW], + ), + // Verify initial_state is same as the first application snark + ( + &main_gate.mul(ctx, initial_state, first_round), + &main_gate.mul(ctx, app_instances[0], first_round), + ), + // Verify current state is same as the current application snark + (&state, &app_instances[1]), + // Verify previous state is same as the current application snark + ( + &main_gate.mul(ctx, app_instances[0], not_first_round), + &previous_instances[Self::STATE_ROW], + ), + // Verify round is increased by 1 when not at first round + ( + &round, + &main_gate.add(ctx, not_first_round, previous_instances[Self::ROUND_ROW]), + ), + ] { + ctx.constrain_equal(lhs, rhs); + } + *self.inner.pool(0) = pool; + + self.inner.assigned_instances[0].extend( + [lhs.x(), lhs.y(), rhs.x(), rhs.y()] + .into_iter() + .flat_map(|coordinate| coordinate.limbs()) + .chain([preprocessed_digest, initial_state, state, round].iter()) + .copied(), + ); + } + + fn as_proof(&self) -> &[u8] { + &self.as_proof + } + + fn load_default_accumulator<'a>( + &self, + loader: &Rc>, + ) -> Result>>, Error> { + let [lhs, rhs] = + [self.default_accumulator.lhs, self.default_accumulator.rhs].map(|default| { + let assigned = loader + .ecc_chip() + .assign_constant(&mut loader.ctx_mut(), default); + loader.ec_point_from_assigned(assigned) + }); + Ok(KzgAccumulator::new(lhs, rhs)) + } + + /// Returns the number of instance cells in the Recursion Circuit, help to refine the CircuitExt + /// trait + pub fn num_instance_fixed() -> usize { + // [ + // ..lhs (accumulator LHS), + // ..rhs (accumulator RHS), + // preprocessed_digest, + // initial_state, + // state, + // round + // ] + 4 * LIMBS + 2 * ST::num_transition_instance() + ST::num_additional_instance() + 2 + } +} + +impl Circuit for RecursionCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = BaseCircuitParams; + + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + Self::Config::configure(meta, load_base_circuit_params()) + } + + fn synthesize(&self, config: Self::Config, layouter: impl Layouter) -> Result<(), Error> { + self.inner.synthesize(config, layouter) + } +} + +impl CircuitExt for RecursionCircuit { + fn num_instance(&self) -> Vec { + vec![Self::num_instance_fixed()] + } + + fn instances(&self) -> Vec> { + vec![self.instances.clone()] + } + + fn accumulator_indices() -> Option> { + Some((0..4 * LIMBS).map(|idx| (0, idx)).collect()) + } + + fn selectors(config: &Self::Config) -> Vec { + config.gate().basic_gates[0] + .iter() + .map(|gate| gate.q_enable) + .collect() + } +} + +fn load_base_circuit_params() -> BaseCircuitParams { + let path = std::env::var("BUNDLE_CONFIG") + .unwrap_or_else(|_| "configs/bundle_circuit.config".to_owned()); + let bundle_params: RecursionCircuitConfigParams = + serde_json::from_reader(File::open(path.as_str()).unwrap_or_else(|err| panic!("{err:?}"))) + .unwrap(); + + BaseCircuitParams { + k: usize::try_from(bundle_params.degree).unwrap(), + lookup_bits: Some(bundle_params.lookup_bits), + num_lookup_advice_per_phase: bundle_params.num_lookup_advice, + num_advice_per_phase: bundle_params.num_advice, + num_fixed: bundle_params.num_fixed, + num_instance_columns: 1, + } +} diff --git a/aggregator/src/recursion/common.rs b/recursion/src/common.rs similarity index 50% rename from aggregator/src/recursion/common.rs rename to recursion/src/common.rs index dc18efce8a..eca6962f1e 100644 --- a/aggregator/src/recursion/common.rs +++ b/recursion/src/common.rs @@ -1,44 +1,42 @@ -use std::rc::Rc; - -use snark_verifier::{ - loader::halo2::EccInstructions, - pcs::{kzg::KzgAccumulator, MultiOpenScheme, PolynomialCommitmentScheme}, +use crate::{ + sv_halo2_base::AssignedValue, + types::{PlonkSuccinctVerifier, Svk}, + G1Affine, +}; +use ce_snark_verifier::{ + loader::{Loader, ScalarLoader}, + pcs::kzg::KzgAccumulator, util::hash, + verifier::SnarkVerifier, }; -use snark_verifier_sdk::{ - types::{BaseFieldEccChip, Halo2Loader, Plonk}, - SnarkWitness, +use ce_snark_verifier_sdk::{ + halo2::{aggregation::Halo2Loader, PoseidonTranscript, POSEIDON_SPEC}, + Snark, }; +use halo2curves::bn256::Fr; +use itertools::Itertools; +use std::rc::Rc; -use super::*; - -type AssignedScalar<'a> = >::AssignedScalar; - -fn poseidon>(loader: &L, inputs: &[L::LoadedScalar]) -> L::LoadedScalar { +pub fn poseidon>(loader: &L, inputs: &[L::LoadedScalar]) -> L::LoadedScalar { let mut hasher = hash::Poseidon::from_spec(loader, POSEIDON_SPEC.clone()); hasher.update(inputs); hasher.squeeze() } -/// It is similar to `succinct_verify` method inside of snark-verifier -/// but allow it allow loader to load preprocessed part as witness (so ANY circuit) -/// can be verified. -pub fn dynamic_verify<'a, PCS>( - svk: &PCS::SuccinctVerifyingKey, +const SECURE_MDS: usize = 0; + +pub fn succinct_verify<'a>( + svk: &Svk, loader: &Rc>, - snark: &SnarkWitness, - preprocessed_digest: Option>, -) -> (Vec>>, Vec) -where - PCS: PolynomialCommitmentScheme< - G1Affine, - Rc>, - Accumulator = KzgAccumulator>>, - > + MultiOpenScheme>>, -{ + snark: &Snark, + preprocessed_digest: Option>, +) -> ( + Vec>>, + Vec>>>, +) { let protocol = if let Some(preprocessed_digest) = preprocessed_digest { let preprocessed_digest = loader.scalar_from_assigned(preprocessed_digest); - let protocol = snark.protocol.loaded_preprocessed_as_witness(loader); + let protocol = snark.protocol.loaded_preprocessed_as_witness(loader, false); let inputs = protocol .preprocessed .iter() @@ -49,9 +47,7 @@ where }) .chain(protocol.transcript_initial_state.clone()) .collect_vec(); - loader - .assert_eq("", &poseidon(loader, &inputs), &preprocessed_digest) - .unwrap(); + loader.assert_eq("", &poseidon(loader, &inputs), &preprocessed_digest); protocol } else { snark.protocol.loaded(loader) @@ -67,9 +63,11 @@ where .collect_vec() }) .collect_vec(); - let mut transcript = PoseidonTranscript::, _>::new(loader, snark.proof()); - let proof = Plonk::::read_proof(svk, &protocol, &instances, &mut transcript); - let accumulators = Plonk::::succinct_verify(svk, &protocol, &instances, &proof); + let mut transcript = + PoseidonTranscript::, _>::new::(loader, snark.proof()); + let proof = + PlonkSuccinctVerifier::read_proof(svk, &protocol, &instances, &mut transcript).unwrap(); + let accumulators = PlonkSuccinctVerifier::verify(svk, &protocol, &instances, &proof).unwrap(); ( instances diff --git a/aggregator/src/recursion.rs b/recursion/src/lib.rs similarity index 89% rename from aggregator/src/recursion.rs rename to recursion/src/lib.rs index 4b2c66dd27..d79d5a76b9 100644 --- a/aggregator/src/recursion.rs +++ b/recursion/src/lib.rs @@ -9,39 +9,36 @@ mod circuit; /// Common functionality utilised by the recursion circuit. mod common; -/// Config for recursion circuit -mod config; +// /// Config for recursion circuit +// mod config; /// Some utility functions. mod util; +/// Type aliases. +mod types; + pub use circuit::RecursionCircuit; -pub(crate) use common::dynamic_verify; +// pub(crate) use common::dynamic_verify; pub use util::{gen_recursion_pk, initial_recursion_snark}; +use ce_snark_verifier::{ + loader::{halo2::halo2_ecc::halo2_base as sv_halo2_base, native::NativeLoader}, + system::halo2::{compile, Config}, + // verifier::{PlonkProof, PlonkVerifier}, +}; +use ce_snark_verifier_sdk::{CircuitExt, BITS, LIMBS}; use halo2_proofs::{ halo2curves::{ - bn256::{Bn256, Fq, Fr, G1Affine}, + bn256::{Bn256, Fr, G1Affine}, group::ff::Field, }, plonk::{Circuit, ConstraintSystem, Error, ProvingKey, Selector, VerifyingKey}, }; use itertools::Itertools; use rand::Rng; -use snark_verifier::{ - loader::{ - halo2::halo2_ecc::halo2_base as sv_halo2_base, native::NativeLoader, Loader, ScalarLoader, - }, - system::halo2::{compile, Config}, - verifier::{PlonkProof, PlonkVerifier}, -}; -use snark_verifier_sdk::{ - types::{PoseidonTranscript, POSEIDON_SPEC}, - CircuitExt, Snark, -}; use sv_halo2_base::halo2_proofs; - -use crate::constants::{BITS, LIMBS}; +const SECURE_MDS: usize = 0; /// Any data that can be recursively bundled must implement the described state transition /// trait. diff --git a/aggregator/src/tests/recursion.rs b/recursion/src/tests.rs similarity index 100% rename from aggregator/src/tests/recursion.rs rename to recursion/src/tests.rs diff --git a/recursion/src/types.rs b/recursion/src/types.rs new file mode 100644 index 0000000000..0c041ab9f9 --- /dev/null +++ b/recursion/src/types.rs @@ -0,0 +1,18 @@ +use ce_snark_verifier::pcs::kzg::{Bdfg21, KzgAs, KzgSuccinctVerifyingKey, LimbsEncoding}; +use ce_snark_verifier_sdk::{BITS, LIMBS}; +use halo2curves::bn256::{Bn256, G1Affine}; + +pub type Svk = KzgSuccinctVerifyingKey; +pub type As = KzgAs; + +use ce_snark_verifier::verifier::plonk; +pub type PlonkSuccinctVerifier = plonk::PlonkSuccinctVerifier>; + +use ce_snark_verifier::loader::halo2::halo2_ecc::ecc; +pub type BaseFieldEccChip<'chip> = ecc::BaseFieldEccChip<'chip, G1Affine>; + +// const T: usize = 3; +// const RATE: usize = 2; + +// use ce_snark_verifier::util::hash; +// pub type Poseidon = hash::Poseidon; diff --git a/aggregator/src/recursion/util.rs b/recursion/src/util.rs similarity index 92% rename from aggregator/src/recursion/util.rs rename to recursion/src/util.rs index c8931cd219..c6c661a792 100644 --- a/aggregator/src/recursion/util.rs +++ b/recursion/src/util.rs @@ -1,17 +1,18 @@ use std::path::Path; +use super::*; +use crate::SECURE_MDS; +use ce_snark_verifier::{ + pcs::kzg::{Bdfg21, KzgAs}, + util::{arithmetic::fe_to_limbs, transcript::TranscriptWrite}, + verifier::plonk::PlonkProof, +}; +use ce_snark_verifier_sdk::{gen_pk, halo2::PoseidonTranscript, CircuitExt, Snark}; use halo2_proofs::{ circuit::Layouter, plonk::keygen_vk, poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, }; -use snark_verifier::{ - pcs::kzg::{Bdfg21, Kzg}, - util::{arithmetic::fe_to_limbs, transcript::TranscriptWrite}, -}; -use snark_verifier_sdk::{gen_pk, CircuitExt, Snark}; - -use super::*; mod dummy_circuit { use super::*; @@ -78,9 +79,9 @@ fn gen_dummy_snark>( num_instance: &[usize], mut rng: impl Rng + Send, ) -> Snark { - use snark_verifier::cost::CostEstimation; + use ce_snark_verifier::cost::CostEstimation; use std::iter; - type Pcs = Kzg; + type As = KzgAs; let protocol = compile( params, @@ -94,7 +95,7 @@ fn gen_dummy_snark>( .map(|&n| iter::repeat_with(|| Fr::random(&mut rng)).take(n).collect()) .collect(); let proof = { - let mut transcript = PoseidonTranscript::::new(Vec::new()); + let mut transcript = PoseidonTranscript::::new::(Vec::new()); for _ in 0..protocol .num_witness .iter() @@ -108,8 +109,8 @@ fn gen_dummy_snark>( for _ in 0..protocol.evaluations.len() { transcript.write_scalar(Fr::random(&mut rng)).unwrap(); } - let queries = PlonkProof::::empty_queries(&protocol); - for _ in 0..Pcs::estimate_cost(&queries).num_commitment { + let queries = PlonkProof::::empty_queries(&protocol); + for _ in 0..As::estimate_cost(&queries).num_commitment { transcript .write_ec_point(G1Affine::random(&mut rng)) .unwrap();