Skip to content

Commit

Permalink
refactor rowmajormatrix witness
Browse files Browse the repository at this point in the history
  • Loading branch information
hero78119 committed Feb 28, 2025
1 parent 6d764ca commit 55d2321
Show file tree
Hide file tree
Showing 33 changed files with 440 additions and 347 deletions.
18 changes: 18 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ members = [
"sumcheck",
"transcript",
"whir",
"witness",
]
resolver = "2"

Expand All @@ -37,6 +38,7 @@ num-traits = "0.2"
p3-challenger = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" }
p3-field = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" }
p3-goldilocks = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" }
p3-matrix = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" }
p3-mds = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" }
p3-poseidon = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" }
p3-poseidon2 = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" }
Expand Down
2 changes: 2 additions & 0 deletions ceno_zkvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ mpcs = { path = "../mpcs" }
multilinear_extensions = { version = "0", path = "../multilinear_extensions" }
sumcheck = { version = "0", path = "../sumcheck" }
transcript = { path = "../transcript" }
witness = { path = "../witness" }

itertools.workspace = true
num-traits.workspace = true
p3-field.workspace = true
p3-goldilocks.workspace = true
p3-matrix.workspace = true
p3-mds.workspace = true
paste.workspace = true
poseidon.workspace = true
Expand Down
24 changes: 9 additions & 15 deletions ceno_zkvm/benches/riscv_add.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::time::Duration;

