Skip to content

Commit

Permalink
pinned memory for fflonk combined monomials
Browse files Browse the repository at this point in the history
  • Loading branch information
saitima committed Jan 13, 2025
1 parent eab05b8 commit ab43d61
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 97 deletions.
130 changes: 105 additions & 25 deletions crates/proof-compression/src/chain.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,74 @@
use circuit_definitions::circuit_definitions::aux_layer::compression_modes::{
CompressionMode1, CompressionMode1ForWrapper, CompressionMode2, CompressionMode3,
CompressionMode4, CompressionMode5ForWrapper,
use circuit_definitions::circuit_definitions::aux_layer::{
compression_modes::{
CompressionMode1, CompressionMode1ForWrapper, CompressionMode2, CompressionMode3,
CompressionMode4, CompressionMode5ForWrapper,
},
ZkSyncCompressionForWrapperProof, ZkSyncCompressionLayerProof,
};
use fflonk::FflonkSnarkVerifierCircuitProof;

use super::*;

pub trait ProofStorage {
fn get_scheduler_proof(&self) -> SchedulerProof;
fn save_compression_layer_proof(&mut self, circuit_id: u8, proof: ZkSyncCompressionLayerProof);
fn save_compression_wrapper_proof(
&mut self,
circuit_id: u8,
proof: ZkSyncCompressionForWrapperProof,
);
fn save_plonk_proof(&mut self, proof: PlonkSnarkVerifierCircuitProof);
fn save_fflonk_proof(&mut self, proof: FflonkSnarkVerifierCircuitProof);
}

pub struct SimpleProofStorage {
compression_layer_storage: std::collections::HashMap<u8, ZkSyncCompressionLayerProof>,
compression_wrapper_storage: std::collections::HashMap<u8, ZkSyncCompressionForWrapperProof>,
fflonk: Option<FflonkSnarkVerifierCircuitProof>,
plonk: Option<PlonkSnarkVerifierCircuitProof>,
}

impl SimpleProofStorage {
pub fn new() -> Self {
Self {
compression_layer_storage: std::collections::HashMap::new(),
compression_wrapper_storage: std::collections::HashMap::new(),
fflonk: None,
plonk: None,
}
}
}

impl ProofStorage for SimpleProofStorage {
fn get_scheduler_proof(&self) -> SchedulerProof {
let scheduler_proof_file =
std::fs::File::open("./data/scheduler_recursive_proof.json").unwrap();
let scheduler_proof: circuit_definitions::circuit_definitions::recursion_layer::ZkSyncRecursionLayerProof =
serde_json::from_reader(&scheduler_proof_file).unwrap();
let scheduler_proof = scheduler_proof.into_inner();
scheduler_proof
}
fn save_compression_layer_proof(&mut self, circuit_id: u8, proof: ZkSyncCompressionLayerProof) {
self.compression_layer_storage.insert(circuit_id, proof);
}

fn save_compression_wrapper_proof(
&mut self,
circuit_id: u8,
proof: ZkSyncCompressionForWrapperProof,
) {
self.compression_wrapper_storage.insert(circuit_id, proof);
}

fn save_plonk_proof(&mut self, proof: PlonkSnarkVerifierCircuitProof) {
self.plonk = Some(proof);
}

fn save_fflonk_proof(&mut self, proof: FflonkSnarkVerifierCircuitProof) {
self.fflonk = Some(proof);
}
}

pub enum SnarkWrapper {
Plonk,
FFfonk,
Expand All @@ -20,61 +84,72 @@ pub type SchedulerProof = franklin_crypto::boojum::cs::implementations::proof::P
GoldilocksExt2,
>;

pub fn run_proof_chain<BS>(
input_proof: SchedulerProof,
pub fn run_proof_chain<BS, PS>(
snark_wrapper: SnarkWrapper,
blob_storage: &BS,
) -> SnarkWrapperProof
where
proof_storage: &mut PS,
) where
BS: BlobStorage,
PS: ProofStorage,
{
match snark_wrapper {
SnarkWrapper::Plonk => {
let proof = run_proof_chain_with_plonk(input_proof, blob_storage);
SnarkWrapperProof::Plonk(proof)
}
SnarkWrapper::FFfonk => {
let proof = run_proof_chain_with_fflonk(input_proof, blob_storage);
SnarkWrapperProof::FFfonk(proof)
}
SnarkWrapper::Plonk => run_proof_chain_with_plonk(blob_storage, proof_storage),
SnarkWrapper::FFfonk => run_proof_chain_with_fflonk(blob_storage, proof_storage),
}
}

pub fn run_proof_chain_with_fflonk<BS>(
input_proof: SchedulerProof,
blob_storage: &BS,
) -> FflonkSnarkVerifierCircuitProof
pub fn run_proof_chain_with_fflonk<BS, PS>(blob_storage: &BS, proof_storage: &mut PS)
where
BS: BlobStorage,
PS: ProofStorage,
{
let context_manager = SimpleContextManager::new();
let start = std::time::Instant::now();
<FflonkSnarkWrapper as SnarkWrapperStep>::run_pre_initialization_tasks();
let compact_raw_crs =
<FflonkSnarkWrapper as SnarkWrapperStep>::load_compact_raw_crs(blob_storage);
let fflonk_precomputation = FflonkSnarkWrapper::get_precomputation(blob_storage);

let input_proof = proof_storage.get_scheduler_proof();

let next_proof =
CompressionMode1::prove_compression_step(input_proof, blob_storage, &context_manager);
let compression_proof_1 = ZkSyncCompressionLayerProof::from_inner(1, next_proof.clone());
proof_storage.save_compression_layer_proof(1, compression_proof_1.clone());

let next_proof = CompressionMode2::prove_compression_step::<_, SimpleContextManager>(
next_proof,
blob_storage,
&context_manager,
);
let compression_proof_2 = ZkSyncCompressionLayerProof::from_inner(2, next_proof.clone());
proof_storage.save_compression_layer_proof(2, compression_proof_2.clone());

let next_proof = CompressionMode3::prove_compression_step::<_, SimpleContextManager>(
next_proof,
blob_storage,
&context_manager,
);
let compression_proof_3 = ZkSyncCompressionLayerProof::from_inner(3, next_proof.clone());
proof_storage.save_compression_layer_proof(3, compression_proof_3.clone());

let next_proof = CompressionMode4::prove_compression_step::<_, SimpleContextManager>(
next_proof,
blob_storage,
&context_manager,
);
let compression_proof_4 = ZkSyncCompressionLayerProof::from_inner(4, next_proof.clone());
proof_storage.save_compression_layer_proof(4, compression_proof_4.clone());

let next_proof = CompressionMode5ForWrapper::prove_compression_step::<_, SimpleContextManager>(
next_proof,
blob_storage,
&context_manager,
);
proof_storage.save_compression_wrapper_proof(
5,
ZkSyncCompressionForWrapperProof::from_inner(5, next_proof.clone()),
);
println!(
"Proving entire compression chain took {}s",
start.elapsed().as_secs()
Expand All @@ -90,14 +165,15 @@ where
"Proving entire chain with snark wrapper took {}s",
start.elapsed().as_secs()
);
final_proof
proof_storage.save_fflonk_proof(final_proof);
}

pub fn precompute_proof_chain_with_fflonk<BS>(blob_storage: &BS)
where
BS: BlobStorageExt,
{
let context_manager = SimpleContextManager::new();
<FflonkSnarkWrapper as SnarkWrapperStep>::run_pre_initialization_tasks();
let compact_raw_crs =
<FflonkSnarkWrapper as SnarkWrapperStep>::load_compact_raw_crs(blob_storage);

Expand Down Expand Up @@ -125,24 +201,28 @@ where
);
}

pub fn run_proof_chain_with_plonk<BS>(
input_proof: SchedulerProof,
blob_storage: &BS,
) -> PlonkSnarkVerifierCircuitProof
pub fn run_proof_chain_with_plonk<BS, PS>(blob_storage: &BS, proof_storage: &mut PS)
where
BS: BlobStorage,
PS: ProofStorage,
{
let context_manager = SimpleContextManager::new();
let start = std::time::Instant::now();
let compact_raw_crs =
<PlonkSnarkWrapper as SnarkWrapperStep>::load_compact_raw_crs(blob_storage);
let plonk_precomputation = PlonkSnarkWrapper::get_precomputation(blob_storage);

let input_proof = proof_storage.get_scheduler_proof();

let next_proof = CompressionMode1ForWrapper::prove_compression_step(
input_proof,
blob_storage,
&context_manager,
);
proof_storage.save_compression_wrapper_proof(
1,
ZkSyncCompressionForWrapperProof::from_inner(1, next_proof.clone()),
);

let final_proof = PlonkSnarkWrapper::prove_snark_wrapper_step(
compact_raw_crs,
Expand All @@ -155,7 +235,7 @@ where
"Entire compression chain with plonk took {}s",
start.elapsed().as_secs()
);
final_proof
proof_storage.save_plonk_proof(final_proof);
}

pub fn precompute_proof_chain_with_plonk<BS>(blob_storage: &BS)
Expand Down
31 changes: 14 additions & 17 deletions crates/proof-compression/src/proof_system/crs.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::*;
use ::fflonk::bellman::{
self,
kate_commitment::{Crs, CrsForMonomialForm},
use ::fflonk::{
bellman::kate_commitment::{Crs, CrsForMonomialForm},
hardcoded_g2_bases,
};
use bellman::{bn256::Bn256, CurveAffine, Engine, Field, PrimeField};
use byteorder::{BigEndian, ReadBytesExt};
Expand All @@ -10,13 +10,11 @@ use gpu_prover::ManagerConfigs;
pub(crate) fn write_crs_into_raw_compact_form<W: std::io::Write>(
original_crs: &Crs<bellman::bn256::Bn256, CrsForMonomialForm>,
mut dst_raw_compact_crs: W,
num_points: usize,
) -> std::io::Result<()> {
use bellman::CurveAffine;
assert!(num_points <= original_crs.g1_bases.len());
use bellman::{PrimeField, PrimeFieldRepr};
use byteorder::{BigEndian, WriteBytesExt};
assert!(num_points < u32::MAX as usize);
let num_points = original_crs.g1_bases.len();
dst_raw_compact_crs.write_u32::<BigEndian>(num_points as u32)?;
for g1_base in original_crs.g1_bases.iter() {
let (x, y) = g1_base.as_xy();
Expand All @@ -39,6 +37,7 @@ pub(crate) fn read_crs_from_raw_compact_form<R: std::io::Read, A: Allocator + De
mut src_raw_compact_crs: R,
num_g1_points: usize,
) -> std::io::Result<Crs<bellman::compact_bn256::Bn256, CrsForMonomialForm, A>> {
// requested number of bases can be smaller than the available bases
use byteorder::{BigEndian, ReadBytesExt};
let actual_num_points = src_raw_compact_crs.read_u32::<BigEndian>()? as usize;
assert!(num_g1_points <= actual_num_points as usize);
Expand All @@ -51,16 +50,9 @@ pub(crate) fn read_crs_from_raw_compact_form<R: std::io::Read, A: Allocator + De
);
src_raw_compact_crs.read_exact(buf)?;
}
let num_g2_points = 2;
let mut g2_bases = Vec::with_capacity_in(num_g2_points, A::default());
unsafe {
g2_bases.set_len(num_g2_points);
let buf = std::slice::from_raw_parts_mut(
g2_bases.as_mut_ptr() as *mut u8,
num_g2_points * std::mem::size_of::<bellman::compact_bn256::G2Affine>(),
);
src_raw_compact_crs.read_exact(buf)?;
}

let g2_bases = hardcoded_g2_bases::<bellman::compact_bn256::Bn256>().to_vec_in(A::default());

Ok(Crs::<_, CrsForMonomialForm, A>::new_in(g1_bases, g2_bases))
}

Expand All @@ -73,7 +65,8 @@ pub(crate) fn create_compact_raw_crs<W: std::io::Write>(dst: W) {
.max()
.unwrap();
let original_crs = make_crs_from_ignition_transcripts(num_points);
write_crs_into_raw_compact_form(&original_crs, dst, num_points).unwrap();
assert_eq!(original_crs.g1_bases.len(), num_points);
write_crs_into_raw_compact_form(&original_crs, dst).unwrap();
}

fn make_crs_from_ignition_transcripts(num_points: usize) -> Crs<Bn256, CrsForMonomialForm> {
Expand Down Expand Up @@ -277,3 +270,7 @@ fn create_crs_from_ignition_transcript<S: AsRef<std::ffi::OsStr> + ?Sized>(

Ok(new)
}

pub(crate) fn hardcoded_canonical_g2_bases() -> [bellman::bn256::G2Affine; 2] {
::fflonk::hardcoded_g2_bases::<bellman::bn256::Bn256>()
}
16 changes: 13 additions & 3 deletions crates/proof-compression/src/proof_system/fflonk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl ProofSystemDefinition for FflonkSnarkWrapper {
type Proof = FflonkSnarkVerifierCircuitProof;
type VK = FflonkSnarkVerifierCircuitVK;
type FinalizationHint = usize;
// type Allocator = GlobalHost; // TODO need global host with preallocated host memory
// Pinned memory with small allocations is expensive e.g Assembly storage
type Allocator = std::alloc::Global;
type ProvingAssembly = FflonkAssembly<SynthesisModeProve, Self::Allocator>;
type Transcript = RollingKeccakTranscript<Self::FieldElement>;
Expand Down Expand Up @@ -69,6 +69,12 @@ impl SnarkWrapperProofSystem for FflonkSnarkWrapper {
CrsForMonomialForm,
Self::Allocator,
>;

fn pre_init() {
let domain_size = ::fflonk::fflonk::L1_VERIFIER_DOMAIN_SIZE_LOG;
Self::Context::init_pinned_memory(domain_size).unwrap();
}

fn load_compact_raw_crs<R: std::io::Read>(src: R) -> Self::CRS {
let domain_size = 1 << ::fflonk::fflonk_cpu::L1_VERIFIER_DOMAIN_SIZE_LOG;
let num_g1_bases_for_crs = ::fflonk::fflonk_cpu::MAX_COMBINED_DEGREE_FACTOR * domain_size;
Expand All @@ -78,7 +84,11 @@ impl SnarkWrapperProofSystem for FflonkSnarkWrapper {
fn init_context(compact_raw_crs: AsyncHandler<Self::CRS>) -> Self::Context {
let compact_raw_crs = compact_raw_crs.wait();
let domain_size = 1 << ::fflonk::fflonk_cpu::L1_VERIFIER_DOMAIN_SIZE_LOG;
let context = Self::Context::init_from_preloaded_crs(domain_size, compact_raw_crs).unwrap();
let context = DeviceContextWithSingleDevice::init_from_preloaded_crs::<Self::Allocator>(
domain_size,
compact_raw_crs,
)
.unwrap();
context
}

Expand Down Expand Up @@ -120,7 +130,7 @@ impl SnarkWrapperProofSystem for FflonkSnarkWrapper {

fn prove_from_witnesses(
_: AsyncHandler<Self::Context>,
_: Vec<Self::FieldElement, Self::Allocator>,
_: Self::ExternalWitnessData,
_: AsyncHandler<Self::Precomputation>,
_: Self::FinalizationHint,
) -> Self::Proof {
Expand Down
3 changes: 2 additions & 1 deletion crates/proof-compression/src/proof_system/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ pub(crate) trait SnarkWrapperProofSystem: ProofSystemDefinition {
type Circuit;
type Context: Send + Sync + 'static;
type CRS: Send + Sync + 'static;
fn pre_init();
fn init_context(crs: AsyncHandler<Self::CRS>) -> Self::Context;
fn load_compact_raw_crs<R: std::io::Read>(src: R) -> Self::CRS;
fn synthesize_for_proving(circuit: Self::Circuit) -> Self::ProvingAssembly;
Expand All @@ -91,7 +92,7 @@ pub(crate) trait SnarkWrapperProofSystem: ProofSystemDefinition {

fn prove_from_witnesses(
_: AsyncHandler<Self::Context>,
_: Vec<Self::FieldElement, Self::Allocator>,
_: Self::ExternalWitnessData,
_: AsyncHandler<Self::Precomputation>,
_: Self::FinalizationHint,
) -> Self::Proof;
Expand Down
Loading

0 comments on commit ab43d61

Please sign in to comment.