Skip to content

Commit

Permalink
feat: optimize proof verification by running implementing zcash algor…
Browse files Browse the repository at this point in the history
…ithm
  • Loading branch information
th7nder committed Mar 7, 2025
1 parent 722b0a4 commit ee698e0
Show file tree
Hide file tree
Showing 7 changed files with 236 additions and 16 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions pallets/proofs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ frame-system.workspace = true
log = { workspace = true }
num-bigint = { workspace = true }
polka-storage-proofs = { workspace = true, features = ["substrate"] }
pairing = { workspace = true }
itertools = { workspace = true }
primitives = { workspace = true, features = ["serde"] }
rand = { workspace = true, features = ["alloc"] }
rand_chacha = { workspace = true }
Expand All @@ -36,7 +38,7 @@ sp-runtime.workspace = true
sp-std.workspace = true

# Runtime benchmarks
rand_xorshift = { workspace = true, optional = true }
rand_xorshift = { workspace = true }

[dev-dependencies]
blstrs = { workspace = true }
Expand All @@ -59,7 +61,7 @@ runtime-benchmarks = [
"frame-benchmarking/runtime-benchmarks",
"frame-support/runtime-benchmarks",
"frame-system/runtime-benchmarks",
"rand_xorshift",
# "rand_xorshift",
"sp-runtime/runtime-benchmarks",
]
std = [
Expand Down
192 changes: 186 additions & 6 deletions pallets/proofs/src/crypto/groth16.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
//! Groth16 ZK-SNARK related implementations.
use core::ops::{AddAssign, Neg};
use core::ops::{AddAssign, Neg, Mul, MulAssign};

use bls12_381::{multi_miller_loop, G2Prepared};
use codec::{Decode, Encode};
use pairing::{Engine, MillerLoopResult, group::Group};
use ff::Field;
pub use polka_storage_proofs::{Bls12, PrimeField, Proof, Scalar as Fr, VerifyingKey};
use polka_storage_proofs::{Curve, MillerLoopResult, MultiMillerLoop, PrimeCurveAffine};
use polka_storage_proofs::{Curve, MultiMillerLoop, PrimeCurveAffine};
use primitives::randomness::{draw_randomness, DomainSeparationTag};
use rand::SeedableRng;
use scale_info::TypeInfo;

use crate::Vec;
use crate::{fr32::bytes_into_fr_repr_safe, Vec};

/// The prepared verifying key needed in a Groth16 verification.
///
Expand All @@ -17,6 +22,8 @@ use crate::Vec;
#[derive(Clone, Decode, Default, Encode)]
pub(crate) struct PreparedVerifyingKey<E: MultiMillerLoop> {
pub alpha_g1_beta_g2: E::Gt,
pub gamma_g2: E::G2Prepared,
pub delta_g2: E::G2Prepared,
pub neg_gamma_g2: E::G2Prepared,
pub neg_delta_g2: E::G2Prepared,
pub ic: Vec<E::G1Affine>,
Expand All @@ -29,6 +36,8 @@ impl<E: MultiMillerLoop> From<VerifyingKey<E>> for PreparedVerifyingKey<E> {

PreparedVerifyingKey::<E> {
alpha_g1_beta_g2: E::pairing(&vkey.alpha_g1, &vkey.beta_g2),
gamma_g2: vkey.gamma_g2.into(),
delta_g2: vkey.delta_g2.into(),
neg_gamma_g2: gamma.into(),
neg_delta_g2: delta.into(),
ic: vkey.ic,
Expand Down Expand Up @@ -57,7 +66,7 @@ pub fn verify_proof<E: MultiMillerLoop>(
pvk: &PreparedVerifyingKey<E>,
proof: &Proof<E>,
public_inputs: &[E::Fr],
) -> Result<(), VerificationError> {
) -> Result<bool, VerificationError> {
if (public_inputs.len() + 1) != pvk.ic.len() {
return Err(VerificationError::InvalidVerifyingKey);
}
Expand All @@ -84,9 +93,9 @@ pub fn verify_proof<E: MultiMillerLoop>(
])
.final_exponentiation()
{
Ok(())
Ok(true)
} else {
Err(VerificationError::InvalidProof)
Ok(false)
}
}

Expand All @@ -97,4 +106,175 @@ pub enum VerificationError {
InvalidProof,
/// Returned when the given verifying key was invalid.
InvalidVerifyingKey,
InvalidInput,
}

pub(crate) fn le_bytes_to_u64s(le_bytes: &[u8]) -> Vec<u64> {
assert_eq!(
le_bytes.len() % 8,
0,
"length must be divisible by u64 byte length (8-bytes)"
);
le_bytes
.chunks(8)
.map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap()))
.collect()
}

pub fn verify_proofs_batch<E>(
pvk: &PreparedVerifyingKey<E>,
// rng: &mut R,
proofs: &[Proof<E>],
public_inputs: &[Vec<E::Fr>],
) -> Result<bool, VerificationError>
where
E: MultiMillerLoop,
<E::Fr as PrimeField>::Repr: Sync + Copy,
// R: rand::RngCore,
{
debug_assert_eq!(proofs.len(), public_inputs.len());

for pub_input in public_inputs {
if (pub_input.len() + 1) != pvk.ic.len() {
return Err(VerificationError::InvalidInput);
}
}

let num_inputs = public_inputs[0].len();
let num_proofs = proofs.len();

log::debug!("num proofs? woot: {}", num_proofs);
if num_proofs < 2 {
return verify_proof(pvk, &proofs[0], &public_inputs[0]);
}


log::debug!("ok going down");
let proof_num = proofs.len();

// Choose random coefficients for combining the proofs
let mut rand_z_repr: Vec<_> = Vec::with_capacity(proof_num);
let mut rand_z: Vec<_> = Vec::with_capacity(proof_num);
let mut accum_y = E::Fr::ZERO;

use rand::Rng;
use rand_xorshift::XorShiftRng;
let rng = &mut XorShiftRng::from_seed([
0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0xd, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc, 0xe5,
]);

log::debug!("generating random numbers");
for _ in 0..proof_num {
let t: u128 = rng.gen();

let mut repr = E::Fr::ZERO.to_repr();
let mut repr_u64s = le_bytes_to_u64s(repr.as_ref());
assert!(repr_u64s.len() > 1);

repr_u64s[0] = (t & (-1i64 as u128) >> 64) as u64;
repr_u64s[1] = (t >> 64) as u64;

for (i, limb) in repr_u64s.iter().enumerate() {
let start = i * 8;
let stop = start + 8;
repr.as_mut()[start..stop].copy_from_slice(&limb.to_le_bytes());
}

let fr = E::Fr::from_repr(repr).unwrap();
let repr = fr.to_repr();

accum_y.add_assign(&fr);
rand_z_repr.push(repr);
rand_z.push(fr);
}
log::debug!("generated random numbers");

log::debug!("acc_g start");
// Calculate Accum_Gamma sequentially
let mut acc_g = E::G1::identity();
for i in 0..(num_inputs + 1) {
let scalar = if i == 0 {
accum_y
} else {
let idx = i - 1;
let mut cur_sum = rand_z[0];
cur_sum.mul_assign(&public_inputs[0][idx]);

for (pi_mont, mut rand_mont) in
public_inputs.iter().zip(rand_z.iter().copied()).skip(1)
{
let pi_mont = &pi_mont[idx];
rand_mont.mul_assign(pi_mont);
cur_sum.add_assign(&rand_mont);
}
cur_sum
};

let term = pvk.ic[i].mul(scalar);
acc_g.add_assign(&term);
}
let ml_g = E::multi_miller_loop(&[(&acc_g.to_affine(), &pvk.gamma_g2)]);
log::debug!("ml_g done");

// Calculate Accum_Delta sequentially
let mut acc_d = E::G1::identity();
for (proof, rand) in proofs.iter().zip(rand_z.iter()) {
let term = proof.c.mul(*rand);
acc_d.add_assign(&term);
}
let ml_d = E::multi_miller_loop(&[(&acc_d.to_affine(), &pvk.delta_g2)]);
log::debug!("ml_d done");

// Calculate Accum_AB sequentially
// OLD
let mut acc_ab = <E as MultiMillerLoop>::Result::default();
for (proof, rand) in proofs.iter().zip(rand_z.iter()) {
let mul_a = proof.a.mul(*rand);
let cur_neg_b = -proof.b.to_curve();
let term = E::multi_miller_loop(&[(&mul_a.to_affine(), &cur_neg_b.to_affine().into())]);
acc_ab += term;
}
log::debug!("acc_ab done");

// v2
// let mut pairs = Vec::with_capacity(num_proofs + 2);

// for (proof, rand) in proofs.iter().zip(rand_z.iter()) {
// let mul_a = proof.a.mul(*rand).to_affine();
// let neg_b: E::G2Prepared = (-proof.b).into();
// pairs.push((mul_a, neg_b));
// }
// let acc_d_aff = acc_d.to_affine();
// let acc_g_aff = acc_g.to_affine();
// pairs.push((acc_d_aff, pvk.delta_g2.into()));
// pairs.push((acc_g_aff, pvk.gamma_g2.into()));

/* // Step 1: Store owned values in vectors
let mut mul_a_vec: Vec<E::G1Affine> = Vec::with_capacity(num_proofs);
let mut neg_b_vec: Vec<E::G2Prepared> = Vec::with_capacity(num_proofs);
for (proof, rand) in proofs.iter().zip(rand_z.iter()) {
let mul_a = proof.a.mul(*rand).to_affine(); // Compute G1Affine
let neg_b: E::G2Prepared = (-proof.b).into(); // Compute G2Prepared
mul_a_vec.push(mul_a); // Store owned value
neg_b_vec.push(neg_b); // Store owned value
}
// Step 2: Create pairs with references to the stored values
let mut pairs: Vec<(&E::G1Affine, &E::G2Prepared)> = Vec::with_capacity(num_proofs);
for (mul_a, neg_b) in mul_a_vec.iter().zip(neg_b_vec.iter()) {
pairs.push((mul_a, neg_b)); // References to owned data
}
let mut ml_all = E::multi_miller_loop(&pairs); */
let mut ml_all = acc_ab;
ml_all += ml_d;
ml_all += ml_g;

// Calculate Y^-Accum_Y
let accum_y_neg = -accum_y;
let y = pvk.alpha_g1_beta_g2 * accum_y_neg;

let actual = ml_all.final_exponentiation();
Ok(actual == y)
}
1 change: 1 addition & 0 deletions pallets/proofs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ pub mod pallet {
ConstU32<MAX_PROOFS_PER_BLOCK>,
>,
) -> DispatchResult {
log::debug!("got multiple proofs: {}", proofs.len());
let replica_count = replicas.len();
ensure!(replica_count <= post_type.sector_count() * proofs.len(), {
log::error!(
Expand Down
25 changes: 20 additions & 5 deletions pallets/proofs/src/porep/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ use sha2::{Digest, Sha256};

use crate::{
crypto::groth16::{
prepare_verifying_key, verify_proof, Bls12, Fr, PrimeField, Proof, VerificationError,
VerifyingKey,
prepare_verifying_key, verify_proof, verify_proofs_batch, Bls12, Fr, PrimeField, Proof, VerificationError, VerifyingKey
},
fr32,
graphs::{
Expand Down Expand Up @@ -102,6 +101,7 @@ impl From<VerificationError> for ProofError {
match value {
VerificationError::InvalidProof => ProofError::InvalidProof,
VerificationError::InvalidVerifyingKey => ProofError::InvalidVerifyingKey,
VerificationError::InvalidInput => ProofError::InvalidProof,
}
}
}
Expand Down Expand Up @@ -177,16 +177,31 @@ impl ProofScheme {
};

let pvk = prepare_verifying_key(vk);

for partition_index in 0..proofs.len() {
/* for partition_index in 0..proofs.len() {
let inputs =
self.generate_public_inputs(public_inputs.clone(), Some(partition_index))?;
verify_proof(&pvk, &proofs[partition_index], inputs.as_slice()).inspect_err(|_| {
log::error!(target: LOG_TARGET, "failed to verify partition {}", partition_index);
})?;
}
Ok(()) */

Ok(())
let mut agg_inputs = vec![];
for partition_index in 0..proofs.len() {
let inputs =
self.generate_public_inputs(public_inputs.clone(), Some(partition_index))?;
agg_inputs.push(inputs);
}

let res = verify_proofs_batch(&pvk, &proofs[..], agg_inputs.as_slice()).inspect_err(|_| {
log::error!(target: LOG_TARGET, "failed to verify all partitions");
})?;

if res {
Ok(())
} else {
Err(ProofError::InvalidProof)
}
}

/// References:
Expand Down
25 changes: 22 additions & 3 deletions pallets/proofs/src/post/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use sha2::{Digest, Sha256};

use crate::{
crypto::groth16::{
prepare_verifying_key, verify_proof, Bls12, Fr, Proof, VerificationError, VerifyingKey,
prepare_verifying_key, verify_proof, verify_proofs_batch, Bls12, Fr, Proof, VerificationError, VerifyingKey
},
fr32, Vec,
};
Expand Down Expand Up @@ -82,17 +82,35 @@ impl ProofScheme {
randomness,
sectors: pub_sectors,
};
log::debug!("preparing verifying key");
let pvk = prepare_verifying_key(vk);

log::debug!("generating pulic inputs");
let mut agg_inputs = vec![];
for partition_index in 0..proofs.len() {
let inputs =
self.generate_public_inputs(public_inputs.clone(), Some(partition_index))?;
agg_inputs.push(inputs);
}

log::debug!("generated public inputs, verifying...");
let res = verify_proofs_batch(&pvk, &proofs[..], agg_inputs.as_slice()).inspect_err(|_| {
log::error!(target: LOG_TARGET, "failed to verify all partitions");
})?;

if res {
Ok(())
} else {
Err(ProofError::InvalidProof)
}
/* for partition_index in 0..proofs.len() {
let inputs =
self.generate_public_inputs(public_inputs.clone(), Some(partition_index))?;
verify_proof(&pvk, &proofs[partition_index], inputs.as_slice()).inspect_err(|_| {
log::error!(target: LOG_TARGET, "failed to verify partition {}", partition_index);
})?;
}
Ok(())
Ok(()) */
}

/// References:
Expand Down Expand Up @@ -172,6 +190,7 @@ impl From<VerificationError> for ProofError {
match value {
VerificationError::InvalidProof => ProofError::InvalidProof,
VerificationError::InvalidVerifyingKey => ProofError::InvalidVerifyingKey,
VerificationError::InvalidInput => ProofError::InvalidProof,
}
}
}
Expand Down
1 change: 1 addition & 0 deletions storage-provider/client/src/commands/proofs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ impl ProofsCommand {
let proof_parameters = post::load_groth16_parameters(proof_parameters_path)
.map_err(|e| UtilsCommandError::GeneratePoStError(e))?;


let prover_id = derive_prover_id(signer.account_id());
let proofs = match_post_proof!(
post_type,
Expand Down

0 comments on commit ee698e0

Please sign in to comment.