use ark_std::test_rng;
use ceno_zkvm::{
self,
instructions::{Instruction, riscv::arith::AddInstruction},
Expand All @@ -10,12 +9,13 @@ use ceno_zkvm::{
use criterion::*;

use ceno_zkvm::scheme::constants::MAX_NUM_VARIABLES;
use ff_ext::{FromUniformBytes, GoldilocksExt2};
use ff_ext::GoldilocksExt2;
use itertools::Itertools;
use mpcs::{BasefoldDefault, PolynomialCommitmentScheme};
use multilinear_extensions::mle::IntoMLE;
use p3_goldilocks::Goldilocks;

use rand::rngs::OsRng;
use transcript::{BasicTranscript, Transcript};
use witness::RowMajorMatrix;

cfg_if::cfg_if! {
if #[cfg(feature = "flamegraph")] {
Expand Down Expand Up @@ -74,22 +74,16 @@ fn bench_add(c: &mut Criterion) {
let mut time = Duration::new(0, 0);
for _ in 0..iters {
// generate mock witness
let mut rng = test_rng();
let num_instances = 1 << instance_num_vars;
let wits_in = (0..num_witin as usize)
.map(|_| {
(0..num_instances)
.map(|_| Goldilocks::random(&mut rng))
.collect::<Vec<Goldilocks>>()
.into_mle()
})
.collect_vec();
let rmm =
RowMajorMatrix::rand(&mut OsRng, num_instances, num_witin as usize);
let polys = rmm.to_mles();

let instant = std::time::Instant::now();
let num_instances = 1 << instance_num_vars;
let mut transcript = BasicTranscript::new(b"riscv");
let commit =
Pcs::batch_commit_and_write(&prover.pk.pp, &wits_in, &mut transcript)
Pcs::batch_commit_and_write(&prover.pk.pp, rmm, &mut transcript)
.unwrap();
let challenges = [
transcript.read_challenge().elements,
Expand All @@ -101,7 +95,7 @@ fn bench_add(c: &mut Criterion) {
"ADD",
&prover.pk.pp,
&circuit_pk,
wits_in.into_iter().map(|mle| mle.into()).collect_vec(),
polys.into_iter().map(|mle| mle.into()).collect_vec(),
commit,
&[],
num_instances,
Expand Down
11 changes: 5 additions & 6 deletions ceno_zkvm/src/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ use crate::{
error::ZKVMError,
expression::{Expression, Fixed, Instance, StructuralWitIn, WitIn},
structs::{ProgramParams, ProvingKey, RAMType, VerifyingKey, WitnessId},
witness::RowMajorMatrix,
};

use p3_field::PrimeCharacteristicRing;
use witness::RowMajorMatrix;

/// namespace used for annotation, preserve meta info during circuit construction
#[derive(Clone, Debug, Default, serde::Serialize)]
Expand Down Expand Up @@ -180,15 +181,13 @@ impl<E: ExtensionField> ConstraintSystem<E> {
fixed_traces: Option<RowMajorMatrix<E::BaseField>>,
) -> ProvingKey<E, PCS> {
// transpose from row-major to column-major
let fixed_traces = fixed_traces.map(RowMajorMatrix::into_mles);
let fixed_traces_polys = fixed_traces.as_ref().map(|rmm| rmm.to_mles());

let fixed_commit_wd = fixed_traces
.as_ref()
.map(|traces| PCS::batch_commit(pp, traces).unwrap());
let fixed_commit_wd = fixed_traces.map(|traces| PCS::batch_commit(pp, traces).unwrap());
let fixed_commit = fixed_commit_wd.as_ref().map(PCS::get_pure_commitment);

ProvingKey {
fixed_traces,
fixed_traces: fixed_traces_polys,
fixed_commit_wd,
vk: VerifyingKey {
cs: self,
Expand Down
22 changes: 3 additions & 19 deletions ceno_zkvm/src/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,12 @@ use rayon::{
iter::{IndexedParallelIterator, ParallelIterator},
slice::ParallelSlice,
};
use std::sync::Arc;

use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
witness::{LkMultiplicity, RowMajorMatrix},
};
use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, witness::LkMultiplicity};

pub mod riscv;
use witness::{InstancePaddingStrategy, RowMajorMatrix};

#[derive(Clone)]
pub enum InstancePaddingStrategy {
// Pads with default values of underlying type
// Usually zero, but check carefully
Default,
// Pads by repeating last row
RepeatLast,
// Custom strategy consists of a closure
// `pad(i, j) = padding value for cell at row i, column j`
// pad should be able to cross thread boundaries
Custom(Arc<dyn Fn(u64, u64) -> u64 + Send + Sync>),
}
pub mod riscv;

pub trait Instruction<E: ExtensionField> {
type InstructionConfig: Send + Sync;
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ mod test {
MockProver::assert_with_expected_errors(
&cb,
&raw_witin
.into_mles()
.to_mles()
.into_iter()
.map(|v| v.into())
.collect_vec(),
Expand Down
6 changes: 3 additions & 3 deletions ceno_zkvm/src/instructions/riscv/insn_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,15 +486,15 @@ impl<E: ExtensionField> MemAddr<E> {
mod test {
use ff_ext::GoldilocksExt2 as E;
use itertools::Itertools;
use witness::{InstancePaddingStrategy, RowMajorMatrix};
use p3_goldilocks::Goldilocks as F;

use crate::{
ROMType,
circuit_builder::{CircuitBuilder, ConstraintSystem},
error::ZKVMError,
instructions::InstancePaddingStrategy,
scheme::mock_prover::MockProver,
witness::{LkMultiplicity, RowMajorMatrix},
witness::LkMultiplicity,
};

use super::MemAddr;
Expand Down Expand Up @@ -562,7 +562,7 @@ mod test {
MockProver::assert_with_expected_errors(
&cb,
&raw_witin
.into_mles()
.to_mles()
.into_iter()
.map(|v| v.into())
.collect_vec(),
Expand Down
32 changes: 18 additions & 14 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
AndTable, LtuTable, OpsTable, OrTable, PowTable, ProgramTableCircuit, RangeTable,
TableCircuit, U5Table, U8Table, U14Table, U16Table, XorTable,
},
witness::{LkMultiplicity, LkMultiplicityRaw, RowMajorMatrix},
witness::{LkMultiplicity, LkMultiplicityRaw},
};
use ark_std::test_rng;
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
Expand All @@ -36,6 +36,7 @@ use std::{
};
use strum::IntoEnumIterator;
use tiny_keccak::{Hasher, Keccak};
use witness::RowMajorMatrix;

const MAX_CONSTRAINT_DEGREE: usize = 2;
const MOCK_PROGRAM_SIZE: usize = 32;
Expand Down Expand Up @@ -747,7 +748,7 @@ Hints:
lkm: Option<LkMultiplicity>,
) {
let wits_in = raw_witin
.into_mles()
.to_mles()
.into_iter()
.map(|v| v.into())
.collect_vec();
Expand Down Expand Up @@ -805,13 +806,20 @@ Hints:

// Process all circuits.
for (circuit_name, cs) in &cs.circuit_css {
let empty_rmm = RowMajorMatrix::empty();
let is_opcode = cs.lk_table_expressions.is_empty()
&& cs.r_table_expressions.is_empty()
&& cs.w_table_expressions.is_empty();
let witness = if is_opcode {
witnesses
.get_opcode_witness(circuit_name)
.unwrap_or_else(|| panic!("witness for {} should not be None", circuit_name))
let [witness, _] = if is_opcode {
&[
witnesses
.get_opcode_witness(circuit_name)
.cloned()
.unwrap_or_else(|| {
panic!("witness for {} should not be None", circuit_name)
}),
empty_rmm,
]
} else {
witnesses
.get_table_witness(circuit_name)
Expand All @@ -827,7 +835,7 @@ Hints:
continue;
}
let mut witness = witness
.into_mles()
.to_mles()
.into_iter()
.map(|w| w.into())
.collect_vec();
Expand All @@ -837,11 +845,7 @@ Hints:
.remove(circuit_name)
.and_then(|fixed| fixed)
.map_or(vec![], |fixed| {
fixed
.into_mles()
.into_iter()
.map(|f| f.into())
.collect_vec()
fixed.to_mles().into_iter().map(|f| f.into()).collect_vec()
});
if is_opcode {
tracing::info!(
Expand Down Expand Up @@ -1249,13 +1253,13 @@ mod tests {
error::ZKVMError,
expression::{ToExpr, WitIn},
gadgets::{AssertLtConfig, IsLtConfig},
instructions::InstancePaddingStrategy,
set_val,
witness::{LkMultiplicity, RowMajorMatrix},
witness::LkMultiplicity,
};
use ff_ext::{FieldInto, GoldilocksExt2};
use multilinear_extensions::mle::IntoMLE;
use p3_goldilocks::Goldilocks;
use witness::InstancePaddingStrategy;

#[derive(Debug)]
struct AssertZeroCircuit {
Expand Down
28 changes: 14 additions & 14 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use sumcheck::{
structs::{IOPProverMessage, IOPProverState},
};
use transcript::{ForkableTranscript, Transcript};
use witness::{RowMajorMatrix, next_pow2_instance_padding};

use crate::{
error::ZKVMError,
Expand All @@ -32,7 +33,7 @@ use crate::{
structs::{
Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses,
},
utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads},
utils::{get_challenge_pows, optimal_sumcheck_threads},
virtual_polys::VirtualPolynomials,
};

Expand Down Expand Up @@ -95,29 +96,28 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {

let commit_to_traces_span = entered_span!("commit_to_traces", profiling_1 = true);
// commit to opcode circuits first and then commit to table circuits, sorted by name
for (circuit_name, witness) in witnesses.into_iter_sorted() {
let num_instances = witness.num_instances();
for (circuit_name, mut rmm) in witnesses.into_iter_sorted() {
let witness_rmm = rmm.remove(0);
let structural_witness_rmm = if !rmm.is_empty() {
rmm.remove(0)
} else {
RowMajorMatrix::empty()
};
let num_instances = witness_rmm.num_instances();
let span = entered_span!(
"commit to iteration",
circuit_name = circuit_name,
profiling_2 = true
);
let num_witin = self
.pk
.circuit_pks
.get(&circuit_name)
.unwrap()
.get_cs()
.num_witin;

let (witness, structural_witness) = match num_instances {
0 => (vec![], vec![]),
_ => {
let mut witness = witness.into_mles();
let structural_witness = witness.split_off(num_witin as usize);
let witness = witness_rmm.to_mles();
let structural_witness = structural_witness_rmm.to_mles();
commitments.insert(
circuit_name.clone(),
PCS::batch_commit_and_write(&self.pk.pp, &witness, &mut transcript)
PCS::batch_commit_and_write(&self.pk.pp, witness_rmm, &mut transcript)
.map_err(ZKVMError::PCSError)?,
);

Expand Down Expand Up @@ -162,7 +162,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
{
let (witness, num_instances) = wits
.remove(circuit_name)
.ok_or(ZKVMError::WitnessNotFound(circuit_name.clone()))?;
.ok_or(ZKVMError::WitnessNotFound(circuit_name.to_string()))?;
if witness.is_empty() {
continue;
}
Expand Down
Loading

0 comments on commit 55d2321

Please sign in to comment.