diff --git a/Cargo.lock b/Cargo.lock index a025f7e29..86ce54b3e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -582,6 +582,7 @@ dependencies = [ "num-traits", "p3-field", "p3-goldilocks", + "p3-matrix", "p3-mds", "parse-size", "paste", @@ -604,6 +605,7 @@ dependencies = [ "tracing-forest", "tracing-subscriber 0.3.19", "transcript", + "witness", ] [[package]] @@ -1525,6 +1527,7 @@ dependencies = [ "num-integer", "p3-field", "p3-goldilocks", + "p3-matrix", "p3-mds", "p3-symmetric", "plonky2", @@ -1535,6 +1538,7 @@ dependencies = [ "serde", "transcript", "whir", + "witness", "zeroize", ] @@ -3302,6 +3306,20 @@ dependencies = [ "bitflags", ] +[[package]] +name = "witness" +version = "0.1.0" +dependencies = [ + "ff_ext", + "multilinear_extensions", + "p3-field", + "p3-goldilocks", + "p3-matrix", + "rand", + "rayon", + "serde", +] + [[package]] name = "wyz" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index 615ed199f..1bc5148d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ members = [ "sumcheck", "transcript", "whir", + "witness", ] resolver = "2" @@ -37,6 +38,7 @@ num-traits = "0.2" p3-challenger = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" } p3-field = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" } p3-goldilocks = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" } +p3-matrix = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" } p3-mds = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" } p3-poseidon = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" } p3-poseidon2 = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" } diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index f96dccb53..8f54b6942 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -23,11 +23,13 @@ mpcs = { path = "../mpcs" } multilinear_extensions = { version = "0", path = "../multilinear_extensions" } sumcheck = { version = "0", path = 
"../sumcheck" } transcript = { path = "../transcript" } +witness = { path = "../witness" } itertools.workspace = true num-traits.workspace = true p3-field.workspace = true p3-goldilocks.workspace = true +p3-matrix.workspace = true p3-mds.workspace = true paste.workspace = true poseidon.workspace = true diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 4dd02b5f9..b0e6cd2b0 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -1,6 +1,5 @@ use std::time::Duration; -use ark_std::test_rng; use ceno_zkvm::{ self, instructions::{Instruction, riscv::arith::AddInstruction}, @@ -10,12 +9,13 @@ use ceno_zkvm::{ use criterion::*; use ceno_zkvm::scheme::constants::MAX_NUM_VARIABLES; -use ff_ext::{FromUniformBytes, GoldilocksExt2}; +use ff_ext::GoldilocksExt2; use itertools::Itertools; use mpcs::{BasefoldDefault, PolynomialCommitmentScheme}; -use multilinear_extensions::mle::IntoMLE; -use p3_goldilocks::Goldilocks; + +use rand::rngs::OsRng; use transcript::{BasicTranscript, Transcript}; +use witness::RowMajorMatrix; cfg_if::cfg_if! 
{ if #[cfg(feature = "flamegraph")] { @@ -74,22 +74,16 @@ fn bench_add(c: &mut Criterion) { let mut time = Duration::new(0, 0); for _ in 0..iters { // generate mock witness - let mut rng = test_rng(); let num_instances = 1 << instance_num_vars; - let wits_in = (0..num_witin as usize) - .map(|_| { - (0..num_instances) - .map(|_| Goldilocks::random(&mut rng)) - .collect::>() - .into_mle() - }) - .collect_vec(); + let rmm = + RowMajorMatrix::rand(&mut OsRng, num_instances, num_witin as usize); + let polys = rmm.to_mles(); let instant = std::time::Instant::now(); let num_instances = 1 << instance_num_vars; let mut transcript = BasicTranscript::new(b"riscv"); let commit = - Pcs::batch_commit_and_write(&prover.pk.pp, &wits_in, &mut transcript) + Pcs::batch_commit_and_write(&prover.pk.pp, rmm, &mut transcript) .unwrap(); let challenges = [ transcript.read_challenge().elements, @@ -101,7 +95,7 @@ fn bench_add(c: &mut Criterion) { "ADD", &prover.pk.pp, &circuit_pk, - wits_in.into_iter().map(|mle| mle.into()).collect_vec(), + polys.into_iter().map(|mle| mle.into()).collect_vec(), commit, &[], num_instances, diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index b1fd2fb48..380582fe6 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -10,9 +10,10 @@ use crate::{ error::ZKVMError, expression::{Expression, Fixed, Instance, StructuralWitIn, WitIn}, structs::{ProgramParams, ProvingKey, RAMType, VerifyingKey, WitnessId}, - witness::RowMajorMatrix, }; + use p3_field::PrimeCharacteristicRing; +use witness::RowMajorMatrix; /// namespace used for annotation, preserve meta info during circuit construction #[derive(Clone, Debug, Default, serde::Serialize)] @@ -180,15 +181,13 @@ impl ConstraintSystem { fixed_traces: Option>, ) -> ProvingKey { // transpose from row-major to column-major - let fixed_traces = fixed_traces.map(RowMajorMatrix::into_mles); + let fixed_traces_polys = fixed_traces.as_ref().map(|rmm| 
rmm.to_mles()); - let fixed_commit_wd = fixed_traces - .as_ref() - .map(|traces| PCS::batch_commit(pp, traces).unwrap()); + let fixed_commit_wd = fixed_traces.map(|traces| PCS::batch_commit(pp, traces).unwrap()); let fixed_commit = fixed_commit_wd.as_ref().map(PCS::get_pure_commitment); ProvingKey { - fixed_traces, + fixed_traces: fixed_traces_polys, fixed_commit_wd, vk: VerifyingKey { cs: self, diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index b72401aec..fc48fb17a 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -5,28 +5,12 @@ use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, slice::ParallelSlice, }; -use std::sync::Arc; -use crate::{ - circuit_builder::CircuitBuilder, - error::ZKVMError, - witness::{LkMultiplicity, RowMajorMatrix}, -}; +use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, witness::LkMultiplicity}; -pub mod riscv; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; -#[derive(Clone)] -pub enum InstancePaddingStrategy { - // Pads with default values of underlying type - // Usually zero, but check carefully - Default, - // Pads by repeating last row - RepeatLast, - // Custom strategy consists of a closure - // `pad(i, j) = padding value for cell at row i, column j` - // pad should be able to cross thread boundaries - Custom(Arc u64 + Send + Sync>), -} +pub mod riscv; pub trait Instruction { type InstructionConfig: Send + Sync; diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index fb388c94a..066752b74 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -598,7 +598,7 @@ mod test { MockProver::assert_with_expected_errors( &cb, &raw_witin - .into_mles() + .to_mles() .into_iter() .map(|v| v.into()) .collect_vec(), diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index ff78ea5d7..0cb1dccf4 100644 --- 
a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -486,15 +486,15 @@ impl MemAddr { mod test { use ff_ext::GoldilocksExt2 as E; use itertools::Itertools; + use witness::{InstancePaddingStrategy, RowMajorMatrix}; use p3_goldilocks::Goldilocks as F; use crate::{ ROMType, circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, - instructions::InstancePaddingStrategy, scheme::mock_prover::MockProver, - witness::{LkMultiplicity, RowMajorMatrix}, + witness::LkMultiplicity, }; use super::MemAddr; @@ -562,7 +562,7 @@ mod test { MockProver::assert_with_expected_errors( &cb, &raw_witin - .into_mles() + .to_mles() .into_iter() .map(|v| v.into()) .collect_vec(), diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 8ef688b1d..62b92b0ed 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -13,7 +13,7 @@ use crate::{ AndTable, LtuTable, OpsTable, OrTable, PowTable, ProgramTableCircuit, RangeTable, TableCircuit, U5Table, U8Table, U14Table, U16Table, XorTable, }, - witness::{LkMultiplicity, LkMultiplicityRaw, RowMajorMatrix}, + witness::{LkMultiplicity, LkMultiplicityRaw}, }; use ark_std::test_rng; use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD}; @@ -36,6 +36,7 @@ use std::{ }; use strum::IntoEnumIterator; use tiny_keccak::{Hasher, Keccak}; +use witness::RowMajorMatrix; const MAX_CONSTRAINT_DEGREE: usize = 2; const MOCK_PROGRAM_SIZE: usize = 32; @@ -747,7 +748,7 @@ Hints: lkm: Option, ) { let wits_in = raw_witin - .into_mles() + .to_mles() .into_iter() .map(|v| v.into()) .collect_vec(); @@ -805,13 +806,20 @@ Hints: // Process all circuits. 
for (circuit_name, cs) in &cs.circuit_css { + let empty_rmm = RowMajorMatrix::empty(); let is_opcode = cs.lk_table_expressions.is_empty() && cs.r_table_expressions.is_empty() && cs.w_table_expressions.is_empty(); - let witness = if is_opcode { - witnesses - .get_opcode_witness(circuit_name) - .unwrap_or_else(|| panic!("witness for {} should not be None", circuit_name)) + let [witness, _] = if is_opcode { + &[ + witnesses + .get_opcode_witness(circuit_name) + .cloned() + .unwrap_or_else(|| { + panic!("witness for {} should not be None", circuit_name) + }), + empty_rmm, + ] } else { witnesses .get_table_witness(circuit_name) @@ -827,7 +835,7 @@ Hints: continue; } let mut witness = witness - .into_mles() + .to_mles() .into_iter() .map(|w| w.into()) .collect_vec(); @@ -837,11 +845,7 @@ Hints: .remove(circuit_name) .and_then(|fixed| fixed) .map_or(vec![], |fixed| { - fixed - .into_mles() - .into_iter() - .map(|f| f.into()) - .collect_vec() + fixed.to_mles().into_iter().map(|f| f.into()).collect_vec() }); if is_opcode { tracing::info!( @@ -1249,13 +1253,13 @@ mod tests { error::ZKVMError, expression::{ToExpr, WitIn}, gadgets::{AssertLtConfig, IsLtConfig}, - instructions::InstancePaddingStrategy, set_val, - witness::{LkMultiplicity, RowMajorMatrix}, + witness::LkMultiplicity, }; use ff_ext::{FieldInto, GoldilocksExt2}; use multilinear_extensions::mle::IntoMLE; use p3_goldilocks::Goldilocks; + use witness::InstancePaddingStrategy; #[derive(Debug)] struct AssertZeroCircuit { diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 97ae87e6e..f7406cdec 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -18,6 +18,7 @@ use sumcheck::{ structs::{IOPProverMessage, IOPProverState}, }; use transcript::{ForkableTranscript, Transcript}; +use witness::{RowMajorMatrix, next_pow2_instance_padding}; use crate::{ error::ZKVMError, @@ -32,7 +33,7 @@ use crate::{ structs::{ Point, ProvingKey, TowerProofs, TowerProver, 
TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, }, - utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads}, + utils::{get_challenge_pows, optimal_sumcheck_threads}, virtual_polys::VirtualPolynomials, }; @@ -95,29 +96,28 @@ impl> ZKVMProver { let commit_to_traces_span = entered_span!("commit_to_traces", profiling_1 = true); // commit to opcode circuits first and then commit to table circuits, sorted by name - for (circuit_name, witness) in witnesses.into_iter_sorted() { - let num_instances = witness.num_instances(); + for (circuit_name, mut rmm) in witnesses.into_iter_sorted() { + let witness_rmm = rmm.remove(0); + let structural_witness_rmm = if !rmm.is_empty() { + rmm.remove(0) + } else { + RowMajorMatrix::empty() + }; + let num_instances = witness_rmm.num_instances(); let span = entered_span!( "commit to iteration", circuit_name = circuit_name, profiling_2 = true ); - let num_witin = self - .pk - .circuit_pks - .get(&circuit_name) - .unwrap() - .get_cs() - .num_witin; let (witness, structural_witness) = match num_instances { 0 => (vec![], vec![]), _ => { - let mut witness = witness.into_mles(); - let structural_witness = witness.split_off(num_witin as usize); + let witness = witness_rmm.to_mles(); + let structural_witness = structural_witness_rmm.to_mles(); commitments.insert( circuit_name.clone(), - PCS::batch_commit_and_write(&self.pk.pp, &witness, &mut transcript) + PCS::batch_commit_and_write(&self.pk.pp, witness_rmm, &mut transcript) .map_err(ZKVMError::PCSError)?, ); @@ -162,7 +162,7 @@ impl> ZKVMProver { { let (witness, num_instances) = wits .remove(circuit_name) - .ok_or(ZKVMError::WitnessNotFound(circuit_name.clone()))?; + .ok_or(ZKVMError::WitnessNotFound(circuit_name.to_string()))?; if witness.is_empty() { continue; } diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 6bec93d22..e2afd5469 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -124,14 +124,10 @@ fn 
test_rw_lk_expression_combination() { // get proof let prover = ZKVMProver::new(pk); let mut transcript = BasicTranscript::new(b"test"); - let wits_in = zkvm_witness - .into_iter_sorted() - .next() - .unwrap() - .1 - .into_mles(); + let rmm = zkvm_witness.into_iter_sorted().next().unwrap().1.remove(0); + let wits_in = rmm.to_mles(); // commit to main traces - let commit = Pcs::batch_commit_and_write(&prover.pk.pp, &wits_in, &mut transcript).unwrap(); + let commit = Pcs::batch_commit_and_write(&prover.pk.pp, rmm, &mut transcript).unwrap(); let wits_in = wits_in.into_iter().map(|v| v.into()).collect_vec(); let prover_challenges = [ transcript.read_challenge().elements, diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 9c0478006..5106cad28 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -17,10 +17,9 @@ use rayon::{ }, prelude::ParallelSliceMut, }; +use witness::next_pow2_instance_padding; -use crate::{ - expression::Expression, scheme::constants::MIN_PAR_SIZE, utils::next_pow2_instance_padding, -}; +use crate::{expression::Expression, scheme::constants::MIN_PAR_SIZE}; /// interleaving multiple mles into mles, and num_limbs indicate number of final limbs vector /// e.g input [[1,2],[3,4],[5,6],[7,8]], num_limbs=2,log2_per_instance_size=3 diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index bc104156e..fea75a50c 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -12,6 +12,7 @@ use multilinear_extensions::{ }; use sumcheck::structs::{IOPProof, IOPVerifierState}; use transcript::{ForkableTranscript, Transcript}; +use witness::next_pow2_instance_padding; use crate::{ error::ZKVMError, @@ -22,10 +23,7 @@ use crate::{ utils::eval_by_expr_with_instance, }, structs::{Point, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey}, - utils::{ - eq_eval_less_or_equal_than, eval_wellform_address_vec, get_challenge_pows, - 
next_pow2_instance_padding, - }, + utils::{eq_eval_less_or_equal_than, eval_wellform_address_vec, get_challenge_pows}, }; use super::{ diff --git a/ceno_zkvm/src/stats.rs b/ceno_zkvm/src/stats.rs index 89271277a..424d528ba 100644 --- a/ceno_zkvm/src/stats.rs +++ b/ceno_zkvm/src/stats.rs @@ -229,7 +229,7 @@ impl Report { let num_instances = zkvm_witnesses .clone() .into_iter_sorted() - .map(|(key, value)| (key, value.num_instances())) + .map(|(key, value)| (key, value[0].num_instances())) .collect::>(); Self::new(static_report, num_instances, program_name) } diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index efaad8ca7..7df6fd895 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -5,7 +5,7 @@ use crate::{ instructions::Instruction, state::StateCircuit, tables::TableCircuit, - witness::{LkMultiplicity, RowMajorMatrix}, + witness::LkMultiplicity, }; use ceno_emul::{CENO_PLATFORM, Platform, StepRecord}; use ff_ext::ExtensionField; @@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::collections::{BTreeMap, HashMap}; use strum_macros::EnumIter; use sumcheck::structs::IOPProverMessage; +use witness::RowMajorMatrix; pub struct TowerProver; @@ -244,18 +245,19 @@ impl ZKVMFixedTraces { #[derive(Default, Clone)] pub struct ZKVMWitnesses { witnesses_opcodes: BTreeMap>, - witnesses_tables: BTreeMap>, + /// table witness format: [witness, structural_witness] + witnesses_tables: BTreeMap; 2]>, lk_mlts: BTreeMap, combined_lk_mlt: Option>>, } impl ZKVMWitnesses { - pub fn get_opcode_witness(&self, name: &String) -> Option> { - self.witnesses_opcodes.get(name).cloned() + pub fn get_opcode_witness(&self, name: &String) -> Option<&RowMajorMatrix> { + self.witnesses_opcodes.get(name) } - pub fn get_table_witness(&self, name: &String) -> Option> { - self.witnesses_tables.get(name).cloned() + pub fn get_table_witness(&self, name: &String) -> Option<&[RowMajorMatrix; 2]> { + self.witnesses_tables.get(name) } pub fn 
get_lk_mlt(&self, name: &String) -> Option<&LkMultiplicity> { @@ -343,8 +345,17 @@ impl ZKVMWitnesses { } /// Iterate opcode circuits, then table circuits, sorted by name. - pub fn into_iter_sorted(self) -> impl Iterator)> { - chain(self.witnesses_opcodes, self.witnesses_tables) + pub fn into_iter_sorted( + self, + ) -> impl Iterator>)> { + chain( + self.witnesses_opcodes + .into_iter() + .map(|(name, witnesses)| (name, vec![witnesses])), + self.witnesses_tables + .into_iter() + .map(|(name, witnesses)| (name, witnesses.to_vec())), + ) } } diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index 8e736228d..7f6efd821 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -1,6 +1,7 @@ -use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, witness::RowMajorMatrix}; +use crate::{circuit_builder::CircuitBuilder, error::ZKVMError}; use ff_ext::ExtensionField; use std::collections::HashMap; +use witness::RowMajorMatrix; mod range; pub use range::*; @@ -36,5 +37,5 @@ pub trait TableCircuit { num_structural_witin: usize, multiplicity: &[HashMap], input: &Self::WitnessInput, - ) -> Result, ZKVMError>; + ) -> Result<[RowMajorMatrix; 2], ZKVMError>; } diff --git a/ceno_zkvm/src/tables/ops/ops_circuit.rs b/ceno_zkvm/src/tables/ops/ops_circuit.rs index 45c3e123b..9fa54c719 100644 --- a/ceno_zkvm/src/tables/ops/ops_circuit.rs +++ b/ceno_zkvm/src/tables/ops/ops_circuit.rs @@ -6,9 +6,9 @@ use std::{collections::HashMap, marker::PhantomData}; use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, structs::ROMType, tables::TableCircuit, - witness::RowMajorMatrix, }; use ff_ext::ExtensionField; +use witness::RowMajorMatrix; /// Use this trait as parameter to OpsTableCircuit. 
pub trait OpsTable { @@ -60,7 +60,7 @@ impl TableCircuit for OpsTableCircuit num_structural_witin: usize, multiplicity: &[HashMap], _input: &(), - ) -> Result, ZKVMError> { + ) -> Result<[RowMajorMatrix; 2], ZKVMError> { let multiplicity = &multiplicity[OP::ROM_TYPE as usize]; config.assign_instances(num_witin, num_structural_witin, multiplicity, OP::len()) } diff --git a/ceno_zkvm/src/tables/ops/ops_impl.rs b/ceno_zkvm/src/tables/ops/ops_impl.rs index a2f5ee937..e65f80d20 100644 --- a/ceno_zkvm/src/tables/ops/ops_impl.rs +++ b/ceno_zkvm/src/tables/ops/ops_impl.rs @@ -2,18 +2,15 @@ use ff_ext::{ExtensionField, SmallField}; use itertools::Itertools; -use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; use std::collections::HashMap; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ circuit_builder::{CircuitBuilder, SetTableSpec}, error::ZKVMError, expression::{Expression, Fixed, ToExpr, WitIn}, - instructions::InstancePaddingStrategy, - scheme::constants::MIN_PAR_SIZE, set_fixed_val, set_val, structs::ROMType, - witness::RowMajorMatrix, }; #[derive(Clone, Debug)] @@ -59,15 +56,11 @@ impl OpTableConfig { let mut fixed = RowMajorMatrix::::new(content.len(), num_fixed, InstancePaddingStrategy::Default); - fixed - .par_iter_mut() - .with_min_len(MIN_PAR_SIZE) - .zip(content.into_par_iter()) - .for_each(|(row, abc)| { - for (col, val) in self.abc.iter().zip(abc.iter()) { - set_fixed_val!(row, *col, F::from_v(*val)); - } - }); + fixed.par_rows_mut().zip(content).for_each(|(row, abc)| { + for (col, val) in self.abc.iter().zip(abc.iter()) { + set_fixed_val!(row, *col, F::from_v(*val)); + } + }); fixed } @@ -78,26 +71,20 @@ impl OpTableConfig { num_structural_witin: usize, multiplicity: &HashMap, length: usize, - ) -> Result, ZKVMError> { - let mut witness = RowMajorMatrix::::new( - length, - num_witin + num_structural_witin, - InstancePaddingStrategy::Default, - ); + ) -> Result<[RowMajorMatrix; 2], ZKVMError> { + 
assert_eq!(num_structural_witin, 0); + let mut witness = + RowMajorMatrix::::new(length, num_witin, InstancePaddingStrategy::Default); let mut mlts = vec![0; length]; for (idx, mlt) in multiplicity { mlts[*idx as usize] = *mlt; } - witness - .par_iter_mut() - .with_min_len(MIN_PAR_SIZE) - .zip(mlts.into_par_iter()) - .for_each(|(row, mlt)| { - set_val!(row, self.mlt, F::from_v(mlt as u64)); - }); + witness.par_rows_mut().zip(mlts).for_each(|(row, mlt)| { + set_val!(row, self.mlt, F::from_v(mlt as u64)); + }); - Ok(witness) + Ok([witness, RowMajorMatrix::empty()]) } } diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 7b3f792c0..e80943ced 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -4,13 +4,10 @@ use crate::{ circuit_builder::{CircuitBuilder, SetTableSpec}, error::ZKVMError, expression::{Expression, Fixed, ToExpr, WitIn}, - instructions::InstancePaddingStrategy, - scheme::constants::MIN_PAR_SIZE, set_fixed_val, set_val, structs::ROMType, tables::TableCircuit, utils::i64_to_base, - witness::RowMajorMatrix, }; use ceno_emul::{ InsnFormat, InsnFormat::*, InsnKind::*, Instruction, PC_STEP_SIZE, Program, WORD_SIZE, @@ -18,7 +15,7 @@ use ceno_emul::{ use ff_ext::{ExtensionField, FieldInto, SmallField}; use itertools::Itertools; use p3_field::PrimeCharacteristicRing; -use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; /// This structure establishes the order of the fields in instruction records, common to the program table and circuit fetches. 
#[derive(Clone, Debug)] @@ -147,9 +144,8 @@ impl TableCircuit for ProgramTableCircuit { ); fixed - .par_iter_mut() - .with_min_len(MIN_PAR_SIZE) - .zip((0..num_instructions).into_par_iter()) + .par_rows_mut() + .zip(0..num_instructions) .for_each(|(row, i)| { let pc = pc_base + (i * PC_STEP_SIZE) as u32; let insn = program.instructions[i]; @@ -170,7 +166,7 @@ impl TableCircuit for ProgramTableCircuit { num_structural_witin: usize, multiplicity: &[HashMap], program: &Program, - ) -> Result, ZKVMError> { + ) -> Result<[RowMajorMatrix; 2], ZKVMError> { let multiplicity = &multiplicity[ROMType::Instruction as usize]; let mut prog_mlt = vec![0_usize; program.instructions.len()]; @@ -184,15 +180,11 @@ impl TableCircuit for ProgramTableCircuit { num_witin + num_structural_witin, InstancePaddingStrategy::Default, ); - witness - .par_iter_mut() - .with_min_len(MIN_PAR_SIZE) - .zip(prog_mlt.into_par_iter()) - .for_each(|(row, mlt)| { - set_val!(row, config.mlt, E::BaseField::from_u64(mlt as u64)); - }); + witness.par_rows_mut().zip(prog_mlt).for_each(|(row, mlt)| { + set_val!(row, config.mlt, E::BaseField::from_u64(mlt as u64)); + }); - Ok(witness) + Ok([witness, RowMajorMatrix::empty()]) } } @@ -242,6 +234,6 @@ mod tests { &program, ) .unwrap(); - check(&witness); + check(&witness[0]); } } diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index ef4fff37c..63f591071 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -2,14 +2,13 @@ use std::{collections::HashMap, marker::PhantomData}; use ceno_emul::{Addr, Cycle, GetAddr, WORD_SIZE, Word}; use ff_ext::ExtensionField; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, - instructions::InstancePaddingStrategy, structs::{ProgramParams, RAMType}, tables::TableCircuit, - witness::RowMajorMatrix, }; use super::ram_impl::{DynVolatileRamTableConfig, 
NonVolatileTableConfig, PubIOTableConfig}; @@ -96,7 +95,7 @@ impl TableCirc num_structural_witin: usize, _multiplicity: &[HashMap], final_v: &Self::WitnessInput, - ) -> Result, ZKVMError> { + ) -> Result<[RowMajorMatrix; 2], ZKVMError> { // assume returned table is well-formed include padding config.assign_instances(num_witin, num_structural_witin, final_v) } @@ -144,7 +143,7 @@ impl TableCirc num_structural_witin: usize, _multiplicity: &[HashMap], final_cycles: &[Cycle], - ) -> Result, ZKVMError> { + ) -> Result<[RowMajorMatrix; 2], ZKVMError> { // assume returned table is well-formed including padding config.assign_instances(num_witin, num_structural_witin, final_cycles) } @@ -216,7 +215,7 @@ impl TableC num_structural_witin: usize, _multiplicity: &[HashMap], final_v: &Self::WitnessInput, - ) -> Result, ZKVMError> { + ) -> Result<[RowMajorMatrix; 2], ZKVMError> { // assume returned table is well-formed include padding config.assign_instances(num_witin, num_structural_witin, final_v) } diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index ee75405a1..8b2d0d7f1 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -3,20 +3,15 @@ use std::{marker::PhantomData, sync::Arc}; use ceno_emul::{Addr, Cycle, WORD_SIZE}; use ff_ext::{ExtensionField, SmallField}; use itertools::Itertools; -use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ circuit_builder::{CircuitBuilder, SetTableSpec}, error::ZKVMError, expression::{Expression, Fixed, StructuralWitIn, ToExpr, WitIn}, - instructions::{ - InstancePaddingStrategy, - riscv::constants::{LIMB_BITS, LIMB_MASK}, - }, - scheme::constants::MIN_PAR_SIZE, + instructions::riscv::constants::{LIMB_BITS, LIMB_MASK}, set_fixed_val, set_val, structs::ProgramParams, - witness::RowMajorMatrix, }; use ff_ext::FieldInto; @@ -127,9 +122,8 @@ impl 
NonVolatileTableConfig NonVolatileTableConfig Result, ZKVMError> { + ) -> Result<[RowMajorMatrix; 2], ZKVMError> { + assert_eq!(num_structural_witin, 0); let mut final_table = RowMajorMatrix::::new( NVRAM::len(&self.params), - num_witin + num_structural_witin, + num_witin, InstancePaddingStrategy::Default, ); final_table - .par_iter_mut() - .with_min_len(MIN_PAR_SIZE) - .zip_eq(final_mem.into_par_iter()) + .par_rows_mut() + .zip_eq(final_mem) .for_each(|(row, rec)| { if let Some(final_v) = &self.final_v { if final_v.len() == 1 { @@ -180,7 +174,7 @@ impl NonVolatileTableConfig PubIOTableConfig { assert_eq!(init_table.num_padding_instances(), 0); init_table - .par_iter_mut() - .with_min_len(MIN_PAR_SIZE) - .zip_eq(io_addrs.into_par_iter()) + .par_rows_mut() + .zip_eq(io_addrs) .for_each(|(row, addr)| { set_fixed_val!(row, self.addr, (*addr as u64).into_f()); }); @@ -281,22 +274,22 @@ impl PubIOTableConfig { num_witin: usize, num_structural_witin: usize, final_cycles: &[Cycle], - ) -> Result, ZKVMError> { + ) -> Result<[RowMajorMatrix; 2], ZKVMError> { + assert_eq!(num_structural_witin, 0); let mut final_table = RowMajorMatrix::::new( NVRAM::len(&self.params), - num_witin + num_structural_witin, + num_witin, InstancePaddingStrategy::Default, ); final_table - .par_iter_mut() - .with_min_len(MIN_PAR_SIZE) - .zip_eq(final_cycles.into_par_iter()) + .par_rows_mut() + .zip_eq(final_cycles) .for_each(|(row, &cycle)| { set_val!(row, self.final_cycle, cycle); }); - Ok(final_table) + Ok([final_table, RowMajorMatrix::empty()]) } } @@ -388,36 +381,31 @@ impl DynVolatileRamTableConfig num_witin: usize, num_structural_witin: usize, final_mem: &[MemFinalRecord], - ) -> Result, ZKVMError> { + ) -> Result<[RowMajorMatrix; 2], ZKVMError> { assert!(final_mem.len() <= DVRAM::max_len(&self.params)); assert!(DVRAM::max_len(&self.params).is_power_of_two()); - let offset_addr = StructuralWitIn { - id: self.addr.id + (num_witin as u16), - ..self.addr - }; - let params = self.params.clone(); - 
let padding_fn = move |row: u64, col: u64| { - if col == offset_addr.id as u64 { - DVRAM::addr(&params, row as usize) as u64 - } else { - 0u64 - } + let addr_id = self.addr.id as u64; + let addr_padding_fn = move |row: u64, col: u64| { + assert_eq!(col, addr_id); + DVRAM::addr(&params, row as usize) as u64 }; - let mut final_table = RowMajorMatrix::::new( + let mut witness = + RowMajorMatrix::::new(final_mem.len(), num_witin, InstancePaddingStrategy::Default); + let mut structural_witness = RowMajorMatrix::::new( final_mem.len(), - num_witin + num_structural_witin, - InstancePaddingStrategy::Custom(Arc::new(padding_fn)), + num_structural_witin, + InstancePaddingStrategy::Custom(Arc::new(addr_padding_fn)), ); - final_table - .par_iter_mut() - .with_min_len(MIN_PAR_SIZE) - .zip(final_mem.into_par_iter()) + witness + .par_rows_mut() + .zip(structural_witness.par_rows_mut()) + .zip(final_mem) .enumerate() - .for_each(|(i, (row, rec))| { + .for_each(|(i, ((row, structural_row), rec))| { assert_eq!(rec.addr, DVRAM::addr(&self.params, i)); if self.final_v.len() == 1 { @@ -432,10 +420,10 @@ impl DynVolatileRamTableConfig } set_val!(row, self.final_cycle, rec.cycle); - set_val!(row, offset_addr, rec.addr as u64); + set_val!(structural_row, self.addr, rec.addr as u64); }); - Ok(final_table) + Ok([witness, structural_witness]) } } @@ -447,7 +435,6 @@ mod tests { circuit_builder::{CircuitBuilder, ConstraintSystem}, structs::ProgramParams, tables::{DynVolatileRamTable, HintsCircuit, HintsTable, MemFinalRecord, TableCircuit}, - utils::next_pow2_instance_padding, witness::LkMultiplicity, }; @@ -456,6 +443,7 @@ mod tests { use itertools::Itertools; use p3_field::PrimeCharacteristicRing; use p3_goldilocks::Goldilocks as F; + use witness::next_pow2_instance_padding; #[test] fn test_well_formed_address_padding() { @@ -475,7 +463,7 @@ mod tests { value: 0, }) .collect_vec(); - let wit = HintsCircuit::::assign_instances( + let [_, structural_witness] = HintsCircuit::::assign_instances( 
&config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -491,12 +479,14 @@ mod tests { .position(|name| name == "riscv/RAM_Memory_HintsTable/addr") .unwrap(); - let addr_padded_view = wit.column_padded(addr_column + cb.cs.num_witin as usize); + let addr_padded_view = structural_witness.column_padded(addr_column); // Expect addresses to proceed consecutively inside the padding as well let expected = successors(Some(addr_padded_view[0]), |idx| { Some(*idx + F::from_u64(WORD_SIZE as u64)) }) - .take(next_pow2_instance_padding(wit.num_instances())) + .take(next_pow2_instance_padding( + structural_witness.num_instances(), + )) .collect::>(); assert_eq!(addr_padded_view, expected) diff --git a/ceno_zkvm/src/tables/range/range_circuit.rs b/ceno_zkvm/src/tables/range/range_circuit.rs index 2ecb6c6f3..f528cc0d4 100644 --- a/ceno_zkvm/src/tables/range/range_circuit.rs +++ b/ceno_zkvm/src/tables/range/range_circuit.rs @@ -5,10 +5,10 @@ use super::range_impl::RangeTableConfig; use std::{collections::HashMap, marker::PhantomData}; use crate::{ - circuit_builder::CircuitBuilder, error::ZKVMError, instructions::InstancePaddingStrategy, - structs::ROMType, tables::TableCircuit, witness::RowMajorMatrix, + circuit_builder::CircuitBuilder, error::ZKVMError, structs::ROMType, tables::TableCircuit, }; use ff_ext::ExtensionField; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; /// Use this trait as parameter to RangeTableCircuit. 
pub trait RangeTable { @@ -53,8 +53,9 @@ impl TableCircuit for RangeTableCircuit num_structural_witin: usize, multiplicity: &[HashMap], _input: &(), - ) -> Result, ZKVMError> { + ) -> Result<[RowMajorMatrix; 2], ZKVMError> { let multiplicity = &multiplicity[RANGE::ROM_TYPE as usize]; + config.assign_instances( num_witin, num_structural_witin, diff --git a/ceno_zkvm/src/tables/range/range_impl.rs b/ceno_zkvm/src/tables/range/range_impl.rs index 4671b123d..32fd0e7d8 100644 --- a/ceno_zkvm/src/tables/range/range_impl.rs +++ b/ceno_zkvm/src/tables/range/range_impl.rs @@ -1,18 +1,15 @@ //! The implementation of range tables. No generics. use ff_ext::{ExtensionField, SmallField}; -use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; use std::collections::HashMap; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ circuit_builder::{CircuitBuilder, SetTableSpec}, error::ZKVMError, expression::{StructuralWitIn, ToExpr, WitIn}, - instructions::InstancePaddingStrategy, - scheme::constants::MIN_PAR_SIZE, set_val, structs::ROMType, - witness::RowMajorMatrix, }; #[derive(Clone, Debug)] @@ -53,10 +50,12 @@ impl RangeTableConfig { multiplicity: &HashMap, content: Vec, length: usize, - ) -> Result, ZKVMError> { - let mut witness = RowMajorMatrix::::new( + ) -> Result<[RowMajorMatrix; 2], ZKVMError> { + let mut witness: RowMajorMatrix = + RowMajorMatrix::::new(length, num_witin, InstancePaddingStrategy::Default); + let mut structural_witness = RowMajorMatrix::::new( length, - num_witin + num_structural_witin, + num_structural_witin, InstancePaddingStrategy::Default, ); @@ -65,21 +64,16 @@ impl RangeTableConfig { mlts[*idx as usize] = *mlt; } - let offset_range = StructuralWitIn { - id: self.range.id + (num_witin as u16), - ..self.range - }; - witness - .par_iter_mut() - .with_min_len(MIN_PAR_SIZE) - .zip(mlts.into_par_iter()) - .zip(content.into_par_iter()) - .for_each(|((row, mlt), i)| { + .par_rows_mut() + 
.zip(structural_witness.par_rows_mut()) + .zip(mlts) + .zip(content) + .for_each(|(((row, structural_row), mlt), i)| { set_val!(row, self.mlt, F::from_u64(mlt as u64)); - set_val!(row, offset_range, F::from_u64(i)); + set_val!(structural_row, self.range, F::from_u64(i)); }); - Ok(witness) + Ok([witness, structural_witness]) } } diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 720b0c5db..5ec14bec9 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -182,11 +182,6 @@ pub fn transpose(v: Vec>) -> Vec> { .collect() } -/// get next power of 2 instance with minimal size 2 -pub fn next_pow2_instance_padding(num_instance: usize) -> usize { - num_instance.next_power_of_two().max(2) -} - pub fn display_hashmap(map: &HashMap) -> String { format!( "[{}]", diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index d471ce958..26c0542d8 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -1,27 +1,18 @@ use itertools::izip; -use multilinear_extensions::mle::{DenseMultilinearExtension, IntoMLE}; -use p3_field::{Field, PrimeCharacteristicRing}; -use rayon::{ - iter::{IntoParallelIterator, ParallelIterator}, - slice::ParallelSliceMut, -}; use std::{ cell::RefCell, collections::HashMap, fmt::Debug, hash::Hash, mem::{self}, - ops::{AddAssign, Index}, - slice::{Chunks, ChunksMut}, + ops::AddAssign, sync::Arc, }; use thread_local::ThreadLocal; use crate::{ - instructions::InstancePaddingStrategy, structs::ROMType, tables::{AndTable, LtuTable, OpsTable, OrTable, PowTable, XorTable}, - utils::next_pow2_instance_padding, }; #[macro_export] @@ -38,94 +29,6 @@ macro_rules! 
set_fixed_val { }; } -#[derive(Clone)] -pub struct RowMajorMatrix { - // represent 2D in 1D linear memory and avoid double indirection by Vec> to improve performance - values: Vec, - num_col: usize, - padding_strategy: InstancePaddingStrategy, -} - -impl RowMajorMatrix { - pub fn new(num_rows: usize, num_col: usize, padding_strategy: InstancePaddingStrategy) -> Self { - RowMajorMatrix { - values: (0..num_rows * num_col) - .into_par_iter() - .map(|_| T::default()) - .collect(), - num_col, - padding_strategy, - } - } - - pub fn num_padding_instances(&self) -> usize { - next_pow2_instance_padding(self.num_instances()) - self.num_instances() - } - - pub fn num_instances(&self) -> usize { - self.values.len() / self.num_col - } - - pub fn iter_rows(&self) -> Chunks { - self.values.chunks(self.num_col) - } - - pub fn iter_mut(&mut self) -> ChunksMut { - self.values.chunks_mut(self.num_col) - } - - pub fn par_iter_mut(&mut self) -> rayon::slice::ChunksMut { - self.values.par_chunks_mut(self.num_col) - } - - pub fn par_batch_iter_mut(&mut self, num_rows: usize) -> rayon::slice::ChunksMut { - self.values.par_chunks_mut(num_rows * self.num_col) - } - - // Returns column number `column`, padded appropriately according to the stored strategy - pub fn column_padded(&self, column: usize) -> Vec { - let num_instances = self.num_instances(); - let num_padding_instances = self.num_padding_instances(); - - let padding_iter = (num_instances..num_instances + num_padding_instances).map(|i| { - match &self.padding_strategy { - InstancePaddingStrategy::Custom(fun) => T::from_u64(fun(i as u64, column as u64)), - InstancePaddingStrategy::RepeatLast if num_instances > 0 => { - self[num_instances - 1][column] - } - _ => T::default(), - } - }); - - self.values - .iter() - .skip(column) - .step_by(self.num_col) - .copied() - .chain(padding_iter) - .collect::>() - } -} - -impl RowMajorMatrix { - pub fn into_mles>( - self, - ) -> Vec> { - (0..self.num_col) - .into_par_iter() - .map(|i| 
self.column_padded(i).into_mle()) - .collect() - } -} - -impl Index for RowMajorMatrix { - type Output = [F]; - - fn index(&self, idx: usize) -> &Self::Output { - &self.values[self.num_col * idx..][..self.num_col] - } -} - pub type MultiplicityRaw = [HashMap; mem::variant_count::()]; #[derive(Clone, Default, Debug)] diff --git a/mpcs/Cargo.toml b/mpcs/Cargo.toml index 54305ecc7..ee7e68d22 100644 --- a/mpcs/Cargo.toml +++ b/mpcs/Cargo.toml @@ -26,6 +26,7 @@ num-bigint = "0.4" num-integer = "0.1" p3-field.workspace = true p3-goldilocks.workspace = true +p3-matrix.workspace = true p3-mds.workspace = true p3-symmetric.workspace = true plonky2.workspace = true @@ -36,6 +37,7 @@ rayon = { workspace = true, optional = true } serde.workspace = true transcript = { path = "../transcript" } whir = { path = "../whir", features = ["ceno"] } +witness = { path = "../witness" } zeroize = "1.8" [dev-dependencies] diff --git a/mpcs/benches/basefold.rs b/mpcs/benches/basefold.rs index 8c574818f..fde29a2e9 100644 --- a/mpcs/benches/basefold.rs +++ b/mpcs/benches/basefold.rs @@ -17,7 +17,9 @@ use multilinear_extensions::{ mle::{DenseMultilinearExtension, MultilinearExtension}, virtual_poly::ArcMultilinearExtension, }; +use rand::rngs::OsRng; use transcript::{BasicTranscript, Transcript}; +use witness::RowMajorMatrix; type PcsGoldilocksRSCode = Basefold; type PcsGoldilocksBasecode = Basefold; @@ -246,14 +248,23 @@ fn bench_simple_batch_commit_open_verify_goldilocks(num_vars); let mut transcript = T::new(b"BaseFold"); - let polys = gen_rand_polys(|_| num_vars, batch_size, switch.gen_rand_poly); - let comm = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); + let rmm = RowMajorMatrix::rand(&mut OsRng, 1 << num_vars, batch_size); + let polys = rmm.to_mles(); + let comm = Pcs::batch_commit_and_write(&pp, rmm, &mut transcript).unwrap(); group.bench_function( BenchmarkId::new("batch_commit", format!("{}-{}", num_vars, batch_size)), |b| { - b.iter(|| { - 
Pcs::batch_commit(&pp, &polys).unwrap(); + b.iter_custom(|iters| { + let mut time = Duration::new(0, 0); + for _ in 0..iters { + let instant = std::time::Instant::now(); + let rmm = RowMajorMatrix::rand(&mut OsRng, 1 << num_vars, batch_size); + Pcs::batch_commit(&pp, rmm).unwrap(); + let elapsed = instant.elapsed(); + time += elapsed; + } + time }) }, ); diff --git a/mpcs/benches/whir.rs b/mpcs/benches/whir.rs index 615921346..1b16b60ba 100644 --- a/mpcs/benches/whir.rs +++ b/mpcs/benches/whir.rs @@ -6,11 +6,13 @@ use ff_ext::GoldilocksExt2; use itertools::Itertools; use mpcs::{ PolynomialCommitmentScheme, WhirDefault, - test_util::{gen_rand_poly_base, gen_rand_polys, get_point_from_challenge, setup_pcs}, + test_util::{gen_rand_poly_base, get_point_from_challenge, setup_pcs}, }; use multilinear_extensions::{mle::MultilinearExtension, virtual_poly::ArcMultilinearExtension}; +use rand::rngs::OsRng; use transcript::{BasicTranscript, Transcript}; +use witness::RowMajorMatrix; type T = BasicTranscript; type E = GoldilocksExt2; @@ -95,17 +97,26 @@ fn bench_simple_batch_commit_open_verify_goldilocks(num_vars); let mut transcript = T::new(b"BaseFold"); - let polys = gen_rand_polys(|_| num_vars, batch_size, gen_rand_poly_base); - let comm = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); group.bench_function( BenchmarkId::new("batch_commit", format!("{}-{}", num_vars, batch_size)), |b| { - b.iter(|| { - Pcs::batch_commit(&pp, &polys).unwrap(); + b.iter_custom(|iters| { + let mut time = Duration::new(0, 0); + for _ in 0..iters { + let rmm = RowMajorMatrix::rand(&mut OsRng, 1 << num_vars, batch_size); + let instant = std::time::Instant::now(); + Pcs::batch_commit(&pp, rmm).unwrap(); + let elapsed = instant.elapsed(); + time += elapsed; + } + time }) }, ); + let rmm = RowMajorMatrix::rand(&mut OsRng, 1 << num_vars, batch_size); + let polys = rmm.to_mles(); + let comm = Pcs::batch_commit(&pp, rmm).unwrap(); let point = get_point_from_challenge(num_vars, &mut 
transcript); let evals = polys.iter().map(|poly| poly.evaluate(&point)).collect_vec(); transcript.append_field_element_exts(&evals); diff --git a/mpcs/src/basefold.rs b/mpcs/src/basefold.rs index ffa151921..e095cc2f4 100644 --- a/mpcs/src/basefold.rs +++ b/mpcs/src/basefold.rs @@ -358,8 +358,9 @@ where fn batch_commit( pp: &Self::ProverParam, - polys: &[DenseMultilinearExtension], + rmm: witness::RowMajorMatrix<::BaseField>, ) -> Result { + let polys = rmm.to_mles(); // assumptions // 1. there must be at least one polynomial // 2. all polynomials must exist in the same field type diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index a2c5f55b8..ac161bda8 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -6,6 +6,7 @@ use serde::{Serialize, de::DeserializeOwned}; use std::fmt::Debug; use transcript::{BasicTranscript, Transcript}; use util::hash::Digest; +use witness::RowMajorMatrix; pub mod sum_check; pub mod util; @@ -49,17 +50,17 @@ pub fn pcs_commit_and_write>( pp: &Pcs::ProverParam, - polys: &[DenseMultilinearExtension], + rmm: RowMajorMatrix<::BaseField>, ) -> Result { - Pcs::batch_commit(pp, polys) + Pcs::batch_commit(pp, rmm) } pub fn pcs_batch_commit_and_write>( pp: &Pcs::ProverParam, - polys: &[DenseMultilinearExtension], + rmm: RowMajorMatrix<::BaseField>, transcript: &mut impl Transcript, ) -> Result { - Pcs::batch_commit_and_write(pp, polys, transcript) + Pcs::batch_commit_and_write(pp, rmm, transcript) } pub fn pcs_open>( @@ -149,15 +150,15 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn batch_commit( pp: &Self::ProverParam, - polys: &[DenseMultilinearExtension], + polys: RowMajorMatrix, ) -> Result; fn batch_commit_and_write( pp: &Self::ProverParam, - polys: &[DenseMultilinearExtension], + rmm: RowMajorMatrix<::BaseField>, transcript: &mut impl Transcript, ) -> Result { - let comm = Self::batch_commit(pp, polys)?; + let comm = Self::batch_commit(pp, rmm)?; Self::write_commitment(&Self::get_pure_commitment(&comm), transcript)?; Ok(comm) } 
@@ -383,10 +384,14 @@ pub mod test_util { mle::MultilinearExtension, virtual_poly::ArcMultilinearExtension, }; use rand::rngs::OsRng; + #[cfg(test)] + use rand::{distributions::Standard, prelude::Distribution}; use rayon::iter::{IntoParallelIterator, ParallelIterator}; #[cfg(test)] use transcript::BasicTranscript; use transcript::Transcript; + #[cfg(test)] + use witness::RowMajorMatrix; pub fn setup_pcs>( num_vars: usize, @@ -584,22 +589,24 @@ pub mod test_util { #[cfg(test)] pub(super) fn run_simple_batch_commit_open_verify( - gen_rand_poly: fn(usize) -> DenseMultilinearExtension, + _gen_rand_poly: fn(usize) -> DenseMultilinearExtension, num_vars_start: usize, num_vars_end: usize, batch_size: usize, ) where E: ExtensionField, Pcs: PolynomialCommitmentScheme, + Standard: Distribution, { for num_vars in num_vars_start..num_vars_end { let (pp, vp) = setup_pcs::(num_vars); let (comm, evals, proof, challenge) = { let mut transcript = BasicTranscript::new(b"BaseFold"); - let polys = gen_rand_polys(|_| num_vars, batch_size, gen_rand_poly); - let comm = - Pcs::batch_commit_and_write(&pp, polys.as_slice(), &mut transcript).unwrap(); + let rmm = + RowMajorMatrix::::rand(&mut OsRng, 1 << num_vars, batch_size); + let polys = rmm.to_mles(); + let comm = Pcs::batch_commit_and_write(&pp, rmm, &mut transcript).unwrap(); let point = get_point_from_challenge(num_vars, &mut transcript); let evals = polys.iter().map(|poly| poly.evaluate(&point)).collect_vec(); transcript.append_field_element_exts(&evals); diff --git a/mpcs/src/whir.rs b/mpcs/src/whir.rs index 637d8f35b..e0418372b 100644 --- a/mpcs/src/whir.rs +++ b/mpcs/src/whir.rs @@ -105,9 +105,10 @@ where fn batch_commit( pp: &Self::ProverParam, - polys: &[multilinear_extensions::mle::DenseMultilinearExtension], + polys: witness::RowMajorMatrix, ) -> Result { - let witness = WhirInnerT::::batch_commit(pp, &polys2whir(polys)) + let polys = polys.to_mles(); + let witness = WhirInnerT::::batch_commit(pp, &polys2whir(&polys)) 
.map_err(crate::Error::WhirError)?; Ok(witness) diff --git a/witness/Cargo.toml b/witness/Cargo.toml new file mode 100644 index 000000000..3d6a9c40d --- /dev/null +++ b/witness/Cargo.toml @@ -0,0 +1,22 @@ +[package] +categories.workspace = true +description = "Witness for Ceno" +edition.workspace = true +keywords.workspace = true +license.workspace = true +name = "witness" +readme.workspace = true +repository.workspace = true +version.workspace = true + +[dependencies] +multilinear_extensions = { version = "0", path = "../multilinear_extensions" } + +ff_ext = { path = "../ff_ext" } +p3-field.workspace = true +p3-goldilocks.workspace = true +p3-matrix.workspace = true +rayon.workspace = true +serde.workspace = true + +rand.workspace = true diff --git a/witness/src/lib.rs b/witness/src/lib.rs new file mode 100644 index 000000000..9db2ec75d --- /dev/null +++ b/witness/src/lib.rs @@ -0,0 +1,169 @@ +use multilinear_extensions::mle::{DenseMultilinearExtension, IntoMLE}; +use p3_field::{Field, PrimeCharacteristicRing}; +use p3_matrix::{Matrix, bitrev::BitReversableMatrix}; +use rand::{Rng, distributions::Standard, prelude::Distribution}; +use rayon::{ + iter::{IntoParallelIterator, ParallelIterator}, + slice::ParallelSliceMut, +}; +use std::{ + ops::{Deref, DerefMut, Index}, + slice::{Chunks, ChunksMut}, + sync::Arc, +}; + +/// get next power of 2 instance with minimal size 2 +pub fn next_pow2_instance_padding(num_instance: usize) -> usize { + num_instance.next_power_of_two().max(2) +} + +#[derive(Clone)] +pub enum InstancePaddingStrategy { + // Pads with default values of underlying type + // Usually zero, but check carefully + Default, + // Pads by repeating last row + RepeatLast, + // Custom strategy consists of a closure + // `pad(i, j) = padding value for cell at row i, column j` + // pad should be able to cross thread boundaries + Custom(Arc u64 + Send + Sync>), +} + +#[derive(Clone)] +pub struct RowMajorMatrix { + inner: p3_matrix::dense::RowMajorMatrix, + 
padding_strategy: InstancePaddingStrategy, +} + +impl RowMajorMatrix { + pub fn rand(rng: &mut R, rows: usize, cols: usize) -> Self + where + Standard: Distribution, + { + Self { + inner: p3_matrix::dense::RowMajorMatrix::rand(rng, rows, cols), + padding_strategy: InstancePaddingStrategy::Default, + } + } + pub fn empty() -> Self { + Self { + inner: p3_matrix::dense::RowMajorMatrix::new(vec![], 0), + padding_strategy: InstancePaddingStrategy::Default, + } + } + pub fn into_default_padded_p3_rmm( + self, + is_bit_reserse: bool, + ) -> p3_matrix::dense::RowMajorMatrix { + let padded_height = next_pow2_instance_padding(self.num_instances()); + let mut inner = self.inner; + if is_bit_reserse { + inner = inner.bit_reverse_rows().to_row_major_matrix(); + } + inner.pad_to_height(padded_height, T::default()); + inner + } + + pub fn n_col(&self) -> usize { + self.inner.width + } + + pub fn new(num_rows: usize, num_col: usize, padding_strategy: InstancePaddingStrategy) -> Self { + let value = (0..num_rows * num_col) + .into_par_iter() + .map(|_| T::default()) + .collect(); + RowMajorMatrix { + inner: p3_matrix::dense::RowMajorMatrix::new(value, num_col), + padding_strategy, + } + } + + pub fn num_padding_instances(&self) -> usize { + next_pow2_instance_padding(self.num_instances()) - self.num_instances() + } + + pub fn num_instances(&self) -> usize { + self.inner.height() + } + + pub fn iter_rows(&self) -> Chunks { + self.inner.values.chunks(self.inner.width) + } + + pub fn iter_mut(&mut self) -> ChunksMut { + self.inner.values.chunks_mut(self.inner.width) + } + + pub fn par_batch_iter_mut(&mut self, num_rows: usize) -> rayon::slice::ChunksMut { + self.inner + .values + .par_chunks_mut(num_rows * self.inner.width) + } + + // Returns column number `column`, padded appropriately according to the stored strategy + pub fn column_padded(&self, column: usize) -> Vec { + let n_column = self.n_col(); + let num_instances = self.num_instances(); + let num_padding_instances = 
self.num_padding_instances(); + + let padding_iter = (num_instances..num_instances + num_padding_instances).map(|i| { + match &self.padding_strategy { + InstancePaddingStrategy::Custom(fun) => T::from_u64(fun(i as u64, column as u64)), + InstancePaddingStrategy::RepeatLast if num_instances > 0 => { + self[num_instances - 1][column] + } + _ => T::default(), + } + }); + + self.inner + .values + .iter() + .skip(column) + .step_by(n_column) + .copied() + .chain(padding_iter) + .collect::>() + } +} + +impl RowMajorMatrix { + pub fn to_mles>( + &self, + ) -> Vec> { + let n_column = self.inner.width; + (0..n_column) + .into_par_iter() + .map(|i| self.column_padded(i).into_mle()) + .collect() + } +} + +impl Deref + for RowMajorMatrix +{ + type Target = p3_matrix::dense::DenseMatrix; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut + for RowMajorMatrix +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl Index for RowMajorMatrix { + type Output = [F]; + + fn index(&self, idx: usize) -> &Self::Output { + let num_col = self.n_col(); + &self.inner.values[num_col * idx..][..num_col] + } +}