From b5c15250c0f40d4310786775868dd9f4bb4088b3 Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Thu, 3 Oct 2024 19:06:07 -0700 Subject: [PATCH] Add V2 Prover/Verifier to support optional AIRs --- benchmark/benches/single_rw.rs | 4 +- sdk/src/interaction/dummy_interaction_air.rs | 8 +- stark-backend/src/keygen/mod.rs | 6 +- stark-backend/src/keygen/v2/mod.rs | 269 ++++++++++++++ stark-backend/src/keygen/v2/types.rs | 99 ++++++ stark-backend/src/keygen/v2/view.rs | 103 ++++++ stark-backend/src/prover/mod.rs | 86 +++-- stark-backend/src/prover/quotient/helper.rs | 42 +++ stark-backend/src/prover/quotient/mod.rs | 41 ++- stark-backend/src/prover/v2/mod.rs | 354 +++++++++++++++++++ stark-backend/src/prover/v2/trace.rs | 208 +++++++++++ stark-backend/src/prover/v2/types.rs | 108 ++++++ stark-backend/src/verifier/constraints.rs | 10 +- stark-backend/src/verifier/mod.rs | 5 +- stark-backend/src/verifier/v2/mod.rs | 311 ++++++++++++++++ stark-backend/tests/cached_lookup/mod.rs | 95 +++-- 16 files changed, 1644 insertions(+), 105 deletions(-) create mode 100644 stark-backend/src/keygen/v2/mod.rs create mode 100644 stark-backend/src/keygen/v2/types.rs create mode 100644 stark-backend/src/keygen/v2/view.rs create mode 100644 stark-backend/src/prover/quotient/helper.rs create mode 100644 stark-backend/src/prover/v2/mod.rs create mode 100644 stark-backend/src/prover/v2/trace.rs create mode 100644 stark-backend/src/prover/v2/types.rs create mode 100644 stark-backend/src/verifier/v2/mod.rs diff --git a/benchmark/benches/single_rw.rs b/benchmark/benches/single_rw.rs index c2c2aaed58..5557822fa9 100644 --- a/benchmark/benches/single_rw.rs +++ b/benchmark/benches/single_rw.rs @@ -449,8 +449,8 @@ pub fn prove_raps_with_committed_traces_with_groups<'a, SC: StarkGenericConfig>( b.iter(|| { let _ = quotient_committer.quotient_values( raps.clone(), - pk, - trace_views.clone(), + &pk.get_quotient_vk_data_per_air(), + &trace_views, public_values, ); }) diff --git a/sdk/src/interaction/dummy_interaction_air.rs b/sdk/src/interaction/dummy_interaction_air.rs index da5c41aca6..9906952426 100644 --- a/sdk/src/interaction/dummy_interaction_air.rs +++ b/sdk/src/interaction/dummy_interaction_air.rs @@ -58,14 +58,14 @@ impl BaseAirWithPublicValues for DummyInteractionAir {} impl PartitionedBaseAir for DummyInteractionAir { fn cached_main_widths(&self) -> Vec { if self.partition { - vec![1] + vec![self.field_width] } else { vec![] } } fn common_main_width(&self) -> usize { if self.partition { - self.field_width + 1 } else { 1 + self.field_width } @@ -84,8 +84,8 @@ impl BaseAir for DummyInteractionAir { impl Air for DummyInteractionAir { fn eval(&self, builder: &mut AB) { let (fields, count) = if self.partition { - let local_0 = builder.cached_mains()[0].row_slice(0); - let local_1 = builder.common_main().row_slice(0); + let local_0 = builder.common_main().row_slice(0); + let local_1 = builder.cached_mains()[0].row_slice(0); let count = local_0[0]; let fields = local_1.to_vec(); (fields, count) diff --git a/stark-backend/src/keygen/mod.rs b/stark-backend/src/keygen/mod.rs index ed15ce7863..070d580d82 100644 --- a/stark-backend/src/keygen/mod.rs +++ b/stark-backend/src/keygen/mod.rs @@ -8,6 +8,7 @@ use p3_uni_stark::{StarkGenericConfig, Val}; use tracing::instrument; pub mod types; +pub mod v2; use self::types::{ create_commit_to_air_graph, MultiStarkProvingKey, ProverOnlySinglePreprocessedData, @@ -242,11 +243,6 @@ impl<'a, SC: StarkGenericConfig> MultiStarkKeygenBuilder<'a, SC> { .push((air, partitioned_main_ptrs, interaction_chunk_size)); } - /// Default way to add a single Interactive AIR. - /// DO NOT use this if the main trace needs to be partitioned. - /// - `degree` is height of trace matrix - /// - Generates preprocessed trace and creates a dedicated commitment for it. - /// - Adds main trace to the default shared main trace commitment. #[instrument(level = "debug", skip_all)] pub fn get_single_preprocessed_data( &self, diff --git a/stark-backend/src/keygen/v2/mod.rs b/stark-backend/src/keygen/v2/mod.rs new file mode 100644 index 0000000000..6995f3a871 --- /dev/null +++ b/stark-backend/src/keygen/v2/mod.rs @@ -0,0 +1,269 @@ +use itertools::Itertools; +use p3_field::AbstractExtensionField; +use p3_matrix::Matrix; +use p3_uni_stark::{StarkGenericConfig, Val}; +use tracing::instrument; + +use crate::{ + air_builders::symbolic::{get_symbolic_builder, SymbolicRapBuilder}, + keygen::{ + types::{ProverOnlySinglePreprocessedData, TraceWidth, VerifierSinglePreprocessedData}, + v2::types::{MultiStarkProvingKeyV2, StarkProvingKeyV2, StarkVerifyingKeyV2}, + }, + prover::trace::TraceCommitter, + rap::AnyRap, +}; + +pub mod types; +pub(crate) mod view; + +struct AIRKeygenBuilder<'a, SC: StarkGenericConfig> { + air: &'a dyn AnyRap, + prep_keygen_data: PrepKeygenData, + interaction_chunk_size: Option, +} + +/// Stateful builder to create multi-stark proving and verifying keys +/// for system of multiple RAPs with multiple multi-matrix commitments +pub struct MultiStarkKeygenBuilderV2<'a, SC: StarkGenericConfig> { + pub config: &'a SC, + /// Information for partitioned AIRs. + partitioned_airs: Vec>, +} + +impl<'a, SC: StarkGenericConfig> MultiStarkKeygenBuilderV2<'a, SC> { + pub fn new(config: &'a SC) -> Self { + Self { + config, + partitioned_airs: vec![], + } + } + + /// Default way to add a single Interactive AIR. + /// Returns `air_id` + #[instrument(level = "debug", skip_all)] + pub fn add_air(&mut self, air: &'a dyn AnyRap) -> usize { + self.add_air_with_interaction_chunk_size(air, None) + } + + /// Add a single Interactive AIR with a specified interaction chunk size. + /// Returns `air_id` + pub fn add_air_with_interaction_chunk_size( + &mut self, + air: &'a dyn AnyRap, + interaction_chunk_size: Option, + ) -> usize { + self.partitioned_airs.push(AIRKeygenBuilder::new( + self.config.pcs(), + air, + interaction_chunk_size, + )); + self.partitioned_airs.len() - 1 + } + + /// Consume the builder and generate proving key. + /// The verifying key can be obtained from the proving key. + pub fn generate_pk(self) -> MultiStarkProvingKeyV2 { + let global_max_constraint_degree = self + .partitioned_airs + .iter() + .map(|keygen_builder| { + let max_constraint_degree = keygen_builder.max_constraint_degree(); + tracing::debug!( + "{} has constraint degree {}", + keygen_builder.air.name(), + max_constraint_degree + ); + max_constraint_degree + }) + .max() + .unwrap(); + tracing::info!( + "Max constraint (excluding logup constraints) degree across all AIRs: {}", + global_max_constraint_degree + ); + + let pk_per_air: Vec<_> = self + .partitioned_airs + .into_iter() + .map(|keygen_builder| keygen_builder.generate_pk(global_max_constraint_degree)) + .collect(); + + for pk in pk_per_air.iter() { + let width = &pk.vk.params.width; + tracing::info!("{:<20} | Quotient Deg = {:<2} | Prep Cols = {:<2} | Main Cols = {:<8} | Perm Cols = {:<4} | {:<4} Constraints | {:<3} Interactions On Buses {:?}", + pk.air_name, + pk.vk.quotient_degree, + width.preprocessed.unwrap_or(0), + format!("{:?}",width.main_widths()), + format!("{:?}",width.after_challenge.iter().map(|&x| x * >>::D).collect_vec()), + pk.vk.symbolic_constraints.constraints.len(), + pk.vk.symbolic_constraints.interactions.len(), + pk.vk + .symbolic_constraints + .interactions + .iter() + .map(|i| i.bus_index) + .collect_vec() + ); + #[cfg(feature = "bench-metrics")] + { + let labels = [("air_name", pk.air_name.clone())]; + metrics::counter!("quotient_deg", &labels).absolute(pk.vk.quotient_degree as u64); + // column info will be logged by prover later + metrics::counter!("constraints", &labels) + .absolute(pk.vk.symbolic_constraints.constraints.len() as u64); + metrics::counter!("interactions", &labels) + .absolute(pk.vk.symbolic_constraints.interactions.len() as u64); + } + } + + MultiStarkProvingKeyV2 { + per_air: pk_per_air, + } + } +} + +impl<'a, SC: StarkGenericConfig> AIRKeygenBuilder<'a, SC> { + fn new(pcs: &SC::Pcs, air: &'a dyn AnyRap, interaction_chunk_size: Option) -> Self { + let prep_keygen_data = compute_prep_data_for_air(pcs, air); + AIRKeygenBuilder { + air, + prep_keygen_data, + interaction_chunk_size, + } + } + + fn max_constraint_degree(&self) -> usize { + self.get_symbolic_builder() + .constraints() + .max_constraint_degree() + } + + fn generate_pk(mut self, max_constraint_degree: usize) -> StarkProvingKeyV2 { + let air_name = self.air.name(); + self.find_interaction_chunk_size(max_constraint_degree); + + let symbolic_builder = self.get_symbolic_builder(); + let params = symbolic_builder.params(); + let symbolic_constraints = symbolic_builder.constraints(); + let log_quotient_degree = symbolic_constraints.get_log_quotient_degree(); + let quotient_degree = 1 << log_quotient_degree; + + let Self { + prep_keygen_data: + PrepKeygenData { + verifier_data: prep_verifier_data, + prover_data: prep_prover_data, + }, + interaction_chunk_size, + .. + } = self; + let interaction_chunk_size = interaction_chunk_size + .expect("Interaction chunk size should be set before generating proving key"); + + let vk = StarkVerifyingKeyV2 { + preprocessed_data: prep_verifier_data, + params, + symbolic_constraints, + quotient_degree, + interaction_chunk_size, + }; + StarkProvingKeyV2 { + air_name, + vk, + preprocessed_data: prep_prover_data, + } + } + + /// Finds the interaction chunk size for the AIR if it is not provided. + /// `global_max_constraint_degree` is the maximum constraint degree across all AIRs. + /// The degree of the dominating logup constraint is bounded by + /// logup_degree = max(1 + max_field_degree * interaction_chunk_size, + /// max_count_degree + max_field_degree * (interaction_chunk_size - 1)) + /// More details about this can be found in the function eval_permutation_constraints + /// + /// The goal is to pick interaction_chunk_size so that logup_degree does not + /// exceed max_constraint_degree (if possible), while maximizing interaction_chunk_size + fn find_interaction_chunk_size(&mut self, global_max_constraint_degree: usize) { + if self.interaction_chunk_size.is_some() { + return; + } + + let (max_field_degree, max_count_degree) = self + .get_symbolic_builder() + .constraints() + .max_interaction_degrees(); + + let interaction_chunk_size = if max_field_degree == 0 { + 1 + } else { + let mut interaction_chunk_size = (global_max_constraint_degree - 1) / max_field_degree; + interaction_chunk_size = interaction_chunk_size.min( + (global_max_constraint_degree - max_count_degree + max_field_degree) + / max_field_degree, + ); + interaction_chunk_size = interaction_chunk_size.max(1); + interaction_chunk_size + }; + + self.interaction_chunk_size = Some(interaction_chunk_size); + } + + fn get_symbolic_builder(&self) -> SymbolicRapBuilder> { + let width = TraceWidth { + preprocessed: self.prep_keygen_data.width(), + cached_mains: self.air.cached_main_widths(), + common_main: self.air.common_main_width(), + after_challenge: vec![], + }; + get_symbolic_builder( + self.air, + &width, + &[], + &[], + self.interaction_chunk_size.unwrap_or(1), + ) + } +} + +pub(super) struct PrepKeygenData { + pub verifier_data: Option>, + pub prover_data: Option>, +} + +impl PrepKeygenData { + pub fn width(&self) -> Option { + self.prover_data.as_ref().map(|d| d.trace.width()) + } +} + +fn compute_prep_data_for_air( + pcs: &SC::Pcs, + air: &dyn AnyRap, +) -> PrepKeygenData { + let preprocessed_trace = air.preprocessed_trace(); + let vpdata_opt = preprocessed_trace.map(|trace| { + let trace_committer = TraceCommitter::::new(pcs); + let data = trace_committer.commit(vec![trace.clone()]); + let vdata = VerifierSinglePreprocessedData { + commit: data.commit, + }; + let pdata = ProverOnlySinglePreprocessedData { + trace, + data: data.data, + }; + (vdata, pdata) + }); + if let Some((vdata, pdata)) = vpdata_opt { + PrepKeygenData { + prover_data: Some(pdata), + verifier_data: Some(vdata), + } + } else { + PrepKeygenData { + prover_data: None, + verifier_data: None, + } + } +} diff --git a/stark-backend/src/keygen/v2/types.rs b/stark-backend/src/keygen/v2/types.rs new file mode 100644 index 0000000000..dcb0623f45 --- /dev/null +++ b/stark-backend/src/keygen/v2/types.rs @@ -0,0 +1,99 @@ +/// Keygen V2 API for STARK backend +/// Changes: +/// - All AIRs can be optional +use derivative::Derivative; +use p3_uni_stark::{StarkGenericConfig, Val}; +use serde::{Deserialize, Serialize}; + +use crate::{ + air_builders::symbolic::SymbolicConstraints, + config::{Com, PcsProverData}, + keygen::types::{ + ProverOnlySinglePreprocessedData, StarkVerifyingParams, VerifierSinglePreprocessedData, + }, +}; + +/// Verifying key for a single STARK (corresponding to single AIR matrix) +#[derive(Derivative, Serialize, Deserialize)] +#[derivative(Clone(bound = "Com: Clone"))] +#[serde(bound( + serialize = "Com: Serialize", + deserialize = "Com: Deserialize<'de>" +))] +pub struct StarkVerifyingKeyV2 { + /// Preprocessed trace data, if any + pub preprocessed_data: Option>, + /// Parameters of the STARK + pub params: StarkVerifyingParams, + /// Symbolic constraints of the AIR in all challenge phases. This is + /// a serialization of the constraints in the AIR. + pub symbolic_constraints: SymbolicConstraints>, + /// The factor to multiple the trace degree by to get the degree of the quotient polynomial. Determined from the max constraint degree of the AIR constraints. + /// This is equivalently the number of chunks the quotient polynomial is split into. + pub quotient_degree: usize, + /// Number of interactions to bundle in permutation trace + pub interaction_chunk_size: usize, +} + +/// Common verifying key for multiple AIRs. +/// +/// This struct contains the necessary data for the verifier to verify proofs generated for +/// multiple AIRs using a single verifying key. +#[derive(Serialize, Deserialize)] +#[serde(bound( + serialize = "Com: Serialize", + deserialize = "Com: Deserialize<'de>" +))] +pub struct MultiStarkVerifyingKeyV2 { + pub per_air: Vec>, +} + +/// Proving key for a single STARK (corresponding to single AIR matrix) +#[derive(Serialize, Deserialize, Clone)] +#[serde(bound( + serialize = "PcsProverData: Serialize", + deserialize = "PcsProverData: Deserialize<'de>" +))] +pub struct StarkProvingKeyV2 { + /// Type name of the AIR, for display purposes only + pub air_name: String, + /// Verifying key + pub vk: StarkVerifyingKeyV2, + /// Prover only data for preprocessed trace + pub preprocessed_data: Option>, +} + +/// Common proving key for multiple AIRs. +/// +/// This struct contains the necessary data for the prover to generate proofs for multiple AIRs +/// using a single proving key. +#[derive(Serialize, Deserialize)] +#[serde(bound( + serialize = "PcsProverData: Serialize", + deserialize = "PcsProverData: Deserialize<'de>" +))] +pub struct MultiStarkProvingKeyV2 { + pub per_air: Vec>, +} + +impl StarkVerifyingKeyV2 { + pub fn num_cached_mains(&self) -> usize { + self.params.width.cached_mains.len() + } + + pub fn has_common_main(&self) -> bool { + self.params.width.common_main != 0 + } + + pub fn has_interaction(&self) -> bool { + !self.symbolic_constraints.interactions.is_empty() + } +} + +impl MultiStarkProvingKeyV2 { + pub fn get_vk(&self) -> MultiStarkVerifyingKeyV2 { + MultiStarkVerifyingKeyV2 { + per_air: self.per_air.iter().map(|pk| pk.vk.clone()).collect(), + } + } +} diff --git a/stark-backend/src/keygen/v2/view.rs b/stark-backend/src/keygen/v2/view.rs new file mode 100644 index 0000000000..085a6f2d8d --- /dev/null +++ b/stark-backend/src/keygen/v2/view.rs @@ -0,0 +1,103 @@ +use itertools::Itertools; +use p3_challenger::FieldChallenger; +use p3_uni_stark::StarkGenericConfig; + +use crate::{ + config::Com, + keygen::v2::types::{ + MultiStarkProvingKeyV2, MultiStarkVerifyingKeyV2, StarkProvingKeyV2, StarkVerifyingKeyV2, + }, +}; + +pub(crate) struct MultiStarkVerifyingKeyV2View<'a, SC: StarkGenericConfig> { + pub per_air: Vec<&'a StarkVerifyingKeyV2>, +} + +pub(crate) struct MultiStarkProvingKeyV2View<'a, SC: StarkGenericConfig> { + pub air_ids: Vec, + pub per_air: Vec<&'a StarkProvingKeyV2>, +} + +impl MultiStarkVerifyingKeyV2 { + pub(crate) fn view(&self, air_ids: &[usize]) -> MultiStarkVerifyingKeyV2View { + MultiStarkVerifyingKeyV2View { + per_air: air_ids.iter().map(|&id| &self.per_air[id]).collect(), + } + } +} +impl MultiStarkProvingKeyV2 { + pub(crate) fn view(&self, air_ids: Vec) -> MultiStarkProvingKeyV2View { + let per_air = air_ids.iter().map(|&id| &self.per_air[id]).collect(); + MultiStarkProvingKeyV2View { air_ids, per_air } + } +} + +impl<'a, SC: StarkGenericConfig> MultiStarkVerifyingKeyV2View<'a, SC> { + /// Returns the preprocessed commit of each AIR. If the AIR does not have a preprocessed trace, returns None. + pub fn preprocessed_commits(&self) -> Vec>> { + self.per_air + .iter() + .map(|vk| { + vk.preprocessed_data + .as_ref() + .map(|data| data.commit.clone()) + }) + .collect() + } + /// Returns all non-empty preprocessed commits. + pub fn flattened_preprocessed_commits(&self) -> Vec> { + self.preprocessed_commits().into_iter().flatten().collect() + } + /// Returns challenges of each phase. + pub fn sample_challenges(&self, challenger: &mut SC::Challenger) -> Vec> { + // Generate 2 permutation challenges + let num_challenges_to_sample = self.num_challenges_to_sample(); + assert!(num_challenges_to_sample.len() <= 1); + num_challenges_to_sample + .iter() + .map(|&num_challenges| { + (0..num_challenges) + .map(|_| challenger.sample_ext_element::()) + .collect_vec() + }) + .collect() + } + pub fn num_phases(&self) -> usize { + self.per_air + .iter() + .map(|vk| { + // Consistency check + let num = vk.params.width.after_challenge.len(); + assert_eq!(num, vk.params.num_challenges_to_sample.len()); + assert_eq!(num, vk.params.num_exposed_values_after_challenge.len()); + num + }) + .max() + .unwrap_or(0) + } + pub fn num_challenges_to_sample(&self) -> Vec { + let num_phases = self.num_phases(); + (0..num_phases) + .map(|phase_idx| { + self.per_air + .iter() + .map(|vk| { + *vk.params + .num_challenges_to_sample + .get(phase_idx) + .unwrap_or(&0) + }) + .max() + .unwrap_or_else(|| panic!("No challenges used in challenge phase {phase_idx}")) + }) + .collect() + } +} + +impl<'a, SC: StarkGenericConfig> MultiStarkProvingKeyV2View<'a, SC> { + pub fn vk_view(&self) -> MultiStarkVerifyingKeyV2View { + MultiStarkVerifyingKeyV2View { + per_air: self.per_air.iter().map(|pk| &pk.vk).collect(), + } + } +} diff --git a/stark-backend/src/prover/mod.rs b/stark-backend/src/prover/mod.rs index 0c4701d0de..0219f21e23 100644 --- a/stark-backend/src/prover/mod.rs +++ b/stark-backend/src/prover/mod.rs @@ -5,7 +5,7 @@ use metrics::trace_metrics; use p3_challenger::{CanObserve, FieldChallenger}; use p3_commit::{Pcs, PolynomialSpace}; use p3_field::AbstractExtensionField; -use p3_matrix::Matrix; +use p3_matrix::{dense::RowMajorMatrix, Matrix}; use p3_maybe_rayon::prelude::*; use p3_uni_stark::{Domain, StarkGenericConfig, Val}; use tracing::instrument; @@ -22,7 +22,7 @@ use crate::{ config::{Com, PcsProof, PcsProverData}, interaction::trace::generate_permutation_trace, keygen::types::MultiStarkProvingKey, - prover::trace::SingleRapCommittedTraceView, + prover::trace::{ProverTraceData, SingleRapCommittedTraceView}, rap::AnyRap, }; @@ -35,6 +35,7 @@ pub mod quotient; /// Trace commitment computation pub mod trace; pub mod types; +pub mod v2; thread_local! { pub static USE_DEBUG_BUILDER: Arc> = Arc::new(Mutex::new(true)); @@ -203,23 +204,21 @@ impl<'c, SC: StarkGenericConfig> MultiTraceStarkProver<'c, SC> { // Commit to permutation traces: this means only 1 challenge round right now // One shared commit for all permutation traces let perm_pcs_data = tracing::info_span!("commit to permutation traces").in_scope(|| { - let flattened_traces_with_domains: Vec<_> = perm_traces - .into_iter() - .zip_eq(&main_trace_data.air_traces) - .flat_map(|(perm_trace, data)| { - perm_trace.map(|trace| (data.domain, trace.flatten_to_base())) - }) - .collect(); - // Only commit if there are permutation traces - if !flattened_traces_with_domains.is_empty() { - let (commit, data) = pcs.commit(flattened_traces_with_domains); - // Challenger observes commitment - challenger.observe(commit.clone()); - Some((commit, data)) - } else { - None - } + commit_perm_traces::( + pcs, + perm_traces, + &main_trace_data + .air_traces + .iter() + .map(|t| t.domain) + .collect_vec(), + ) }); + // Challenger observes commitment if exists + if let Some(data) = &perm_pcs_data { + challenger.observe(data.commit.clone()); + } + // Either 0 or 1 after_challenge commits, depending on if there are any permutation traces let after_challenge_pcs_data: Vec<_> = perm_pcs_data.into_iter().collect(); let main_pcs_data = &main_trace_data.pcs_data; @@ -253,7 +252,10 @@ impl<'c, SC: StarkGenericConfig> MultiTraceStarkProver<'c, SC> { // There will be either 0 or 1 after_challenge traces let after_challenge = if let Some((cumulative_sum, index)) = cumulative_sum_and_index { - let matrix = CommittedSingleMatrixView::new(&after_challenge_pcs_data[0].1, index); + let matrix = CommittedSingleMatrixView::new( + after_challenge_pcs_data[0].data.as_ref(), + index, + ); let exposed_values = vec![cumulative_sum]; vec![(matrix, exposed_values)] } else { @@ -270,7 +272,8 @@ impl<'c, SC: StarkGenericConfig> MultiTraceStarkProver<'c, SC> { .unzip(); // === END of logic specific to Interactions/permutations, we can now deal with general RAP === - self.prove_raps_with_committed_traces( + Self::prove_raps_with_committed_traces( + pcs, challenger, pk, raps, @@ -281,7 +284,6 @@ impl<'c, SC: StarkGenericConfig> MultiTraceStarkProver<'c, SC> { public_values, ) } - /// Proves general RAPs after all traces have been committed. /// Soundness depends on `challenger` having already observed /// public values, exposed values after challenge, and all @@ -296,13 +298,13 @@ impl<'c, SC: StarkGenericConfig> MultiTraceStarkProver<'c, SC> { #[allow(clippy::too_many_arguments)] #[instrument(level = "info", skip_all)] pub fn prove_raps_with_committed_traces<'a>( - &self, + pcs: &SC::Pcs, challenger: &mut SC::Challenger, pk: &'a MultiStarkProvingKey, raps: Vec<&'a dyn AnyRap>, trace_views: Vec>, main_pcs_data: &[(Com, &PcsProverData)], - after_challenge_pcs_data: &[(Com, PcsProverData)], + after_challenge_pcs_data: &[ProverTraceData], challenges: &[Vec], public_values: &'a [Vec>], ) -> Proof @@ -314,10 +316,9 @@ impl<'c, SC: StarkGenericConfig> MultiTraceStarkProver<'c, SC> { SC::Challenge: Send + Sync, PcsProof: Send + Sync, { - let pcs = self.config.pcs(); let after_challenge_commitments: Vec<_> = after_challenge_pcs_data .iter() - .map(|(commit, _)| commit.clone()) + .map(|data| data.commit.clone()) .collect(); // Generate `alpha` challenge @@ -328,14 +329,11 @@ impl<'c, SC: StarkGenericConfig> MultiTraceStarkProver<'c, SC> { .iter() .map(|view| view.domain.size()) .collect_vec(); - let quotient_degrees = pk - .per_air - .iter() - .map(|pk| pk.vk.quotient_degree) - .collect_vec(); + let qvks = pk.get_quotient_vk_data_per_air(); + let quotient_degrees = qvks.iter().map(|qvk| qvk.quotient_degree).collect_vec(); let quotient_committer = QuotientCommitter::new(pcs, challenges, alpha); let quotient_values = - quotient_committer.quotient_values(raps, pk, trace_views.clone(), public_values); + quotient_committer.quotient_values(raps, &qvks, &trace_views, public_values); // Commit to quotient polynomias. One shared commit for all quotient polynomials let quotient_data = quotient_committer.commit(quotient_values); @@ -383,12 +381,12 @@ impl<'c, SC: StarkGenericConfig> MultiTraceStarkProver<'c, SC> { let after_challenge_data: Vec<_> = after_challenge_pcs_data .iter() .enumerate() - .map(|(round, (_, data))| { + .map(|(round, data)| { let domains = trace_views .iter() .flat_map(|view| (view.after_challenge.len() > round).then_some(view.domain)) .collect_vec(); - (data, domains) + (data.data.as_ref(), domains) }) .collect(); @@ -424,3 +422,25 @@ impl<'c, SC: StarkGenericConfig> MultiTraceStarkProver<'c, SC> { } } } + +fn commit_perm_traces( + pcs: &SC::Pcs, + perm_traces: Vec>>, + domain_per_air: &[Domain], +) -> Option> { + let flattened_traces_with_domains: Vec<_> = perm_traces + .into_iter() + .zip_eq(domain_per_air) + .flat_map(|(perm_trace, domain)| perm_trace.map(|trace| (*domain, trace.flatten_to_base()))) + .collect(); + // Only commit if there are permutation traces + if !flattened_traces_with_domains.is_empty() { + let (commit, data) = pcs.commit(flattened_traces_with_domains); + Some(ProverTraceData { + commit, + data: data.into(), + }) + } else { + None + } +} diff --git a/stark-backend/src/prover/quotient/helper.rs b/stark-backend/src/prover/quotient/helper.rs new file mode 100644 index 0000000000..1c3b3de3ab --- /dev/null +++ b/stark-backend/src/prover/quotient/helper.rs @@ -0,0 +1,42 @@ +use p3_uni_stark::StarkGenericConfig; + +use crate::{ + keygen::{ + types::{MultiStarkProvingKey, StarkProvingKey}, + v2::types::StarkProvingKeyV2, + }, + prover::quotient::QuotientVKData, +}; + +pub(crate) trait QuotientVKDataHelper { + fn get_quotient_vk_data(&self) -> QuotientVKData; +} + +impl QuotientVKDataHelper for StarkProvingKeyV2 { + fn get_quotient_vk_data(&self) -> QuotientVKData { + QuotientVKData { + quotient_degree: self.vk.quotient_degree, + interaction_chunk_size: self.vk.interaction_chunk_size, + symbolic_constraints: &self.vk.symbolic_constraints, + } + } +} + +impl QuotientVKDataHelper for StarkProvingKey { + fn get_quotient_vk_data(&self) -> QuotientVKData { + QuotientVKData { + quotient_degree: self.vk.quotient_degree, + interaction_chunk_size: self.vk.interaction_chunk_size, + symbolic_constraints: &self.vk.symbolic_constraints, + } + } +} + +impl MultiStarkProvingKey { + pub fn get_quotient_vk_data_per_air(&self) -> Vec> { + self.per_air + .iter() + .map(|pk| pk.get_quotient_vk_data()) + .collect() + } +} diff --git a/stark-backend/src/prover/quotient/mod.rs b/stark-backend/src/prover/quotient/mod.rs index 7b82e3afab..f918adae09 100644 --- a/stark-backend/src/prover/quotient/mod.rs +++ b/stark-backend/src/prover/quotient/mod.rs @@ -9,12 +9,12 @@ use tracing::instrument; use self::single::compute_single_rap_quotient_values; use super::trace::SingleRapCommittedTraceView; use crate::{ - air_builders::prover::ProverConstraintFolder, + air_builders::{prover::ProverConstraintFolder, symbolic::SymbolicConstraints}, config::{Com, PcsProverData}, - keygen::types::{MultiStarkProvingKey, StarkVerifyingKey}, rap::{AnyRap, PartitionedBaseAir, Rap}, }; +pub(crate) mod helper; pub mod single; pub struct QuotientCommitter<'pcs, SC: StarkGenericConfig> { @@ -58,8 +58,8 @@ impl<'pcs, SC: StarkGenericConfig> QuotientCommitter<'pcs, SC> { pub fn quotient_values<'a>( &self, raps: Vec<&'a dyn AnyRap>, - pk: &'a MultiStarkProvingKey, - traces: Vec>, + qvks: &[QuotientVKData<'a, SC>], + traces: &[SingleRapCommittedTraceView<'a, SC>], public_values: &'a [Vec>], ) -> QuotientData where @@ -69,22 +69,22 @@ impl<'pcs, SC: StarkGenericConfig> QuotientCommitter<'pcs, SC> { Com: Send + Sync, { #[cfg(feature = "parallel")] - let inner = (raps, &pk.per_air, traces, public_values) + let inner = (raps, qvks, traces, public_values) .into_par_iter() // uses rayon multizip - .map(|(rap, pk, trace, pis)| self.single_rap_quotient_values(rap, &pk.vk, trace, pis)) + .map(|(rap, qvk, trace, pis)| self.single_rap_quotient_values(rap, qvk, trace, pis)) .collect(); #[cfg(not(feature = "parallel"))] - let inner = itertools::izip!(raps, &pk.per_air, traces, public_values) - .map(|(rap, pk, trace, pis)| self.single_rap_quotient_values(rap, &pk.vk, trace, pis)) + let inner = itertools::izip!(raps, qvks, traces, public_values) + .map(|(rap, qvk, trace, pis)| self.single_rap_quotient_values(rap, qvk, trace, pis)) .collect(); QuotientData { inner } } - pub fn single_rap_quotient_values<'a, R>( + pub(crate) fn single_rap_quotient_values<'a, R>( &self, rap: &'a R, - vk: &StarkVerifyingKey, - trace: SingleRapCommittedTraceView<'a, SC>, + qvk: &QuotientVKData<'a, SC>, + trace: &SingleRapCommittedTraceView<'a, SC>, public_values: &'a [Val], ) -> SingleQuotientData where @@ -93,12 +93,12 @@ impl<'pcs, SC: StarkGenericConfig> QuotientCommitter<'pcs, SC> { + Sync + ?Sized, { - let quotient_degree = vk.quotient_degree; + let quotient_degree = qvk.quotient_degree; let trace_domain = trace.domain; let quotient_domain = trace_domain.create_disjoint_domain(trace_domain.size() * quotient_degree); // Empty matrix if no preprocessed trace - let preprocessed_lde_on_quotient_domain = if let Some(view) = trace.preprocessed { + let preprocessed_lde_on_quotient_domain = if let Some(view) = trace.preprocessed.as_ref() { self.pcs .get_evaluations_on_domain(view.data, view.matrix_index, quotient_domain) .to_row_major_matrix() @@ -107,7 +107,7 @@ impl<'pcs, SC: StarkGenericConfig> QuotientCommitter<'pcs, SC> { }; let partitioned_main_lde_on_quotient_domain: Vec<_> = trace .partitioned_main - .into_iter() + .iter() .map(|view| { self.pcs .get_evaluations_on_domain(view.data, view.matrix_index, quotient_domain) @@ -136,7 +136,7 @@ impl<'pcs, SC: StarkGenericConfig> QuotientCommitter<'pcs, SC> { let quotient_values = compute_single_rap_quotient_values( rap, - &vk.symbolic_constraints, + qvk.symbolic_constraints, trace_domain, quotient_domain, preprocessed_lde_on_quotient_domain, @@ -149,7 +149,7 @@ impl<'pcs, SC: StarkGenericConfig> QuotientCommitter<'pcs, SC> { .iter() .map(|v| v.as_slice()) .collect_vec(), - vk.interaction_chunk_size, + qvk.interaction_chunk_size, ); SingleQuotientData { quotient_degree, @@ -243,3 +243,12 @@ pub struct QuotientChunk { /// and number of columns equal to extension field degree. pub chunk: RowMajorMatrix>, } + +/// All necessary data from VK to compute ProverQuotientData +pub struct QuotientVKData<'a, SC: StarkGenericConfig> { + pub quotient_degree: usize, + pub interaction_chunk_size: usize, + /// Symbolic constraints of the AIR in all challenge phases. This is + /// a serialization of the constraints in the AIR. + pub symbolic_constraints: &'a SymbolicConstraints>, +} diff --git a/stark-backend/src/prover/v2/mod.rs b/stark-backend/src/prover/v2/mod.rs new file mode 100644 index 0000000000..2567279ef7 --- /dev/null +++ b/stark-backend/src/prover/v2/mod.rs @@ -0,0 +1,354 @@ +use std::iter; + +use itertools::{izip, multiunzip, Itertools}; +use p3_challenger::{CanObserve, FieldChallenger}; +use p3_commit::{Pcs, PolynomialSpace}; +use p3_field::AbstractExtensionField; +use p3_matrix::Matrix; +use p3_uni_stark::{Domain, StarkGenericConfig, Val}; +use tracing::instrument; + +use crate::{ + config::{Com, PcsProof, PcsProverData}, + keygen::v2::{types::MultiStarkProvingKeyV2, view::MultiStarkProvingKeyV2View}, + prover::{ + opener::OpeningProver, + quotient::ProverQuotientData, + trace::{ProverTraceData, TraceCommitter}, + types::Commitments, + v2::{ + trace::{commit_permutation_traces, commit_quotient_traces}, + types::{AIRProofData, ProofInput, ProofV2}, + }, + }, +}; + +mod trace; +pub mod types; + +/// Proves multiple chips with interactions together. +/// This prover implementation is specialized for Interactive AIRs. +pub struct MultiTraceStarkProverV2<'c, SC: StarkGenericConfig> { + pub config: &'c SC, +} + +impl<'c, SC: StarkGenericConfig> MultiTraceStarkProverV2<'c, SC> { + pub fn new(config: &'c SC) -> Self { + Self { config } + } + + pub fn pcs(&self) -> &SC::Pcs { + self.config.pcs() + } + + /// Specialized prove for InteractiveAirs. + /// Handles trace generation of the permutation traces. + /// Assumes the main traces have been generated and committed already. + /// + /// Public values: for each AIR, a separate list of public values. + /// The prover can support global public values that are shared among all AIRs, + /// but we currently split public values per-AIR for modularity. + #[instrument(name = "MultiTraceStarkProveV2r::prove", level = "info", skip_all)] + pub fn prove<'a>( + &self, + challenger: &mut SC::Challenger, + mpk: &'a MultiStarkProvingKeyV2, + proof_input: ProofInput, + ) -> ProofV2 + where + SC::Pcs: Sync, + Domain: Send + Sync, + PcsProverData: Send + Sync, + Com: Send + Sync, + SC::Challenge: Send + Sync, + PcsProof: Send + Sync, + { + assert!(mpk.validate(&proof_input), "Invalid proof input"); + let pcs = self.config.pcs(); + + let (air_ids, air_inputs): (Vec<_>, Vec<_>) = multiunzip(proof_input.per_air.into_iter()); + let (airs, cached_mains_per_air, common_main_per_air, pvs_per_air): ( + Vec<_>, + Vec<_>, + Vec<_>, + Vec<_>, + ) = multiunzip(air_inputs.into_iter().map(|input| { + ( + input.air, + input.cached_mains, + input.common_main, + input.public_values, + ) + })); + + let num_air = air_ids.len(); + // Ignore unused AIRs. + let mpk = mpk.view(air_ids); + + // Challenger must observe public values + for pvs in &pvs_per_air { + challenger.observe_slice(pvs); + } + + let preprocessed_commits = mpk.vk_view().flattened_preprocessed_commits(); + challenger.observe_slice(&preprocessed_commits); + + // Commit all common main traces in a commitment. Traces inside are ordered by AIR id. + let (common_main_trace_views, common_main_prover_data) = { + let committer = TraceCommitter::::new(pcs); + let (trace_views, traces): (Vec<_>, Vec<_>) = common_main_per_air + .iter() + .filter_map(|cm| cm.as_ref()) + .map(|m| (m.as_view(), m.clone())) + .unzip(); + + (trace_views, committer.commit(traces)) + }; + + // Generate main trace commitments on the fly. + // Commitments order: + // - for each air: + // - for each cached main trace + // - 1 commitment + // - 1 commitment of all common main traces + let main_trace_commitments: Vec<_> = cached_mains_per_air + .iter() + .flatten() + .map(|cm| &cm.prover_data.commit) + .chain(iter::once(&common_main_prover_data.commit)) + .cloned() + .collect(); + challenger.observe_slice(&main_trace_commitments); + + // TODO: this is not needed if there are no interactions. Number of challenge rounds should be specified in proving key + // Generate permutation challenges + let challenges = mpk.vk_view().sample_challenges(challenger); + + let mut common_main_idx = 0; + let mut degree_per_air = Vec::with_capacity(num_air); + let mut main_views_per_air = Vec::with_capacity(num_air); + for (pk, cached_mains) in mpk.per_air.iter().zip(&cached_mains_per_air) { + let mut main_views: Vec<_> = cached_mains + .iter() + .map(|cm| cm.raw_data.as_view()) + .collect(); + if pk.vk.has_common_main() { + main_views.push(common_main_trace_views[common_main_idx].as_view()); + common_main_idx += 1; + } + degree_per_air.push(main_views[0].height()); + main_views_per_air.push(main_views); + } + let domain_per_air: Vec<_> = degree_per_air + .iter() + .map(|°ree| pcs.natural_domain_for_degree(degree)) + .collect(); + + let (cumulative_sum_per_air, perm_prover_data) = commit_permutation_traces( + pcs, + &mpk, + &challenges, + &main_views_per_air, + &pvs_per_air, + domain_per_air.clone(), + ); + + // Challenger needs to observe permutation_exposed_values (aka cumulative_sums) + for cumulative_sum in cumulative_sum_per_air.iter().flatten() { + challenger.observe_slice(cumulative_sum.as_base_slice()); + } + // Challenger observes commitment if exists + if let Some(data) = &perm_prover_data { + challenger.observe(data.commit.clone()); + } + // Generate `alpha` challenge + let alpha: SC::Challenge = challenger.sample_ext_element(); + tracing::debug!("alpha: {alpha:?}"); + + let quotient_data = commit_quotient_traces( + pcs, + &mpk, + alpha, + &challenges, + airs, + &pvs_per_air, + domain_per_air.clone(), + &cached_mains_per_air, + &common_main_prover_data, + &perm_prover_data, + cumulative_sum_per_air.clone(), + ); + + let (_, mut main_prover_data): (Vec<_>, Vec<_>) = cached_mains_per_air + .into_iter() + .flatten() + .map(|cm| (cm.raw_data, cm.prover_data)) + .unzip(); + main_prover_data.push(common_main_prover_data); + prove_raps_with_committed_traces( + pcs, + challenger, + mpk, + &main_prover_data, + perm_prover_data, + cumulative_sum_per_air, + quotient_data, + domain_per_air, + pvs_per_air, + ) + } +} +// +/// Proves general RAPs after all traces have been committed. +/// Soundness depends on `challenger` having already observed +/// public values, exposed values after challenge, and all +/// trace commitments. +/// +/// - `challenges`: for each trace challenge phase, the challenges sampled +/// +/// ## Assumptions +/// - `raps, trace_views, public_values` have same length and same order +/// - per challenge round, shared commitment for +/// all trace matrices, with matrices in increasing order of air index +#[allow(clippy::too_many_arguments)] +#[instrument(level = "info", skip_all)] +fn prove_raps_with_committed_traces<'a, SC: StarkGenericConfig>( + pcs: &SC::Pcs, + challenger: &mut SC::Challenger, + mpk: MultiStarkProvingKeyV2View, + main_prover_data: &[ProverTraceData], + perm_prover_data: Option>, + cumulative_sum_per_air: Vec>, + quotient_data: ProverQuotientData, + domain_per_air: Vec>, + public_values_per_air: Vec>>, +) -> ProofV2 +where + SC::Pcs: Sync, + Domain: Send + Sync, + PcsProverData: Send + Sync, + Com: Send + Sync, + SC::Challenge: Send + Sync, + PcsProof: Send + Sync, +{ + // Observe quotient commitment + challenger.observe(quotient_data.commit.clone()); + + let after_challenge_commitments: Vec<_> = perm_prover_data + .iter() + .map(|data| data.commit.clone()) + .collect(); + // Collect the commitments + let commitments = Commitments { + main_trace: main_prover_data + .iter() + .map(|data| data.commit.clone()) + .collect(), + after_challenge: after_challenge_commitments, + quotient: quotient_data.commit.clone(), + }; + + // Draw `zeta` challenge + let zeta: SC::Challenge = challenger.sample_ext_element(); + tracing::debug!("zeta: {zeta:?}"); + + // Open all polynomials at random points using pcs + let opener = OpeningProver::new(pcs, zeta); + let preprocessed_data: Vec<_> = mpk + .per_air + .iter() + .zip_eq(&domain_per_air) + .flat_map(|(pk, domain)| { + pk.preprocessed_data + .as_ref() + .map(|prover_data| (prover_data.data.as_ref(), *domain)) + }) + .collect(); + + let mut main_prover_data_idx = 0; + let mut main_data = Vec::with_capacity(main_prover_data.len()); + let mut common_main_domains = Vec::with_capacity(mpk.per_air.len()); + for (air_id, pk) in mpk.per_air.iter().enumerate() { + for _ in 0..pk.vk.num_cached_mains() { + main_data.push(( + main_prover_data[main_prover_data_idx].data.as_ref(), + vec![domain_per_air[air_id]], + )); + main_prover_data_idx += 1; + } + if pk.vk.has_common_main() { + common_main_domains.push(domain_per_air[air_id]); + } + } + main_data.push(( + main_prover_data[main_prover_data_idx].data.as_ref(), + common_main_domains, + )); + + // ASSUMING: per challenge round, shared commitment for all trace matrices, with matrices in increasing order of air index + let after_challenge_data = if let Some(perm_prover_data) = &perm_prover_data { + let mut domains = Vec::new(); + for (air_id, pk) in mpk.per_air.iter().enumerate() { + if pk.vk.has_interaction() { + domains.push(domain_per_air[air_id]); + } + } + vec![(perm_prover_data.data.as_ref(), domains)] + } else { + vec![] + }; + + let quotient_degrees = mpk + .per_air + .iter() + .map(|pk| pk.vk.quotient_degree) + .collect_vec(); + let opening = opener.open( + challenger, + preprocessed_data, + main_data, + after_challenge_data, + "ient_data.data, + "ient_degrees, + ); + + let degrees = domain_per_air + .iter() + .map(|domain| domain.size()) + .collect_vec(); + + let exposed_values_after_challenge = cumulative_sum_per_air + .into_iter() + .map(|csum| { + if let Some(csum) = csum { + vec![vec![csum]] + } else { + vec![] + } + }) + .collect_vec(); + + // tracing::info!("{}", trace_metrics(&pk.per_air, °rees)); + // #[cfg(feature = "bench-metrics")] + // trace_metrics(&pk.per_air, °rees).emit(); + + ProofV2 { + commitments, + opening, + per_air: izip!( + mpk.air_ids, + degrees, + exposed_values_after_challenge, + public_values_per_air + ) + .map( + |(air_id, degree, exposed_values, public_values)| AIRProofData { + air_id, + degree, + public_values, + exposed_values_after_challenge: exposed_values, + }, + ) + .collect(), + } +} diff --git a/stark-backend/src/prover/v2/trace.rs b/stark-backend/src/prover/v2/trace.rs new file mode 100644 index 0000000000..da48ac1139 --- /dev/null +++ b/stark-backend/src/prover/v2/trace.rs @@ -0,0 +1,208 @@ +use itertools::{izip, Itertools}; +use p3_matrix::{ + dense::{RowMajorMatrix, RowMajorMatrixView}, + Matrix, +}; +use p3_maybe_rayon::prelude::*; +use p3_uni_stark::{Domain, StarkGenericConfig, Val}; + +use crate::{ + commit::CommittedSingleMatrixView, + config::{Com, PcsProof, PcsProverData}, + interaction::trace::generate_permutation_trace, + keygen::v2::{types::StarkProvingKeyV2, view::MultiStarkProvingKeyV2View}, + prover::{ + commit_perm_traces, + quotient::{helper::QuotientVKDataHelper, ProverQuotientData, QuotientCommitter}, + trace::{ProverTraceData, SingleRapCommittedTraceView}, + v2::types::CommittedTraceData, + }, + rap::AnyRap, +}; + +#[allow(clippy::too_many_arguments)] +pub(super) fn commit_permutation_traces( + pcs: &SC::Pcs, + mpk: &MultiStarkProvingKeyV2View, + challenges: &[Vec], + main_views_per_air: &[Vec>>], + public_values_per_air: &[Vec>], + domain_per_air: Vec>, +) -> (Vec>, Option>) +where + SC::Pcs: Sync, + Domain: Send + Sync, + PcsProverData: Send + Sync, + Com: Send + Sync, + SC::Challenge: Send + Sync, + PcsProof: Send + Sync, +{ + let perm_trace_per_air = tracing::info_span!("generate permutation traces").in_scope(|| { + generate_permutation_trace_per_air( + challenges, + mpk, + main_views_per_air, + public_values_per_air, + ) + }); + let cumulative_sum_per_air = extract_cumulative_sums::(&perm_trace_per_air); + // Commit to permutation traces: this means only 1 challenge round right now + // One shared commit for all permutation traces + let perm_prover_data = tracing::info_span!("commit to permutation traces") + .in_scope(|| commit_perm_traces::(pcs, perm_trace_per_air, &domain_per_air)); + + (cumulative_sum_per_air, perm_prover_data) +} + +#[allow(clippy::too_many_arguments)] +pub(super) fn commit_quotient_traces<'a, SC: StarkGenericConfig>( + pcs: &SC::Pcs, + mpk: &MultiStarkProvingKeyV2View, + alpha: SC::Challenge, + challenges: &[Vec], + raps: Vec<&'a dyn AnyRap>, + public_values_per_air: &[Vec>], + domain_per_air: Vec>, + cached_mains_per_air: &'a [Vec>], + common_main_prover_data: &'a ProverTraceData, + perm_prover_data: &'a Option>, + cumulative_sum_per_air: Vec>, +) -> ProverQuotientData +where + SC::Pcs: Sync, + Domain: Send + Sync, + PcsProverData: Send + Sync, + Com: Send + Sync, + SC::Challenge: Send + Sync, + PcsProof: Send + Sync, +{ + let trace_views = create_trace_view_per_air( + domain_per_air, + cached_mains_per_air, + mpk, + cumulative_sum_per_air, + common_main_prover_data, + perm_prover_data, + ); + let quotient_committer = QuotientCommitter::new(pcs, challenges, alpha); + let qvks = mpk + .per_air + .iter() + .map(|pk| pk.get_quotient_vk_data()) + .collect_vec(); + let quotient_values = + quotient_committer.quotient_values(raps, &qvks, &trace_views, public_values_per_air); + // Commit to quotient polynomias. One shared commit for all quotient polynomials + quotient_committer.commit(quotient_values) +} + +/// Returns a list of optional tuples of (permutation trace,cumulative sum) for each AIR. +fn generate_permutation_trace_per_air( + challenges: &[Vec], + mpk: &MultiStarkProvingKeyV2View, + main_views_per_air: &[Vec>>], + public_values_per_air: &[Vec>], +) -> Vec>> +where + StarkProvingKeyV2: Send + Sync, +{ + // Generate permutation traces + let perm_challenges = challenges + .first() + .map(|c| [c[0], c[1]]) + .expect("Need 2 challenges"); // must have 2 challenges + + mpk.per_air + .par_iter() + .zip_eq(main_views_per_air.par_iter()) + .zip_eq(public_values_per_air.par_iter()) + .map(|((pk, main), public_values)| { + let interactions = &pk.vk.symbolic_constraints.interactions; + let preprocessed_trace = pk.preprocessed_data.as_ref().map(|d| d.trace.as_view()); + generate_permutation_trace( + interactions, + &preprocessed_trace, + main, + public_values, + Some(perm_challenges), + pk.vk.interaction_chunk_size, + ) + }) + .collect::>() +} + +fn extract_cumulative_sums( + perm_traces: &[Option>], +) -> Vec> { + perm_traces + .iter() + .map(|perm_trace| { + perm_trace.as_ref().map(|perm_trace| { + *perm_trace + .row_slice(perm_trace.height() - 1) + .last() + .unwrap() + }) + }) + .collect() +} + +fn create_trace_view_per_air<'a, SC: StarkGenericConfig>( + domain_per_air: Vec>, + cached_mains_per_air: &'a [Vec>], + mpk: &'a MultiStarkProvingKeyV2View, + cumulative_sum_per_air: Vec>, + common_main_prover_data: &'a ProverTraceData, + perm_prover_data: &'a Option>, +) -> Vec> { + let mut common_main_idx = 0; + let mut after_challenge_idx = 0; + izip!( + domain_per_air, + cached_mains_per_air, + &mpk.per_air, + cumulative_sum_per_air, + ) + .map(|(domain, cached_mains, pk, cumulative_sum)| { + // The AIR will be treated as the full RAP with virtual columns after this + let preprocessed = pk.preprocessed_data.as_ref().map(|p| { + // TODO: currently assuming each chip has it's own preprocessed commitment + CommittedSingleMatrixView::::new(p.data.as_ref(), 0) + }); + let mut partitioned_main: Vec<_> = cached_mains + .iter() + .map(|cm| CommittedSingleMatrixView::new(cm.prover_data.data.as_ref(), 0)) + .collect(); + if pk.vk.has_common_main() { + partitioned_main.push(CommittedSingleMatrixView::new( + common_main_prover_data.data.as_ref(), + common_main_idx, + )); + common_main_idx += 1; + } + + // There will be either 0 or 1 after_challenge traces + let after_challenge = if let Some(cumulative_sum) = cumulative_sum { + let matrix = CommittedSingleMatrixView::new( + perm_prover_data + .as_ref() + .expect("AIR uses interactions but no permutation trace commitment") + .data + .as_ref(), + after_challenge_idx, + ); + after_challenge_idx += 1; + let exposed_values = vec![cumulative_sum]; + vec![(matrix, exposed_values)] + } else { + Vec::new() + }; + SingleRapCommittedTraceView { + domain, + preprocessed, + partitioned_main, + after_challenge, + } + }) + .collect() +} diff --git a/stark-backend/src/prover/v2/types.rs b/stark-backend/src/prover/v2/types.rs new file mode 100644 index 0000000000..65db00f6ec --- /dev/null +++ b/stark-backend/src/prover/v2/types.rs @@ -0,0 +1,108 @@ +use itertools::Itertools; +use p3_matrix::dense::RowMajorMatrix; +use p3_uni_stark::{StarkGenericConfig, Val}; +use serde::{Deserialize, Serialize}; + +use crate::{ + keygen::v2::types::{MultiStarkProvingKeyV2, MultiStarkVerifyingKeyV2}, + prover::{opener::OpeningProof, trace::ProverTraceData, types::Commitments}, + rap::AnyRap, +}; + +/// The full proof for multiple RAPs where trace matrices are committed into +/// multiple commitments, where each commitment is multi-matrix. +/// +/// Includes the quotient commitments and FRI opening proofs for the constraints as well. +#[derive(Serialize, Deserialize)] +#[serde(bound = "")] +pub struct ProofV2 { + /// The PCS commitments + pub commitments: Commitments, + /// Opening proofs separated by partition, but this may change + pub opening: OpeningProof, + /// Proof data for each AIR + pub per_air: Vec>, +} + +#[derive(Serialize, Deserialize)] +#[serde(bound = "")] +pub struct AIRProofData { + pub air_id: usize, + /// height of trace matrix. + pub degree: usize, + /// For each challenge phase with trace, the values to expose to the verifier in that phase + pub exposed_values_after_challenge: Vec>, + // The public values to expose to the verifier + pub public_values: Vec>, +} + +/// Proof input +pub struct ProofInput<'a, SC: StarkGenericConfig> { + /// (AIR id, AIR input) + pub per_air: Vec<(usize, AIRProofInput<'a, SC>)>, +} + +pub struct CommittedTraceData { + pub raw_data: RowMajorMatrix>, + pub prover_data: ProverTraceData, +} + +/// Necessary input for proving a single AIR. +pub struct AIRProofInput<'a, SC: StarkGenericConfig> { + pub air: &'a dyn AnyRap, + /// Cached main trace matrices + pub cached_mains: Vec>, + /// Common main trace matrix + pub common_main: Option>>, + /// Public values + pub public_values: Vec>, +} + +pub trait AIRProofInputGenerator { + fn generate_air_proof_input<'a>() -> AIRProofInput<'a, SC>; +} + +impl ProofV2 { + pub fn get_air_ids(&self) -> Vec { + self.per_air.iter().map(|p| p.air_id).collect() + } + pub fn get_public_values(&self) -> Vec>> { + self.per_air + .iter() + .map(|p| p.public_values.clone()) + .collect() + } +} + +impl<'a, SC: StarkGenericConfig> ProofInput<'a, SC> { + pub fn sort(&mut self) { + self.per_air.sort_by_key(|p| p.0); + } +} + +impl MultiStarkVerifyingKeyV2 { + pub fn validate(&self, proof_input: &ProofInput) -> bool { + if !proof_input + .per_air + .iter() + .all(|input| input.0 < self.per_air.len()) + { + return false; + } + if !proof_input + .per_air + .iter() + .tuple_windows() + .all(|(a, b)| a.0 < b.0) + { + return false; + } + true + } +} + +impl MultiStarkProvingKeyV2 { + pub fn validate(&self, proof_input: &ProofInput) -> bool { + self.get_vk().validate(proof_input) + } +} diff --git a/stark-backend/src/verifier/constraints.rs b/stark-backend/src/verifier/constraints.rs index af9f658d7e..a6d3ea8d11 100644 --- a/stark-backend/src/verifier/constraints.rs +++ b/stark-backend/src/verifier/constraints.rs @@ -9,15 +9,17 @@ use tracing::instrument; use super::error::VerificationError; use crate::{ - air_builders::verifier::{GenericVerifierConstraintFolder, VerifierConstraintFolder}, - keygen::types::StarkVerifyingKey, + air_builders::{ + symbolic::symbolic_expression::SymbolicExpression, + verifier::{GenericVerifierConstraintFolder, VerifierConstraintFolder}, + }, prover::opener::AdjacentOpenedValues, }; #[allow(clippy::too_many_arguments)] #[instrument(skip_all)] pub fn verify_single_rap_constraints( - vk: &StarkVerifyingKey, + constraints: &[SymbolicExpression>], preprocessed_values: Option<&AdjacentOpenedValues>, partitioned_main_values: Vec<&AdjacentOpenedValues>, after_challenge_values: Vec<&AdjacentOpenedValues>, @@ -125,7 +127,7 @@ where exposed_values_after_challenge, _marker: PhantomData, }; - folder.eval_constraints(&vk.symbolic_constraints.constraints); + folder.eval_constraints(constraints); let folded_constraints = folder.accumulator; // Finally, check that diff --git a/stark-backend/src/verifier/mod.rs b/stark-backend/src/verifier/mod.rs index 23b3c2c917..244d539eed 100644 --- a/stark-backend/src/verifier/mod.rs +++ b/stark-backend/src/verifier/mod.rs @@ -7,6 +7,7 @@ use tracing::instrument; pub mod constraints; mod error; +pub mod v2; pub use error::*; @@ -272,8 +273,8 @@ impl<'c, SC: StarkGenericConfig> MultiTraceStarkVerifier<'c, SC> { &opened_values.after_challenge[phase_idx][matrix_idx] }) .collect_vec(); - verify_single_rap_constraints( - vk, + verify_single_rap_constraints::( + &vk.symbolic_constraints.constraints, preprocessed_values, partitioned_main_values, after_challenge_values, diff --git a/stark-backend/src/verifier/v2/mod.rs b/stark-backend/src/verifier/v2/mod.rs new file mode 100644 index 0000000000..374163a36c --- /dev/null +++ b/stark-backend/src/verifier/v2/mod.rs @@ -0,0 +1,311 @@ +use itertools::{izip, Itertools}; +use p3_challenger::{CanObserve, FieldChallenger}; +use p3_commit::{Pcs, PolynomialSpace}; +use p3_field::{AbstractExtensionField, AbstractField}; +use p3_uni_stark::{Domain, StarkGenericConfig}; +use tracing::instrument; + +use crate::{ + keygen::v2::{types::MultiStarkVerifyingKeyV2, view::MultiStarkVerifyingKeyV2View}, + prover::{opener::AdjacentOpenedValues, v2::types::ProofV2}, + verifier::{constraints::verify_single_rap_constraints, VerificationError}, +}; + +/// Verifies a partitioned proof of multi-matrix AIRs. +pub struct MultiTraceStarkVerifierV2<'c, SC: StarkGenericConfig> { + config: &'c SC, +} + +impl<'c, SC: StarkGenericConfig> MultiTraceStarkVerifierV2<'c, SC> { + pub fn new(config: &'c SC) -> Self { + Self { config } + } + /// Verify collection of InteractiveAIRs and check the permutation + /// cumulative sum is equal to zero across all AIRs. + #[instrument(name = "MultiTraceStarkVerifier::verify", level = "debug", skip_all)] + pub fn verify( + &self, + challenger: &mut SC::Challenger, + mvk: &MultiStarkVerifyingKeyV2, + proof: &ProofV2, + ) -> Result<(), VerificationError> { + let mvk = mvk.view(&proof.get_air_ids()); + let cumulative_sums = proof + .per_air + .iter() + .map(|p| { + assert!( + p.exposed_values_after_challenge.len() <= 1, + "Verifier does not support more than 1 challenge phase" + ); + p.exposed_values_after_challenge.first().map(|values| { + assert_eq!( + values.len(), + 1, + "Only exposed value should be cumulative sum" + ); + values[0] + }) + }) + .collect_vec(); + + self.verify_raps(challenger, &mvk, proof)?; + + // Check cumulative sum + let sum: SC::Challenge = cumulative_sums + .into_iter() + .map(|c| c.unwrap_or(SC::Challenge::zero())) + .sum(); + if sum != SC::Challenge::zero() { + return Err(VerificationError::NonZeroCumulativeSum); + } + Ok(()) + } + + /// Verify general RAPs without checking any relations (e.g., cumulative sum) between exposed values of different RAPs. + /// + /// Public values is a global list shared across all AIRs. + /// + /// - `num_challenges_to_sample[i]` is the number of challenges to sample in the trace challenge phase corresponding to `proof.commitments.after_challenge[i]`. This must have length equal + /// to `proof.commitments.after_challenge`. + #[instrument(level = "debug", skip_all)] + pub fn verify_raps( + &self, + challenger: &mut SC::Challenger, + mvk: &MultiStarkVerifyingKeyV2View, + proof: &ProofV2, + ) -> Result<(), VerificationError> { + let public_values = proof.get_public_values(); + // Challenger must observe public values + for pis in &public_values { + challenger.observe_slice(pis); + } + + // TODO: valid shape check from verifying key + for preprocessed_commit in mvk.flattened_preprocessed_commits() { + challenger.observe(preprocessed_commit); + } + + // Observe main trace commitments + challenger.observe_slice(&proof.commitments.main_trace); + + let mut challenges = Vec::new(); + for (phase_idx, (&num_to_sample, commit)) in mvk + .num_challenges_to_sample() + .iter() + .zip_eq(&proof.commitments.after_challenge) + .enumerate() + { + // Sample challenges needed in this phase + challenges.push( + (0..num_to_sample) + .map(|_| challenger.sample_ext_element::()) + .collect_vec(), + ); + // For each RAP, the exposed values in current phase + for air_proof in &proof.per_air { + let exposed_values = air_proof.exposed_values_after_challenge.get(phase_idx); + if let Some(values) = exposed_values { + // Observe exposed values (in ext field) + for value in values { + challenger.observe_slice(value.as_base_slice()); + } + } + } + // Observe single commitment to all trace matrices in this phase + challenger.observe(commit.clone()); + } + + // Draw `alpha` challenge + let alpha: SC::Challenge = challenger.sample_ext_element(); + tracing::debug!("alpha: {alpha:?}"); + + // Observe quotient commitments + challenger.observe(proof.commitments.quotient.clone()); + + // Draw `zeta` challenge + let zeta: SC::Challenge = challenger.sample_ext_element(); + tracing::debug!("zeta: {zeta:?}"); + + let pcs = self.config.pcs(); + // Build domains + let (domains, quotient_chunks_domains): (Vec<_>, Vec>) = mvk + .per_air + .iter() + .zip_eq(&proof.per_air) + .map(|(vk, air_proof)| { + let degree = air_proof.degree; + let quotient_degree = vk.quotient_degree; + let domain = pcs.natural_domain_for_degree(degree); + let quotient_domain = domain.create_disjoint_domain(degree * quotient_degree); + let qc_domains = quotient_domain.split_domains(quotient_degree); + (domain, qc_domains) + }) + .unzip(); + // Verify all opening proofs + let opened_values = &proof.opening.values; + let trace_domain_and_openings = + |domain: Domain, + zeta: SC::Challenge, + values: &AdjacentOpenedValues| { + ( + domain, + vec![ + (zeta, values.local.clone()), + (domain.next_point(zeta).unwrap(), values.next.clone()), + ], + ) + }; + // Build the opening rounds + // 1. First the preprocessed trace openings + // Assumption: each AIR with preprocessed trace has its own commitment and opening values + let mut rounds: Vec<_> = mvk + .preprocessed_commits() + .into_iter() + .zip_eq(&domains) + .flat_map(|(commit, domain)| commit.map(|commit| (commit, *domain))) + .zip_eq(&opened_values.preprocessed) + .map(|((commit, domain), values)| { + let domain_and_openings = trace_domain_and_openings(domain, zeta, values); + (commit, vec![domain_and_openings]) + }) + .collect(); + + // 2. Then the main trace openings + let mut air_idx = 0; + let num_main_commits = opened_values.main.len(); + assert_eq!(num_main_commits, proof.commitments.main_trace.len()); + izip!(&opened_values.main, &proof.commitments.main_trace) + .enumerate() + .for_each(|(commit_idx, (values_per_mat, commit))| { + // All commits except the last one are cached main traces. + let domains_and_openings = if commit_idx + 1 < num_main_commits { + assert_eq!( + values_per_mat.len(), + 1, + "Cached main trace should have only 1 matrix" + ); + let domain = domains[air_idx]; + air_idx += 1; + vec![trace_domain_and_openings(domain, zeta, &values_per_mat[0])] + } else { + // Each matrix corresponds to an AIR with a common main trace. + mvk.per_air + .iter() + .zip_eq(&domains) + .flat_map(|(vk, domain)| { + if vk.has_common_main() { + Some(*domain) + } else { + None + } + }) + .zip_eq(values_per_mat) + .map(|(domain, values)| trace_domain_and_openings(domain, zeta, values)) + .collect_vec() + }; + rounds.push((commit.clone(), domains_and_openings)); + }); + + // 3. Then after_challenge trace openings, at most 1 phase for now. + // All AIRs with interactions should an after challenge trace. + let after_challenge_domain_per_air = mvk + .per_air + .iter() + .zip_eq(&domains) + .filter_map(|(vk, domain)| { + if vk.has_interaction() { + Some(*domain) + } else { + None + } + }) + .collect_vec(); + if after_challenge_domain_per_air.is_empty() { + assert_eq!(proof.commitments.after_challenge.len(), 0); + assert_eq!(opened_values.after_challenge.len(), 0); + } else { + let after_challenge_commit = proof.commitments.after_challenge[0].clone(); + let domains_and_openings = after_challenge_domain_per_air + .into_iter() + .zip_eq(&opened_values.after_challenge[0]) + .map(|(domain, values)| trace_domain_and_openings(domain, zeta, values)) + .collect_vec(); + rounds.push((after_challenge_commit, domains_and_openings)); + } + + let quotient_domains_and_openings = opened_values + .quotient + .iter() + .zip_eq("ient_chunks_domains) + .flat_map(|(chunk, quotient_chunks_domains_per_air)| { + chunk + .iter() + .zip_eq(quotient_chunks_domains_per_air) + .map(|(values, &domain)| (domain, vec![(zeta, values.clone())])) + }) + .collect_vec(); + rounds.push(( + proof.commitments.quotient.clone(), + quotient_domains_and_openings, + )); + + pcs.verify(rounds, &proof.opening.proof, challenger) + .map_err(|e| VerificationError::InvalidOpeningArgument(format!("{:?}", e)))?; + + let mut preprocessed_idx = 0usize; // preprocessed commit idx + let num_phases = mvk.num_phases(); + let mut after_challenge_idx = vec![0usize; num_phases]; + let mut cached_main_commit_idx = 0; + let mut common_main_matrix_idx = 0; + + // Verify each RAP's constraints + for (domain, qc_domains, quotient_chunks, vk, air_proof) in izip!( + domains, + quotient_chunks_domains, + &opened_values.quotient, + &mvk.per_air, + &proof.per_air + ) { + let preprocessed_values = vk.preprocessed_data.as_ref().map(|_| { + let values = &opened_values.preprocessed[preprocessed_idx]; + preprocessed_idx += 1; + values + }); + let mut partitioned_main_values = Vec::with_capacity(vk.num_cached_mains()); + for _ in 0..vk.num_cached_mains() { + partitioned_main_values.push(&opened_values.main[cached_main_commit_idx][0]); + cached_main_commit_idx += 1; + } + if vk.has_common_main() { + partitioned_main_values + .push(&opened_values.main.last().unwrap()[common_main_matrix_idx]); + common_main_matrix_idx += 1; + } + // loop through challenge phases of this single RAP + let after_challenge_values = (0..num_phases) + .map(|phase_idx| { + let matrix_idx = after_challenge_idx[phase_idx]; + after_challenge_idx[phase_idx] += 1; + &opened_values.after_challenge[phase_idx][matrix_idx] + }) + .collect_vec(); + verify_single_rap_constraints::( + &vk.symbolic_constraints.constraints, + preprocessed_values, + partitioned_main_values, + after_challenge_values, + quotient_chunks, + domain, + &qc_domains, + zeta, + alpha, + &challenges, + &air_proof.public_values, + &air_proof.exposed_values_after_challenge, + )?; + } + + Ok(()) + } +} diff --git a/stark-backend/tests/cached_lookup/mod.rs b/stark-backend/tests/cached_lookup/mod.rs index b835aa73d9..0ddbb4bc8f 100644 --- a/stark-backend/tests/cached_lookup/mod.rs +++ b/stark-backend/tests/cached_lookup/mod.rs @@ -1,14 +1,22 @@ use std::iter; use afs_stark_backend::{ - keygen::MultiStarkKeygenBuilder, - prover::{trace::TraceCommitmentBuilder, MultiTraceStarkProver, USE_DEBUG_BUILDER}, - verifier::{MultiTraceStarkVerifier, VerificationError}, + keygen::v2::MultiStarkKeygenBuilderV2, + prover::{ + trace::TraceCommitter, + v2::{ + types::{AIRProofInput, CommittedTraceData, ProofInput}, + MultiTraceStarkProverV2, + }, + USE_DEBUG_BUILDER, + }, + verifier::{v2::MultiTraceStarkVerifierV2, VerificationError}, }; use ax_sdk::interaction::dummy_interaction_air::DummyInteractionAir; use p3_baby_bear::BabyBear; use p3_field::AbstractField; use p3_matrix::dense::RowMajorMatrix; +use p3_uni_stark::StarkGenericConfig; use p3_util::log2_ceil_usize; use crate::config; @@ -65,42 +73,51 @@ pub fn prove_and_verify_indexless_lookups( .collect(), receiver_air.field_width(), ); - - let mut keygen_builder = MultiStarkKeygenBuilder::new(&config); - // Cached table pointer: - let recv_fields_ptr = keygen_builder.add_cached_main_matrix(receiver_air.field_width()); - // Everything else together - let recv_count_ptr = keygen_builder.add_main_matrix(1); - keygen_builder.add_partitioned_air(&receiver_air, vec![recv_count_ptr, recv_fields_ptr]); - // Auto-adds sender matrix - keygen_builder.add_air(&sender_air); - let pk = keygen_builder.generate_pk(); - let vk = pk.vk(); - - let prover = MultiTraceStarkProver::new(&config); - // Must add trace matrices in the same order as above - let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); - // Receiver fields table is cached - let cached_trace_data = trace_builder - .committer - .commit(vec![recv_fields_trace.clone()]); - trace_builder.load_cached_trace(recv_fields_trace, cached_trace_data); - // Load x normally - trace_builder.load_trace(recv_count_trace); - trace_builder.load_trace(sender_trace); - trace_builder.commit_current(); - - let main_trace_data = trace_builder.view(&vk, vec![&receiver_air, &sender_air]); - let pis = vec![vec![]; 2]; - - let mut challenger = config::baby_bear_poseidon2::Challenger::new(perm.clone()); - let proof = prover.prove(&mut challenger, &pk, main_trace_data, &pis); - - // Verify the proof: - // Start from clean challenger - let mut challenger = config::baby_bear_poseidon2::Challenger::new(perm.clone()); - let verifier = MultiTraceStarkVerifier::new(prover.config); - verifier.verify(&mut challenger, &vk, &proof) + { + let mut keygen_builder = MultiStarkKeygenBuilderV2::new(&config); + let receiver_air_id = keygen_builder.add_air(&receiver_air); + // Auto-adds sender matrix + let sender_air_id = keygen_builder.add_air(&sender_air); + let pk = keygen_builder.generate_pk(); + let committer = TraceCommitter::new(config.pcs()); + let cached_trace_data = committer.commit(vec![recv_fields_trace.clone()]); + let proof_input = ProofInput { + per_air: vec![ + ( + receiver_air_id, + AIRProofInput { + air: &receiver_air, + cached_mains: vec![CommittedTraceData { + raw_data: recv_fields_trace, + prover_data: cached_trace_data, + }], + common_main: Some(recv_count_trace), + public_values: vec![], + }, + ), + ( + sender_air_id, + AIRProofInput { + air: &sender_air, + cached_mains: vec![], + common_main: Some(sender_trace), + public_values: vec![], + }, + ), + ], + }; + + let prover = MultiTraceStarkProverV2::new(&config); + + let mut challenger = config::baby_bear_poseidon2::Challenger::new(perm.clone()); + let proof = prover.prove(&mut challenger, &pk, proof_input); + + // Verify the proof: + // Start from clean challenger + let mut challenger = config::baby_bear_poseidon2::Challenger::new(perm.clone()); + let verifier = MultiTraceStarkVerifierV2::new(prover.config); + verifier.verify(&mut challenger, &pk.get_vk(), &proof) + } } /// tests for cached_lookup