From e8197db3e0fbcfd8bb0052be3f1a4c2f9a86b8b0 Mon Sep 17 00:00:00 2001 From: Ming Date: Wed, 5 Mar 2025 20:05:52 +0700 Subject: [PATCH] zkvm proof size statistics data (#852) This PR implement formatter on `ZKVMProof` to break down proof size, which helps to identify and prioritize recursion tasks > NOTE: it's based on basefold mpcs ```bash e2e proof stat: overall_size 19.07mb. opcode mpcs commitment 0% opcode mpcs opening 40% opcode tower proof 0% opcode main sumcheck proof 0% table mpcs commitment 0% table mpcs opening 32% table mpcs fixed opening 25% table tower proof 0% table same r sumcheck proof 0% ``` --- ceno_zkvm/benches/fibonacci.rs | 5 +- ceno_zkvm/src/bin/e2e.rs | 5 +- ceno_zkvm/src/scheme.rs | 126 ++++++++++++++++++++++++++++++++- ceno_zkvm/src/scheme/tests.rs | 6 +- 4 files changed, 129 insertions(+), 13 deletions(-) diff --git a/ceno_zkvm/benches/fibonacci.rs b/ceno_zkvm/benches/fibonacci.rs index 7f725921e..f2bc65b00 100644 --- a/ceno_zkvm/benches/fibonacci.rs +++ b/ceno_zkvm/benches/fibonacci.rs @@ -51,7 +51,7 @@ fn fibonacci_prove(c: &mut Criterion) { .0 .expect("PrepSanityCheck do not provide proof and verifier"); - let serialize_size = bincode::serialize(&proof).unwrap().len(); + println!("e2e proof {}", proof); let stat_recorder = StatisticRecorder::default(); let transcript = BasicTranscriptWithStat::new(&stat_recorder, b"riscv"); assert!( @@ -61,9 +61,8 @@ fn fibonacci_prove(c: &mut Criterion) { ); println!(); println!( - "max_steps = {}, proof size = {}, hashes count = {}", + "max_steps = {}, hashes count = {}", max_steps, - serialize_size, stat_recorder.into_inner().field_appended_num ); diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 5b0bb18b8..18d6268d6 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -146,13 +146,12 @@ fn main() { let (mut zkvm_proof, verifier) = state.expect("PrepSanityCheck should yield state."); // do statistics - let serialize_size = bincode::serialize(&zkvm_proof).unwrap().len(); let stat_recorder = StatisticRecorder::default(); let transcript = TranscriptWithStat::new(&stat_recorder, b"riscv"); verifier.verify_proof(zkvm_proof.clone(), transcript).ok(); + println!("e2e proof stat: {}", zkvm_proof); println!( - "e2e proof stat: proof size = {}, hashes count = {}", - serialize_size, + "hashes count = {}", stat_recorder.into_inner().field_appended_num ); diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index dded3ebd7..0c5107e45 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -3,7 +3,11 @@ use itertools::Itertools; use mpcs::PolynomialCommitmentScheme; use p3_field::PrimeCharacteristicRing; use serde::{Deserialize, Serialize, de::DeserializeOwned}; -use std::{collections::BTreeMap, fmt::Debug}; +use std::{ + collections::BTreeMap, + fmt::{self, Debug}, + ops::Div, +}; use sumcheck::structs::IOPProverMessage; use crate::structs::TowerProofs; @@ -160,10 +164,126 @@ impl> ZKVMProof { pub fn update_pi_eval(&mut self, idx: usize, v: E) { self.pi_evals[idx] = v; } -} -impl> ZKVMProof { pub fn num_circuits(&self) -> usize { self.opcode_proofs.len() + self.table_proofs.len() } } + +impl + Serialize> fmt::Display + for ZKVMProof +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // break down zkvm proof size + // opcode circuit mpcs size + let mpcs_opcode_commitment = self + .opcode_proofs + .iter() + .map(|(_, (_, proof))| bincode::serialized_size(&proof.wits_commit)) + .collect::, _>>() + .expect("serialization error") + .iter() + .sum::(); + let mpcs_opcode_opening = self + .opcode_proofs + .iter() + .map(|(_, (_, proof))| bincode::serialized_size(&proof.wits_opening_proof)) + .collect::, _>>() + .expect("serialization error") + .iter() + .sum::(); + // opcode circuit for tower proof size + let tower_proof_opcode = self + .opcode_proofs + .iter() + .map(|(_, (_, proof))| bincode::serialized_size(&proof.tower_proof)) + .collect::, _>>() + .expect("serialization error") + .iter() + .sum::(); + // opcode circuit main sumcheck + let main_sumcheck_opcode = self + .opcode_proofs + .iter() + .map(|(_, (_, proof))| bincode::serialized_size(&proof.main_sel_sumcheck_proofs)) + .collect::, _>>() + .expect("serialization error") + .iter() + .sum::(); + // table circuit mpcs size + let mpcs_table_commitment = self + .table_proofs + .iter() + .map(|(_, (_, proof))| bincode::serialized_size(&proof.wits_commit)) + .collect::, _>>() + .expect("serialization error") + .iter() + .sum::(); + let mpcs_table_opening = self + .table_proofs + .iter() + .map(|(_, (_, proof))| bincode::serialized_size(&proof.wits_opening_proof)) + .collect::, _>>() + .expect("serialization error") + .iter() + .sum::(); + let mpcs_table_fixed_opening = self + .table_proofs + .iter() + .map(|(_, (_, proof))| bincode::serialized_size(&proof.fixed_opening_proof)) + .collect::, _>>() + .expect("serialization error") + .iter() + .sum::(); + // table circuit for tower proof size + let tower_proof_table = self + .table_proofs + .iter() + .map(|(_, (_, proof))| bincode::serialized_size(&proof.tower_proof)) + .collect::, _>>() + .expect("serialization error") + .iter() + .sum::(); + // table circuit same r sumcheck + let same_r_sumcheck_table = self + .table_proofs + .iter() + .map(|(_, (_, proof))| bincode::serialized_size(&proof.same_r_sumcheck_proofs)) + .collect::, _>>() + .expect("serialization error") + .iter() + .sum::(); + + // overall size + let overall_size = bincode::serialized_size(&self).expect("serialization error"); + + // let mpcs_size = bincode::serialized_size(&proof.).unwrap().len(); + write!( + f, + "overall_size {:.2}mb. \n\ + opcode mpcs commitment {:?}% \n\ + opcode mpcs opening {:?}% \n\ + opcode tower proof {:?}% \n\ + opcode main sumcheck proof {:?}% \n\ + table mpcs commitment {:?}% \n\ + table mpcs opening {:?}% \n\ + table mpcs fixed opening {:?}% \n\ + table tower proof {:?}% \n\ + table same r sumcheck proof {:?}%", + byte_to_mb(overall_size), + (mpcs_opcode_commitment * 100).div(overall_size), + (mpcs_opcode_opening * 100).div(overall_size), + (tower_proof_opcode * 100).div(overall_size), + (main_sumcheck_opcode * 100).div(overall_size), + (mpcs_table_commitment * 100).div(overall_size), + (mpcs_table_opening * 100).div(overall_size), + (mpcs_table_fixed_opening * 100).div(overall_size), + (tower_proof_table * 100).div(overall_size), + (same_r_sumcheck_table * 100).div(overall_size), + ) + } +} + +fn byte_to_mb(byte_size: u64) -> f64 { + byte_size as f64 / (1024.0 * 1024.0) +} diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 6bec93d22..6267b0540 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -292,8 +292,7 @@ fn test_single_add_instance_e2e() { .create_proof(zkvm_witness, pi, transcript) .expect("create_proof failed"); - let encoded_bin = bincode::serialize(&zkvm_proof).unwrap(); - + println!("encoded zkvm proof {}", &zkvm_proof,); let stat_recorder = StatisticRecorder::default(); { let transcript = BasicTranscriptWithStat::new(&stat_recorder, b"riscv"); @@ -304,8 +303,7 @@ fn test_single_add_instance_e2e() { ); } println!( - "encoded zkvm proof size: {}, hash_num: {}", - encoded_bin.len(), + "hash_num: {}", stat_recorder.into_inner().field_appended_num ); }