Skip to content

Commit

Permalink
WIP zkvm crate to p3
Browse files Browse the repository at this point in the history
  • Loading branch information
hero78119 committed Feb 17, 2025
1 parent 50f4561 commit f1fc090
Show file tree
Hide file tree
Showing 14 changed files with 114 additions and 86 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions ceno_zkvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ multilinear_extensions = { version = "0", path = "../multilinear_extensions" }
sumcheck = { version = "0", path = "../sumcheck" }
transcript = { path = "../transcript" }

p3-field.workspace = true
itertools.workspace = true
num-traits.workspace = true
paste.workspace = true
Expand All @@ -50,6 +51,7 @@ cfg-if.workspace = true
criterion.workspace = true
pprof2.workspace = true
proptest.workspace = true
p3-goldilocks.workspace = true

[build-dependencies]
glob = "0.3"
Expand Down
12 changes: 8 additions & 4 deletions ceno_zkvm/src/chip_handler/global_state.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
use ff_ext::ExtensionField;

use super::GlobalStateRegisterMachineChipOperations;
use crate::{
circuit_builder::CircuitBuilder, error::ZKVMError, expression::Expression, structs::RAMType,
};

use super::GlobalStateRegisterMachineChipOperations;
use p3_field::FieldAlgebra;

impl<E: ExtensionField> GlobalStateRegisterMachineChipOperations<E> for CircuitBuilder<'_, E> {
fn state_in(&mut self, pc: Expression<E>, ts: Expression<E>) -> Result<(), ZKVMError> {
let record: Vec<Expression<E>> = vec![
Expression::Constant(E::BaseField::from(RAMType::GlobalState as u64)),
Expression::Constant(E::BaseField::from_canonical_u64(
RAMType::GlobalState as u64,
)),
pc,
ts,
];
Expand All @@ -19,7 +21,9 @@ impl<E: ExtensionField> GlobalStateRegisterMachineChipOperations<E> for CircuitB

fn state_out(&mut self, pc: Expression<E>, ts: Expression<E>) -> Result<(), ZKVMError> {
let record: Vec<Expression<E>> = vec![
Expression::Constant(E::BaseField::from(RAMType::GlobalState as u64)),
Expression::Constant(E::BaseField::from_canonical_u64(
RAMType::GlobalState as u64,
)),
pc,
ts,
];
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/chip_handler/utils.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::iter::successors;

use crate::expression::Expression;
use ff::Field;
use ff_ext::ExtensionField;
use itertools::izip;
use p3_field::FieldAlgebra;

pub fn rlc_chip_record<E: ExtensionField>(
records: Vec<Expression<E>>,
Expand Down
52 changes: 26 additions & 26 deletions ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ use std::{
};

use ceno_emul::InsnKind;
use ff::Field;
use ff_ext::ExtensionField;
use goldilocks::SmallField;
use ff_ext::{ExtensionField, SmallField};
use p3_field::FieldAlgebra;

use multilinear_extensions::virtual_poly::ArcMultilinearExtension;

Expand Down Expand Up @@ -358,7 +357,7 @@ impl<E: ExtensionField> Add for Expression<E> {
| (
Expression::Challenge(challenge_id, pow, scalar, offset),
Expression::Constant(c1),
) => Expression::Challenge(*challenge_id, *pow, *scalar, *offset + c1),
) => Expression::Challenge(*challenge_id, *pow, *scalar, *offset + *c1),

// challenge + challenge
(
Expand All @@ -369,16 +368,16 @@ impl<E: ExtensionField> Add for Expression<E> {
Expression::Challenge(
*challenge_id1,
*pow1,
*scalar1 + scalar2,
*offset1 + offset2,
*scalar1 + *scalar2,
*offset1 + *offset2,
)
} else {
Expression::Sum(Box::new(self), Box::new(rhs))
}
}

// constant + constant
(Expression::Constant(c1), Expression::Constant(c2)) => Expression::Constant(*c1 + c2),
(Expression::Constant(c1), Expression::Constant(c2)) => Expression::Constant(*c1 + *c2),

// constant + scaled sum
(c1 @ Expression::Constant(_), Expression::ScaledSum(x, a, b))
Expand Down Expand Up @@ -500,13 +499,13 @@ impl<E: ExtensionField> Sub for Expression<E> {
(
Expression::Constant(c1),
Expression::Challenge(challenge_id, pow, scalar, offset),
) => Expression::Challenge(*challenge_id, *pow, *scalar, offset.neg() + c1),
) => Expression::Challenge(*challenge_id, *pow, *scalar, offset.neg() + *c1),

// challenge - constant
(
Expression::Challenge(challenge_id, pow, scalar, offset),
Expression::Constant(c1),
) => Expression::Challenge(*challenge_id, *pow, *scalar, *offset - c1),
) => Expression::Challenge(*challenge_id, *pow, *scalar, *offset - *c1),

// challenge - challenge
(
Expand All @@ -517,16 +516,16 @@ impl<E: ExtensionField> Sub for Expression<E> {
Expression::Challenge(
*challenge_id1,
*pow1,
*scalar1 - scalar2,
*offset1 - offset2,
*scalar1 - *scalar2,
*offset1 - *offset2,
)
} else {
Expression::Sum(Box::new(self), Box::new(-rhs))
}
}

// constant - constant
(Expression::Constant(c1), Expression::Constant(c2)) => Expression::Constant(*c1 - c2),
(Expression::Constant(c1), Expression::Constant(c2)) => Expression::Constant(*c1 - *c2),

// constant - scalesum
(c1 @ Expression::Constant(_), Expression::ScaledSum(x, a, b)) => {
Expand Down Expand Up @@ -705,7 +704,7 @@ impl<E: ExtensionField> Mul for Expression<E> {
| (
Expression::Challenge(challenge_id, pow, scalar, offset),
Expression::Constant(c1),
) => Expression::Challenge(*challenge_id, *pow, *scalar * c1, *offset * c1),
) => Expression::Challenge(*challenge_id, *pow, *scalar * *c1, *offset * *c1),
// challenge * challenge
(
Expression::Challenge(challenge_id1, pow1, s1, offset1),
Expand All @@ -719,8 +718,8 @@ impl<E: ExtensionField> Mul for Expression<E> {
let mut result = Expression::Challenge(
*challenge_id1,
pow1 + pow2,
*s1 * s2,
*offset1 * offset2,
*s1 * *s2,
*offset1 * *offset2,
);

// offset2 * s1 * c1^(pow1)
Expand Down Expand Up @@ -756,7 +755,7 @@ impl<E: ExtensionField> Mul for Expression<E> {
}

// constant * constant
(Expression::Constant(c1), Expression::Constant(c2)) => Expression::Constant(*c1 * c2),
(Expression::Constant(c1), Expression::Constant(c2)) => Expression::Constant(*c1 * *c2),
// scaledsum * constant
(Expression::ScaledSum(x, a, b), c2 @ Expression::Constant(_))
| (c2 @ Expression::Constant(_), Expression::ScaledSum(x, a, b)) => {
Expand Down Expand Up @@ -914,7 +913,7 @@ macro_rules! impl_from_unsigned {
$(
impl<F: SmallField, E: ExtensionField<BaseField = F>> From<$t> for Expression<E> {
fn from(value: $t) -> Self {
Expression::Constant(F::from(value as u64))
Expression::Constant(F::from_canonical_u64(value as u64))
}
}
)*
Expand All @@ -929,7 +928,7 @@ macro_rules! impl_from_signed {
impl<F: SmallField, E: ExtensionField<BaseField = F>> From<$t> for Expression<E> {
fn from(value: $t) -> Self {
let reduced = (value as i128).rem_euclid(F::MODULUS_U64 as i128) as u64;
Expression::Constant(F::from(reduced))
Expression::Constant(F::from_canonical_u64(reduced))
}
}
)*
Expand Down Expand Up @@ -967,18 +966,18 @@ pub mod fmt {
)
}
Expression::Challenge(id, pow, scaler, offset) => {
if *pow == 1 && *scaler == 1.into() && *offset == 0.into() {
if *pow == 1 && *scaler == E::ONE && *offset == E::ZERO {
format!("Challenge({})", id)
} else {
let mut s = String::new();
if *scaler != 1.into() {
if *scaler != E::ONE {
write!(s, "{}*", field(scaler)).unwrap();
}
write!(s, "Challenge({})", id,).unwrap();
if *pow > 1 {
write!(s, "^{}", pow).unwrap();
}
if *offset != 0.into() {
if *offset != E::ZERO {
write!(s, "+{}", field(offset)).unwrap();
}
s
Expand Down Expand Up @@ -1025,7 +1024,9 @@ pub mod fmt {
.iter()
.map(|b| base_field::<E::BaseField>(b, false))
.collect::<Vec<String>>();
let only_one_limb = field.as_bases()[1..].iter().all(|&x| x == 0.into());
let only_one_limb = field.as_bases()[1..]
.iter()
.all(|&x| x == E::BaseField::ZERO);

if only_one_limb {
data[0].to_string()
Expand Down Expand Up @@ -1083,12 +1084,11 @@ pub mod fmt {

#[cfg(test)]
mod tests {
use goldilocks::GoldilocksExt2;

use crate::circuit_builder::{CircuitBuilder, ConstraintSystem};

use super::{Expression, ToExpr, fmt};
use ff::Field;
use crate::circuit_builder::{CircuitBuilder, ConstraintSystem};
use ff_ext::GoldilocksExt2;
use p3_field::FieldAlgebra;

#[test]
fn test_expression_arithmetics() {
Expand Down
9 changes: 5 additions & 4 deletions ceno_zkvm/src/expression/monomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ mod tests {
use crate::{expression::Fixed as FixedS, scheme::utils::eval_by_expr_with_fixed};

use super::*;
use ff::Field;
use goldilocks::{Goldilocks as F, GoldilocksExt2 as E};
use ff_ext::{FromUniformBytes, GoldilocksExt2 as E};
use p3_field::FieldAlgebra;
use p3_goldilocks::Goldilocks as F;
use rand_chacha::{ChaChaRng, rand_core::SeedableRng};

#[test]
Expand All @@ -98,8 +99,8 @@ mod tests {
let y = || WitIn(1);
let z = || WitIn(2);
let n = || Constant(104.into());
let m = || Constant(-F::from(599));
let r = || Challenge(0, 1, E::from(1), E::from(0));
let m = || Constant(-F::from_canonical_u64(599));
let r = || Challenge(0, 1, E::ONE, E::ZERO);

let test_exprs: &[Expression<E>] = &[
a() * x() * x(),
Expand Down
11 changes: 8 additions & 3 deletions ceno_zkvm/src/gadgets/is_lt.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use std::fmt::Display;

use ceno_emul::{SWord, Word};
use ff_ext::ExtensionField;
use goldilocks::SmallField;
use ff_ext::{ExtensionField, SmallField};
use itertools::izip;

use crate::{
Expand Down Expand Up @@ -242,7 +241,13 @@ impl InnerLtConfig {
lhs: u64,
rhs: u64,
) -> Result<(), ZKVMError> {
self.assign_instance_field(instance, lkm, lhs.into(), rhs.into(), lhs < rhs)
self.assign_instance_field(
instance,
lkm,
F::from_canonical_u64(lhs),
F::from_canonical_u64(rhs),
lhs < rhs,
)
}

/// Assign instance values to this configuration where the ordering is
Expand Down
25 changes: 17 additions & 8 deletions ceno_zkvm/src/scheme.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use ff_ext::ExtensionField;
use itertools::Itertools;
use mpcs::PolynomialCommitmentScheme;
use serde::{Deserialize, Serialize};
use p3_field::FieldAlgebra;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::{collections::BTreeMap, fmt::Debug};
use sumcheck::structs::IOPProverMessage;

Expand Down Expand Up @@ -45,6 +46,10 @@ pub struct ZKVMOpcodeProof<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>
}

#[derive(Clone, Serialize, Deserialize)]
#[serde(bound(
serialize = "E::BaseField: Serialize",
deserialize = "E::BaseField: DeserializeOwned"
))]
pub struct ZKVMTableProof<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> {
// tower evaluation at layer 1
pub r_out_evals: Vec<[E; 2]>,
Expand Down Expand Up @@ -98,15 +103,19 @@ impl PublicValues<u32> {
}
pub fn to_vec<E: ExtensionField>(&self) -> Vec<Vec<E::BaseField>> {
vec![
vec![E::BaseField::from((self.exit_code & 0xffff) as u64)],
vec![E::BaseField::from(((self.exit_code >> 16) & 0xffff) as u64)],
vec![E::BaseField::from(self.init_pc as u64)],
vec![E::BaseField::from(self.init_cycle as u64)],
vec![E::BaseField::from(self.end_pc as u64)],
vec![E::BaseField::from(self.end_cycle as u64)],
vec![E::BaseField::from_canonical_u64(
(self.exit_code & 0xffff) as u64,
)],
vec![E::BaseField::from_canonical_u64(
((self.exit_code >> 16) & 0xffff) as u64,
)],
vec![E::BaseField::from_canonical_u64(self.init_pc as u64)],
vec![E::BaseField::from_canonical_u64(self.init_cycle as u64)],
vec![E::BaseField::from_canonical_u64(self.end_pc as u64)],
vec![E::BaseField::from_canonical_u64(self.end_cycle as u64)],
self.public_io
.iter()
.map(|e| E::BaseField::from(*e as u64))
.map(|e| E::BaseField::from_canonical_u64(*e as u64))
.collect(),
]
}
Expand Down
5 changes: 2 additions & 3 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ use crate::{
use ark_std::test_rng;
use base64::{Engine, engine::general_purpose::STANDARD_NO_PAD};
use ceno_emul::{ByteAddr, CENO_PLATFORM, Platform, Program};
use ff::Field;
use ff_ext::ExtensionField;
use ff_ext::{ExtensionField, GoldilocksExt2, SmallField};
use generic_static::StaticTypeMap;
use goldilocks::{GoldilocksExt2, SmallField};
use itertools::{Itertools, chain, enumerate, izip};
use multilinear_extensions::{mle::IntoMLEs, virtual_poly::ArcMultilinearExtension};
use p3_field::FieldAlgebra;
use rand::thread_rng;
use std::{
cmp::max,
Expand Down
Loading

0 comments on commit f1fc090

Please sign in to comment.