diff --git a/Cargo.lock b/Cargo.lock index 5f610a75c..2443db97a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -12155,8 +12155,10 @@ dependencies = [ "frame-system 28.0.0", "generic-array 1.1.1", "hex", + "itertools 0.13.0", "log", "num-bigint", + "pairing", "parity-scale-codec", "polka-storage-proofs", "primitives", diff --git a/pallets/proofs/Cargo.toml b/pallets/proofs/Cargo.toml index b3d221383..5319cdb9d 100644 --- a/pallets/proofs/Cargo.toml +++ b/pallets/proofs/Cargo.toml @@ -27,6 +27,8 @@ frame-system.workspace = true log = { workspace = true } num-bigint = { workspace = true } polka-storage-proofs = { workspace = true, features = ["substrate"] } +pairing = { workspace = true } +itertools = { workspace = true } primitives = { workspace = true, features = ["serde"] } rand = { workspace = true, features = ["alloc"] } rand_chacha = { workspace = true } @@ -36,7 +38,7 @@ sp-runtime.workspace = true sp-std.workspace = true # Runtime benchmarks -rand_xorshift = { workspace = true, optional = true } +rand_xorshift = { workspace = true } [dev-dependencies] blstrs = { workspace = true } @@ -59,7 +61,7 @@ runtime-benchmarks = [ "frame-benchmarking/runtime-benchmarks", "frame-support/runtime-benchmarks", "frame-system/runtime-benchmarks", - "rand_xorshift", + # "rand_xorshift", "sp-runtime/runtime-benchmarks", ] std = [ diff --git a/pallets/proofs/src/crypto/groth16.rs b/pallets/proofs/src/crypto/groth16.rs index 81daaa5d1..c9f8f8c2a 100644 --- a/pallets/proofs/src/crypto/groth16.rs +++ b/pallets/proofs/src/crypto/groth16.rs @@ -1,13 +1,18 @@ //! Groth16 ZK-SNARK related implementations. -use core::ops::{AddAssign, Neg}; +use core::ops::{AddAssign, Neg, Mul, MulAssign}; +use bls12_381::{multi_miller_loop, G2Prepared}; use codec::{Decode, Encode}; +use pairing::{Engine, MillerLoopResult, group::Group}; +use ff::Field; pub use polka_storage_proofs::{Bls12, PrimeField, Proof, Scalar as Fr, VerifyingKey}; -use polka_storage_proofs::{Curve, MillerLoopResult, MultiMillerLoop, PrimeCurveAffine}; +use polka_storage_proofs::{Curve, MultiMillerLoop, PrimeCurveAffine}; +use primitives::randomness::{draw_randomness, DomainSeparationTag}; +use rand::SeedableRng; use scale_info::TypeInfo; -use crate::Vec; +use crate::{fr32::bytes_into_fr_repr_safe, Vec}; /// The prepared verifying key needed in a Groth16 verification. /// @@ -17,6 +22,8 @@ use crate::Vec; #[derive(Clone, Decode, Default, Encode)] pub(crate) struct PreparedVerifyingKey { pub alpha_g1_beta_g2: E::Gt, + pub gamma_g2: E::G2Prepared, + pub delta_g2: E::G2Prepared, pub neg_gamma_g2: E::G2Prepared, pub neg_delta_g2: E::G2Prepared, pub ic: Vec, @@ -29,6 +36,8 @@ impl From> for PreparedVerifyingKey { PreparedVerifyingKey:: { alpha_g1_beta_g2: E::pairing(&vkey.alpha_g1, &vkey.beta_g2), + gamma_g2: vkey.gamma_g2.into(), + delta_g2: vkey.delta_g2.into(), neg_gamma_g2: gamma.into(), neg_delta_g2: delta.into(), ic: vkey.ic, @@ -57,7 +66,7 @@ pub fn verify_proof( pvk: &PreparedVerifyingKey, proof: &Proof, public_inputs: &[E::Fr], -) -> Result<(), VerificationError> { +) -> Result { if (public_inputs.len() + 1) != pvk.ic.len() { return Err(VerificationError::InvalidVerifyingKey); } @@ -84,9 +93,9 @@ pub fn verify_proof( ]) .final_exponentiation() { - Ok(()) + Ok(true) } else { - Err(VerificationError::InvalidProof) + Ok(false) } } @@ -97,4 +106,175 @@ pub enum VerificationError { InvalidProof, /// Returned when the given verifying key was invalid. InvalidVerifyingKey, + InvalidInput, } + +pub(crate) fn le_bytes_to_u64s(le_bytes: &[u8]) -> Vec { + assert_eq!( + le_bytes.len() % 8, + 0, + "length must be divisible by u64 byte length (8-bytes)" + ); + le_bytes + .chunks(8) + .map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap())) + .collect() +} + +pub fn verify_proofs_batch( + pvk: &PreparedVerifyingKey, + // rng: &mut R, + proofs: &[Proof], + public_inputs: &[Vec], +) -> Result +where + E: MultiMillerLoop, + ::Repr: Sync + Copy, + // R: rand::RngCore, +{ + debug_assert_eq!(proofs.len(), public_inputs.len()); + + for pub_input in public_inputs { + if (pub_input.len() + 1) != pvk.ic.len() { + return Err(VerificationError::InvalidInput); + } + } + + let num_inputs = public_inputs[0].len(); + let num_proofs = proofs.len(); + + log::debug!("num proofs? woot: {}", num_proofs); + if num_proofs < 2 { + return verify_proof(pvk, &proofs[0], &public_inputs[0]); + } + + + log::debug!("ok going down"); + let proof_num = proofs.len(); + + // Choose random coefficients for combining the proofs + let mut rand_z_repr: Vec<_> = Vec::with_capacity(proof_num); + let mut rand_z: Vec<_> = Vec::with_capacity(proof_num); + let mut accum_y = E::Fr::ZERO; + + use rand::Rng; + use rand_xorshift::XorShiftRng; + let rng = &mut XorShiftRng::from_seed([ + 0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0xd, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc, 0xe5, + ]); + + log::debug!("generating random numbers"); + for _ in 0..proof_num { + let t: u128 = rng.gen(); + + let mut repr = E::Fr::ZERO.to_repr(); + let mut repr_u64s = le_bytes_to_u64s(repr.as_ref()); + assert!(repr_u64s.len() > 1); + + repr_u64s[0] = (t & (-1i64 as u128) >> 64) as u64; + repr_u64s[1] = (t >> 64) as u64; + + for (i, limb) in repr_u64s.iter().enumerate() { + let start = i * 8; + let stop = start + 8; + repr.as_mut()[start..stop].copy_from_slice(&limb.to_le_bytes()); + } + + let fr = E::Fr::from_repr(repr).unwrap(); + let repr = fr.to_repr(); + + accum_y.add_assign(&fr); + rand_z_repr.push(repr); + rand_z.push(fr); + } + log::debug!("generated random numbers"); + + log::debug!("acc_g start"); + // Calculate Accum_Gamma sequentially + let mut acc_g = E::G1::identity(); + for i in 0..(num_inputs + 1) { + let scalar = if i == 0 { + accum_y + } else { + let idx = i - 1; + let mut cur_sum = rand_z[0]; + cur_sum.mul_assign(&public_inputs[0][idx]); + + for (pi_mont, mut rand_mont) in + public_inputs.iter().zip(rand_z.iter().copied()).skip(1) + { + let pi_mont = &pi_mont[idx]; + rand_mont.mul_assign(pi_mont); + cur_sum.add_assign(&rand_mont); + } + cur_sum + }; + + let term = pvk.ic[i].mul(scalar); + acc_g.add_assign(&term); + } + let ml_g = E::multi_miller_loop(&[(&acc_g.to_affine(), &pvk.gamma_g2)]); + log::debug!("ml_g done"); + + // Calculate Accum_Delta sequentially + let mut acc_d = E::G1::identity(); + for (proof, rand) in proofs.iter().zip(rand_z.iter()) { + let term = proof.c.mul(*rand); + acc_d.add_assign(&term); + } + let ml_d = E::multi_miller_loop(&[(&acc_d.to_affine(), &pvk.delta_g2)]); + log::debug!("ml_d done"); + + // Calculate Accum_AB sequentially + // OLD + let mut acc_ab = ::Result::default(); + for (proof, rand) in proofs.iter().zip(rand_z.iter()) { + let mul_a = proof.a.mul(*rand); + let cur_neg_b = -proof.b.to_curve(); + let term = E::multi_miller_loop(&[(&mul_a.to_affine(), &cur_neg_b.to_affine().into())]); + acc_ab += term; + } + log::debug!("acc_ab done"); + + // v2 + // let mut pairs = Vec::with_capacity(num_proofs + 2); + + // for (proof, rand) in proofs.iter().zip(rand_z.iter()) { + // let mul_a = proof.a.mul(*rand).to_affine(); + // let neg_b: E::G2Prepared = (-proof.b).into(); + // pairs.push((mul_a, neg_b)); + // } + // let acc_d_aff = acc_d.to_affine(); + // let acc_g_aff = acc_g.to_affine(); + // pairs.push((acc_d_aff, pvk.delta_g2.into())); + // pairs.push((acc_g_aff, pvk.gamma_g2.into())); + + /* // Step 1: Store owned values in vectors + let mut mul_a_vec: Vec = Vec::with_capacity(num_proofs); + let mut neg_b_vec: Vec = Vec::with_capacity(num_proofs); + + for (proof, rand) in proofs.iter().zip(rand_z.iter()) { + let mul_a = proof.a.mul(*rand).to_affine(); // Compute G1Affine + let neg_b: E::G2Prepared = (-proof.b).into(); // Compute G2Prepared + mul_a_vec.push(mul_a); // Store owned value + neg_b_vec.push(neg_b); // Store owned value + } + + // Step 2: Create pairs with references to the stored values + let mut pairs: Vec<(&E::G1Affine, &E::G2Prepared)> = Vec::with_capacity(num_proofs); + for (mul_a, neg_b) in mul_a_vec.iter().zip(neg_b_vec.iter()) { + pairs.push((mul_a, neg_b)); // References to owned data + } + + let mut ml_all = E::multi_miller_loop(&pairs); */ + let mut ml_all = acc_ab; + ml_all += ml_d; + ml_all += ml_g; + + // Calculate Y^-Accum_Y + let accum_y_neg = -accum_y; + let y = pvk.alpha_g1_beta_g2 * accum_y_neg; + + let actual = ml_all.final_exponentiation(); + Ok(actual == y) +} \ No newline at end of file diff --git a/pallets/proofs/src/lib.rs b/pallets/proofs/src/lib.rs index 4d8d34a17..bb00a0e66 100644 --- a/pallets/proofs/src/lib.rs +++ b/pallets/proofs/src/lib.rs @@ -242,6 +242,7 @@ pub mod pallet { ConstU32, >, ) -> DispatchResult { + log::debug!("got multiple proofs: {}", proofs.len()); let replica_count = replicas.len(); ensure!(replica_count <= post_type.sector_count() * proofs.len(), { log::error!( diff --git a/pallets/proofs/src/porep/mod.rs b/pallets/proofs/src/porep/mod.rs index e112c256a..30aee18c0 100644 --- a/pallets/proofs/src/porep/mod.rs +++ b/pallets/proofs/src/porep/mod.rs @@ -12,8 +12,7 @@ use sha2::{Digest, Sha256}; use crate::{ crypto::groth16::{ - prepare_verifying_key, verify_proof, Bls12, Fr, PrimeField, Proof, VerificationError, - VerifyingKey, + prepare_verifying_key, verify_proof, verify_proofs_batch, Bls12, Fr, PrimeField, Proof, VerificationError, VerifyingKey }, fr32, graphs::{ @@ -102,6 +101,7 @@ impl From for ProofError { match value { VerificationError::InvalidProof => ProofError::InvalidProof, VerificationError::InvalidVerifyingKey => ProofError::InvalidVerifyingKey, + VerificationError::InvalidInput => ProofError::InvalidProof, } } } @@ -177,16 +177,31 @@ impl ProofScheme { }; let pvk = prepare_verifying_key(vk); - - for partition_index in 0..proofs.len() { + /* for partition_index in 0..proofs.len() { let inputs = self.generate_public_inputs(public_inputs.clone(), Some(partition_index))?; verify_proof(&pvk, &proofs[partition_index], inputs.as_slice()).inspect_err(|_| { log::error!(target: LOG_TARGET, "failed to verify partition {}", partition_index); })?; } + Ok(()) */ - Ok(()) + let mut agg_inputs = vec![]; + for partition_index in 0..proofs.len() { + let inputs = + self.generate_public_inputs(public_inputs.clone(), Some(partition_index))?; + agg_inputs.push(inputs); + } + + let res = verify_proofs_batch(&pvk, &proofs[..], agg_inputs.as_slice()).inspect_err(|_| { + log::error!(target: LOG_TARGET, "failed to verify all partitions"); + })?; + + if res { + Ok(()) + } else { + Err(ProofError::InvalidProof) + } } /// References: diff --git a/pallets/proofs/src/post/mod.rs b/pallets/proofs/src/post/mod.rs index 66a02e8d4..9438ab8d1 100644 --- a/pallets/proofs/src/post/mod.rs +++ b/pallets/proofs/src/post/mod.rs @@ -13,7 +13,7 @@ use sha2::{Digest, Sha256}; use crate::{ crypto::groth16::{ - prepare_verifying_key, verify_proof, Bls12, Fr, Proof, VerificationError, VerifyingKey, + prepare_verifying_key, verify_proof, verify_proofs_batch, Bls12, Fr, Proof, VerificationError, VerifyingKey }, fr32, Vec, }; @@ -82,9 +82,27 @@ impl ProofScheme { randomness, sectors: pub_sectors, }; + log::debug!("preparing verifying key"); let pvk = prepare_verifying_key(vk); - + log::debug!("generating pulic inputs"); + let mut agg_inputs = vec![]; for partition_index in 0..proofs.len() { + let inputs = + self.generate_public_inputs(public_inputs.clone(), Some(partition_index))?; + agg_inputs.push(inputs); + } + + log::debug!("generated public inputs, verifying..."); + let res = verify_proofs_batch(&pvk, &proofs[..], agg_inputs.as_slice()).inspect_err(|_| { + log::error!(target: LOG_TARGET, "failed to verify all partitions"); + })?; + + if res { + Ok(()) + } else { + Err(ProofError::InvalidProof) + } + /* for partition_index in 0..proofs.len() { let inputs = self.generate_public_inputs(public_inputs.clone(), Some(partition_index))?; verify_proof(&pvk, &proofs[partition_index], inputs.as_slice()).inspect_err(|_| { @@ -92,7 +110,7 @@ impl ProofScheme { })?; } - Ok(()) + Ok(()) */ } /// References: @@ -172,6 +190,7 @@ impl From for ProofError { match value { VerificationError::InvalidProof => ProofError::InvalidProof, VerificationError::InvalidVerifyingKey => ProofError::InvalidVerifyingKey, + VerificationError::InvalidInput => ProofError::InvalidProof, } } } diff --git a/storage-provider/client/src/commands/proofs.rs b/storage-provider/client/src/commands/proofs.rs index 41d2f4da2..3a712d67c 100644 --- a/storage-provider/client/src/commands/proofs.rs +++ b/storage-provider/client/src/commands/proofs.rs @@ -459,6 +459,7 @@ impl ProofsCommand { let proof_parameters = post::load_groth16_parameters(proof_parameters_path) .map_err(|e| UtilsCommandError::GeneratePoStError(e))?; + let prover_id = derive_prover_id(signer.account_id()); let proofs = match_post_proof!( post_type,