Skip to content

Commit

Permalink
initial integration of hyperplonk snark(#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenfeizhang authored Aug 1, 2022
1 parent 229148e commit a6ea6ac
Show file tree
Hide file tree
Showing 38 changed files with 1,943 additions and 441 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[workspace]
members = [
"arithmetic",
"hyperplonk",
"pcs",
"poly-iop",
Expand Down
34 changes: 34 additions & 0 deletions arithmetic/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
[package]
name = "arithmetic"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]

ark-ff = { version = "^0.3.0", default-features = false }
ark-std = { version = "^0.3.0", default-features = false }
ark-poly = { version = "^0.3.0", default-features = false }
ark-serialize = { version = "^0.3.0", default-features = false }
ark-bls12-381 = { version = "0.3.0", default-features = false, features = [ "curve" ] }

rand_chacha = { version = "0.3.0", default-features = false }
displaydoc = { version = "0.2.3", default-features = false }
rayon = { version = "1.5.2", default-features = false, optional = true }

[dev-dependencies]
ark-ec = { version = "^0.3.0", default-features = false }

[features]
# default = [ "parallel", "print-trace" ]
default = [ "parallel" ]
parallel = [
"rayon",
"ark-std/parallel",
"ark-ff/parallel",
"ark-poly/parallel"
]
print-trace = [
"ark-std/print-trace"
]
21 changes: 21 additions & 0 deletions arithmetic/src/errors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//! Error module.
use ark_std::string::String;
use displaydoc::Display;

/// A `enum` specifying the possible failure modes of the arithmetics.
#[derive(Display, Debug)]
pub enum ArithErrors {
/// Invalid parameters: {0}
InvalidParameters(String),
/// Should not arrive to this point
ShouldNotArrive,
/// An error during (de)serialization: {0}
SerializationErrors(ark_serialize::SerializationError),
}

impl From<ark_serialize::SerializationError> for ArithErrors {
fn from(e: ark_serialize::SerializationError) -> Self {
Self::SerializationErrors(e)
}
}
7 changes: 7 additions & 0 deletions arithmetic/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mod errors;
mod multilinear_polynomial;
mod virtual_polynomial;

pub use errors::ArithErrors;
pub use multilinear_polynomial::{random_zero_mle_list, DenseMultilinearExtension};
pub use virtual_polynomial::{build_eq_x_r, VPAuxInfo, VirtualPolynomial};
33 changes: 33 additions & 0 deletions arithmetic/src/multilinear_polynomial.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use ark_ff::PrimeField;
use ark_std::{end_timer, rand::RngCore, start_timer};
use std::rc::Rc;

pub use ark_poly::DenseMultilinearExtension;

// Build a randomize list of mle-s whose sum is zero.
pub fn random_zero_mle_list<F: PrimeField, R: RngCore>(
nv: usize,
degree: usize,
rng: &mut R,
) -> Vec<Rc<DenseMultilinearExtension<F>>> {
let start = start_timer!(|| "sample random zero mle list");

let mut multiplicands = Vec::with_capacity(degree);
for _ in 0..degree {
multiplicands.push(Vec::with_capacity(1 << nv))
}
for _ in 0..(1 << nv) {
multiplicands[0].push(F::zero());
for e in multiplicands.iter_mut().skip(1) {
e.push(F::rand(rng));
}
}

let list = multiplicands
.into_iter()
.map(|x| Rc::new(DenseMultilinearExtension::from_evaluations_vec(nv, x)))
.collect();

end_timer!(start);
list
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! This module defines our main mathematical object `VirtualPolynomial`; and
//! various functions associated with it.
use crate::errors::PolyIOPErrors;
use crate::{errors::ArithErrors, multilinear_polynomial::random_zero_mle_list};
use ark_ff::PrimeField;
use ark_poly::{DenseMultilinearExtension, MultilinearExtension};
use ark_serialize::{CanonicalSerialize, SerializationError, Write};
Expand Down Expand Up @@ -84,6 +84,7 @@ impl<F: PrimeField> Add for &VirtualPolynomial<F> {
}
}

// TODO: convert this into a trait
impl<F: PrimeField> VirtualPolynomial<F> {
/// Creates an empty virtual polynomial with `num_variables`.
pub fn new(num_variables: usize) -> Self {
Expand Down Expand Up @@ -129,12 +130,12 @@ impl<F: PrimeField> VirtualPolynomial<F> {
&mut self,
mle_list: impl IntoIterator<Item = Rc<DenseMultilinearExtension<F>>>,
coefficient: F,
) -> Result<(), PolyIOPErrors> {
) -> Result<(), ArithErrors> {
let mle_list: Vec<Rc<DenseMultilinearExtension<F>>> = mle_list.into_iter().collect();
let mut indexed_product = Vec::with_capacity(mle_list.len());

if mle_list.is_empty() {
return Err(PolyIOPErrors::InvalidParameters(
return Err(ArithErrors::InvalidParameters(
"input mle_list is empty".to_string(),
));
}
Expand All @@ -143,7 +144,7 @@ impl<F: PrimeField> VirtualPolynomial<F> {

for mle in mle_list {
if mle.num_vars != self.aux_info.num_variables {
return Err(PolyIOPErrors::InvalidParameters(format!(
return Err(ArithErrors::InvalidParameters(format!(
"product has a multiplicand with wrong number of variables {} vs {}",
mle.num_vars, self.aux_info.num_variables
)));
Expand Down Expand Up @@ -171,11 +172,11 @@ impl<F: PrimeField> VirtualPolynomial<F> {
&mut self,
mle: Rc<DenseMultilinearExtension<F>>,
coefficient: F,
) -> Result<(), PolyIOPErrors> {
) -> Result<(), ArithErrors> {
let start = start_timer!(|| "mul by mle");

if mle.num_vars != self.aux_info.num_variables {
return Err(PolyIOPErrors::InvalidParameters(format!(
return Err(ArithErrors::InvalidParameters(format!(
"product has a multiplicand with wrong number of variables {} vs {}",
mle.num_vars, self.aux_info.num_variables
)));
Expand Down Expand Up @@ -209,11 +210,11 @@ impl<F: PrimeField> VirtualPolynomial<F> {

/// Evaluate the virtual polynomial at point `point`.
/// Returns an error is point.len() does not match `num_variables`.
pub fn evaluate(&self, point: &[F]) -> Result<F, PolyIOPErrors> {
pub fn evaluate(&self, point: &[F]) -> Result<F, ArithErrors> {
let start = start_timer!(|| "evaluation");

if self.aux_info.num_variables != point.len() {
return Err(PolyIOPErrors::InvalidParameters(format!(
return Err(ArithErrors::InvalidParameters(format!(
"wrong number of variables {} vs {}",
self.aux_info.num_variables,
point.len()
Expand Down Expand Up @@ -246,7 +247,7 @@ impl<F: PrimeField> VirtualPolynomial<F> {
num_multiplicands_range: (usize, usize),
num_products: usize,
rng: &mut R,
) -> Result<(Self, F), PolyIOPErrors> {
) -> Result<(Self, F), ArithErrors> {
let start = start_timer!(|| "sample random virtual polynomial");

let mut sum = F::zero();
Expand All @@ -271,7 +272,7 @@ impl<F: PrimeField> VirtualPolynomial<F> {
num_multiplicands_range: (usize, usize),
num_products: usize,
rng: &mut R,
) -> Result<Self, PolyIOPErrors> {
) -> Result<Self, ArithErrors> {
let mut poly = VirtualPolynomial::new(nv);
for _ in 0..num_products {
let num_multiplicands =
Expand All @@ -290,11 +291,11 @@ impl<F: PrimeField> VirtualPolynomial<F> {
// eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i))
//
// This function is used in ZeroCheck.
pub(crate) fn build_f_hat(&self, r: &[F]) -> Result<Self, PolyIOPErrors> {
pub fn build_f_hat(&self, r: &[F]) -> Result<Self, ArithErrors> {
let start = start_timer!(|| "zero check build hat f");

if self.aux_info.num_variables != r.len() {
return Err(PolyIOPErrors::InvalidParameters(format!(
return Err(ArithErrors::InvalidParameters(format!(
"r.len() is different from number of variables: {} vs {}",
r.len(),
self.aux_info.num_variables
Expand All @@ -308,6 +309,19 @@ impl<F: PrimeField> VirtualPolynomial<F> {
end_timer!(start);
Ok(res)
}

/// Print out the evaluation map for testing. Panic if the num_vars > 5.
pub fn print_evals(&self) {
if self.aux_info.num_variables > 5 {
panic!("this function is used for testing only. cannot print more than 5 num_vars")
}
for i in 0..1 << self.aux_info.num_variables {
let point = bit_decompose(i, self.aux_info.num_variables);
let point_fr: Vec<F> = point.iter().map(|&x| F::from(x)).collect();
println!("{} {}", i, self.evaluate(point_fr.as_ref()).unwrap())
}
println!()
}
}

/// Sample a random list of multilinear polynomials.
Expand Down Expand Up @@ -346,41 +360,15 @@ fn random_mle_list<F: PrimeField, R: RngCore>(
(list, sum)
}

// Build a randomize list of mle-s whose sum is zero.
pub fn random_zero_mle_list<F: PrimeField, R: RngCore>(
nv: usize,
degree: usize,
rng: &mut R,
) -> Vec<Rc<DenseMultilinearExtension<F>>> {
let start = start_timer!(|| "sample random zero mle list");

let mut multiplicands = Vec::with_capacity(degree);
for _ in 0..degree {
multiplicands.push(Vec::with_capacity(1 << nv))
}
for _ in 0..(1 << nv) {
multiplicands[0].push(F::zero());
for e in multiplicands.iter_mut().skip(1) {
e.push(F::rand(rng));
}
}

let list = multiplicands
.into_iter()
.map(|x| Rc::new(DenseMultilinearExtension::from_evaluations_vec(nv, x)))
.collect();

end_timer!(start);
list
}

// This function build the eq(x, r) polynomial for any given r.
//
// Evaluate
// eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i))
// over r, which is
// eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i))
fn build_eq_x_r<F: PrimeField>(r: &[F]) -> Result<Rc<DenseMultilinearExtension<F>>, PolyIOPErrors> {
pub fn build_eq_x_r<F: PrimeField>(
r: &[F],
) -> Result<Rc<DenseMultilinearExtension<F>>, ArithErrors> {
let start = start_timer!(|| "zero check build eq_x_r");

// we build eq(x,r) from its evaluations
Expand All @@ -407,11 +395,9 @@ fn build_eq_x_r<F: PrimeField>(r: &[F]) -> Result<Rc<DenseMultilinearExtension<F
/// A helper function to build eq(x, r) recursively.
/// This function takes `r.len()` steps, and for each step it requires a maximum
/// `r.len()-1` multiplications.
fn build_eq_x_r_helper<F: PrimeField>(r: &[F], buf: &mut Vec<F>) -> Result<(), PolyIOPErrors> {
fn build_eq_x_r_helper<F: PrimeField>(r: &[F], buf: &mut Vec<F>) -> Result<(), ArithErrors> {
if r.is_empty() {
return Err(PolyIOPErrors::InvalidParameters(
"r length is 0".to_string(),
));
return Err(ArithErrors::InvalidParameters("r length is 0".to_string()));
} else if r.len() == 1 {
// initializing the buffer with [1-r_0, r_0]
buf.push(F::one() - r[0]);
Expand All @@ -436,16 +422,26 @@ fn build_eq_x_r_helper<F: PrimeField>(r: &[F], buf: &mut Vec<F>) -> Result<(), P
Ok(())
}

/// Decompose an integer into a binary vector in little endian.
pub fn bit_decompose(input: u64, num_var: usize) -> Vec<bool> {
let mut res = Vec::with_capacity(num_var);
let mut i = input;
for _ in 0..num_var {
res.push(i & 1 == 1);
i >>= 1;
}
res
}

#[cfg(test)]
mod test {
use super::*;
use crate::utils::bit_decompose;
use ark_bls12_381::Fr;
use ark_ff::UniformRand;
use ark_std::test_rng;

#[test]
fn test_virtual_polynomial_additions() -> Result<(), PolyIOPErrors> {
fn test_virtual_polynomial_additions() -> Result<(), ArithErrors> {
let mut rng = test_rng();
for nv in 2..5 {
for num_products in 2..5 {
Expand All @@ -468,7 +464,7 @@ mod test {
}

#[test]
fn test_virtual_polynomial_mul_by_mle() -> Result<(), PolyIOPErrors> {
fn test_virtual_polynomial_mul_by_mle() -> Result<(), ArithErrors> {
let mut rng = test_rng();
for nv in 2..5 {
for num_products in 2..5 {
Expand Down
36 changes: 36 additions & 0 deletions hyperplonk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,39 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
poly-iop = { path = "../poly-iop" }
pcs = { path = "../pcs" }

ark-std = { version = "^0.3.0", default-features = false }
ark-ec = { version = "^0.3.0", default-features = false }
ark-ff = { version = "^0.3.0", default-features = false }
ark-poly = { version = "^0.3.0", default-features = false }
ark-serialize = { version = "^0.3.0", default-features = false, features = [ "derive" ] }

displaydoc = { version = "0.2.3", default-features = false }
transcript = { path = "../transcript" }
arithmetic = { path = "../arithmetic" }


[dev-dependencies]
ark-bls12-381 = { version = "0.3.0", default-features = false, features = [ "curve" ] }


[features]
default = [ "parallel", "print-trace" ]
# default = [ "parallel" ]
parallel = [
"ark-std/parallel",
"ark-ff/parallel",
"ark-poly/parallel",
"ark-ec/parallel",
"poly-iop/parallel",
"pcs/parallel",
"arithmetic/parallel",
]
print-trace = [
"ark-std/print-trace",
"poly-iop/print-trace",
"pcs/print-trace",
"arithmetic/print-trace",
]
Loading

0 comments on commit a6ea6ac

Please sign in to comment.